diff --git a/internal/cluster/gateway.go b/internal/cluster/gateway.go index ebafcc8..4b20b15 100644 --- a/internal/cluster/gateway.go +++ b/internal/cluster/gateway.go @@ -5,9 +5,9 @@ import ( "fmt" "net" "net/http" - "net/url" "path/filepath" "strconv" + "sync" "time" dqlite "github.com/CanonicalLtd/go-dqlite" @@ -15,6 +15,7 @@ import ( "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" hashiraft "github.com/hashicorp/raft" + "github.com/lxc/lxd/shared/logger" "github.com/pkg/errors" "github.com/spoke-d/thermionic/internal/cert" "github.com/spoke-d/thermionic/internal/clock" @@ -84,6 +85,9 @@ type RaftInstance interface { // MembershipChanger returns the underlying rafthttp.Layer, which can be used // to change the membership of this node in the cluster. MembershipChanger() raftmembership.Changer + + // Info returns the information around the current server. + Info() ServerInfo } // ServerProvider creates a new Server instance. @@ -93,8 +97,17 @@ type ServerProvider interface { New(RaftInstance, net.Listener, *raft.AddressProvider, log.Logger) (Server, error) } +// ServerInfo describes the id and address of a server +type ServerInfo struct { + ID uint64 + Address string +} + // Server implements the dqlite network protocol. type Server interface { + // Leader returns information about the current leader, if any. + Leader() *ServerInfo + // Dump the files of a database to disk. Dump(string, string) error @@ -130,6 +143,7 @@ type StoreProvider interface { // possibly runs a dqlite replica on this node (if we're configured to do // so). type Gateway struct { + mutex sync.RWMutex database Node fileSystem fsys.FileSystem cert *cert.Info @@ -297,8 +311,11 @@ func (g *Gateway) Init(cert *cert.Info) error { // These handlers might return 404, either because this node is a // non-clustered node not available over the network or because it is not a // database node part of the dqlite cluster. -func (g *Gateway) HandlerFuncs() map[string]http.HandlerFunc { +func (g *Gateway) HandlerFuncs(nodeRefreshTask func(*heartbeat.APIHeartbeat)) map[string]http.HandlerFunc { databaseHandler := func(w http.ResponseWriter, r *http.Request) { + g.mutex.RLock() + defer g.mutex.RUnlock() + ok, err := cert.TLSCheckCert(r, g.cert) if err != nil { http.Error(w, "500 server error", http.StatusInternalServerError) @@ -309,9 +326,30 @@ func (g *Gateway) HandlerFuncs() map[string]http.HandlerFunc { return } + // Compare the dqlite version of the connecting client + // with our own one. + versionHeader := r.Header.Get("X-Dqlite-Version") + if versionHeader == "" { + // No version header means an old pre dqlite 1.0 client + versionHeader = "0" + } + version, err := strconv.Atoi(versionHeader) + if err != nil { + http.Error(w, "400 invalid dqlite version", http.StatusBadRequest) + return + } + if version != dqliteVersion { + if version > dqliteVersion { + http.Error(w, "503 unsupported dqlite version", http.StatusServiceUnavailable) + } else { + http.Error(w, "426 dqlite version too old ", http.StatusUpgradeRequired) + } + return + } + // Handle hearbeats if r.Method == "PUT" { - g.handleHeartbeats(w, r) + g.handleHeartbeats(w, r, nodeRefreshTask) return } @@ -372,47 +410,8 @@ func (g *Gateway) HandlerFuncs() map[string]http.HandlerFunc { g.acceptCh <- conn } - raftHandler := func(w http.ResponseWriter, r *http.Request) { - // If we are not part of the raft cluster, reply with a - // redirect to one of the raft nodes that we know about. - if g.raft == nil { - var address string - if err := g.database.Transaction(func(tx *db.NodeTx) error { - nodes, err := tx.RaftNodes() - if err != nil { - return errors.WithStack(err) - } - if len(nodes) == 0 { - return errors.Errorf("no raft nodes found") - } - address = nodes[0].Address - return nil - }); err != nil { - http.Error(w, "500 failed to fetch raft nodes", http.StatusInternalServerError) - return - } - url := &url.URL{ - Scheme: "http", - Path: r.URL.Path, - RawQuery: r.URL.RawQuery, - Host: address, - } - http.Redirect(w, r, url.String(), http.StatusPermanentRedirect) - return - } - - // If this node is not clustered return a 404. - handlerFunc := g.raft.HandlerFunc() - if handlerFunc == nil { - http.NotFound(w, r) - return - } - - handlerFunc(w, r) - } return map[string]http.HandlerFunc{ heartbeat.DatabaseEndpoint: databaseHandler, - raft.Endpoint: raftHandler, } } @@ -463,16 +462,18 @@ func (g *Gateway) Raft() RaftInstance { // RaftNodes returns the nodes currently part of the raft cluster. func (g *Gateway) RaftNodes() ([]db.RaftNode, error) { - if g.raft == nil { + g.mutex.RLock() + defer g.mutex.RUnlock() + + if g.raft == nil || !g.isLeader() { return nil, hashiraft.ErrNotLeader } - servers, err := g.raft.Servers() + servers, err := g.server.Cluster() if err != nil { return nil, errors.WithStack(err) } nodes := make([]db.RaftNode, len(servers)) - addressProvider := g.addressProvider for i, server := range servers { address, err := addressProvider.ServerAddr(server.ID) @@ -500,6 +501,9 @@ func (g *Gateway) RaftNodes() ([]db.RaftNode, error) { // LeaderAddress returns the address of the current raft leader. func (g *Gateway) LeaderAddress() (string, error) { + g.mutex.RLock() + defer g.mutex.RUnlock() + // If we aren't clustered, return an error. if g.memoryDial != nil { return "", errors.New("node is not clustered") @@ -512,8 +516,8 @@ func (g *Gateway) LeaderAddress() (string, error) { // wait a bit until one is elected. if g.raft != nil { for ctx.Err() == nil { - if address := string(g.raft.Raft().Leader()); address != "" { - return address, nil + if info := g.server.Leader(); info != nil { + return info.Address, nil } time.Sleep(time.Second) } @@ -556,6 +560,7 @@ func (g *Gateway) LeaderAddress() (string, error) { return "", errors.WithStack(err) } + setDqliteVersionHeader(request) request = request.WithContext(ctx) client := &http.Client{ Transport: &http.Transport{ @@ -685,48 +690,60 @@ func (g *Gateway) Reset(cert *cert.Info) error { return g.Init(cert) } -func (g *Gateway) handleHeartbeats(w http.ResponseWriter, r *http.Request) { - var nodes []db.RaftNode - if err := json.Read(r.Body, &nodes); err != nil { - http.Error(w, "400 invalid raft nodes payload", http.StatusBadRequest) +func (g *Gateway) isLeader() bool { + if g.server == nil { + return false + } + info := g.server.Leader() + return info != nil && info.ID == g.raft.Info().ID +} + +func (g *Gateway) handleHeartbeats(w http.ResponseWriter, r *http.Request, nodeRefreshTask func(*heartbeat.APIHeartbeat)) { + var heartbeatData heartbeat.APIHeartbeat + err := json.Read(r.Body, &heartbeatData) + if err != nil { + logger.Errorf("Error decoding heartbeat body: %v", err) + http.Error(w, "400 invalid heartbeat payload", http.StatusBadRequest) return } - if err := g.database.Transaction(func(tx *db.NodeTx) error { - // validate the raft nodes - current, err := tx.RaftNodes() - if err != nil { - return errors.WithStack(err) + + raftNodes := make([]db.RaftNode, 0) + for _, node := range heartbeatData.Members { + if node.RaftID > 0 { + raftNodes = append(raftNodes, db.RaftNode{ + ID: node.RaftID, + Address: node.Address, + }) } - // If nodes match then we can be assured that we do want to check if - // nodes match. - if len(current) == len(nodes) { - new := make(map[string]db.RaftNode, len(current)) - for _, v := range current { - new[v.Address] = v - } - identical := true - for _, v := range nodes { - if n, ok := new[v.Address]; !ok || n.ID == v.ID { - identical = false - break - } - } - // Nothing to do, we're already at the same quorum - if identical { - return nil - } + } + + if len(raftNodes) > 0 { + if err := g.database.Transaction(func(tx *db.NodeTx) error { + level.Debug(g.logger).Log("msg", fmt.Sprintf("replace current raft nodes with notes %+v", raftNodes)) + err = tx.RaftNodesReplace(raftNodes) + return errors.WithStack(err) + }); err != nil { + http.Error(w, "500 failed to update raft nodes", http.StatusInternalServerError) + return } - level.Debug(g.logger).Log("msg", fmt.Sprintf("replace current raft nodes with notes %+v", nodes)) - err = tx.RaftNodesReplace(nodes) - return errors.WithStack(err) - }); err != nil { - http.Error(w, "500 failed to update raft nodes", http.StatusInternalServerError) + } else { + level.Error(g.logger).Log("msg", "Empty raft node set received") + } + + // Only perform node refresh task if we have received a full state list from leader. + if !heartbeatData.FullStateList { + level.Debug(g.logger).Log("msg", "Partial node list heartbeat received, skipping full update") return } + + // If node refresh task is specified, run it async. + if nodeRefreshTask != nil { + go nodeRefreshTask(&heartbeatData) + } } func (g *Gateway) handleLeadershipState(w http.ResponseWriter, r *http.Request) { - if g.raft.Raft().State() != hashiraft.Leader { + if info := g.server.Leader(); info == nil || info.ID != g.raft.Info().ID { http.Error(w, "503 not leader", http.StatusServiceUnavailable) return } diff --git a/internal/cluster/heartbeat/heartbeat.go b/internal/cluster/heartbeat/heartbeat.go index cb9404f..aeff453 100644 --- a/internal/cluster/heartbeat/heartbeat.go +++ b/internal/cluster/heartbeat/heartbeat.go @@ -16,10 +16,14 @@ import ( "github.com/pkg/errors" "github.com/spoke-d/task" "github.com/spoke-d/thermionic/internal/cert" + "github.com/spoke-d/thermionic/internal/clock" "github.com/spoke-d/thermionic/internal/db" "github.com/spoke-d/thermionic/internal/net" ) +// Current dqlite protocol version. +const dqliteVersion = 1 + // DatabaseEndpoint specifies the API endpoint path that gets routed to a dqlite // server handler for performing SQL queries against the dqlite server running // on this node. @@ -71,6 +75,206 @@ type CertConfig interface { Read(*cert.Info) (*tls.Config, error) } +// APIHeartbeatMember contains specific cluster node info. +type APIHeartbeatMember struct { + ID int64 // ID field value in nodes table. + Address string // Host and Port of node. + RaftID int64 // ID field value in raft_nodes table, zero if non-raft node. + Raft bool // Deprecated, use non-zero RaftID instead to indicate raft node. + LastHeartbeat time.Time // Last time we received a successful response from node. + Online bool // Calculated from offline threshold and LastHeatbeat time. + updated bool // Has node been updated during this heartbeat run. Not sent to nodes. +} + +// APIHeartbeatVersion contains max versions for all nodes in cluster. +type APIHeartbeatVersion struct { + Schema int + APIExtensions int +} + +// APIHeartbeat contains data sent to nodes in heartbeat. +type APIHeartbeat struct { + sync.Mutex + + Members map[int64]APIHeartbeatMember + Version APIHeartbeatVersion + Time time.Time + + // Indicates if heartbeat contains a fresh set of node states. + // This can be used to indicate to the receiving node that the state is fresh enough to + // trigger node refresh activies (such as forkdns). + FullStateList bool + + databaseEndpoint string + certConfig CertConfig + clock clock.Clock + logger log.Logger +} + +// Update updates an existing APIHeartbeat struct with the raft and all node states supplied. +// If allNodes provided is an empty set then this is considered a non-full state list. +func (h *APIHeartbeat) Update(fullStateList bool, raftNodes []db.RaftNode, allNodes []db.NodeInfo, offlineThreshold time.Duration) { + var maxSchemaVersion, maxAPIExtensionsVersion int + h.Time = h.clock.Now() + + if h.Members == nil { + h.Members = make(map[int64]APIHeartbeatMember) + } + + // If we've been supplied a fresh set of node states, this is a full state list. + h.FullStateList = fullStateList + + raftNodeMap := make(map[string]db.RaftNode) + + // Convert raftNodes to a map keyed on address for lookups later. + for _, raftNode := range raftNodes { + raftNodeMap[raftNode.Address] = raftNode + } + + // Add nodes (overwrites any nodes with same ID in map with fresh data). + for _, node := range allNodes { + member := APIHeartbeatMember{ + ID: node.ID, + Address: node.Address, + LastHeartbeat: node.Heartbeat, + Online: !node.Heartbeat.Before(h.clock.Now().Add(-offlineThreshold)), + } + + if raftNode, exists := raftNodeMap[member.Address]; exists { + member.Raft = true // Deprecated + member.RaftID = raftNode.ID + delete(raftNodeMap, member.Address) // Used to check any remaining later. + } + + // Add to the members map using the node ID (not the Raft Node ID). + h.Members[node.ID] = member + + // Keep a record of highest APIExtensions and Schema version seen in all nodes. + if node.APIExtensions > maxAPIExtensionsVersion { + maxAPIExtensionsVersion = node.APIExtensions + } + + if node.Schema > maxSchemaVersion { + maxSchemaVersion = node.Schema + } + } + + h.Version = APIHeartbeatVersion{ + Schema: maxSchemaVersion, + APIExtensions: maxAPIExtensionsVersion, + } + + if len(raftNodeMap) > 0 { + level.Error(h.logger).Log("msg", "Unaccounted raft node(s) not found in 'nodes' table for heartbeat", "nodes", fmt.Sprintf("%+v", raftNodeMap)) + } + + return +} + +// Send sends heartbeat requests to the nodes supplied and updates heartbeat state. +func (h *APIHeartbeat) Send(ctx context.Context, cert *cert.Info, localAddress string, nodes []db.NodeInfo) { + heartbeatsWg := sync.WaitGroup{} + sendHeartbeat := func(nodeID int64, address string, heartbeatData *APIHeartbeat) { + defer heartbeatsWg.Done() + + level.Debug(h.logger).Log("msg", "Sending heartbeat", "address", address) + + err := h.heartbeatNode(ctx, address, cert, heartbeatData) + + if err == nil { + h.Lock() + // Ensure only update nodes that exist in Members already. + hbNode, existing := h.Members[nodeID] + if !existing { + return + } + + hbNode.LastHeartbeat = h.clock.Now() + hbNode.Online = true + hbNode.updated = true + h.Members[nodeID] = hbNode + h.Unlock() + level.Debug(h.logger).Log("msg", "Successful heartbeat", "address", address) + } else { + level.Error(h.logger).Log("msg", "Failed heartbeat for", "address", address, "err", err) + } + } + + for _, node := range nodes { + // Special case for the local node - just record the time now. + if node.Address == localAddress { + h.Lock() + hbNode := h.Members[node.ID] + hbNode.LastHeartbeat = h.clock.Now() + hbNode.Online = true + hbNode.updated = true + h.Members[node.ID] = hbNode + h.Unlock() + continue + } + + // Parallelize the rest. + heartbeatsWg.Add(1) + go sendHeartbeat(node.ID, node.Address, h) + } + heartbeatsWg.Wait() +} + +func (h *APIHeartbeat) heartbeatNode(taskCtx context.Context, address string, cert *cert.Info, heartbeatData *APIHeartbeat) error { + level.Debug(h.logger).Log("msg", "Sending heartbeat request", "address", address) + + config, err := h.certConfig.Read(cert) + if err != nil { + return errors.WithStack(err) + } + client := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: config, + }, + } + + var buffer bytes.Buffer + if err := json.NewEncoder(&buffer).Encode(heartbeatData); err != nil { + return errors.WithStack(err) + } + + url := net.EnsureHTTPS(fmt.Sprintf("%s%s", address, h.databaseEndpoint)) + request, err := http.NewRequest("PUT", url, bytes.NewReader(buffer.Bytes())) + if err != nil { + return errors.WithStack(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + request = request.WithContext(ctx) + request.Close = true // Immediately close the connection after the request is done + + SetDqliteVersionHeader(request) + + // Perform the request asynchronously, so we can abort it if the task context is done. + errCh := make(chan error) + go func() { + response, err := client.Do(request) + if err != nil { + errCh <- errors.Wrap(err, "failed to send HTTP request") + return + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + errCh <- errors.Errorf("HTTP request failed: %s", response.Status) + return + } + errCh <- nil + }() + + select { + case err := <-errCh: + return err + case <-taskCtx.Done(): + return taskCtx.Err() + } +} + // Interval represents the number of seconds to wait between to heartbeat // rounds. const Interval = 4 @@ -83,6 +287,7 @@ type Heartbeat struct { task Task certConfig CertConfig databaseEndpoint string + clock clock.Clock logger log.Logger } @@ -114,7 +319,7 @@ func (h *Heartbeat) Run() (task.Func, task.Schedule) { heartbeatWrapper := func(ctx context.Context) { ch := make(chan struct{}) go func() { - h.run(ctx) + h.run(ctx, false) ch <- struct{}{} }() select { @@ -127,7 +332,7 @@ func (h *Heartbeat) Run() (task.Func, task.Schedule) { return heartbeatWrapper, schedule } -func (h *Heartbeat) run(ctx context.Context) { +func (h *Heartbeat) run(ctx context.Context, initialHeartbeat bool) { if !h.gateway.Clustered() { // We're not a raft node or we're not clustered return @@ -155,6 +360,7 @@ func (h *Heartbeat) run(ctx context.Context) { } var nodes []db.NodeInfo var nodeAddress string + var offlineThreshold time.Duration if err := h.cluster.Transaction(func(tx *db.ClusterTx) error { var err error if nodes, err = tx.Nodes(); err != nil { @@ -163,44 +369,74 @@ func (h *Heartbeat) run(ctx context.Context) { if nodeAddress, err = tx.NodeAddress(); err != nil { return errors.WithStack(err) } + if offlineThreshold, err = tx.NodeOfflineThreshold(); err != nil { + return errors.WithStack(err) + } return nil }); err != nil { level.Error(h.logger).Log("msg", "Failed to get current cluster nodes", "err", err) return } - heartbeats := make([]time.Time, len(nodes)) - var mutex sync.Mutex - var wg sync.WaitGroup - - for i, node := range nodes { + // Cumulative set of node states (will be written back to database once done). + heartbeats := &APIHeartbeat{ + certConfig: h.certConfig, + databaseEndpoint: h.databaseEndpoint, + } + cert := h.gateway.Cert() + + // If this leader node hasn't sent a heartbeat recently, then its node state records + // are likely out of date, this can happen when a node becomes a leader. + // Send stale set to all nodes in database to get a fresh set of active nodes. + if initialHeartbeat { + heartbeats.Update(false, raftNodes, nodes, offlineThreshold) + heartbeats.Send(ctx, cert, nodeAddress, nodes) + + // We have the latest set of node states now, lets send that state set to all nodes. + heartbeats.Update(true, raftNodes, nodes, offlineThreshold) + heartbeats.Send(ctx, cert, nodeAddress, nodes) + } else { + heartbeats.Update(true, raftNodes, nodes, offlineThreshold) + heartbeats.Send(ctx, cert, nodeAddress, nodes) + } - // Special case the local node - if node.Address == nodeAddress { - mutex.Lock() - heartbeats[i] = time.Now() - mutex.Unlock() - continue + // Look for any new node which appeared since sending last heartbeat. + var currentNodes []db.NodeInfo + err = h.cluster.Transaction(func(tx *db.ClusterTx) error { + var err error + currentNodes, err = tx.Nodes() + if err != nil { + return errors.WithStack(err) } + return nil + }) + if err != nil { + level.Warn(h.logger).Log("msg", "Failed to get current cluster nodes", "err", err) + return + } - // Parallelize the rest - wg.Add(1) - go func(i int, address string) { - defer wg.Done() + newNodes := []db.NodeInfo{} + for _, currentNode := range currentNodes { + existing := false + for _, node := range nodes { + if node.Address == currentNode.Address && node.ID == currentNode.ID { + existing = true + break + } + } - level.Debug(h.logger).Log("msg", "Sending heartbeat", "address", address) - if err := h.heartbeatNode(ctx, address, h.gateway.Cert(), raftNodes); err == nil { - level.Debug(h.logger).Log("msg", "Successful heartbeat", "address", address) + if !existing { + // We found a new node + nodes = append(nodes, currentNode) + newNodes = append(newNodes, currentNode) + } + } - mutex.Lock() - heartbeats[i] = time.Now() - mutex.Unlock() - } else { - level.Error(h.logger).Log("msg", "Failed heartbeat", "address", address, "err", err) - } - }(i, node.Address) + // If any new nodes found, send heartbeat to just them (with full node state). + if len(newNodes) > 0 { + heartbeats.Update(true, raftNodes, nodes, offlineThreshold) + heartbeats.Send(ctx, cert, nodeAddress, newNodes) } - wg.Wait() // If the context has been cancelled, return immediately. if ctx.Err() != nil { @@ -209,11 +445,11 @@ func (h *Heartbeat) run(ctx context.Context) { } if err := h.cluster.Transaction(func(tx *db.ClusterTx) error { - for i, node := range nodes { - if heartbeats[i].Equal(time.Time{}) { + for _, node := range heartbeats.Members { + if !node.updated { continue } - if err := tx.NodeHeartbeat(node.Address, heartbeats[i]); err != nil { + if err := tx.NodeHeartbeat(node.Address, h.clock.Now()); err != nil { return errors.WithStack(err) } } @@ -221,60 +457,8 @@ func (h *Heartbeat) run(ctx context.Context) { }); err != nil { level.Error(h.logger).Log("msg", "Failed to update heartbeat", "err", err) } - level.Info(h.logger).Log("msg", "Completed heartbeat round") -} - -func (h *Heartbeat) heartbeatNode(taskCtx context.Context, address string, cert *cert.Info, raftNodes []db.RaftNode) error { - level.Debug(h.logger).Log("msg", "Sending heartbeat request", "address", address) - config, err := h.certConfig.Read(cert) - if err != nil { - return errors.WithStack(err) - } - client := http.Client{ - Transport: &http.Transport{ - TLSClientConfig: config, - }, - } - - var buffer bytes.Buffer - if err := json.NewEncoder(&buffer).Encode(raftNodes); err != nil { - return errors.WithStack(err) - } - - url := net.EnsureHTTPS(fmt.Sprintf("%s%s", address, h.databaseEndpoint)) - request, err := http.NewRequest("PUT", url, bytes.NewReader(buffer.Bytes())) - if err != nil { - return errors.WithStack(err) - } - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - request = request.WithContext(ctx) - request.Close = true // Immediately close the connection after the request is done - - // Perform the request asynchronously, so we can abort it if the task context is done. - errCh := make(chan error) - go func() { - response, err := client.Do(request) - if err != nil { - errCh <- errors.Wrap(err, "failed to send HTTP request") - return - } - defer response.Body.Close() - if response.StatusCode != http.StatusOK { - errCh <- errors.Errorf("HTTP request failed: %s", response.Status) - return - } - errCh <- nil - }() - - select { - case err := <-errCh: - return err - case <-taskCtx.Done(): - return taskCtx.Err() - } + level.Info(h.logger).Log("msg", "Completed heartbeat round") } type taskShim struct{} @@ -288,3 +472,8 @@ type certConfigShim struct{} func (certConfigShim) Read(info *cert.Info) (*tls.Config, error) { return cert.TLSClientConfig(info) } + +// SetDqliteVersionHeader the dqlite version header. +func SetDqliteVersionHeader(request *http.Request) { + request.Header.Set("X-Dqlite-Version", fmt.Sprintf("%d", dqliteVersion)) +}