Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions NetSerializer/Primitives.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public static class Primitives
private const int StringByteBufferLength = 256;
private const int StringCharBufferLength = 128;

public static uint MaxByteArrayLength = 16 * 1024 * 1024;
public static uint MaxStringLength = 16 * 1024 * 1024;

public static MethodInfo GetWritePrimitive(Type type)
{
return typeof(Primitives).GetMethod("WritePrimitive",
Expand Down Expand Up @@ -547,6 +550,8 @@ public static void ReadPrimitive(Stream stream, out string value)
totalBytes -= 1;

ReadPrimitive(stream, out uint totalChars);
if (totalChars > MaxStringLength)
throw new InvalidDataException($"Serialized string length {totalChars} exceeds maximum {MaxStringLength}.");

value = string.Create((int) totalChars, ((int) totalBytes, stream), _stringSpanRead);
}
Expand Down Expand Up @@ -645,6 +650,8 @@ public static void ReadPrimitive(Stream stream, out string value)

uint totalChars;
ReadPrimitive(stream, out totalChars);
if (totalChars > MaxStringLength)
throw new InvalidDataException($"Serialized string length {totalChars} exceeds maximum {MaxStringLength}.");

len -= 1;

Expand Down Expand Up @@ -759,6 +766,8 @@ public static void ReadPrimitive(Stream stream, out string value)

uint totalChars;
ReadPrimitive(stream, out totalChars);
if (totalChars > MaxStringLength)
throw new InvalidDataException($"Serialized string length {totalChars} exceeds maximum {MaxStringLength}.");

var helper = s_stringHelper;
if (helper == null)
Expand Down Expand Up @@ -840,6 +849,8 @@ public static void ReadPrimitive(Stream stream, out byte[] value)
}

len -= 1;
if (len > MaxByteArrayLength)
throw new InvalidDataException($"Serialized byte array length {len} exceeds maximum {MaxByteArrayLength}.");

value = new byte[len];
int l = 0;
Expand Down
70 changes: 69 additions & 1 deletion NetSerializer/Serializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ public Serializer(IEnumerable<Type> rootTypes)
public Serializer(IEnumerable<Type> rootTypes, Settings settings)
{
this.Settings = settings;
Primitives.MaxByteArrayLength = settings.MaxByteArrayLength;
Primitives.MaxStringLength = settings.MaxStringLength;

if (this.Settings.CustomTypeSerializers.All(s => s is IDynamicTypeSerializer || s is IStaticTypeSerializer) == false)
throw new ArgumentException("TypeSerializers have to implement IDynamicTypeSerializer or IStaticTypeSerializer");
Expand Down Expand Up @@ -93,6 +95,8 @@ public Serializer(Dictionary<Type, uint> typeMap)
public Serializer(Dictionary<Type, uint> typeMap, Settings settings)
{
this.Settings = settings;
Primitives.MaxByteArrayLength = settings.MaxByteArrayLength;
Primitives.MaxStringLength = settings.MaxStringLength;

if (this.Settings.CustomTypeSerializers.All(s => s is IDynamicTypeSerializer || s is IStaticTypeSerializer) == false)
throw new ArgumentException("TypeSerializers have to implement IDynamicTypeSerializer or IStaticTypeSerializer");
Expand Down Expand Up @@ -337,6 +341,35 @@ public void Deserialize(Stream stream, out object ob)
ObjectSerializer.Deserialize(this, stream, out ob);
}

public bool TryDeserialize(Stream stream, out object ob)
{
return ObjectSerializer.TryDeserialize(this, stream, out ob);
}

public bool TryGetTypeFromSerializedObject(Stream stream, out Type type)
{
Primitives.ReadPrimitive(stream, out uint id);
return TryGetTypeFromId(id, out type);
}

public bool TryGetTypeFromId(uint id, out Type type)
{
if (id == 0)
{
type = null;
return true;
}

if (m_runtimeTypeIDList.TryGetValue(id, out var data))
{
type = data.Type;
return true;
}

type = null;
return false;
}

/// <summary>
/// Serialize object graph without writing the type-id of the root type. This can be useful e.g. when
/// serializing a known value type, as this will avoid boxing.
Expand Down Expand Up @@ -371,6 +404,19 @@ public void DeserializeDirect<T>(Stream stream, out T value)
del(this, stream, out value);
}

internal void CheckCollectionLength(uint encodedLength)
{
if (encodedLength == 0)
return;

var length = encodedLength - 1;
if (length > Settings.MaxCollectionLength)
{
throw new InvalidDataException(
$"Serialized collection length {length} exceeds maximum {Settings.MaxCollectionLength}.");
}
}

public int RegisterContext(object context)
{
AssertLocked();
Expand Down Expand Up @@ -412,7 +458,8 @@ internal uint GetTypeIdAndSerializer(Type type, out SerializeDelegate<object> de

internal DeserializeDelegate<object> GetDeserializeTrampolineFromId(uint id)
{
var data = m_runtimeTypeIDList[id];
if (!m_runtimeTypeIDList.TryGetValue(id, out var data))
throw new InvalidDataException($"Unknown serialized type ID {id}.");

if (data.ReaderTrampolineDelegate != null)
return data.ReaderTrampolineDelegate;
Expand All @@ -423,6 +470,27 @@ internal DeserializeDelegate<object> GetDeserializeTrampolineFromId(uint id)
}
}

internal bool TryGetDeserializeTrampolineFromId(uint id, out DeserializeDelegate<object> del)
{
if (!m_runtimeTypeIDList.TryGetValue(id, out var data))
{
del = null;
return false;
}

if (data.ReaderTrampolineDelegate != null)
{
del = data.ReaderTrampolineDelegate;
return true;
}

lock (m_modifyLock)
{
del = GenerateReaderTrampoline(data.Type);
return true;
}
}

ITypeSerializer GetTypeSerializer(Type type)
{
var serializer = this.Settings.CustomTypeSerializers.FirstOrDefault(h => h.Handles(type));
Expand Down
19 changes: 19 additions & 0 deletions NetSerializer/Settings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@ namespace NetSerializer
{
public class Settings
{
/// <summary>
/// Maximum number of elements allowed in a deserialized collection.
/// </summary>
/// <remarks>
/// This defends network deserialization from allocating attacker-controlled
/// arrays, lists, dictionaries, and similar collection types.
/// </remarks>
public uint MaxCollectionLength = 1024 * 1024;

/// <summary>
/// Maximum number of bytes allowed in a deserialized byte array.
/// </summary>
public uint MaxByteArrayLength = 16 * 1024 * 1024;

/// <summary>
/// Maximum number of UTF-16 code units allowed in a deserialized string.
/// </summary>
public uint MaxStringLength = 16 * 1024 * 1024;

/// <summary>
/// Array of custom TypeSerializers
/// </summary>
Expand Down
12 changes: 12 additions & 0 deletions NetSerializer/TypeIDList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ public bool ContainsTypeID(uint typeID)
return typeID < m_array.Length && m_array[typeID] != null;
}

public bool TryGetValue(uint typeID, out TypeData data)
{
if (typeID < m_array.Length && m_array[typeID] != null)
{
data = m_array[typeID];
return true;
}

data = null;
return false;
}

public TypeData this[uint idx]
{
get
Expand Down
4 changes: 4 additions & 0 deletions NetSerializer/TypeSerializers/ArraySerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ public void GenerateReaderMethod(Serializer serializer, Type type, ILGenerator i

il.MarkLabel(notNullLabel);

il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldloc_S, lenLocal);
il.Emit(OpCodes.Call, typeof(Serializer).GetMethod("CheckCollectionLength", BindingFlags.Instance | BindingFlags.NonPublic)!);

var arrLocal = il.DeclareLocal(type);

// create new array with len - 1
Expand Down
4 changes: 4 additions & 0 deletions NetSerializer/TypeSerializers/ListSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ public void GenerateReaderMethod(Serializer serializer, Type type, ILGenerator i

il.MarkLabel(notNullLabel);

il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldloc_S, lenLocal);
il.Emit(OpCodes.Call, typeof(Serializer).GetMethod("CheckCollectionLength", BindingFlags.Instance | BindingFlags.NonPublic)!);

var listLocal = il.DeclareLocal(type);

// -- length
Expand Down
28 changes: 28 additions & 0 deletions NetSerializer/TypeSerializers/ObjectSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,33 @@ public static void Deserialize(Serializer serializer, Stream stream, out object
var del = serializer.GetDeserializeTrampolineFromId(id);
del(serializer, stream, out ob);
}

public static bool TryDeserialize(Serializer serializer, Stream stream, out object ob)
{
uint id;

Primitives.ReadPrimitive(stream, out id);

if (id == 0)
{
ob = null;
return true;
}

if (id == Serializer.ObjectTypeId)
{
ob = new object();
return true;
}

if (!serializer.TryGetDeserializeTrampolineFromId(id, out var del))
{
ob = null;
return false;
}

del(serializer, stream, out ob);
return true;
}
}
}