diff --git a/ocp/rpc/messaging/server.go b/ocp/rpc/messaging/server.go index f625779..a7fc008 100644 --- a/ocp/rpc/messaging/server.go +++ b/ocp/rpc/messaging/server.go @@ -21,6 +21,7 @@ import ( messagingpb "github.com/code-payments/ocp-protobuf-api/generated/go/messaging/v1" "github.com/code-payments/ocp-server/grpc/client" + "github.com/code-payments/ocp-server/protoutil" "github.com/code-payments/ocp-server/retry" "github.com/code-payments/ocp-server/retry/backoff" @@ -224,7 +225,9 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O go s.flush(ctx, req.GetRequest().RendezvousKey, ms) sendPingCh := time.After(0) - streamHealthCh := s.monitorOpenMessageStreamHealth(ctx, log, ssRef, streamer) + streamHealthCh := protoutil.MonitorStreamHealth(ctx, log, streamer, func(t *messagingpb.OpenMessageStreamWithKeepAliveRequest) bool { + return t.GetPong() != nil + }) updateRendezvousRecordCh := time.After(rendezvousRecordRefreshInterval) for { @@ -315,35 +318,6 @@ func (s *server) boundedRecv( } } -// Very naive implementation to start -func (s *server) monitorOpenMessageStreamHealth( - ctx context.Context, - log *zap.Logger, - ssRef string, - streamer messagingpb.Messaging_OpenMessageStreamWithKeepAliveServer, -) <-chan struct{} { - streamHealthChan := make(chan struct{}) - go func() { - defer close(streamHealthChan) - - for { - req, err := s.boundedRecv(ctx, streamer, messageStreamKeepAliveRecvTimeout) - if err != nil { - return - } - - switch req.RequestOrPong.(type) { - case *messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong: - log.Debug(fmt.Sprintf("received pong from client (stream=%s)", ssRef)) - default: - // Client sent something unexpected. Terminate the stream - return - } - } - }() - return streamHealthChan -} - // OpenMessageStream implements messagingpb.MessagingServer.OpenMessageStream. // // Note: This variant is more suitable for short-lived streams, and is coded as diff --git a/protoutil/compare.go b/protoutil/compare.go new file mode 100644 index 0000000..23f8b87 --- /dev/null +++ b/protoutil/compare.go @@ -0,0 +1,58 @@ +package protoutil + +import ( + "fmt" + + "google.golang.org/protobuf/proto" +) + +func SliceEqualError[T proto.Message](a, b []T) error { + if len(a) != len(b) { + return fmt.Errorf("len(%d) != len(%d)", len(a), len(b)) + } + + for i := 0; i < len(a); i++ { + if err := ProtoEqualError(a[i], b[i]); err != nil { + return fmt.Errorf("mismatch[%d]: %w", i, err) + } + } + + return nil +} + +func SetEqualError[T proto.Message](a, b []T) error { + if len(a) != len(b) { + return fmt.Errorf("len(%d) != len(%d)", len(a), len(b)) + } + + for i := 0; i < len(a); i++ { + found := false + for j := 0; j < len(b); j++ { + if proto.Equal(a[i], b[j]) { + found = true + break + } + } + if !found { + return fmt.Errorf("missing[%d]: %v", i, a[i]) + } + } + + return nil +} + +func ProtoEqualError(a, b proto.Message) error { + if !proto.Equal(a, b) { + return fmt.Errorf("expected: %v\nactual: %v\n", a, b) + } + + return nil +} + +func SliceClone[T proto.Message](src []T) []T { + cloned := make([]T, len(src)) + for i := range src { + cloned[i] = proto.Clone(src[i]).(T) + } + return cloned +} diff --git a/protoutil/stream.go b/protoutil/stream.go index c984bc2..0058bff 100644 --- a/protoutil/stream.go +++ b/protoutil/stream.go @@ -4,12 +4,15 @@ import ( "context" "time" + "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) +const streamKeepAliveRecvTimeout = 10 * time.Second + type Ptr[T any] interface { proto.Message *T @@ -38,3 +41,29 @@ func BoundedReceive[Req any]( return req, status.Error(codes.DeadlineExceeded, "timeout receiving message") } } + +func MonitorStreamHealth[Req any]( + ctx context.Context, + log *zap.Logger, + streamer grpc.ServerStream, + validFn func(*Req) bool, +) <-chan struct{} { + healthCh := make(chan struct{}) + go func() { + defer close(healthCh) + + for { + req, err := BoundedReceive[Req](ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + if !validFn(req) { + return + } + + log.Debug("receiving pong from client") + } + }() + return healthCh +}