Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 97 additions & 80 deletions internal/cluster/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"path/filepath"
"strconv"
"sync"
"time"

dqlite "github.com/CanonicalLtd/go-dqlite"
raftmembership "github.com/CanonicalLtd/raft-membership"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
hashiraft "github.com/hashicorp/raft"
"github.com/lxc/lxd/shared/logger"
"github.com/pkg/errors"
"github.com/spoke-d/thermionic/internal/cert"
"github.com/spoke-d/thermionic/internal/clock"
Expand Down Expand Up @@ -84,6 +85,9 @@ type RaftInstance interface {
// MembershipChanger returns the underlying rafthttp.Layer, which can be used
// to change the membership of this node in the cluster.
MembershipChanger() raftmembership.Changer

// Info returns the information around the current server.
Info() ServerInfo
}

// ServerProvider creates a new Server instance.
Expand All @@ -93,8 +97,17 @@ type ServerProvider interface {
New(RaftInstance, net.Listener, *raft.AddressProvider, log.Logger) (Server, error)
}

// ServerInfo holds the identifying details of a single server: its ID and
// the address associated with it.
type ServerInfo struct {
	// ID is the identifier of the server.
	ID uint64

	// Address is the address of the server.
	Address string
}

// Server implements the dqlite network protocol.
type Server interface {
// Leader returns information about the current leader, if any.
Leader() *ServerInfo

// Dump the files of a database to disk.
Dump(string, string) error

Expand Down Expand Up @@ -130,6 +143,7 @@ type StoreProvider interface {
// possibly runs a dqlite replica on this node (if we're configured to do
// so).
type Gateway struct {
mutex sync.RWMutex
database Node
fileSystem fsys.FileSystem
cert *cert.Info
Expand Down Expand Up @@ -297,8 +311,11 @@ func (g *Gateway) Init(cert *cert.Info) error {
// These handlers might return 404, either because this node is a
// non-clustered node not available over the network or because it is not a
// database node part of the dqlite cluster.
func (g *Gateway) HandlerFuncs() map[string]http.HandlerFunc {
func (g *Gateway) HandlerFuncs(nodeRefreshTask func(*heartbeat.APIHeartbeat)) map[string]http.HandlerFunc {
databaseHandler := func(w http.ResponseWriter, r *http.Request) {
g.mutex.RLock()
defer g.mutex.RUnlock()

ok, err := cert.TLSCheckCert(r, g.cert)
if err != nil {
http.Error(w, "500 server error", http.StatusInternalServerError)
Expand All @@ -309,9 +326,30 @@ func (g *Gateway) HandlerFuncs() map[string]http.HandlerFunc {
return
}

// Compare the dqlite version of the connecting client
// with our own one.
versionHeader := r.Header.Get("X-Dqlite-Version")
if versionHeader == "" {
// No version header means an old pre dqlite 1.0 client
versionHeader = "0"
}
version, err := strconv.Atoi(versionHeader)
if err != nil {
http.Error(w, "400 invalid dqlite version", http.StatusBadRequest)
return
}
if version != dqliteVersion {
if version > dqliteVersion {
http.Error(w, "503 unsupported dqlite version", http.StatusServiceUnavailable)
} else {
http.Error(w, "426 dqlite version too old ", http.StatusUpgradeRequired)
}
return
}

// Handle heartbeats
if r.Method == "PUT" {
g.handleHeartbeats(w, r)
g.handleHeartbeats(w, r, nodeRefreshTask)
return
}

Expand Down Expand Up @@ -372,47 +410,8 @@ func (g *Gateway) HandlerFuncs() map[string]http.HandlerFunc {

g.acceptCh <- conn
}
raftHandler := func(w http.ResponseWriter, r *http.Request) {
// If we are not part of the raft cluster, reply with a
// redirect to one of the raft nodes that we know about.
if g.raft == nil {
var address string
if err := g.database.Transaction(func(tx *db.NodeTx) error {
nodes, err := tx.RaftNodes()
if err != nil {
return errors.WithStack(err)
}
if len(nodes) == 0 {
return errors.Errorf("no raft nodes found")
}
address = nodes[0].Address
return nil
}); err != nil {
http.Error(w, "500 failed to fetch raft nodes", http.StatusInternalServerError)
return
}
url := &url.URL{
Scheme: "http",
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
Host: address,
}
http.Redirect(w, r, url.String(), http.StatusPermanentRedirect)
return
}

// If this node is not clustered return a 404.
handlerFunc := g.raft.HandlerFunc()
if handlerFunc == nil {
http.NotFound(w, r)
return
}

handlerFunc(w, r)
}
return map[string]http.HandlerFunc{
heartbeat.DatabaseEndpoint: databaseHandler,
raft.Endpoint: raftHandler,
}
}

Expand Down Expand Up @@ -463,16 +462,18 @@ func (g *Gateway) Raft() RaftInstance {

// RaftNodes returns the nodes currently part of the raft cluster.
func (g *Gateway) RaftNodes() ([]db.RaftNode, error) {
if g.raft == nil {
g.mutex.RLock()
defer g.mutex.RUnlock()

if g.raft == nil || !g.isLeader() {
return nil, hashiraft.ErrNotLeader
}
servers, err := g.raft.Servers()
servers, err := g.server.Cluster()
if err != nil {
return nil, errors.WithStack(err)
}

nodes := make([]db.RaftNode, len(servers))

addressProvider := g.addressProvider
for i, server := range servers {
address, err := addressProvider.ServerAddr(server.ID)
Expand Down Expand Up @@ -500,6 +501,9 @@ func (g *Gateway) RaftNodes() ([]db.RaftNode, error) {

// LeaderAddress returns the address of the current raft leader.
func (g *Gateway) LeaderAddress() (string, error) {
g.mutex.RLock()
defer g.mutex.RUnlock()

// If we aren't clustered, return an error.
if g.memoryDial != nil {
return "", errors.New("node is not clustered")
Expand All @@ -512,8 +516,8 @@ func (g *Gateway) LeaderAddress() (string, error) {
// wait a bit until one is elected.
if g.raft != nil {
for ctx.Err() == nil {
if address := string(g.raft.Raft().Leader()); address != "" {
return address, nil
if info := g.server.Leader(); info != nil {
return info.Address, nil
}
time.Sleep(time.Second)
}
Expand Down Expand Up @@ -556,6 +560,7 @@ func (g *Gateway) LeaderAddress() (string, error) {
return "", errors.WithStack(err)
}

setDqliteVersionHeader(request)
request = request.WithContext(ctx)
client := &http.Client{
Transport: &http.Transport{
Expand Down Expand Up @@ -685,48 +690,60 @@ func (g *Gateway) Reset(cert *cert.Info) error {
return g.Init(cert)
}

func (g *Gateway) handleHeartbeats(w http.ResponseWriter, r *http.Request) {
var nodes []db.RaftNode
if err := json.Read(r.Body, &nodes); err != nil {
http.Error(w, "400 invalid raft nodes payload", http.StatusBadRequest)
// isLeader reports whether this node is the current leader. It compares the
// ID of the leader reported by g.server with the ID returned by
// g.raft.Info().
//
// It returns false when either g.server or g.raft is nil, or when g.server
// reports no leader.
func (g *Gateway) isLeader() bool {
	// Both the server and the raft instance are needed for the comparison.
	// Checking g.raft here prevents a nil dereference in g.raft.Info() when
	// a caller has not already verified it.
	if g.server == nil || g.raft == nil {
		return false
	}
	info := g.server.Leader()
	return info != nil && info.ID == g.raft.Info().ID
}

func (g *Gateway) handleHeartbeats(w http.ResponseWriter, r *http.Request, nodeRefreshTask func(*heartbeat.APIHeartbeat)) {
var heartbeatData heartbeat.APIHeartbeat
err := json.Read(r.Body, &heartbeatData)
if err != nil {
logger.Errorf("Error decoding heartbeat body: %v", err)
http.Error(w, "400 invalid heartbeat payload", http.StatusBadRequest)
return
}
if err := g.database.Transaction(func(tx *db.NodeTx) error {
// validate the raft nodes
current, err := tx.RaftNodes()
if err != nil {
return errors.WithStack(err)

raftNodes := make([]db.RaftNode, 0)
for _, node := range heartbeatData.Members {
if node.RaftID > 0 {
raftNodes = append(raftNodes, db.RaftNode{
ID: node.RaftID,
Address: node.Address,
})
}
// If nodes match then we can be assured that we do want to check if
// nodes match.
if len(current) == len(nodes) {
new := make(map[string]db.RaftNode, len(current))
for _, v := range current {
new[v.Address] = v
}
identical := true
for _, v := range nodes {
if n, ok := new[v.Address]; !ok || n.ID == v.ID {
identical = false
break
}
}
// Nothing to do, we're already at the same quorum
if identical {
return nil
}
}

if len(raftNodes) > 0 {
if err := g.database.Transaction(func(tx *db.NodeTx) error {
level.Debug(g.logger).Log("msg", fmt.Sprintf("replace current raft nodes with notes %+v", raftNodes))
err = tx.RaftNodesReplace(raftNodes)
return errors.WithStack(err)
}); err != nil {
http.Error(w, "500 failed to update raft nodes", http.StatusInternalServerError)
return
}
level.Debug(g.logger).Log("msg", fmt.Sprintf("replace current raft nodes with notes %+v", nodes))
err = tx.RaftNodesReplace(nodes)
return errors.WithStack(err)
}); err != nil {
http.Error(w, "500 failed to update raft nodes", http.StatusInternalServerError)
} else {
level.Error(g.logger).Log("msg", "Empty raft node set received")
}

// Only perform node refresh task if we have received a full state list from leader.
if !heartbeatData.FullStateList {
level.Debug(g.logger).Log("msg", "Partial node list heartbeat received, skipping full update")
return
}

// If node refresh task is specified, run it async.
if nodeRefreshTask != nil {
go nodeRefreshTask(&heartbeatData)
}
}

func (g *Gateway) handleLeadershipState(w http.ResponseWriter, r *http.Request) {
if g.raft.Raft().State() != hashiraft.Leader {
if info := g.server.Leader(); info == nil || info.ID != g.raft.Info().ID {
http.Error(w, "503 not leader", http.StatusServiceUnavailable)
return
}
Expand Down
Loading