diff --git a/service/authorization/authorization.go b/service/authorization/authorization.go index 17c6b24070..edaed4530e 100644 --- a/service/authorization/authorization.go +++ b/service/authorization/authorization.go @@ -33,8 +33,6 @@ import ( "github.com/opentdf/platform/service/pkg/config" "github.com/opentdf/platform/service/pkg/db" "github.com/opentdf/platform/service/pkg/serviceregistry" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -154,10 +152,6 @@ func (as AuthorizationService) IsReady(ctx context.Context) error { } func (as *AuthorizationService) GetDecisionsByToken(ctx context.Context, req *connect.Request[authorization.GetDecisionsByTokenRequest]) (*connect.Response[authorization.GetDecisionsByTokenResponse], error) { - // Extract trace context from the incoming request - propagator := otel.GetTextMapPropagator() - ctx = propagator.Extract(ctx, propagation.HeaderCarrier(req.Header())) - ctx, span := as.Start(ctx, "GetDecisionsByToken") defer span.End() diff --git a/service/authorization/v2/authorization.go b/service/authorization/v2/authorization.go index 4fb8e18b55..761e3ff682 100644 --- a/service/authorization/v2/authorization.go +++ b/service/authorization/v2/authorization.go @@ -19,8 +19,6 @@ import ( ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/opentdf/platform/service/pkg/cache" "github.com/opentdf/platform/service/pkg/serviceregistry" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -141,10 +139,6 @@ func (as *Service) GetEntitlements(ctx context.Context, req *connect.Request[aut ctx, span := as.Start(ctx, "GetEntitlements") defer span.End() - // Extract trace context from the incoming request - propagator := otel.GetTextMapPropagator() - ctx = propagator.Extract(ctx, 
propagation.HeaderCarrier(req.Header())) - entityIdentifier := req.Msg.GetEntityIdentifier() withComprehensiveHierarchy := req.Msg.GetWithComprehensiveHierarchy() @@ -172,10 +166,6 @@ func (as *Service) GetDecision(ctx context.Context, req *connect.Request[authzV2 ctx, span := as.Start(ctx, "GetDecision") defer span.End() - // Extract trace context from the incoming request - propagator := otel.GetTextMapPropagator() - ctx = propagator.Extract(ctx, propagation.HeaderCarrier(req.Header())) - pdp, err := access.NewJustInTimePDP(ctx, as.logger, as.sdk, as.cache, as.config.AllowDirectEntitlements) if err != nil { return nil, statusifyError(ctx, as.logger, errors.Join(ErrFailedToInitPDP, err)) @@ -222,10 +212,6 @@ func (as *Service) GetDecisionMultiResource(ctx context.Context, req *connect.Re ctx, span := as.Start(ctx, "GetDecisionMultiResource") defer span.End() - // Extract trace context from the incoming request - propagator := otel.GetTextMapPropagator() - ctx = propagator.Extract(ctx, propagation.HeaderCarrier(req.Header())) - pdp, err := access.NewJustInTimePDP(ctx, as.logger, as.sdk, as.cache, as.config.AllowDirectEntitlements) if err != nil { return nil, statusifyError(ctx, as.logger, errors.Join(ErrFailedToInitPDP, err)) @@ -275,10 +261,6 @@ func (as *Service) GetDecisionBulk(ctx context.Context, req *connect.Request[aut ctx, span := as.Start(ctx, "GetDecisionBulk") defer span.End() - // Extract trace context from the incoming request - propagator := otel.GetTextMapPropagator() - ctx = propagator.Extract(ctx, propagation.HeaderCarrier(req.Header())) - pdp, err := access.NewJustInTimePDP(ctx, as.logger, as.sdk, as.cache, as.config.AllowDirectEntitlements) if err != nil { return nil, statusifyError(ctx, as.logger, errors.Join(ErrFailedToInitPDP, err)) diff --git a/service/go.mod b/service/go.mod index 33cea5edb4..ecb9177807 100644 --- a/service/go.mod +++ b/service/go.mod @@ -9,6 +9,7 @@ require ( connectrpc.com/connect v1.19.1 connectrpc.com/grpchealth v1.4.0 
connectrpc.com/grpcreflect v1.3.0 + connectrpc.com/otelconnect v0.9.0 connectrpc.com/validate v0.6.0 github.com/Masterminds/squirrel v1.5.4 github.com/Nerzal/gocloak/v13 v13.9.0 diff --git a/service/go.sum b/service/go.sum index 8892b27b38..25e0741d3b 100644 --- a/service/go.sum +++ b/service/go.sum @@ -10,6 +10,8 @@ connectrpc.com/grpchealth v1.4.0 h1:MJC96JLelARPgZTiRF9KRfY/2N9OcoQvF2EWX07v2IE= connectrpc.com/grpchealth v1.4.0/go.mod h1:WhW6m1EzTmq3Ky1FE8EfkIpSDc6TfUx2M2KqZO3ts/Q= connectrpc.com/grpcreflect v1.3.0 h1:Y4V+ACf8/vOb1XOc251Qun7jMB75gCUNw6llvB9csXc= connectrpc.com/grpcreflect v1.3.0/go.mod h1:nfloOtCS8VUQOQ1+GTdFzVg2CJo4ZGaat8JIovCtDYs= +connectrpc.com/otelconnect v0.9.0 h1:NggB3pzRC3pukQWaYbRHJulxuXvmCKCKkQ9hbrHAWoA= +connectrpc.com/otelconnect v0.9.0/go.mod h1:AEkVLjCPXra+ObGFCOClcJkNjS7zPaQSqvO0lCyjfZc= connectrpc.com/validate v0.6.0 h1:DcrgDKt2ZScrUs/d/mh9itD2yeEa0UbBBa+i0mwzx+4= connectrpc.com/validate v0.6.0/go.mod h1:ihrpI+8gVbLH1fvVWJL1I3j0CfWnF8P/90LsmluRiZs= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 0332e91e5d..8d16b784f7 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -523,6 +523,13 @@ func pprofHandler(h http.Handler) http.Handler { func newConnectRPC(c Config, authInt connect.Interceptor, ints []connect.Interceptor, logger *logger.Logger) (*ConnectRPC, error) { interceptors := make([]connect.HandlerOption, 0) + // OTel tracing and metrics for incoming Connect requests, before all other interceptors + serverTraceInt, err := tracing.ConnectServerTraceInterceptor() + if err != nil { + return nil, fmt.Errorf("failed to create server trace interceptor: %w", err) + } + interceptors = append(interceptors, connect.WithInterceptors(serverTraceInt)) + if c.Auth.Enabled { if authInt == nil { return nil, errors.New("authentication enabled but no interceptor provided") @@ -597,6 +604,13 @@ func 
(s OpenTDFServer) Stop() { func (s inProcessServer) Conn() *sdk.ConnectRPCConnection { var clientInterceptors []connect.Interceptor + // OTel tracing and metrics for outbound IPC Connect RPCs + if clientTraceInt, err := tracing.ConnectClientTraceInterceptor(); err != nil { + s.logger.Error("failed to create IPC client trace interceptor", slog.String("error", err.Error())) + } else { + clientInterceptors = append(clientInterceptors, clientTraceInt) + } + // Add audit interceptor clientInterceptors = append(clientInterceptors, sdkAudit.MetadataAddingConnectInterceptor()) diff --git a/service/internal/server/server_test.go b/service/internal/server/server_test.go index e34bd29519..12884d1302 100644 --- a/service/internal/server/server_test.go +++ b/service/internal/server/server_test.go @@ -554,32 +554,32 @@ func TestNewConnectRPC(t *testing.T) { authEnabled: true, authInt: noopInterceptor(), extraInts: []connect.Interceptor{noopInterceptor(), noopInterceptor()}, - wantIntLen: 3, - wantDescription: "1 auth + 1 extras + 1 validation/audit", + wantIntLen: 4, + wantDescription: "1 trace + 1 auth + 1 extras + 1 validation/audit", }, { name: "auth enabled no extras", authEnabled: true, authInt: noopInterceptor(), extraInts: nil, - wantIntLen: 2, - wantDescription: "1 auth + 1 validation/audit", + wantIntLen: 3, + wantDescription: "1 trace + 1 auth + 1 validation/audit", }, { name: "auth disabled no extras", authEnabled: false, authInt: nil, extraInts: nil, - wantIntLen: 1, - wantDescription: "1 validation/audit only", + wantIntLen: 2, + wantDescription: "1 trace + 1 validation/audit only", }, { name: "auth disabled with extras", authEnabled: false, authInt: nil, extraInts: []connect.Interceptor{noopInterceptor()}, - wantIntLen: 2, - wantDescription: "1 extras + 1 validation/audit", + wantIntLen: 3, + wantDescription: "1 trace + 1 extras + 1 validation/audit", }, { name: "auth enabled but nil authInt returns error", diff --git a/service/pkg/server/start.go 
b/service/pkg/server/start.go index f8bfacd092..25e617fddb 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -372,6 +372,13 @@ func setupERSConnection(cfg *config.Config, oidcconfig *auth.OIDCConfiguration, ersConnectRPCConn := &sdk.ConnectRPCConnection{} + // OTel tracing and metrics for outbound ERS Connect RPCs (outermost interceptor) + if ersTraceInt, err := tracing.ConnectClientTraceInterceptor(); err != nil { + logger.Error("failed to create ERS trace interceptor", slog.String("error", err.Error())) + } else { + ersConnectRPCConn.Options = append(ersConnectRPCConn.Options, connect.WithInterceptors(ersTraceInt)) + } + // Configure TLS tlsConfig := configureTLSForERS(cfg, ersConnectRPCConn) diff --git a/service/tracing/connect_interceptor.go b/service/tracing/connect_interceptor.go new file mode 100644 index 0000000000..6168965990 --- /dev/null +++ b/service/tracing/connect_interceptor.go @@ -0,0 +1,29 @@ +package tracing + +import ( + "connectrpc.com/connect" + "connectrpc.com/otelconnect" +) + +// ConnectClientTraceInterceptor returns a Connect interceptor backed by +// otelconnect that injects OpenTelemetry trace context into outbound requests +// and creates per-RPC spans and metrics. +func ConnectClientTraceInterceptor() (connect.Interceptor, error) { + return otelconnect.NewInterceptor( + otelconnect.WithoutTraceEvents(), + ) +} + +// ConnectServerTraceInterceptor returns a Connect interceptor backed by +// otelconnect that extracts OpenTelemetry trace context from incoming requests +// and creates per-RPC spans and metrics. +// +// WithTrustRemote makes server spans children of the incoming trace rather +// than linked root spans. WithoutServerPeerAttributes reduces cardinality. 
+func ConnectServerTraceInterceptor() (connect.Interceptor, error) { + return otelconnect.NewInterceptor( + otelconnect.WithTrustRemote(), + otelconnect.WithoutServerPeerAttributes(), + otelconnect.WithoutTraceEvents(), + ) +} diff --git a/service/tracing/connect_interceptor_test.go b/service/tracing/connect_interceptor_test.go new file mode 100644 index 0000000000..7f95ee7442 --- /dev/null +++ b/service/tracing/connect_interceptor_test.go @@ -0,0 +1,235 @@ +package tracing_test + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "connectrpc.com/connect" + "github.com/opentdf/platform/service/tracing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/emptypb" +) + +// setupOTel installs an in-memory tracer provider and W3C trace propagator as the +// OTel globals and returns the provider; t.Cleanup shuts it down and restores prior globals. 
+func setupOTel(t *testing.T) *sdktrace.TracerProvider { + t.Helper() + + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider( + sdktrace.WithSyncer(exporter), + sdktrace.WithSampler(sdktrace.AlwaysSample()), + ) + + prevTP := otel.GetTracerProvider() + prevProp := otel.GetTextMapPropagator() + t.Cleanup(func() { + _ = tp.Shutdown(context.Background()) + otel.SetTracerProvider(prevTP) + otel.SetTextMapPropagator(prevProp) + }) + + otel.SetTracerProvider(tp) + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + )) + + return tp +} + +// TestTraceContextPropagation_Unary verifies that the client interceptor +// injects traceparent/tracestate headers and the server interceptor extracts them, +// resulting in both sides sharing the same trace ID for unary RPCs. +func TestTraceContextPropagation_Unary(t *testing.T) { + tp := setupOTel(t) + + serverInt, err := tracing.ConnectServerTraceInterceptor() + require.NoError(t, err) + clientInt, err := tracing.ConnectClientTraceInterceptor() + require.NoError(t, err) + + var ( + mu sync.Mutex + serverTraceID trace.TraceID + ) + + mux := http.NewServeMux() + handler := connect.NewUnaryHandler( + "/test.v1.TestService/Ping", + func(ctx context.Context, _ *connect.Request[emptypb.Empty]) (*connect.Response[emptypb.Empty], error) { + sc := trace.SpanContextFromContext(ctx) + mu.Lock() + serverTraceID = sc.TraceID() + mu.Unlock() + return connect.NewResponse(&emptypb.Empty{}), nil + }, + connect.WithInterceptors(serverInt), + ) + mux.Handle("/test.v1.TestService/", handler) + + srv := httptest.NewServer(mux) + defer srv.Close() + + client := connect.NewClient[emptypb.Empty, emptypb.Empty]( + srv.Client(), + srv.URL+"/test.v1.TestService/Ping", + connect.WithInterceptors(clientInt), + ) + + ctx, span := tp.Tracer("test").Start(context.Background(), "client-call") + clientTraceID := span.SpanContext().TraceID() + + _, err = 
client.CallUnary(ctx, connect.NewRequest(&emptypb.Empty{})) + span.End() + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + + assert.True(t, clientTraceID.IsValid(), "client trace ID should be valid") + assert.True(t, serverTraceID.IsValid(), "server trace ID should be valid") + assert.Equal(t, clientTraceID, serverTraceID, + "server must see the same trace ID as the client") + + t.Logf("client trace: %s", clientTraceID) + t.Logf("server trace: %s", serverTraceID) +} + +// TestTraceContextPropagation_ServerStream verifies trace context propagation +// for server-streaming RPCs, exercising WrapStreamingClient on the client side +// and WrapStreamingHandler on the server side. +func TestTraceContextPropagation_ServerStream(t *testing.T) { + tp := setupOTel(t) + + serverInt, err := tracing.ConnectServerTraceInterceptor() + require.NoError(t, err) + clientInt, err := tracing.ConnectClientTraceInterceptor() + require.NoError(t, err) + + var ( + mu sync.Mutex + serverTraceID trace.TraceID + ) + + mux := http.NewServeMux() + handler := connect.NewServerStreamHandler( + "/test.v1.TestService/StreamPing", + func(ctx context.Context, _ *connect.Request[emptypb.Empty], stream *connect.ServerStream[emptypb.Empty]) error { + sc := trace.SpanContextFromContext(ctx) + mu.Lock() + serverTraceID = sc.TraceID() + mu.Unlock() + return stream.Send(&emptypb.Empty{}) + }, + connect.WithInterceptors(serverInt), + ) + mux.Handle("/test.v1.TestService/", handler) + + srv := httptest.NewServer(mux) + defer srv.Close() + + client := connect.NewClient[emptypb.Empty, emptypb.Empty]( + srv.Client(), + srv.URL+"/test.v1.TestService/StreamPing", + connect.WithInterceptors(clientInt), + ) + + ctx, span := tp.Tracer("test").Start(context.Background(), "client-stream-call") + clientTraceID := span.SpanContext().TraceID() + + stream, err := client.CallServerStream(ctx, connect.NewRequest(&emptypb.Empty{})) + require.NoError(t, err) + for stream.Receive() { + } + require.NoError(t, 
stream.Err()) + require.NoError(t, stream.Close()) + span.End() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, clientTraceID.IsValid(), "client trace ID should be valid") + assert.True(t, serverTraceID.IsValid(), "server trace ID should be valid") + assert.Equal(t, clientTraceID, serverTraceID, + "server must see the same trace ID as the client (streaming)") + + t.Logf("client trace: %s", clientTraceID) + t.Logf("server trace: %s", serverTraceID) +} + +// TestTraceContextPropagation_NoTraceContext verifies that a no-op propagator +// prevents trace context from reaching the server, even when the client has +// an active span. This proves the interceptor respects the propagator config. +func TestTraceContextPropagation_NoTraceContext(t *testing.T) { + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + defer func() { _ = tp.Shutdown(context.Background()) }() + + prevTP := otel.GetTracerProvider() + prevProp := otel.GetTextMapPropagator() + defer func() { + otel.SetTracerProvider(prevTP) + otel.SetTextMapPropagator(prevProp) + }() + otel.SetTracerProvider(tp) + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator()) + + serverInt, err := tracing.ConnectServerTraceInterceptor() + require.NoError(t, err) + clientInt, err := tracing.ConnectClientTraceInterceptor() + require.NoError(t, err) + + var ( + mu sync.Mutex + serverTraceID trace.TraceID + ) + + mux := http.NewServeMux() + handler := connect.NewUnaryHandler( + "/test.v1.TestService/Ping", + func(ctx context.Context, _ *connect.Request[emptypb.Empty]) (*connect.Response[emptypb.Empty], error) { + mu.Lock() + serverTraceID = trace.SpanContextFromContext(ctx).TraceID() + mu.Unlock() + return connect.NewResponse(&emptypb.Empty{}), nil + }, + connect.WithInterceptors(serverInt), + ) + mux.Handle("/test.v1.TestService/", handler) + + srv := httptest.NewServer(mux) + defer srv.Close() + + client := connect.NewClient[emptypb.Empty, emptypb.Empty]( + srv.Client(), + 
srv.URL+"/test.v1.TestService/Ping", + connect.WithInterceptors(clientInt), + ) + + ctx, span := tp.Tracer("test").Start(context.Background(), "client-call") + clientTraceID := span.SpanContext().TraceID() + require.True(t, clientTraceID.IsValid(), "client must have a valid trace ID for this test") + + _, err = client.CallUnary(ctx, connect.NewRequest(&emptypb.Empty{})) + span.End() + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + + // otelconnect still creates a server span, so serverTraceID must be valid. + // But with a no-op propagator, the client's trace context is not injected + // into headers — the server starts a new independent trace. + require.True(t, serverTraceID.IsValid(), "server span should still be created") + assert.NotEqual(t, clientTraceID, serverTraceID, + "server should have a different trace ID when no propagator is configured") +}