diff --git a/sentry_sdk/integrations/grpc/__init__.py b/sentry_sdk/integrations/grpc/__init__.py index b6641163a9..a41631c37e 100644 --- a/sentry_sdk/integrations/grpc/__init__.py +++ b/sentry_sdk/integrations/grpc/__init__.py @@ -50,14 +50,29 @@ def __getitem__(self, _): GRPC_VERSION = parse_version(grpc.__version__) +def _is_channel_intercepted(channel: "Channel") -> bool: + interceptor = getattr(channel, "_interceptor", None) + while interceptor is not None: + if isinstance(interceptor, ClientInterceptor): + return True + + inner_channel = getattr(channel, "_channel", None) + if inner_channel is None: + return False + + channel = inner_channel + interceptor = getattr(channel, "_interceptor", None) + + return False + + def _wrap_channel_sync(func: "Callable[P, Channel]") -> "Callable[P, Channel]": "Wrapper for synchronous secure and insecure channel." @wraps(func) def patched_channel(*args: "Any", **kwargs: "Any") -> "Channel": channel = func(*args, **kwargs) - if not ClientInterceptor._is_intercepted: - ClientInterceptor._is_intercepted = True + if not _is_channel_intercepted(channel): return intercept_channel(channel, ClientInterceptor()) else: return channel @@ -70,7 +85,7 @@ def _wrap_intercept_channel(func: "Callable[P, Channel]") -> "Callable[P, Channe def patched_intercept_channel( channel: "Channel", *interceptors: "grpc.ServerInterceptor" ) -> "Channel": - if ClientInterceptor._is_intercepted: + if _is_channel_intercepted(channel): interceptors = tuple( [ interceptor diff --git a/sentry_sdk/integrations/grpc/client.py b/sentry_sdk/integrations/grpc/client.py index 69b3f3d318..b6cbc54f10 100644 --- a/sentry_sdk/integrations/grpc/client.py +++ b/sentry_sdk/integrations/grpc/client.py @@ -22,8 +22,6 @@ class ClientInterceptor( grpc.UnaryUnaryClientInterceptor, # type: ignore grpc.UnaryStreamClientInterceptor, # type: ignore ): - _is_intercepted = False - def intercept_unary_unary( self: "ClientInterceptor", continuation: "Callable[[ClientCallDetails, Message], _UnaryOutcome]", diff --git a/tests/integrations/grpc/test_grpc.py b/tests/integrations/grpc/test_grpc.py index 8d2698f411..25436d9feb 100644 --- a/tests/integrations/grpc/test_grpc.py +++ b/tests/integrations/grpc/test_grpc.py @@ -8,6 +8,7 @@ from sentry_sdk import start_span, start_transaction from sentry_sdk.consts import OP from sentry_sdk.integrations.grpc import GRPCIntegration +from sentry_sdk.integrations.grpc.client import ClientInterceptor from tests.conftest import ApproxDict from tests.integrations.grpc.grpc_test_service_pb2 import gRPCTestMessage from tests.integrations.grpc.grpc_test_service_pb2_grpc import ( @@ -269,6 +270,42 @@ def test_grpc_client_other_interceptor(sentry_init, capture_events_forksafe): ) +@pytest.mark.forked +def test_prevent_dual_client_interceptor(sentry_init, capture_events_forksafe): + sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()]) + events = capture_events_forksafe() + + server, channel = _set_up() + + # Intercept the channel + channel = grpc.intercept_channel(channel, ClientInterceptor()) + stub = gRPCTestServiceStub(channel) + + with start_transaction(): + stub.TestServe(gRPCTestMessage(text="test")) + + _tear_down(server=server) + + events.write_file.close() + events.read_event() + local_transaction = events.read_event() + span = local_transaction["spans"][0] + + assert len(local_transaction["spans"]) == 1 + assert span["op"] == OP.GRPC_CLIENT + assert ( + span["description"] + == "unary unary call to /grpc_test_server.gRPCTestService/TestServe" + ) + assert span["data"] == ApproxDict( + { + "type": "unary unary", + "method": "/grpc_test_server.gRPCTestService/TestServe", + "code": "OK", + } + ) + + @pytest.mark.forked def test_grpc_client_and_servers_interceptors_integration( sentry_init, capture_events_forksafe