diff --git a/src/Worker.Extensions.DurableTask/Execution/DurableWorkerBuilderExtensions.cs b/src/Worker.Extensions.DurableTask/Execution/DurableWorkerBuilderExtensions.cs index 746a6ac56..2e54323dd 100644 --- a/src/Worker.Extensions.DurableTask/Execution/DurableWorkerBuilderExtensions.cs +++ b/src/Worker.Extensions.DurableTask/Execution/DurableWorkerBuilderExtensions.cs @@ -47,7 +47,8 @@ public static bool ValidateBuildTarget(this IDurableTaskWorkerBuilder builder) } #pragma warning disable CS9113 // Parameter is unread. Suppressed to let a breaking change get fixed before we remove this parameter. - private class Worker(string name, IDurableTaskFactory factory, IExceptionPropertiesProvider? provider = null) : DurableTaskWorker(name, factory) + private class Worker(string name, IDurableTaskFactory factory, IExceptionPropertiesProvider? provider = null) + : DurableTaskWorker(name, TypeHintingDurableTaskFactory.WrapIfNeeded(factory)) { public new IDurableTaskFactory Factory => base.Factory; @@ -57,4 +58,4 @@ protected override Task ExecuteAsync(CancellationToken stoppingToken) } } #pragma warning restore CS9113 // Parameter is unread. -} \ No newline at end of file +} diff --git a/src/Worker.Extensions.DurableTask/Execution/TypeHintingDurableTaskFactory.cs b/src/Worker.Extensions.DurableTask/Execution/TypeHintingDurableTaskFactory.cs new file mode 100644 index 000000000..0fe3b16fa --- /dev/null +++ b/src/Worker.Extensions.DurableTask/Execution/TypeHintingDurableTaskFactory.cs @@ -0,0 +1,98 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; +using Microsoft.DurableTask.Worker; + +namespace Microsoft.Azure.Functions.Worker.Extensions.DurableTask.Execution; + +internal sealed class TypeHintingDurableTaskFactory : IDurableTaskFactory2 +{ + private readonly IDurableTaskFactory inner; + private readonly IDurableTaskFactory2? inner2; + + public TypeHintingDurableTaskFactory(IDurableTaskFactory inner) + { + this.inner = inner ?? throw new ArgumentNullException(nameof(inner)); + this.inner2 = inner as IDurableTaskFactory2; + } + + public bool TryCreateActivity(TaskName name, IServiceProvider services, [NotNullWhen(true)] out ITaskActivity? activity) + { + if (!this.inner.TryCreateActivity(name, services, out activity)) + { + return false; + } + + activity = SerializationHintTaskActivity.Wrap(activity); + return true; + } + + public bool TryCreateOrchestrator( + TaskName name, + IServiceProvider services, + [NotNullWhen(true)] out ITaskOrchestrator? orchestrator) + { + return this.inner.TryCreateOrchestrator(name, services, out orchestrator); + } + + public bool TryCreateEntity(TaskName name, IServiceProvider services, [NotNullWhen(true)] out ITaskEntity? entity) + { + if (this.inner2 is null) + { + entity = null; + return false; + } + + return this.inner2.TryCreateEntity(name, services, out entity); + } + + internal static IDurableTaskFactory WrapIfNeeded(IDurableTaskFactory factory) + { + if (factory is TypeHintingDurableTaskFactory) + { + return factory; + } + + return new TypeHintingDurableTaskFactory(factory); + } + + internal sealed class SerializationHintTaskActivity : ITaskActivity + { + private readonly ITaskActivity inner; + + private SerializationHintTaskActivity(ITaskActivity inner) + { + this.inner = inner ?? throw new ArgumentNullException(nameof(inner)); + } + + public Type InputType => this.inner.InputType; + + public Type OutputType => this.inner.OutputType; + + public Task RunAsync(TaskActivityContext context, object? input) + { + return this.RunWithHintAsync(context, input); + } + + internal static ITaskActivity Wrap(ITaskActivity activity) + { + if (activity is SerializationHintTaskActivity) + { + return activity; + } + + return new SerializationHintTaskActivity(activity); + } + + private async Task RunWithHintAsync(TaskActivityContext context, object? input) + { + object? result = await this.inner.RunAsync(context, input).ConfigureAwait(false); + return ObjectConverterShim.WithDeclaredType(result, this.OutputType); + } + } +} diff --git a/src/Worker.Extensions.DurableTask/ObjectConverterShim.cs b/src/Worker.Extensions.DurableTask/ObjectConverterShim.cs index 9f8abd884..98f10d4c8 100644 --- a/src/Worker.Extensions.DurableTask/ObjectConverterShim.cs +++ b/src/Worker.Extensions.DurableTask/ObjectConverterShim.cs @@ -39,7 +39,40 @@ public ObjectConverterShim(ObjectSerializer serializer) return null; } - BinaryData data = this.serializer.Serialize(value, value.GetType(), default); + if (value is SerializationHint hint) + { + return this.Serialize(hint.Value, hint.DeclaredType); + } + + return this.Serialize(value, value.GetType()); + } + + internal static SerializationHint WithDeclaredType(object? value, Type? declaredType) + { + return new SerializationHint(value, declaredType); + } + + private string? Serialize(object? value, Type? declaredType) + { + if (value is null) + { + return null; + } + + BinaryData data = this.serializer.Serialize(value, declaredType ?? value.GetType(), default); return data.ToString(); } + + internal sealed class SerializationHint + { + public SerializationHint(object? value, Type? declaredType) + { + this.Value = value; + this.DeclaredType = declaredType; + } + + public object? Value { get; } + + public Type? DeclaredType { get; } + } } diff --git a/test/Worker.Extensions.DurableTask.Tests/ObjectConverterShimTests.cs b/test/Worker.Extensions.DurableTask.Tests/ObjectConverterShimTests.cs new file mode 100644 index 000000000..1f0e19758 --- /dev/null +++ b/test/Worker.Extensions.DurableTask.Tests/ObjectConverterShimTests.cs @@ -0,0 +1,69 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using Azure.Core.Serialization; +using Microsoft.Azure.Functions.Worker.Extensions.DurableTask; +using Microsoft.Azure.Functions.Worker.Extensions.DurableTask.Execution; + +namespace Microsoft.Azure.Functions.Worker.Tests; + +public class ObjectConverterShimTests +{ + [Fact] + public void Serialize_UsesDeclaredTypeHintWhenProvided() + { + JsonSerializerOptions options = CreateOptions(); + JsonObjectSerializer serializer = new(options); + ObjectConverterShim converter = new(serializer); + + DerivedResponse value = new() { Field1 = 42, Field2 = 99 }; + + string runtimeJson = JsonSerializer.Serialize(value, options); + string declaredJson = JsonSerializer.Serialize(value, options); + + Assert.NotEqual(declaredJson, runtimeJson); + Assert.Equal(runtimeJson, converter.Serialize(value)); + + string hintedJson = converter.Serialize( + ObjectConverterShim.WithDeclaredType(value, typeof(BaseResponse)))!; + + Assert.Equal(declaredJson, hintedJson); + Assert.Equal(runtimeJson, converter.Serialize(value)); + } + + private static JsonSerializerOptions CreateOptions() + { + DefaultJsonTypeInfoResolver resolver = new(); + resolver.Modifiers.Add(info => + { + if (info.Type == typeof(BaseResponse)) + { + info.PolymorphismOptions = new JsonPolymorphismOptions + { + TypeDiscriminatorPropertyName = "type", + UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FailSerialization, + }; + info.PolymorphismOptions.DerivedTypes.Add( + new JsonDerivedType(typeof(DerivedResponse), "derived")); + } + }); + + return new JsonSerializerOptions + { + TypeInfoResolver = resolver, + }; + } + + private class BaseResponse + { + public int Field1 { get; set; } + } + + private class DerivedResponse : BaseResponse + { + public int Field2 { get; set; } + } +}