diff --git a/NetSerializer/Primitives.cs b/NetSerializer/Primitives.cs index 47a90b9..3011ae3 100644 --- a/NetSerializer/Primitives.cs +++ b/NetSerializer/Primitives.cs @@ -25,6 +25,9 @@ public static class Primitives private const int StringByteBufferLength = 256; private const int StringCharBufferLength = 128; + public static uint MaxByteArrayLength = 16 * 1024 * 1024; + public static uint MaxStringLength = 16 * 1024 * 1024; + public static MethodInfo GetWritePrimitive(Type type) { return typeof(Primitives).GetMethod("WritePrimitive", @@ -547,6 +550,8 @@ public static void ReadPrimitive(Stream stream, out string value) totalBytes -= 1; ReadPrimitive(stream, out uint totalChars); + if (totalChars > MaxStringLength) + throw new InvalidDataException($"Serialized string length {totalChars} exceeds maximum {MaxStringLength}."); value = string.Create((int) totalChars, ((int) totalBytes, stream), _stringSpanRead); } @@ -645,6 +650,8 @@ public static void ReadPrimitive(Stream stream, out string value) uint totalChars; ReadPrimitive(stream, out totalChars); + if (totalChars > MaxStringLength) + throw new InvalidDataException($"Serialized string length {totalChars} exceeds maximum {MaxStringLength}."); len -= 1; @@ -759,6 +766,8 @@ public static void ReadPrimitive(Stream stream, out string value) uint totalChars; ReadPrimitive(stream, out totalChars); + if (totalChars > MaxStringLength) + throw new InvalidDataException($"Serialized string length {totalChars} exceeds maximum {MaxStringLength}."); var helper = s_stringHelper; if (helper == null) @@ -840,6 +849,8 @@ public static void ReadPrimitive(Stream stream, out byte[] value) } len -= 1; + if (len > MaxByteArrayLength) + throw new InvalidDataException($"Serialized byte array length {len} exceeds maximum {MaxByteArrayLength}."); value = new byte[len]; int l = 0; diff --git a/NetSerializer/Serializer.cs b/NetSerializer/Serializer.cs index 92df080..caf72ee 100644 --- a/NetSerializer/Serializer.cs +++ b/NetSerializer/Serializer.cs @@ -55,6 +55,8 @@ public Serializer(IEnumerable rootTypes) public Serializer(IEnumerable rootTypes, Settings settings) { this.Settings = settings; + Primitives.MaxByteArrayLength = settings.MaxByteArrayLength; + Primitives.MaxStringLength = settings.MaxStringLength; if (this.Settings.CustomTypeSerializers.All(s => s is IDynamicTypeSerializer || s is IStaticTypeSerializer) == false) throw new ArgumentException("TypeSerializers have to implement IDynamicTypeSerializer or IStaticTypeSerializer"); @@ -93,6 +95,8 @@ public Serializer(Dictionary typeMap) public Serializer(Dictionary typeMap, Settings settings) { this.Settings = settings; + Primitives.MaxByteArrayLength = settings.MaxByteArrayLength; + Primitives.MaxStringLength = settings.MaxStringLength; if (this.Settings.CustomTypeSerializers.All(s => s is IDynamicTypeSerializer || s is IStaticTypeSerializer) == false) throw new ArgumentException("TypeSerializers have to implement IDynamicTypeSerializer or IStaticTypeSerializer"); @@ -337,6 +341,35 @@ public void Deserialize(Stream stream, out object ob) ObjectSerializer.Deserialize(this, stream, out ob); } + public bool TryDeserialize(Stream stream, out object ob) + { + return ObjectSerializer.TryDeserialize(this, stream, out ob); + } + + public bool TryGetTypeFromSerializedObject(Stream stream, out Type type) + { + Primitives.ReadPrimitive(stream, out uint id); + return TryGetTypeFromId(id, out type); + } + + public bool TryGetTypeFromId(uint id, out Type type) + { + if (id == 0) + { + type = null; + return true; + } + + if (m_runtimeTypeIDList.TryGetValue(id, out var data)) + { + type = data.Type; + return true; + } + + type = null; + return false; + } + /// /// Serialize object graph without writing the type-id of the root type. This can be useful e.g. when /// serializing a known value type, as this will avoid boxing. @@ -371,6 +404,19 @@ public void DeserializeDirect(Stream stream, out T value) del(this, stream, out value); } + internal void CheckCollectionLength(uint encodedLength) + { + if (encodedLength == 0) + return; + + var length = encodedLength - 1; + if (length > Settings.MaxCollectionLength) + { + throw new InvalidDataException( + $"Serialized collection length {length} exceeds maximum {Settings.MaxCollectionLength}."); + } + } + public int RegisterContext(object context) { AssertLocked(); @@ -412,7 +458,8 @@ internal uint GetTypeIdAndSerializer(Type type, out SerializeDelegate de internal DeserializeDelegate GetDeserializeTrampolineFromId(uint id) { - var data = m_runtimeTypeIDList[id]; + if (!m_runtimeTypeIDList.TryGetValue(id, out var data)) + throw new InvalidDataException($"Unknown serialized type ID {id}."); if (data.ReaderTrampolineDelegate != null) return data.ReaderTrampolineDelegate; @@ -423,6 +470,27 @@ internal DeserializeDelegate GetDeserializeTrampolineFromId(uint id) } } + internal bool TryGetDeserializeTrampolineFromId(uint id, out DeserializeDelegate del) + { + if (!m_runtimeTypeIDList.TryGetValue(id, out var data)) + { + del = null; + return false; + } + + if (data.ReaderTrampolineDelegate != null) + { + del = data.ReaderTrampolineDelegate; + return true; + } + + lock (m_modifyLock) + { + del = GenerateReaderTrampoline(data.Type); + return true; + } + } + ITypeSerializer GetTypeSerializer(Type type) { var serializer = this.Settings.CustomTypeSerializers.FirstOrDefault(h => h.Handles(type)); diff --git a/NetSerializer/Settings.cs b/NetSerializer/Settings.cs index 94880da..f3ef8f6 100644 --- a/NetSerializer/Settings.cs +++ b/NetSerializer/Settings.cs @@ -13,6 +13,25 @@ namespace NetSerializer { public class Settings { + /// + /// Maximum number of elements allowed in a deserialized collection. + /// + /// + /// This defends network deserialization from allocating attacker-controlled + /// arrays, lists, dictionaries, and similar collection types. + /// + public uint MaxCollectionLength = 1024 * 1024; + + /// + /// Maximum number of bytes allowed in a deserialized byte array. + /// + public uint MaxByteArrayLength = 16 * 1024 * 1024; + + /// + /// Maximum number of UTF-16 code units allowed in a deserialized string. + /// + public uint MaxStringLength = 16 * 1024 * 1024; + /// /// Array of custom TypeSerializers /// diff --git a/NetSerializer/TypeIDList.cs b/NetSerializer/TypeIDList.cs index 569d159..b4fecbe 100644 --- a/NetSerializer/TypeIDList.cs +++ b/NetSerializer/TypeIDList.cs @@ -34,6 +34,18 @@ public bool ContainsTypeID(uint typeID) return typeID < m_array.Length && m_array[typeID] != null; } + public bool TryGetValue(uint typeID, out TypeData data) + { + if (typeID < m_array.Length && m_array[typeID] != null) + { + data = m_array[typeID]; + return true; + } + + data = null; + return false; + } + public TypeData this[uint idx] { get diff --git a/NetSerializer/TypeSerializers/ArraySerializer.cs b/NetSerializer/TypeSerializers/ArraySerializer.cs index 23231f1..3ed17c5 100644 --- a/NetSerializer/TypeSerializers/ArraySerializer.cs +++ b/NetSerializer/TypeSerializers/ArraySerializer.cs @@ -129,6 +129,10 @@ public void GenerateReaderMethod(Serializer serializer, Type type, ILGenerator i il.MarkLabel(notNullLabel); + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc_S, lenLocal); + il.Emit(OpCodes.Call, typeof(Serializer).GetMethod("CheckCollectionLength", BindingFlags.Instance | BindingFlags.NonPublic)!); + var arrLocal = il.DeclareLocal(type); // create new array with len - 1 diff --git a/NetSerializer/TypeSerializers/ListSerializer.cs b/NetSerializer/TypeSerializers/ListSerializer.cs index e0c38c3..60cc638 100644 --- a/NetSerializer/TypeSerializers/ListSerializer.cs +++ b/NetSerializer/TypeSerializers/ListSerializer.cs @@ -123,6 +123,10 @@ public void GenerateReaderMethod(Serializer serializer, Type type, ILGenerator i il.MarkLabel(notNullLabel); + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc_S, lenLocal); + il.Emit(OpCodes.Call, typeof(Serializer).GetMethod("CheckCollectionLength", BindingFlags.Instance | BindingFlags.NonPublic)!); + var listLocal = il.DeclareLocal(type); // -- length diff --git a/NetSerializer/TypeSerializers/ObjectSerializer.cs b/NetSerializer/TypeSerializers/ObjectSerializer.cs index 9b8bbd9..99d958b 100644 --- a/NetSerializer/TypeSerializers/ObjectSerializer.cs +++ b/NetSerializer/TypeSerializers/ObjectSerializer.cs @@ -79,5 +79,33 @@ public static void Deserialize(Serializer serializer, Stream stream, out object var del = serializer.GetDeserializeTrampolineFromId(id); del(serializer, stream, out ob); } + + public static bool TryDeserialize(Serializer serializer, Stream stream, out object ob) + { + uint id; + + Primitives.ReadPrimitive(stream, out id); + + if (id == 0) + { + ob = null; + return true; + } + + if (id == Serializer.ObjectTypeId) + { + ob = new object(); + return true; + } + + if (!serializer.TryGetDeserializeTrampolineFromId(id, out var del)) + { + ob = null; + return false; + } + + del(serializer, stream, out ob); + return true; + } } }