diff --git a/src/IceRpc/Internal/IceProtocolConnection.cs b/src/IceRpc/Internal/IceProtocolConnection.cs index cb3205323..4201a3fd1 100644 --- a/src/IceRpc/Internal/IceProtocolConnection.cs +++ b/src/IceRpc/Internal/IceProtocolConnection.cs @@ -1052,7 +1052,9 @@ private async Task DispatchRequestAsync(IncomingRequest request, int requestId, { if (exception is not DispatchException dispatchException) { - dispatchException = new DispatchException(StatusCode.InternalError, innerException: exception); + StatusCode statusCode = exception is InvalidDataException ? + StatusCode.InvalidData : StatusCode.InternalError; + dispatchException = new DispatchException(statusCode, innerException: exception); } response = dispatchException.ToOutgoingResponse(request); } diff --git a/tests/IceRpc.Tests/IceProtocolConnectionTests.cs b/tests/IceRpc.Tests/IceProtocolConnectionTests.cs index ad2ecb0bb..b70cefad6 100644 --- a/tests/IceRpc.Tests/IceProtocolConnectionTests.cs +++ b/tests/IceRpc.Tests/IceProtocolConnectionTests.cs @@ -145,6 +145,36 @@ public async Task Dispatcher_failure( Assert.That(readResult.Buffer.IsEmpty, Is.True); } + /// Verifies that exceptions thrown by the dispatcher are classified into the right status code: + /// InvalidDataException surfaces as StatusCode.InvalidData, anything else as StatusCode.InternalError. + [TestCase("InvalidData", StatusCode.InvalidData)] + [TestCase("Other", StatusCode.InternalError)] + public async Task Thrown_dispatcher_exception_is_classified(string exceptionKind, StatusCode expectedStatusCode) + { + // Arrange + var dispatcher = new InlineDispatcher((request, cancellationToken) => exceptionKind switch + { + "InvalidData" => throw new InvalidDataException("boom"), + _ => throw new InvalidOperationException("boom") + }); + + await using ServiceProvider provider = new ServiceCollection() + .AddProtocolTest(Protocol.Ice, dispatcher) + .BuildServiceProvider(validateScopes: true); + var sut = provider.GetRequiredService(); + await sut.ConnectAsync(); + using var request = new OutgoingRequest(new ServiceAddress(Protocol.Ice) { Path = "/foo" }) + { + Operation = "op" + }; + + // Act + IncomingResponse response = await sut.Client.InvokeAsync(request); + + // Assert + Assert.That(response.StatusCode, Is.EqualTo(expectedStatusCode)); + } + /// Verifies that a StatusCode dispatched by the server is encoded as a ReplyStatus and decoded back to /// the expected StatusCode by the client. [Test, TestCaseSource(nameof(StatusCodeRoundTripSource))] diff --git a/tests/IceRpc.Tests/ProtocolConnectionTests.cs b/tests/IceRpc.Tests/ProtocolConnectionTests.cs index 7ddcf9af1..2bf48bec7 100644 --- a/tests/IceRpc.Tests/ProtocolConnectionTests.cs +++ b/tests/IceRpc.Tests/ProtocolConnectionTests.cs @@ -33,7 +33,7 @@ private static IEnumerable DispatcherThrowsExceptionSource } yield return new(Protocol.IceRpc, new InvalidDataException("invalid data"), StatusCode.InvalidData); - yield return new(Protocol.Ice, new InvalidDataException("invalid data"), StatusCode.InternalError); + yield return new(Protocol.Ice, new InvalidDataException("invalid data"), StatusCode.InvalidData); yield return new( Protocol.IceRpc,