Implementing the Raft distributed consensus protocol in Go
As part of bringing myself up-to-speed after joining TigerBeetle, I wanted some background on how distributed consensus and replicated state machine protocols work. TigerBeetle uses Viewstamped Replication. But I wanted to understand all popular protocols and I decided to start with Raft.
We'll implement two key components of Raft in this post (leader election and log replication). Around 1k lines of Go. It took me around 7 months of sporadic studying to come to (what I hope is) an understanding of the basics.
Disclaimer: I'm not an expert. My implementation isn't yet hooked up to Jepsen. I've run it through a mix of manual and automated tests and it seems generally correct. This is not intended to be used in production. It's just for my education.
All code for this project is available on GitHub.
Let's dig in!
The algorithm
The Raft paper itself is quite readable. Give it a read and you'll get the basic idea.
The gist is that nodes in a cluster conduct elections to pick a leader. Users of the Raft cluster send messages to the leader. The leader passes the message to followers and waits for a majority to store the message. Once the message is committed (majority consensus has been reached), the message is applied to a state machine the user supplies. Followers learn about the latest committed message from the leader and apply each new committed message to their local user-supplied state machine.
There's more to it including reconfiguration and snapshotting, which I won't get into in this post. But you can get the gist of Raft by thinking about 1) leader election and 2) replicated logs powering replicated state machines.
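To make "majority" concrete: in a cluster of n nodes, the leader needs acknowledgements from ⌊n/2⌋ + 1 nodes (counting itself) before an entry counts as committed. A tiny illustration, separate from the implementation we'll build below:
func quorum(clusterSize int) int {
	// Smallest number of nodes that forms a majority of the cluster.
	return clusterSize/2 + 1
}
So quorum(3) is 2 and quorum(5) is 3: a five-node cluster can keep committing entries with up to two nodes down.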
Modeling with state machines and key-value stores
I've written before about how you can build a key-value store on top of Raft. How you can build a SQL database on top of a key-value store. And how you can build a distributed SQL database on top of Raft.
This post will start quite similarly to that first post, except that this time we won't stop at the Raft layer.
A distributed key-value store
To build on top of the Raft library we'll build, we need to create a state machine and commands that are sent to the state machine.
Our state machine will have two operations: get a value from a key, and set a key to a value.
This will go in cmd/kvapi/main.go.
package main
import (
"bytes"
crypto "crypto/rand"
"encoding/binary"
"fmt"
"log"
"math/rand"
"net/http"
"os"
"strconv"
"strings"
"sync"
"github.com/eatonphil/goraft"
)
type statemachine struct {
db *sync.Map
server int
}
type commandKind uint8
const (
setCommand commandKind = iota
getCommand
)
type command struct {
kind commandKind
key string
value string
}
func (s *statemachine) Apply(cmd []byte) ([]byte, error) {
c := decodeCommand(cmd)
switch c.kind {
case setCommand:
s.db.Store(c.key, c.value)
case getCommand:
value, ok := s.db.Load(c.key)
if !ok {
return nil, fmt.Errorf("Key not found")
}
return []byte(value.(string)), nil
default:
return nil, fmt.Errorf("Unknown command: %x", cmd)
}
return nil, nil
}
But the Raft library we'll build needs to deal with various state machines. So commands passed from the user into the Raft cluster must be serialized to bytes.
func encodeCommand(c command) []byte {
msg := bytes.NewBuffer(nil)
err := msg.WriteByte(uint8(c.kind))
if err != nil {
panic(err)
}
err = binary.Write(msg, binary.LittleEndian, uint64(len(c.key)))
if err != nil {
panic(err)
}
msg.WriteString(c.key)
err = binary.Write(msg, binary.LittleEndian, uint64(len(c.value)))
if err != nil {
panic(err)
}
msg.WriteString(c.value)
return msg.Bytes()
}
And the Apply() function from above needs to be able to decode the bytes:
func decodeCommand(msg []byte) command {
var c command
c.kind = commandKind(msg[0])
keyLen := binary.LittleEndian.Uint64(msg[1:9])
c.key = string(msg[9 : 9+keyLen])
if c.kind == setCommand {
valLen := binary.LittleEndian.Uint64(msg[9+keyLen : 9+keyLen+8])
c.value = string(msg[9+keyLen+8 : 9+keyLen+8+valLen])
}
return c
}
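As a quick sanity check that encode and decode agree on the byte layout (1 byte for the kind, 8 bytes for the key length, the key, 8 bytes for the value length, the value), a round-trip test works. The test file is hypothetical, e.g. cmd/kvapi/main_test.go in the same package, with "testing" imported:
func TestCommandRoundTrip(t *testing.T) {
	original := command{kind: setCommand, key: "x", value: "1"}
	decoded := decodeCommand(encodeCommand(original))
	if decoded.kind != original.kind || decoded.key != original.key || decoded.value != original.value {
		t.Fatalf("round trip mismatch: %#v != %#v", decoded, original)
	}
}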
HTTP API
Now that we've modeled the key-value store as a state machine, let's build the HTTP endpoints that allow the user to operate the state machine through the Raft cluster.
First, let's implement the set operation. We need to grab the key and value the user passes in and call Apply() on the Raft cluster. Calling Apply() on the Raft cluster will eventually call the Apply() function we just wrote, but not until the message sent to the Raft cluster is actually replicated.
type httpServer struct {
raft *goraft.Server
db *sync.Map
}
// Example:
//
// curl http://localhost:2020/set?key=x&value=1
func (hs httpServer) setHandler(w http.ResponseWriter, r *http.Request) {
var c command
c.kind = setCommand
c.key = r.URL.Query().Get("key")
c.value = r.URL.Query().Get("value")
_, err := hs.raft.Apply([][]byte{encodeCommand(c)})
if err != nil {
log.Printf("Could not write key-value: %s", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
}
To reiterate, we tell the Raft cluster we want this message replicated. The message contains the operation type (set) and the operation details (key and value). These messages are custom to the state machine we wrote. And they will be interpreted by the state machine we wrote, on each node in the cluster.
Next we handle get-ing values from the cluster. There are two ways to do this. We already embed a local copy of the distributed key-value map. We could just read from that map in the current process. But it might not be up-to-date or correct. It would be fast to read though. And convenient for debugging.
But the only correct way to read from a Raft cluster is to pass the read through the log replication too.
So we'll support both.
// Example:
//
// curl http://localhost:2020/get?key=x
// 1
// curl http://localhost:2020/get?key=x&relaxed=true # Skips consensus for the read.
// 1
func (hs httpServer) getHandler(w http.ResponseWriter, r *http.Request) {
var c command
c.kind = getCommand
c.key = r.URL.Query().Get("key")
var value []byte
var err error
if r.URL.Query().Get("relaxed") == "true" {
v, ok := hs.db.Load(c.key)
if !ok {
err = fmt.Errorf("Key not found")
} else {
value = []byte(v.(string))
}
} else {
var results []goraft.ApplyResult
results, err = hs.raft.Apply([][]byte{encodeCommand(c)})
if err == nil {
if len(results) != 1 {
err = fmt.Errorf("Expected single response from Raft, got: %d.", len(results))
} else if results[0].Error != nil {
err = results[0].Error
} else {
value = results[0].Result
}
}
}
if err != nil {
log.Printf("Could not encode key-value in http response: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
written := 0
for written < len(value) {
n, err := w.Write(value[written:])
if err != nil {
log.Printf("Could not encode key-value in http response: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
written += n
}
}
Main
Now that we've set up our custom state machine and our HTTP API for interacting with the Raft cluster, we'll tie it together with reading configuration from the command-line and actually starting the Raft node and the HTTP API.
type config struct {
cluster []goraft.ClusterMember
index int
id string
address string
http string
}
func getConfig() config {
cfg := config{}
var node string
for i, arg := range os.Args[1:] {
if arg == "--node" {
var err error
node = os.Args[i+2]
cfg.index, err = strconv.Atoi(node)
if err != nil {
log.Fatal("Expected $value to be a valid integer in `--node $value`, got: %s", node)
}
i++
continue
}
if arg == "--http" {
cfg.http = os.Args[i+2]
i++
continue
}
if arg == "--cluster" {
cluster := os.Args[i+2]
var clusterEntry goraft.ClusterMember
for _, part := range strings.Split(cluster, ";") {
idAddress := strings.Split(part, ",")
var err error
clusterEntry.Id, err = strconv.ParseUint(idAddress[0], 10, 64)
if err != nil {
log.Fatal("Expected $id to be a valid integer in `--cluster $id,$ip`, got: %s", idAddress[0])
}
clusterEntry.Address = idAddress[1]
cfg.cluster = append(cfg.cluster, clusterEntry)
}
i++
continue
}
}
if node == "" {
log.Fatal("Missing required parameter: --node $index")
}
if cfg.http == "" {
log.Fatal("Missing required parameter: --http $address")
}
if len(cfg.cluster) == 0 {
log.Fatal("Missing required parameter: --cluster $node1Id,$node1Address;...;$nodeNId,$nodeNAddress")
}
return cfg
}
func main() {
var b [8]byte
_, err := crypto.Read(b[:])
if err != nil {
panic("cannot seed math/rand package with cryptographically secure random number generator")
}
rand.Seed(int64(binary.LittleEndian.Uint64(b[:])))
cfg := getConfig()
var db sync.Map
var sm statemachine
sm.db = &db
sm.server = cfg.index
s := goraft.NewServer(cfg.cluster, &sm, ".", cfg.index)
go s.Start()
hs := httpServer{s, &db}
http.HandleFunc("/set", hs.setHandler)
http.HandleFunc("/get", hs.getHandler)
err = http.ListenAndServe(cfg.http, nil)
if err != nil {
panic(err)
}
}
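If you'd rather poke at the library without the CLI, here's a hedged sketch of starting a three-node cluster in a single process. It would drop into the same package main as above (reusing the imports already there); the ports and per-node metadata directories are made up, and it assumes, as the real main does, that goraft.Server.Start brings up the RPC listener on each member's Address.
func startLocalCluster() []*goraft.Server {
	cluster := []goraft.ClusterMember{
		{Id: 1, Address: "localhost:3030"},
		{Id: 2, Address: "localhost:3031"},
		{Id: 3, Address: "localhost:3032"},
	}

	var servers []*goraft.Server
	for i := range cluster {
		// Each node gets its own state machine backed by its own map, and its
		// own metadata directory so persisted state doesn't collide.
		dir := fmt.Sprintf("node-%d-data", i)
		if err := os.MkdirAll(dir, 0755); err != nil {
			panic(err)
		}

		var db sync.Map
		s := goraft.NewServer(cluster, &statemachine{db: &db, server: i}, dir, i)
		go s.Start()
		servers = append(servers, s)
	}

	// Once a leader is elected, calling Apply on the leader replicates the
	// command and eventually runs statemachine.Apply on every node.
	return servers
}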
And that's it for the easy part: a distributed key-value store on top of a Raft cluster.
Next we need to implement Raft.
A Raft server
If we take a look at Figure 2 in the Raft paper, we get an idea for all the state we need to model.
We'll dig into the details as we go. But for now let's turn that model into a few Go types. This goes in raft.go in the base directory, not cmd/kvapi.
package goraft
import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/rpc"
"os"
"path"
"sync"
"time"
)
type StateMachine interface {
Apply(cmd []byte) ([]byte, error)
}
type ApplyResult struct {
Result []byte
Error error
}
type Entry struct {
Command []byte
Term uint64
// Set by the primary so it can learn about the result of
// applying this command to the state machine
result chan ApplyResult
}
type ClusterMember struct {
Id uint64
Address string
// Index of the next log entry to send
nextIndex uint64
// Highest log entry known to be replicated
matchIndex uint64
// Who was voted for in the most recent term
votedFor uint64
// TCP connection
rpcClient *rpc.Client
}
type ServerState string
const (
leaderState ServerState = "leader"
followerState = "follower"
candidateState = "candidate"
)
type Server struct {
// These variables are for shutting down.
done bool
server *http.Server
Debug bool
mu sync.Mutex
// ----------- PERSISTENT STATE -----------
// The current term
currentTerm uint64
log []Entry
// votedFor is stored in `cluster []ClusterMember` below,
// mapped by `clusterIndex` below
// ----------- READONLY STATE -----------
// Unique identifier for this Server
id uint64
// The TCP address for RPC
address string
// When to start elections after no append entry messages
electionTimeout time.Time
// How often to send empty messages
heartbeatMs int
// When to next send empty message
heartbeatTimeout time.Time
// User-provided state machine
statemachine StateMachine
// Metadata directory
metadataDir string
// Metadata store
fd *os.File
// ----------- VOLATILE STATE -----------
// Index of highest log entry known to be committed
commitIndex uint64
// Index of highest log entry applied to state machine
lastApplied uint64
// Candidate, follower, or leader
state ServerState
// Servers in the cluster, including this one
cluster []ClusterMember
// Index of this server
clusterIndex int
}
And let's build a constructor to initialize the state for all servers in the cluster, as well as local server state.
func NewServer(
clusterConfig []ClusterMember,
statemachine StateMachine,
metadataDir string,
clusterIndex int,
) *Server {
// Explicitly make a copy of the cluster because we'll be
// modifying it in this server.
var cluster []ClusterMember
for _, c := range clusterConfig {
if c.Id == 0 {
panic("Id must not be 0.")
}
cluster = append(cluster, c)
}
return &Server{
id: cluster[clusterIndex].Id,
address: cluster[clusterIndex].Address,
cluster: cluster,
statemachine: statemachine,
metadataDir: metadataDir,
clusterIndex: clusterIndex,
heartbeatMs: 300,
mu: sync.Mutex{},
}
}
And add a few debugging and assertion helpers.
func (s *Server) debugmsg(msg string) string {
return fmt.Sprintf("%s [Id: %d, Term: %d] %s", time.Now().Format(time.RFC3339Nano), s.id, s.currentTerm, msg)
}
func (s *Server) debug(msg string) {
if !s.Debug {
return
}
fmt.Println(s.debugmsg(msg))
}
func (s *Server) debugf(msg string, args ...any) {
if !s.Debug {
return
}
s.debug(fmt.Sprintf(msg, args...))
}
func (s *Server) warn(msg string) {
fmt.Println("[WARN] " + s.debugmsg(msg))
}
func (s *Server) warnf(msg string, args ...any) {
s.warn(fmt.Sprintf(msg, args...))
}
func Assert[T comparable](msg string, a, b T) {
if a != b {
panic(fmt.Sprintf("%s. Got a = %#v, b = %#v", msg, a, b))
}
}
func Server_assert[T comparable](s *Server, msg string, a, b T) {
Assert(s.debugmsg(msg), a, b)
}
Persistent state
As Figure 2 says, currentTerm, log, and votedFor must be persisted to disk as they're edited.
I like to initially do the stupidest thing possible. So in the first version of this project I used encoding/gob to write these three fields to disk every time s.persist() was called.
Here is what this first version looked like:
func (s *Server) persist() {
s.mu.Lock()
defer s.mu.Unlock()
s.fd.Truncate(0)
s.fd.Seek(0, 0)
enc := gob.NewEncoder(s.fd)
err := enc.Encode(PersistentState{
CurrentTerm: s.currentTerm,
Log: s.log,
VotedFor: s.votedFor,
})
if err != nil {
panic(err)
}
if err = s.fd.Sync(); err != nil {
panic(err)
}
s.debug(fmt.Sprintf("Persisted. Term: %d. Log Len: %d. Voted For: %s.", s.currentTerm, len(s.log), s.votedFor))
}
But doing so means the cost of each persist call is a function of the size of the entire log. And that was horrible for throughput.
I also noticed that encoding/gob is pretty inefficient.
For a simple struct like:
type X struct {
A uint64
B []uint64
C bool
}
encoding/gob uses 68 bytes to store that data when B has two entries. If we wrote the encoder/decoder ourselves we could store that struct in 33 bytes (8 (sizeof(A)) + 8 (sizeof(len(B))) + 16 (len(B) * sizeof(uint64)) + 1 (sizeof(C))).
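To see where 33 comes from, here's what a hand-written encoder for X might look like; this is just an illustration of the arithmetic, not code the project actually uses:
func encodeX(x X) []byte {
	// 8 bytes for A, 8 bytes for len(B), 8 bytes per element of B, 1 byte for C.
	buf := make([]byte, 8+8+8*len(x.B)+1)
	binary.LittleEndian.PutUint64(buf[0:8], x.A)
	binary.LittleEndian.PutUint64(buf[8:16], uint64(len(x.B)))
	for i, b := range x.B {
		binary.LittleEndian.PutUint64(buf[16+8*i:24+8*i], b)
	}
	if x.C {
		buf[len(buf)-1] = 1
	}
	return buf
}
With two entries in B, len(buf) is 8 + 8 + 16 + 1 = 33 bytes.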
It's not that encoding/gob is bad. It just likely has different constraints than we are party to.
So I decided to swap out encoding/gob for simply binary encoding the fields and also, importantly, keeping track of exactly how many entries in the log must be written and only writing that many.
s.persist()
Here's what that looks like.
const PAGE_SIZE = 4096
const ENTRY_HEADER = 16
const ENTRY_SIZE = 128
// Must be called within s.mu.Lock()
func (s *Server) persist(writeLog bool, nNewEntries int) {
t := time.Now()
if nNewEntries == 0 && writeLog {
nNewEntries = len(s.log)
}
s.fd.Seek(0, 0)
var page [PAGE_SIZE]byte
// Bytes 0 - 8: Current term
// Bytes 8 - 16: Voted for
// Bytes 16 - 24: Log length
// Bytes 4096 - N: Log
binary.LittleEndian.PutUint64(page[:8], s.currentTerm)
binary.LittleEndian.PutUint64(page[8:16], s.getVotedFor())
binary.LittleEndian.PutUint64(page[16:24], uint64(len(s.log)))
n, err := s.fd.Write(page[:])
if err != nil {
panic(err)
}
Server_assert(s, "Wrote full page", n, PAGE_SIZE)
if writeLog && nNewEntries > 0 {
newLogOffset := max(len(s.log)-nNewEntries, 0)
s.fd.Seek(int64(PAGE_SIZE+ENTRY_SIZE*newLogOffset), 0)
bw := bufio.NewWriter(s.fd)
var entryBytes [ENTRY_SIZE]byte
for i := newLogOffset; i < len(s.log); i++ {
// Bytes 0 - 8: Entry term
// Bytes 8 - 16: Entry command length
// Bytes 16 - ENTRY_SIZE: Entry command
if len(s.log[i].Command) > ENTRY_SIZE-ENTRY_HEADER {
panic(fmt.Sprintf("Command is too large (%d). Must be at most %d bytes.", len(s.log[i].Command), ENTRY_SIZE-ENTRY_HEADER))
}
binary.LittleEndian.PutUint64(entryBytes[:8], s.log[i].Term)
binary.LittleEndian.PutUint64(entryBytes[8:16], uint64(len(s.log[i].Command)))
copy(entryBytes[16:], []byte(s.log[i].Command))
n, err := bw.Write(entryBytes[:])
if err != nil {
panic(err)
}
Server_assert(s, "Wrote full page", n, ENTRY_SIZE)
}
err = bw.Flush()
if err != nil {
panic(err)
}
}
if err = s.fd.Sync(); err != nil {
panic(err)
}
s.debugf("Persisted in %s. Term: %d. Log Len: %d (%d new). Voted For: %d.", time.Now().Sub(t), s.currentTerm, len(s.log), nNewEntries, s.getVotedFor())
}
Again the important thing is that only the entries that need to be written are written. We do that by seek-ing to the offset of the first entry that needs to be written.
And we collect writes of entries in a bufio.Writer so we don't waste write syscalls. Don't forget to flush the buffered writer! And don't forget to flush all writes to disk with fd.Sync().
ENTRY_SIZE is something that I could see being configurable based on the workload. Some workloads truly need only 128 bytes. But a key-value store probably wants much more than that. This implementation doesn't try to handle the case of completely arbitrary sized keys and values.
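One way to read that layout: the header occupies the first PAGE_SIZE bytes of the file, and entry i always lives at a fixed offset after it, which is exactly what the Seek call in persist() computes. As a standalone helper, purely for illustration (the real code inlines this arithmetic):
func entryOffset(i int) int64 {
	// Entries are fixed-size and start immediately after the 4KiB header page.
	return int64(PAGE_SIZE + ENTRY_SIZE*i)
}
That fixed-size-entry decision is what makes it cheap to write only the new tail of the log.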
Lastly, a few helpers used in there:
func min[T ~int | ~uint64](a, b T) T {
if a < b {
return a
}