Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ private MethodBodyStatement CreateXmlWriteAttributeStatement(XmlPropertyInfo pro
var xmlWireInfo = prop.XmlWireInfo;
if (xmlWireInfo.Namespace != null && namespaces?.TryGetValue(xmlWireInfo.Namespace.Namespace, out var nsInfo) == true)
{
var stringValue = CreateXmlSerializeValueExpression(prop.SerializationExp, prop.PropertyType, prop.SerializationFormat);
var writeStatement = _xmlWriterSnippet.WriteAttributeString(
nsInfo.Prefix,
xmlWireInfo.Name,
Expand Down Expand Up @@ -268,9 +267,7 @@ private MethodBodyStatement CreateXmlWriteValueStatement(ValueExpression value,

if (!underlyingType.IsFrameworkType)
{
return underlyingType.IsEnum
? _xmlWriterSnippet.WriteValue(CreateXmlSerializeValueExpression(value, valueType, serializationFormat))
: _xmlWriterSnippet.WriteObjectValue(value.As(valueType), _serializationOptionsParameter);
return ScmCodeModelGenerator.Instance.TypeFactory.SerializeXmlValue(valueType, value, _xmlWriterSnippet, _mrwOptionsParameterSnippet, serializationFormat);
}

return underlyingType.FrameworkType switch
Expand All @@ -282,7 +279,7 @@ Type t when (t == typeof(byte[]) || t == typeof(BinaryData)) && serializationFor
? value.As<BinaryData>().ToArray()
: value.NullableStructValue(valueType),
serializationFormat.ToFormatSpecifier()),
_ => _xmlWriterSnippet.WriteValue(CreateXmlSerializeValueExpression(value, valueType, serializationFormat))
_ => ScmCodeModelGenerator.Instance.TypeFactory.SerializeXmlValue(valueType, value, _xmlWriterSnippet, _mrwOptionsParameterSnippet, serializationFormat)
};
}

Expand Down Expand Up @@ -430,27 +427,8 @@ private MethodBodyStatement CreateXmlWriteDictionaryEntryStatement(

private MethodBodyStatement CreateXmlWriteTextContentStatement(XmlPropertyInfo prop)
{
var serializedValue = CreateXmlSerializeValueExpression(prop.SerializationExp, prop.PropertyType, prop.SerializationFormat);
return WrapInIsDefinedCheck(prop, _xmlWriterSnippet.WriteValue(serializedValue));
}

private ValueExpression CreateXmlSerializeValueExpression(ValueExpression value, CSharpType valueType, SerializationFormat serializationFormat)
{
var underlyingType = valueType.IsNullable && valueType.Arguments.Count > 0
? valueType.Arguments[0]
: valueType;

if (underlyingType.IsEnum)
{
return underlyingType.ToSerial(value.NullableStructValue(valueType));
}

if (!underlyingType.IsFrameworkType)
{
return value;
}

return CreateXmlSerializePrimitiveExpression(value.NullableStructValue(valueType), underlyingType, serializationFormat);
var writeStatement = ScmCodeModelGenerator.Instance.TypeFactory.SerializeXmlValue(prop.PropertyType, prop.SerializationExp, _xmlWriterSnippet, _mrwOptionsParameterSnippet, prop.SerializationFormat);
return WrapInIsDefinedCheck(prop, writeStatement);
}

private static ValueExpression CreateXmlSerializePrimitiveExpression(ValueExpression value, CSharpType valueType, SerializationFormat serializationFormat)
Expand Down Expand Up @@ -886,7 +864,7 @@ private MethodBodyStatement CreateXmlDeserializePropertyAssignment(
return CreateXmlDeserializeDictionaryAssignment(childElement, propertyType, propertyExpression, xmlWireInfo, serializationFormat);
}

var deserializedValue = CreateXmlDeserializeValueExpression(childElement, propertyType, serializationFormat);
var deserializedValue = ScmCodeModelGenerator.Instance.TypeFactory.DeserializeXmlValue(propertyType, childElement, _mrwOptionsParameterSnippet, serializationFormat);
return propertyExpression.Assign(deserializedValue).Terminate();
}

Expand Down Expand Up @@ -1028,30 +1006,10 @@ private MethodBodyStatement DeserializeXmlValue(
return new MethodBodyStatement[] { dictDeclaration, foreachStatement };
}

value = CreateXmlDeserializeValueExpression(element, valueType, serializationFormat);
value = ScmCodeModelGenerator.Instance.TypeFactory.DeserializeXmlValue(valueType, element, _mrwOptionsParameterSnippet, serializationFormat);
return MethodBodyStatement.Empty;
}

private ValueExpression CreateXmlDeserializeValueExpression(ScopedApi<XElement> element, CSharpType valueType, SerializationFormat serializationFormat)
{
var underlyingType = valueType.IsNullable && valueType.Arguments.Count > 0
? valueType.Arguments[0]
: valueType;

if (underlyingType.IsEnum && underlyingType.UnderlyingEnumType != null)
{
var underlyingExpression = CreateXmlDeserializePrimitiveExpression(element, underlyingType.UnderlyingEnumType, serializationFormat);
return underlyingType.ToEnum(underlyingExpression);
}

if (!underlyingType.IsFrameworkType)
{
return GetDeserializationMethodInvocationForType(underlyingType, element, null, _serializationOptionsParameter);
}

return CreateXmlDeserializePrimitiveExpression(element, valueType, serializationFormat);
}

private static ValueExpression CreateXmlDeserializePrimitiveExpression(
ScopedApi<XElement> element,
CSharpType valueType,
Expand All @@ -1076,6 +1034,54 @@ private static ValueExpression CreateXmlDeserializePrimitiveExpression(
};
}

internal static ValueExpression DeserializeXmlValueCore(
Comment thread
jorgerangel-msft marked this conversation as resolved.
CSharpType valueType,
ScopedApi<XElement> element,
ScopedApi<ModelReaderWriterOptions> mrwOptions,
SerializationFormat format)
{
var underlyingType = valueType.IsNullable && valueType.Arguments.Count > 0
? valueType.Arguments[0]
: valueType;

if (underlyingType.IsEnum && underlyingType.UnderlyingEnumType != null)
{
var underlyingExpression = CreateXmlDeserializePrimitiveExpression(element, underlyingType.UnderlyingEnumType, format);
return underlyingType.ToEnum(underlyingExpression);
}

if (!underlyingType.IsFrameworkType)
{
return GetDeserializationMethodInvocationForType(underlyingType, element, null, mrwOptions);
}

return CreateXmlDeserializePrimitiveExpression(element, valueType, format);
}

internal static MethodBodyStatement SerializeXmlValueCore(
CSharpType valueType,
ValueExpression value,
ScopedApi<XmlWriter> xmlWriter,
ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter,
SerializationFormat serializationFormat)
{
var underlyingType = valueType.IsNullable && valueType.Arguments.Count > 0
? valueType.Arguments[0]
: valueType;

if (underlyingType.IsEnum)
{
return xmlWriter.WriteValue(underlyingType.ToSerial(value.NullableStructValue(valueType)));
}

if (!underlyingType.IsFrameworkType)
{
return xmlWriter.WriteObjectValue(value.As(valueType), mrwOptionsParameter);
}

return xmlWriter.WriteValue(CreateXmlSerializePrimitiveExpression(value.NullableStructValue(valueType), underlyingType, serializationFormat));
}

private MethodBodyStatement CreateXmlDeserializeAttributeStatements(
ScopedApi<XAttribute> attrVariable,
List<XmlPropertyInfo> attributeProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Xml;
using System.Xml.Linq;
using Microsoft.TypeSpec.Generator.ClientModel.Primitives;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
using Microsoft.TypeSpec.Generator.Expressions;
Expand Down Expand Up @@ -246,6 +248,21 @@ public virtual MethodBodyStatement SerializeJsonValue(
SerializationFormat serializationFormat)
=> MrwSerializationTypeDefinition.SerializeJsonValueCore(valueType, value, utf8JsonWriter, mrwOptionsParameter, serializationFormat);

public virtual ValueExpression DeserializeXmlValue(
Comment thread
jorgerangel-msft marked this conversation as resolved.
CSharpType valueType,
ScopedApi<XElement> element,
ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter,
SerializationFormat format)
=> MrwSerializationTypeDefinition.DeserializeXmlValueCore(valueType, element, mrwOptionsParameter, format);

public virtual MethodBodyStatement SerializeXmlValue(
CSharpType valueType,
ValueExpression value,
ScopedApi<XmlWriter> xmlWriter,
ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter,
SerializationFormat format)
=> MrwSerializationTypeDefinition.SerializeXmlValueCore(valueType, value, xmlWriter, mrwOptionsParameter, format);

protected override ModelProvider? CreateModelCore(InputModelType model) => new ScmModelProvider(model);

protected override ScmSerializationOptions? CreateSerializationOptionsCore(InputSerializationOptions inputSerializationOptions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1464,5 +1464,6 @@ public void TestDeserializationOfNonBase64ByteArrayPropertyUsesGetRawText()
Assert.IsFalse(methodBody.Contains("EnumerateArray"),
$"byte[] property should not use array enumeration. Actual:\n{methodBody}");
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Xml.Linq;
using Sample.Models;

namespace Sample
{
public partial class TestXmlModel
{
internal static global::Sample.Models.TestXmlModel DeserializeTestXmlModel(global::System.Xml.Linq.XElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
if ((element == null))
{
return null;
}

string name = default;
global::System.Collections.Generic.IDictionary<string, global::System.BinaryData> additionalBinaryDataProperties = new global::Sample.ChangeTrackingDictionary<string, global::System.BinaryData>();

foreach (var child in element.Elements())
{
string localName = child.Name.LocalName;
if ((localName == "Name"))
{
name = child.ToString();
continue;
}
}
return new global::Sample.Models.TestXmlModel(name, additionalBinaryDataProperties);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel.Primitives;
using System.Xml;
using Sample.Models;

namespace Sample
{
public partial class TestXmlModel
{
internal virtual void XmlModelWriteCore(global::System.Xml.XmlWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
Comment thread
jorgerangel-msft marked this conversation as resolved.
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.TestXmlModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "X"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.TestXmlModel)} does not support writing '{format}' format.");
}

if (global::Sample.Optional.IsDefined(Name))
{
writer.WriteStartElement("Name");
writer.WriteValue(Name.ToString());
writer.WriteEndElement();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Xml.Linq;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Input;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Snippets;
using Microsoft.TypeSpec.Generator.Tests.Common;
using Moq;
using NUnit.Framework;

namespace Microsoft.TypeSpec.Generator.ClientModel.Tests.Providers.MrwSerializationTypeDefinitions
Expand Down Expand Up @@ -601,5 +605,82 @@ protected override MethodProvider[] BuildMethods()

protected override FieldProvider[] BuildFields() => [];
}

[TestCase(typeof(int), SerializationFormat.Default, ExpectedResult = "((int)foo)")]
[TestCase(typeof(string), SerializationFormat.Default, ExpectedResult = "((string)foo)")]
[TestCase(typeof(bool), SerializationFormat.Default, ExpectedResult = "((bool)foo)")]
[TestCase(typeof(long), SerializationFormat.Default, ExpectedResult = "((long)foo)")]
[TestCase(typeof(float), SerializationFormat.Default, ExpectedResult = "((float)foo)")]
[TestCase(typeof(double), SerializationFormat.Default, ExpectedResult = "((double)foo)")]
[TestCase(typeof(byte), SerializationFormat.Default, ExpectedResult = "((byte)((int)foo))")]
[TestCase(typeof(sbyte), SerializationFormat.Default, ExpectedResult = "((sbyte)((int)foo))")]
[TestCase(typeof(short), SerializationFormat.Default, ExpectedResult = "((short)((int)foo))")]
public string DeserializeXmlValueCore_PrimitiveTypes(Type type, SerializationFormat format)
{
var expr = MrwSerializationTypeDefinition.DeserializeXmlValueCore(
type,
new ScopedApi<XElement>(new VariableExpression(typeof(XElement), "foo")),
new ScopedApi<ModelReaderWriterOptions>(new VariableExpression(typeof(ModelReaderWriterOptions), "options")),
format);
return expr.ToDisplayString();
}

[TestCase(SerializationFormat.DateTime_ISO8601, ExpectedResult = "foo.GetDateTimeOffset(\"O\")")]
[TestCase(SerializationFormat.DateTime_RFC1123, ExpectedResult = "foo.GetDateTimeOffset(\"R\")")]
[TestCase(SerializationFormat.DateTime_RFC3339, ExpectedResult = "foo.GetDateTimeOffset(\"O\")")]
public string DeserializeXmlValueCore_DateTimeOffset(SerializationFormat format)
{
var expr = MrwSerializationTypeDefinition.DeserializeXmlValueCore(
typeof(DateTimeOffset),
new ScopedApi<XElement>(new VariableExpression(typeof(XElement), "foo")),
new ScopedApi<ModelReaderWriterOptions>(new VariableExpression(typeof(ModelReaderWriterOptions), "options")),
format);
return expr.ToDisplayString();
}

[TestCase(SerializationFormat.Duration_ISO8601, ExpectedResult = "foo.GetTimeSpan(\"P\")")]
[TestCase(SerializationFormat.Duration_Constant, ExpectedResult = "foo.GetTimeSpan(\"c\")")]
public string DeserializeXmlValueCore_TimeSpan(SerializationFormat format)
{
var expr = MrwSerializationTypeDefinition.DeserializeXmlValueCore(
typeof(TimeSpan),
new ScopedApi<XElement>(new VariableExpression(typeof(XElement), "foo")),
new ScopedApi<ModelReaderWriterOptions>(new VariableExpression(typeof(ModelReaderWriterOptions), "options")),
format);
return expr.ToDisplayString();
}

[Test]
public void DeserializeXmlValueOverride_CustomTypeDeserialization()
{
var inputModel = InputFactory.Model(
"TestXmlModel",
usage: InputModelTypeUsage.Input | InputModelTypeUsage.Xml,
properties: [InputFactory.Property("Name", InputPrimitiveType.String,
serializationOptions: InputFactory.Serialization.Options(xml: InputFactory.Serialization.Xml("Name")))]);

var mockGenerator = MockHelpers.LoadMockGenerator(
inputModels: () => [inputModel]);

// override DeserializeXmlValue to return a custom expression for string types
var mockTypeFactory = Mock.Get((ScmTypeFactory)mockGenerator.Object.TypeFactory);
mockTypeFactory.Setup(p => p.DeserializeXmlValue(
It.Is<CSharpType>(t => t.FrameworkType == typeof(string)),
It.IsAny<ScopedApi<XElement>>(),
It.IsAny<ScopedApi<ModelReaderWriterOptions>>(),
It.IsAny<SerializationFormat>()))
.Returns((CSharpType type, ScopedApi<XElement> element, ScopedApi<ModelReaderWriterOptions> mrwOptions, SerializationFormat format) =>
element.InvokeToString());

var modelProvider = mockGenerator.Object.OutputLibrary.TypeProviders.Single(t => t is ModelProvider && t.Name == "TestXmlModel");
var serializationProvider = modelProvider.SerializationProviders.Single(t => t is MrwSerializationTypeDefinition);
Assert.IsNotNull(serializationProvider);

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(
serializationProvider,
name => name == "DeserializeTestXmlModel"));
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}
}
}
Loading
Loading