diff --git a/multinode/mock_rpc_client_test.go b/multinode/mock_rpc_client_test.go
index a90063e..6e129e3 100644
--- a/multinode/mock_rpc_client_test.go
+++ b/multinode/mock_rpc_client_test.go
@@ -324,6 +324,52 @@ func (_c *mockRPCClient_IsSyncing_Call[CHAIN_ID, HEAD]) RunAndReturn(run func(co
 	return _c
 }
 
+// PollHealthCheck provides a mock function with given fields: ctx
+func (_m *mockRPCClient[CHAIN_ID, HEAD]) PollHealthCheck(ctx context.Context) error {
+	ret := _m.Called(ctx)
+
+	if len(ret) == 0 {
+		panic("no return value specified for PollHealthCheck")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context) error); ok {
+		r0 = rf(ctx)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// mockRPCClient_PollHealthCheck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PollHealthCheck'
+type mockRPCClient_PollHealthCheck_Call[CHAIN_ID ID, HEAD Head] struct {
+	*mock.Call
+}
+
+// PollHealthCheck is a helper method to define mock.On call
+//   - ctx context.Context
+func (_e *mockRPCClient_Expecter[CHAIN_ID, HEAD]) PollHealthCheck(ctx interface{}) *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD] {
+	return &mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD]{Call: _e.mock.On("PollHealthCheck", ctx)}
+}
+
+func (_c *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD]) Run(run func(ctx context.Context)) *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD] {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(context.Context))
+	})
+	return _c
+}
+
+func (_c *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD]) Return(_a0 error) *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD] {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD]) RunAndReturn(run func(context.Context) error) *mockRPCClient_PollHealthCheck_Call[CHAIN_ID, HEAD] {
+	_c.Call.Return(run)
+	return _c
+}
+
 // SubscribeToFinalizedHeads provides a mock function with given fields: ctx
 func (_m *mockRPCClient[CHAIN_ID, HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, Subscription, error) {
 	ret := _m.Called(ctx)
diff --git a/multinode/node_lifecycle.go b/multinode/node_lifecycle.go
index e2974c0..dffdcf3 100644
--- a/multinode/node_lifecycle.go
+++ b/multinode/node_lifecycle.go
@@ -111,6 +111,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() {
 			lggr.Tracew("Pinging RPC", "nodeState", n.State(), "pollFailures", pollFailures)
 			pollCtx, cancel := context.WithTimeout(ctx, pollInterval)
 			version, pingErr := n.RPC().ClientVersion(pollCtx)
+			if pingErr == nil {
+				if healthErr := n.RPC().PollHealthCheck(pollCtx); healthErr != nil {
+					pingErr = fmt.Errorf("poll health check failed: %w", healthErr)
+				}
+			}
 			cancel()
 			if pingErr != nil {
 				// prevent overflow
diff --git a/multinode/node_lifecycle_test.go b/multinode/node_lifecycle_test.go
index 684d0c7..dca2531 100644
--- a/multinode/node_lifecycle_test.go
+++ b/multinode/node_lifecycle_test.go
@@ -147,6 +147,8 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
 		}).Once()
 		// redundant call to stay in alive state
 		rpc.On("ClientVersion", mock.Anything).Return("", nil)
+		// PollHealthCheck is called after successful ClientVersion - return nil to pass
+		rpc.On("PollHealthCheck", mock.Anything).Return(nil).Maybe()
 		node.declareAlive()
 		tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), pollFailureThreshold)
 		tests.AssertLogCountEventually(t, observedLogs, "Ping successful", 2)
@@ -176,6 +178,31 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
 			return nodeStateUnreachable == node.State()
 		})
 	})
+	t.Run("optional poll health check failure counts as poll failure and transitions to unreachable", func(t *testing.T) {
+		t.Parallel()
+		rpc := newMockRPCClient[ID, Head](t)
+		rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{})
+		lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel)
+		node := newSubscribedNode(t, testNodeOpts{
+			config: testNodeConfig{
+				pollFailureThreshold: 1,
+				pollInterval:         tests.TestInterval,
+			},
+			rpc:  rpc,
+			lggr: lggr,
+		})
+		defer func() { assert.NoError(t, node.close()) }()
+
+		rpc.On("ClientVersion", mock.Anything).Return("mock-version", nil)
+		rpc.On("PollHealthCheck", mock.Anything).Return(errors.New("health check failed"))
+		rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe()
+
+		node.declareAlive()
+		tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), 1)
+		tests.AssertEventually(t, func() bool {
+			return nodeStateUnreachable == node.State()
+		})
+	})
 	t.Run("with threshold poll failures, but we are the last node alive, forcibly keeps it alive", func(t *testing.T) {
 		t.Parallel()
 		rpc := newMockRPCClient[ID, Head](t)
@@ -247,6 +274,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
 		})
 		defer func() { assert.NoError(t, node.close()) }()
 		rpc.On("ClientVersion", mock.Anything).Return("", nil)
+		rpc.On("PollHealthCheck", mock.Anything).Return(nil).Maybe()
 		const mostRecentBlock = 20
 		rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30})
 		poolInfo := newMockPoolChainInfoProvider(t)
@@ -282,6 +310,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
 		})
 		defer func() { assert.NoError(t, node.close()) }()
 		rpc.On("ClientVersion", mock.Anything).Return("", nil)
+		rpc.On("PollHealthCheck", mock.Anything).Return(nil).Maybe()
 		const mostRecentBlock = 20
 		rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30})
 		poolInfo := newMockPoolChainInfoProvider(t)
@@ -310,6 +339,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
 		})
 		defer func() { assert.NoError(t, node.close()) }()
 		rpc.On("ClientVersion", mock.Anything).Return("", nil)
+		rpc.On("PollHealthCheck", mock.Anything).Return(nil).Maybe()
 		const mostRecentBlock = 20
 		rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30}).Twice()
 		poolInfo := newMockPoolChainInfoProvider(t)
@@ -344,6 +374,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
 		})
 		defer func() { assert.NoError(t, node.close()) }()
 		rpc.On("ClientVersion", mock.Anything).Return("", nil)
+		rpc.On("PollHealthCheck", mock.Anything).Return(nil).Maybe()
 		const mostRecentBlock = 20
 		rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30})
 		node.declareAlive()
diff --git a/multinode/rpc_client_base.go b/multinode/rpc_client_base.go
index b4a886c..a06200d 100644
--- a/multinode/rpc_client_base.go
+++ b/multinode/rpc_client_base.go
@@ -296,3 +296,10 @@ func (m *RPCClientBase[HEAD]) GetInterceptedChainInfo() (latest, highestUserObse
 	defer m.chainInfoLock.RUnlock()
 	return m.latestChainInfo, m.highestUserObservations
 }
+
+// PollHealthCheck provides a default no-op implementation for the RPCClient interface.
+// Chain-specific RPC clients can override this method to perform additional health checks
+// during polling (e.g., verifying historical state availability).
+func (m *RPCClientBase[HEAD]) PollHealthCheck(ctx context.Context) error {
+	return nil
+}
diff --git a/multinode/types.go b/multinode/types.go
index b31c6ca..e9aa954 100644
--- a/multinode/types.go
+++ b/multinode/types.go
@@ -77,6 +77,10 @@ type RPCClient[
 	// Ensure implementation does not have a race condition when values are reset before request completion and as
 	// a result latest ChainInfo contains information from the previous cycle.
 	GetInterceptedChainInfo() (latest, highestUserObservations ChainInfo)
+	// PollHealthCheck - performs an optional additional health check during polling.
+	// Implementations can use this for chain-specific health verification (e.g., historical state availability).
+	// Return nil if the check passes or is not applicable, or an error if the check fails.
+	PollHealthCheck(ctx context.Context) error
 }
 
 // Head is the interface required by the NodeClient
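
Note (not part of the patch): a minimal sketch of how a chain-specific client might override the default no-op PollHealthCheck to verify historical state availability, which is the use case the new interface comment mentions. Everything below except the PollHealthCheck method name is a hypothetical stand-in (exampleClient, latestHeight, blockByNumber, historyDepth), not an existing multinode API; in the real package the client would embed RPCClientBase and so inherit the no-op default unless it defines its own method.

// Hypothetical override sketch; identifiers other than PollHealthCheck are illustrative only.
package example

import (
	"context"
	"fmt"
)

// exampleClient stands in for a chain-specific RPC client. In the real package it would
// embed the multinode RPCClientBase, inheriting the default no-op PollHealthCheck.
type exampleClient struct {
	latestHeight  func(ctx context.Context) (uint64, error)      // assumed helper: fetch the chain tip
	blockByNumber func(ctx context.Context, height uint64) error // assumed helper: fetch a block at a height
	historyDepth  uint64                                         // how far behind the tip history must stay available
}

// PollHealthCheck overrides the no-op default: report the node unhealthy when it can no
// longer serve a block historyDepth blocks behind the tip.
func (c *exampleClient) PollHealthCheck(ctx context.Context) error {
	tip, err := c.latestHeight(ctx)
	if err != nil {
		return fmt.Errorf("failed to fetch latest height: %w", err)
	}
	if tip < c.historyDepth {
		return nil // chain is younger than the required history window; nothing to check
	}
	if err := c.blockByNumber(ctx, tip-c.historyDepth); err != nil {
		return fmt.Errorf("historical state unavailable: %w", err)
	}
	return nil
}

A non-nil error from such an override is wrapped by aliveLoop as "poll health check failed: ..." and counted as a poll failure, which is exactly the path exercised by the new "optional poll health check failure counts as poll failure and transitions to unreachable" test above.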