diff --git a/spring-grpc-core/src/main/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptor.java b/spring-grpc-core/src/main/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptor.java index 76946335..7be47993 100644 --- a/spring-grpc-core/src/main/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptor.java +++ b/spring-grpc-core/src/main/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptor.java @@ -44,6 +44,7 @@ * returns a null. * * @author Dave Syer + * @author Andrey Litvitski * @see ServerInterceptor * @see GrpcExceptionHandler */ @@ -81,7 +82,7 @@ public Listener interceptCall(ServerCall call, this.logger.trace("Failed to start exception handler call", t); StatusException statusEx = fallbackHandler.handleException(t); exceptionHandledServerCall.close(statusEx != null ? statusEx.getStatus() : Status.fromThrowable(t), - headers(t)); + headers(statusEx != null ? statusEx : t)); return new Listener<>() { }; } diff --git a/spring-grpc-core/src/test/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptorTests.java b/spring-grpc-core/src/test/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptorTests.java index 9d91c885..3b0dafe5 100644 --- a/spring-grpc-core/src/test/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptorTests.java +++ b/spring-grpc-core/src/test/java/org/springframework/grpc/server/exception/GrpcExceptionHandlerInterceptorTests.java @@ -17,11 +17,35 @@ package org.springframework.grpc.server.exception; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.grpc.server.exception.GrpcExceptionHandlerInterceptor.FallbackHandler; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.rpc.Code; +import com.google.rpc.Status; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.StatusException; +import io.grpc.protobuf.ProtoUtils; +import io.grpc.protobuf.StatusProto; + +/** + * Tests for {@link GrpcExceptionHandlerInterceptor}. + * + * @author Dave Syer + * @author Andrey Litvitski + */ public class GrpcExceptionHandlerInterceptorTests { @Test @@ -30,4 +54,38 @@ void testNullStatusHandled() { .isNotNull(); } + @Test + void propagatesTrailersFromStatusExceptionWhenStartCallThrows() { + Status statusWithDetails = Status.newBuilder() + .setCode(Code.PERMISSION_DENIED_VALUE) + .setMessage("access denied") + .addDetails(Any.pack(Empty.getDefaultInstance())) + .build(); + StatusException statusEx = StatusProto.toStatusException(statusWithDetails); + GrpcExceptionHandler handler = ex -> statusEx; + ServerInterceptor interceptor = new GrpcExceptionHandlerInterceptor(handler); + @SuppressWarnings("unchecked") + ServerCall call = mock(ServerCall.class); + MethodDescriptor method = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("test/Test") + .setRequestMarshaller(ProtoUtils.marshaller(Empty.getDefaultInstance())) + .setResponseMarshaller(ProtoUtils.marshaller(Empty.getDefaultInstance())) + .build(); + when(call.getMethodDescriptor()).thenReturn(method); + ServerCallHandler next = (c, headers) -> { + throw new RuntimeException("boom"); + }; + interceptor.interceptCall(call, new Metadata(), next); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(io.grpc.Status.class); + ArgumentCaptor trailersCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(call, times(1)).close(statusCaptor.capture(), trailersCaptor.capture()); + io.grpc.Status closedStatus = statusCaptor.getValue(); + Metadata closedTrailers = trailersCaptor.getValue(); + assertThat(closedStatus.getCode()).isEqualTo(io.grpc.Status.Code.PERMISSION_DENIED); + Status extracted = StatusProto.fromThrowable(new StatusException(closedStatus, closedTrailers)); + assertThat(extracted).isNotNull(); + assertThat(extracted).isEqualTo(statusWithDetails); + } + }