Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public static bool ValidateBuildTarget(this IDurableTaskWorkerBuilder builder)
}

#pragma warning disable CS9113 // Parameter is unread. Suppressed to let a breaking change get fixed before we remove this parameter.
private class Worker(string name, IDurableTaskFactory factory, IExceptionPropertiesProvider? provider = null) : DurableTaskWorker(name, factory)
private class Worker(string name, IDurableTaskFactory factory, IExceptionPropertiesProvider? provider = null)
: DurableTaskWorker(name, TypeHintingDurableTaskFactory.WrapIfNeeded(factory))
{
public new IDurableTaskFactory Factory => base.Factory;

Expand All @@ -57,4 +58,4 @@ protected override Task ExecuteAsync(CancellationToken stoppingToken)
}
}
#pragma warning restore CS9113 // Parameter is unread.
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks;
using Microsoft.DurableTask;
using Microsoft.DurableTask.Entities;
using Microsoft.DurableTask.Worker;

namespace Microsoft.Azure.Functions.Worker.Extensions.DurableTask.Execution;

internal sealed class TypeHintingDurableTaskFactory : IDurableTaskFactory2
{
private readonly IDurableTaskFactory inner;
private readonly IDurableTaskFactory2? inner2;

public TypeHintingDurableTaskFactory(IDurableTaskFactory inner)
{
this.inner = inner ?? throw new ArgumentNullException(nameof(inner));
this.inner2 = inner as IDurableTaskFactory2;
}

public bool TryCreateActivity(TaskName name, IServiceProvider services, [NotNullWhen(true)] out ITaskActivity? activity)
{
if (!this.inner.TryCreateActivity(name, services, out activity))
{
return false;
}

activity = SerializationHintTaskActivity.Wrap(activity);
return true;
}

public bool TryCreateOrchestrator(
TaskName name,
IServiceProvider services,
[NotNullWhen(true)] out ITaskOrchestrator? orchestrator)
{
return this.inner.TryCreateOrchestrator(name, services, out orchestrator);
}

public bool TryCreateEntity(TaskName name, IServiceProvider services, [NotNullWhen(true)] out ITaskEntity? entity)
{
if (this.inner2 is null)
{
entity = null;
return false;
}

return this.inner2.TryCreateEntity(name, services, out entity);
}

internal static IDurableTaskFactory WrapIfNeeded(IDurableTaskFactory factory)
{
if (factory is TypeHintingDurableTaskFactory)
{
return factory;
}

return new TypeHintingDurableTaskFactory(factory);
}

internal sealed class SerializationHintTaskActivity : ITaskActivity
{
private readonly ITaskActivity inner;

private SerializationHintTaskActivity(ITaskActivity inner)
{
this.inner = inner ?? throw new ArgumentNullException(nameof(inner));
}

public Type InputType => this.inner.InputType;

public Type OutputType => this.inner.OutputType;

public Task<object?> RunAsync(TaskActivityContext context, object? input)
{
return this.RunWithHintAsync(context, input);
}

internal static ITaskActivity Wrap(ITaskActivity activity)
{
if (activity is SerializationHintTaskActivity)
{
return activity;
}

return new SerializationHintTaskActivity(activity);
}

private async Task<object?> RunWithHintAsync(TaskActivityContext context, object? input)
{
object? result = await this.inner.RunAsync(context, input).ConfigureAwait(false);
return ObjectConverterShim.WithDeclaredType(result, this.OutputType);
}
}
}
35 changes: 34 additions & 1 deletion src/Worker.Extensions.DurableTask/ObjectConverterShim.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,40 @@ public ObjectConverterShim(ObjectSerializer serializer)
return null;
}

BinaryData data = this.serializer.Serialize(value, value.GetType(), default);
if (value is SerializationHint hint)
{
return this.Serialize(hint.Value, hint.DeclaredType);
}

return this.Serialize(value, value.GetType());
}

internal static SerializationHint WithDeclaredType(object? value, Type? declaredType)
{
return new SerializationHint(value, declaredType);
}

private string? Serialize(object? value, Type? declaredType)
{
if (value is null)
{
return null;
}

BinaryData data = this.serializer.Serialize(value, declaredType ?? value.GetType(), default);
return data.ToString();
}

internal sealed class SerializationHint
{
public SerializationHint(object? value, Type? declaredType)
{
this.Value = value;
this.DeclaredType = declaredType;
}

public object? Value { get; }

public Type? DeclaredType { get; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Azure.Core.Serialization;
using Microsoft.Azure.Functions.Worker.Extensions.DurableTask;
using Microsoft.Azure.Functions.Worker.Extensions.DurableTask.Execution;

namespace Microsoft.Azure.Functions.Worker.Tests;

public class ObjectConverterShimTests
{
[Fact]
public void Serialize_UsesDeclaredTypeHintWhenProvided()
{
JsonSerializerOptions options = CreateOptions();
JsonObjectSerializer serializer = new(options);
ObjectConverterShim converter = new(serializer);

DerivedResponse value = new() { Field1 = 42, Field2 = 99 };

string runtimeJson = JsonSerializer.Serialize(value, options);
string declaredJson = JsonSerializer.Serialize<BaseResponse>(value, options);

Assert.NotEqual(declaredJson, runtimeJson);
Assert.Equal(runtimeJson, converter.Serialize(value));

string hintedJson = converter.Serialize(
ObjectConverterShim.WithDeclaredType(value, typeof(BaseResponse)))!;

Assert.Equal(declaredJson, hintedJson);
Assert.Equal(runtimeJson, converter.Serialize(value));
}

private static JsonSerializerOptions CreateOptions()
{
DefaultJsonTypeInfoResolver resolver = new();
resolver.Modifiers.Add(info =>
{
if (info.Type == typeof(BaseResponse))
{
info.PolymorphismOptions = new JsonPolymorphismOptions
{
TypeDiscriminatorPropertyName = "type",
UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FailSerialization,
};
info.PolymorphismOptions.DerivedTypes.Add(
new JsonDerivedType(typeof(DerivedResponse), "derived"));
}
});

return new JsonSerializerOptions
{
TypeInfoResolver = resolver,
};
}

private class BaseResponse
{
public int Field1 { get; set; }
}

private class DerivedResponse : BaseResponse
{
public int Field2 { get; set; }
}
}
Loading