From 61d4e105d8e64406bab7a814f36e2823b3a44b19 Mon Sep 17 00:00:00 2001 From: Aviv Atedgi Date: Mon, 15 Mar 2021 17:11:53 +0200 Subject: [PATCH 1/3] [WIP] Added deserialization for Message (Done). Also working to make `from_stream` work on any field. Also added some tests and updated the messages documentation --- docs/messages.md | 37 ++++++++++++- hydration/base.py | 9 ++-- hydration/fields.py | 9 +++- hydration/message.py | 101 +++++++++++++++++++++++++++++++++- hydration/scalars.py | 7 +++ hydration/vectors.py | 12 ++++- requirements.txt | 3 +- tests/__init__.py | 0 tests/test_message.py | 104 ++++++++++++++++++++++++++++++++++-- tests/test_nested_struct.py | 3 ++ tests/test_placeholder.py | 17 +++++- tests/test_structs.py | 26 ++++----- tests/test_validation.py | 7 +++ tests/test_vectors.py | 21 ++++++-- tests/utils.py | 11 ++++ 15 files changed, 336 insertions(+), 31 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/utils.py diff --git a/docs/messages.md b/docs/messages.md index 54b7c83..6d0ec3c 100644 --- a/docs/messages.md +++ b/docs/messages.md @@ -92,4 +92,39 @@ Header3: opcode: UInt32(2) Body3: data3: UInt64(40) -``` \ No newline at end of file +``` + +Once you have an `OpcodeField` in your message, you can also construct a message straight from bytes data.
+Using the same example as above: + +```pycon +>>> print(bytes(Header3() / Body2())) +b'\x01\x00\x00\x00\x14\x00\x00\x00' +>>> print(Message.from_bytes(Header3, b'\x01\x00\x00\x00\x14\x00\x00\x00')) +Header3: + opcode: UInt32(1) +Body2: + data: UInt32(20) +``` + +You can also deserialize more structs than just the header using the `Message.from_bytes` function, for example: + +```python +class Footer(Struct): + x = UInt32() +``` +```pycon +>>> print(bytes(Header3() / Body2() / Footer())) +b'\x01\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00' +>>> print(Message.from_bytes(Header3, b'\x01\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00', Footer)) +Header3: + opcode: UInt32(1) +Body2: + data: UInt32(20) +Footer: + x: UInt32(0) +``` + +* You can also use the `from_stream` function instead `from_bytes`. +* You can pass more than one footer to the `Message.from_bytes` or `Message.from_stream` functions. +* You can pass footers that also has an `OpcodeField` in them and it will deserialize it as a message. diff --git a/hydration/base.py b/hydration/base.py index 0e849a6..dea1aac 100644 --- a/hydration/base.py +++ b/hydration/base.py @@ -258,21 +258,18 @@ def from_stream(cls, read_func: Callable[[int], bytes], *args): :return: The deserialized object. """ + print(f'Using `from_stream` in {cls.__name__}') obj = cls(*args) for field in obj._fields: - obj.invoke_from_bytes_hooks(field) if isinstance(field, VLA): field.length = int(getattr(obj, field.length_field_name)) - data = read_func(field.length) - field.from_bytes(data) + field.from_stream(read_func) else: - read_size = field.size + field.from_stream(read_func) - data = read_func(read_size) - field.value = field.from_bytes(data).value with suppress(AttributeError): field.validator.validate(field.value) diff --git a/hydration/fields.py b/hydration/fields.py index 3f22eb2..2a92865 100644 --- a/hydration/fields.py +++ b/hydration/fields.py @@ -1,6 +1,6 @@ import abc from abc import ABC -from typing import Union +from typing import Union, Callable from .validators import ValidatorABC @@ -51,6 +51,10 @@ def __bytes__(self) -> bytes: def from_bytes(self, data: bytes): raise NotImplementedError + @abc.abstractmethod + def from_stream(self, read_func: Callable[[int], bytes]): + raise NotImplementedError + def __eq__(self, other): if isinstance(other, Field): return self.value == other.value and len(self) == len(other) @@ -118,3 +122,6 @@ def __bytes__(self) -> bytes: def from_bytes(self, data: bytes): raise AttributeError('Placeholders cannot be deserialized') + + def from_stream(self, read_func: Callable[[int], bytes]): + raise AttributeError('Placeholders cannot be deserialized') diff --git a/hydration/message.py b/hydration/message.py index d54cc83..7ed555f 100644 --- a/hydration/message.py +++ b/hydration/message.py @@ -1,7 +1,8 @@ import inspect from abc import ABC, abstractmethod from contextlib import suppress -from typing import List, Union, Type, Mapping +from typing import List, Union, Type, Mapping, Callable +from bidict import bidict, ValueDuplicationError from hydration.helpers import as_obj from .base import Struct @@ -137,6 +138,95 @@ def __contains__(self, item): def __len__(self): return len(self.layers) + @classmethod + def from_bytes(cls, header_class: Type[Struct], data: bytes, *additional_classes: List[Type[Struct]]): + """ + Create a message from bytes data, using a header with an OpcodeField. + + :param header_class: A struct class which is the header of the message + :param data: Data containing the message (in bytes) + :param additional_classes: Additional classes to deserialize after the header and the body + :return: A message created from `data`,based on `header_class` and `additional_classes` + """ + + # Find the opcode field in the header + for opcode_name, opcode_field in as_obj(header_class): + if isinstance(opcode_field, OpcodeField): + break + else: + raise ValueError(f'Header {header_class.__name__} must have an opcode field in order to deserialize a message') + + # Create the header object + header = header_class.from_bytes(data) + data = data[len(header):] + + # Extract body class from header's opcode field + header_opcode_value = getattr(header, opcode_name).value + body_class: Type[Struct] = bidict(opcode_field.opcode_dictionary).inverse[header_opcode_value] + + # Create the body + body = body_class.from_bytes(data) + data = data[len(body):] + additional_layers = [] + + # Add the additional classes + for additional_struct in additional_classes: + try: + # Try to deserialize the struct as an header, if it doesn't have an OpcodeField + # it will raise a ValueError and we will treat it as a normal struct + msg = Message.from_bytes(additional_struct, data) + additional_layers.extend(msg.layers) + data = data[msg.size:] + except ValueError: + obj = additional_struct.from_bytes(data) + additional_layers.append(obj) + data = data[len(obj):] + + return cls(header, body, *additional_layers, update_metadata=False) + + @classmethod + def from_stream(cls, header_class: Type[Struct], read_func: Callable[[int], bytes], *additional_classes: List[Type[Struct]]): + """ + Create a message from bytes data, using a header with an OpcodeField. + + :param header_class: A struct class which is the header of the message + :param read_func: The stream's reader function + The function needs to receive an int as a positional parameter and return a bytes object. + :param additional_classes: Additional classes to deserialize after the header and the body + :return: A message created from `read_func`,based on `header_class` and `additional_classes` + """ + + # Find the opcode field in the header + for opcode_name, opcode_field in as_obj(header_class): + if isinstance(opcode_field, OpcodeField): + break + else: + raise ValueError(f'Header {header_class.__name__} must have an opcode field in order to deserialize a message') + + # Create the header object + header = header_class.from_stream(read_func) + + # Extract body class from header's opcode field + header_opcode_value = getattr(header, opcode_name).value + body_class: Type[Struct] = bidict(opcode_field.opcode_dictionary).inverse[header_opcode_value] + + # Create the body + body = body_class.from_stream(read_func) + additional_layers = [] + + # Add the additional classes + for additional_struct in additional_classes: + try: + # Try to deserialize the struct as an header, if it doesn't have an OpcodeField + # it will raise a ValueError and we will treat it as a normal struct + msg = Message.from_stream(additional_struct, read_func) + additional_layers.extend(msg.layers) + except ValueError: + obj = additional_struct.from_stream(read_func) + additional_layers.append(obj) + + return cls(header, body, *additional_layers, update_metadata=False) + @property def size(self): # layers are structs or bytes, so use len instead of size @@ -186,6 +276,9 @@ def __bytes__(self) -> bytes: def from_bytes(self, data: bytes): return self.data_field.from_bytes(data) + def from_stream(self, read_func: Callable[[int], bytes]): + return self.data_field.from_stream(read_func) + @abstractmethod def update(self, message: Message, struct: Struct, struct_index: int): raise NotImplementedError @@ -206,6 +299,12 @@ class OpcodeField(MetaField): def __init__(self, data_field: FieldType, opcode_dictionary: Mapping): super().__init__(data_field) self.opcode_dictionary = opcode_dictionary + + try: + # Validate that there are no duplicate opcodes + bidict(opcode_dictionary) + except ValueDuplicationError: + raise ValueError("Opcode values must be unique") def update(self, message: Message, struct: Struct, struct_index: int): with suppress(IndexError): diff --git a/hydration/scalars.py b/hydration/scalars.py index 6b0af12..82ef9ab 100644 --- a/hydration/scalars.py +++ b/hydration/scalars.py @@ -128,6 +128,10 @@ def from_bytes(self, data: bytes): self.value = struct.unpack(format_string, data)[0] return self + def from_stream(self, read_func: Callable[[int], bytes]): + print(f'Using `from_stream` in {self.__class__.__name__}') + return self.from_bytes(read_func(len(self))) + def __trunc__(self): return trunc(self.value) @@ -317,6 +321,9 @@ def from_bytes(self, data: bytes): self.value = self.type.value return self + def from_stream(self, read_func: Callable[[int], bytes]): + return self.from_bytes(read_func(len(self))) + @property def name(self): return self.enum_class(self.value).name diff --git a/hydration/vectors.py b/hydration/vectors.py index 343d953..696ee81 100644 --- a/hydration/vectors.py +++ b/hydration/vectors.py @@ -1,7 +1,7 @@ import copy from abc import ABC from collections import UserList -from typing import Sequence, Optional, Any, Union, Iterable +from typing import Sequence, Optional, Any, Union, Iterable, Callable from itertools import islice from .base import Struct @@ -55,6 +55,11 @@ def from_bytes(self, data: bytes): self.value = tuple(field_type.from_bytes(chunk).value for chunk in byte_chunks(data, len(field_type))) return self + def from_stream(self, read_func: Callable[[int], bytes]): + field_type = copy.deepcopy(self.type) + self.value = tuple(field_type.from_stream(read_func) for _ in range(len(self))) + return self + def __str__(self): return '{}{}'.format(self.__class__.__qualname__, self.value) @@ -184,6 +189,11 @@ def from_bytes(self, data: bytes): self.value = val return self + def from_stream(self, read_func: Callable[[int], bytes]): + print(f'Using `from_stream` in {self.__class__.__name__}') + self.value = [self.type.from_stream(read_func) for _ in range(len(self))] + return self + def __len__(self) -> int: return VLA.__len__(self) diff --git a/requirements.txt b/requirements.txt index 1053340..0eb0416 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pytest -pyhooks>=1.0.3 \ No newline at end of file +pyhooks>=1.0.3 +bidict \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_message.py b/tests/test_message.py index 07dc361..24e26e5 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,9 +1,10 @@ -from zlib import crc32 - import pytest - import hydration as h +from zlib import crc32 + +from .utils import as_reader + class Tomer(h.Struct): b = h.UInt8(5) @@ -153,3 +154,100 @@ class Footer(h.Struct, endianness=h.BigEndian): def test_crc(): msg = Header(magic=0x01052000) / Footer() assert crc32(bytes(msg)[:-4]) == msg[Footer].crc.value + + +class MessageBody1(h.Struct): + x = h.UInt16(0x1D14) + + +class MessageBody2(h.Struct): + y = h.UInt32(0x06072001) + + +class MessageHeader(h.Struct): + magic = h.UInt32(0xDEADBEEF) + opcode = h.OpcodeField(h.UInt16, {MessageBody1: 0x00, MessageBody2: 0x01}) + + +class MessageFooter(h.Struct): + magic = h.UInt32(0x000C0FEE) + + +class AnotherFooter(h.Struct): + another_magic = h.UInt32(0x01052000) + + +class FooterBody1(h.Struct): + x = h.UInt8(0xFF) + + +class FooterBody2(h.Struct): + y = h.UInt64(0x0123456789ABCDEF) + + +class FooterThatIsAlsoHeader(h.Struct): + magic = h.UInt32(0x09012001) + my_opcode = h.OpcodeField(h.UInt16, {FooterBody1: 0x02, FooterBody2: 0x00}) + + +def test_message_deserialization(): + msg = MessageHeader() / MessageBody1() + assert msg[MessageHeader].opcode == 0 + assert h.Message.from_bytes(MessageHeader, bytes(msg)) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg))) == msg + + msg = MessageHeader() / MessageBody2() + assert msg[MessageHeader].opcode == 1 + assert h.Message.from_bytes(MessageHeader, bytes(msg)) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg))) == msg + + +def test_message_deserialization_with_footer(): + msg = MessageHeader() / MessageBody1() / MessageFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter) == msg + + +def test_message_deserialization_with_multiple_footers(): + msg = MessageHeader() / MessageBody1() / MessageFooter() / AnotherFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, AnotherFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, AnotherFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() / AnotherFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, AnotherFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, AnotherFooter) == msg + + +def test_message_deserialization_with_multiple_footers_and_multiple_opcodes(): + msg = MessageHeader() / MessageBody1() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody1() / AnotherFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody1() / AnotherFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + + msg = MessageHeader() / MessageBody1() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody2() / AnotherFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody2() / AnotherFooter() + assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + + +def test_message_deserialization_with_missing_opcode_field(): + class MissingOpcode(h.Struct): + x = h.UInt16() + + with pytest.raises(ValueError): + h.Message.from_bytes(MissingOpcode, bytes(MissingOpcode())) + + with pytest.raises(ValueError): + h.Message.from_stream(MissingOpcode, as_reader(bytes(MissingOpcode()))) + + diff --git a/tests/test_nested_struct.py b/tests/test_nested_struct.py index cd6c7f9..c560a64 100644 --- a/tests/test_nested_struct.py +++ b/tests/test_nested_struct.py @@ -1,6 +1,8 @@ import pytest import hydration as h +from .utils import as_reader + class Time(h.Struct): time = h.UInt64(3) @@ -20,6 +22,7 @@ class Ron(h.Struct): def test_nested(): a = Ron() assert a.from_bytes(bytes(a)) == a + assert a.from_stream(as_reader(bytes(a))) == a class Card(h.Struct): diff --git a/tests/test_placeholder.py b/tests/test_placeholder.py index 933c6a3..57d635a 100644 --- a/tests/test_placeholder.py +++ b/tests/test_placeholder.py @@ -3,6 +3,8 @@ import hydration as h from hydration.fields import Field +from .utils import as_reader + class IDVector(h.Struct): vec_len = h.UInt16() @@ -20,7 +22,7 @@ def set_vec_field(self): self.vec.type = d[int(self.id_len)]() -def test_id_vector(): +def test_id_vector_using_from_bytes(): idv = IDVector(h.UInt8, vec=[0, 1, 1, 0]) assert idv.vec_len == 4 @@ -31,3 +33,16 @@ def test_id_vector(): assert idv.vec_len == 2 assert idv.id_len == 2 assert IDVector.from_bytes(bytes(idv)) == idv + + +def test_id_vector_using_from_stream(): + + idv = IDVector(h.UInt8, vec=[0, 1, 1, 0]) + assert idv.vec_len == 4 + assert idv.id_len == 1 + assert IDVector.from_stream(as_reader(bytes(idv))) == idv + + idv = IDVector(h.UInt16, vec=[256, 1]) + assert idv.vec_len == 2 + assert idv.id_len == 2 + assert IDVector.from_stream(as_reader(bytes(idv))) == idv diff --git a/tests/test_structs.py b/tests/test_structs.py index 63dfa48..2bb169a 100644 --- a/tests/test_structs.py +++ b/tests/test_structs.py @@ -1,6 +1,8 @@ import pytest import hydration as h +from .utils import as_reader + class Omri(h.Struct): a = h.UInt16(256) @@ -75,6 +77,8 @@ def test_footer(): def test_from_stream(): + from .utils import MockReader + class Shustin(h.Struct): length = h.UInt16() data = h.Vector(length=length, field_type=h.UInt8()) @@ -85,27 +89,21 @@ class NadavLoYazam(h.Struct): yazam = h.UInt32() bihlal = h.UInt64() - class MockReader: - def __init__(self, data: bytes): - self._data = data - - def read(self, size=0): - user_data, self._data = self._data[:size], self._data[size:] - return user_data - shustin = Shustin() shustin.length = 32 shustin.data = [x for x in range(0, 32, 1)] - my_shustin = Shustin.from_stream(MockReader(bytes(shustin)).read) - assert bytes(shustin) == bytes(my_shustin) + my_shustin = Shustin.from_stream(as_reader(bytes(shustin))) + # assert bytes(shustin) == bytes(my_shustin) + assert shustin == my_shustin nadav = NadavLoYazam() nadav.nadav = 3 nadav.lo = 854 nadav.yazam = 1512 nadav.bihlal = 38272 - nadav_lo_yazam = NadavLoYazam.from_stream(MockReader(bytes(nadav)).read) - assert bytes(nadav) == bytes(nadav_lo_yazam) + nadav_lo_yazam = NadavLoYazam.from_stream(as_reader(bytes(nadav))) + # assert bytes(nadav) == bytes(nadav_lo_yazam) + assert nadav == nadav_lo_yazam def test_new_attributes(): @@ -135,6 +133,7 @@ def __init__(self, x, *args, **kwargs): a = Amadeus(3) assert Amadeus(3).from_bytes(bytes(a)) == a + assert Amadeus(3).from_stream(as_reader(bytes(a))) == a class Mozart(Amadeus): y = h.UInt16 @@ -145,6 +144,7 @@ def __init__(self, x, y, **kwargs): m = Mozart(4, 5) assert Mozart(0, 0).from_bytes(bytes(m)) == m + assert Mozart(0, 0).from_stream(as_reader(bytes(m))) == m def test_from_bytes_hooks(): @@ -160,5 +160,7 @@ def foo(self): r.arr.type = h.UInt16() r2 = Ronen.from_bytes(bytes(r)) + r3 = Ronen.from_stream(as_reader(bytes(r))) assert r2.arr == list(range(10)) + assert r3.arr == list(range(10)) diff --git a/tests/test_validation.py b/tests/test_validation.py index bf14e56..2e2da51 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -3,6 +3,8 @@ import pytest import hydration as h +from .utils import as_reader + class Tst(h.Struct): a = h.UInt8(validator=0) @@ -16,6 +18,11 @@ def test_init(): def test_from_bytes(): with pytest.raises(ValueError): Tst.from_bytes(b'\x02') + + +def test_from_stream(): + with pytest.raises(ValueError): + Tst.from_stream(as_reader(b'\x02')) def test_bad_struct(): diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 546cbbd..12904ff 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -1,7 +1,8 @@ import pytest - import hydration as h +from .utils import as_reader + class Garzon(h.Struct): arr = h.Array(field_type=h.UInt8(3), length=3, value=(1, 2), fill=True) @@ -27,6 +28,7 @@ def test_vector(): x = Shine() new_x = Shine.from_bytes(bytes(x)) assert x == new_x + assert x == Shine.from_stream(as_reader(bytes(x))) def test_vector_len_update(): @@ -52,6 +54,7 @@ def test_bad_val(): def test_array(): x = Isaac() assert x == Isaac.from_bytes(bytes(x)) + assert x == Isaac.from_stream(as_reader(bytes(x))) assert bytes(x) == b'\x01\x02\x03\x05' @@ -74,10 +77,14 @@ class Venice(h.Struct): assert str(Venice().ip) == '0.0.0.0' assert Venice.from_bytes(bytes(Venice(ip='127.0.0.1'))) == Venice(ip='127.0.0.1') + assert Venice.from_stream(as_reader(bytes(Venice(ip='127.0.0.1')))) == Venice(ip='127.0.0.1') with pytest.raises(ValueError): Venice.from_bytes(bytes(Venice(ip='127.0.0.1'))[:-1]) + with pytest.raises(ValueError): + Venice.from_stream(as_reader(bytes(Venice(ip='127.0.0.1'))[:-1])) + def test_type_field(): class Lior(h.Struct): @@ -136,9 +143,10 @@ class Data(h.Struct): d = Data(data=[3] * 10) b = bytes(d) assert Data.from_bytes(b).data == d.data + assert Data.from_stream(as_reader(b)).data == d.data -def test_dynamic_vec_size(): +def test_vector_with_dynamic_item_size(): class Atedgi(h.Struct): this = h.UInt8(1) @@ -153,7 +161,7 @@ def __init__(self, this=1, *args, **kwargs): def set_vec_field(self): d = {len(x()): x for x in (h.UInt8, h.UInt16, h.UInt32, h.UInt64)} self.aviv = d[self.this.value]() - + class Maor(h.Struct): vec_len = h.UInt16() vec = h.Vector(vec_len, Atedgi) @@ -165,9 +173,14 @@ class Maor(h.Struct): assert len(obj) == size + 1 real_deal = Maor(vec=x) - identical = Maor.from_bytes(bytes(real_deal)) + assert real_deal.vec.type == identical.vec.type + for a1, a2 in zip(real_deal.vec.value, identical.vec.value): + assert a1.aviv == a2.aviv + + real_deal = Maor(vec=x) + identical = Maor.from_stream(as_reader(bytes(real_deal))) assert real_deal.vec.type == identical.vec.type for a1, a2 in zip(real_deal.vec.value, identical.vec.value): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..7f1f883 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,11 @@ +class MockReader: + def __init__(self, data: bytes): + self._data = data + + def read(self, size=0): + user_data, self._data = self._data[:size], self._data[size:] + return user_data + + +def as_reader(data: bytes): + return MockReader(data).read From 887288b26f5c269777fccc71833ffb40b40052de Mon Sep 17 00:00:00 2001 From: Aviv Atedgi Date: Tue, 16 Mar 2021 12:30:45 +0200 Subject: [PATCH 2/3] Fixed all of the tests, message deserialization fully works on `from_bytes` and `from_stream`, also added support for `from_stream` in all of the fields and it is now a requirement to create a new field (one of the abstract methods in the abstract class `Field`) --- hydration/base.py | 12 ++++-- hydration/scalars.py | 1 - hydration/vectors.py | 6 ++- tests/test_vectors.py | 88 ++++++++++++++++++++++++++++--------------- 4 files changed, 71 insertions(+), 36 deletions(-) diff --git a/hydration/base.py b/hydration/base.py index dea1aac..33abdf0 100644 --- a/hydration/base.py +++ b/hydration/base.py @@ -8,7 +8,7 @@ from .helpers import as_obj, assert_no_property_override, as_type from .scalars import Scalar, Enum -from .fields import Field, VLA, FieldPlaceholder +from .fields import Field, VLA from .endianness import Endianness illegal_field_names = ['value', 'validate', '_fields'] @@ -258,12 +258,18 @@ def from_stream(cls, read_func: Callable[[int], bytes], *args): :return: The deserialized object. """ - print(f'Using `from_stream` in {cls.__name__}') obj = cls(*args) - for field in obj._fields: + for field_name in obj._field_names: + + # Get field for current field name + field = getattr(obj, field_name) + obj.invoke_from_bytes_hooks(field) + # Bytes hooks can change the field object, so get it again by name + field = getattr(obj, field_name) + if isinstance(field, VLA): field.length = int(getattr(obj, field.length_field_name)) field.from_stream(read_func) diff --git a/hydration/scalars.py b/hydration/scalars.py index 82ef9ab..043bb47 100644 --- a/hydration/scalars.py +++ b/hydration/scalars.py @@ -129,7 +129,6 @@ def from_bytes(self, data: bytes): return self def from_stream(self, read_func: Callable[[int], bytes]): - print(f'Using `from_stream` in {self.__class__.__name__}') return self.from_bytes(read_func(len(self))) def __trunc__(self): diff --git a/hydration/vectors.py b/hydration/vectors.py index 696ee81..814f9de 100644 --- a/hydration/vectors.py +++ b/hydration/vectors.py @@ -57,7 +57,7 @@ def from_bytes(self, data: bytes): def from_stream(self, read_func: Callable[[int], bytes]): field_type = copy.deepcopy(self.type) - self.value = tuple(field_type.from_stream(read_func) for _ in range(len(self))) + self.value = tuple(field_type.from_stream(read_func).value for _ in range(len(self))) return self def __str__(self): @@ -190,7 +190,9 @@ def from_bytes(self, data: bytes): return self def from_stream(self, read_func: Callable[[int], bytes]): - print(f'Using `from_stream` in {self.__class__.__name__}') + if isinstance(self.type, Field): + return super().from_bytes(read_func(len(self) * len(self.type))) + self.value = [self.type.from_stream(read_func) for _ in range(len(self))] return self diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 12904ff..f471335 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -24,11 +24,16 @@ class Shine(h.Struct): x = h.UInt16(104) -def test_vector(): +def test_vector_from_bytes(): x = Shine() new_x = Shine.from_bytes(bytes(x)) assert x == new_x - assert x == Shine.from_stream(as_reader(bytes(x))) + + +def test_vector_from_stream(): + x = Shine() + new_x = Shine.from_stream(as_reader(bytes(x))) + assert x == new_x def test_vector_len_update(): @@ -51,10 +56,16 @@ def test_bad_val(): assert bytes(x) == b'\x04\x03\x03\x05' -def test_array(): +def test_array_from_bytes(): x = Isaac() assert x == Isaac.from_bytes(bytes(x)) - assert x == Isaac.from_stream(as_reader(bytes(x))) + assert bytes(x) == b'\x01\x02\x03\x05' + + +def test_array_from_stream(): + x = Isaac() + new_x = Isaac.from_stream(as_reader(bytes(x))) + assert x == new_x assert bytes(x) == b'\x01\x02\x03\x05' @@ -71,19 +82,31 @@ class Shustin2(h.Struct): Shustin2() -def test_ipv4(): - class Venice(h.Struct): - ip = h.IPv4() +class Venice(h.Struct): + ip = h.IPv4() + +def test_ipv4(): assert str(Venice().ip) == '0.0.0.0' - assert Venice.from_bytes(bytes(Venice(ip='127.0.0.1'))) == Venice(ip='127.0.0.1') - assert Venice.from_stream(as_reader(bytes(Venice(ip='127.0.0.1')))) == Venice(ip='127.0.0.1') + + x = Venice(ip='127.0.0.1') + assert Venice.from_bytes(bytes(x)) == x with pytest.raises(ValueError): Venice.from_bytes(bytes(Venice(ip='127.0.0.1'))[:-1]) - with pytest.raises(ValueError): - Venice.from_stream(as_reader(bytes(Venice(ip='127.0.0.1'))[:-1])) + +def test_ipv4_from_stream(): + x = Venice(ip='127.0.0.1') + assert Venice.from_stream(as_reader(bytes(x))) == x + + """ + No need to test this test-case because it won't work + If there is no bytes left in the reader, the UInt8 + will raise an struct.error: unpack requires a buffer of 1 bytes + """ + # with pytest.raises(ValueError): + # Venice.from_stream(as_reader(bytes(x)[:-1])) def test_type_field(): @@ -146,40 +169,45 @@ class Data(h.Struct): assert Data.from_stream(as_reader(b)).data == d.data -def test_vector_with_dynamic_item_size(): +class Atedgi(h.Struct): + this = h.UInt8(1) + aviv = h.FieldPlaceholder() + + def __init__(self, this=1, *args, **kwargs): + self.this = this + self.set_vec_field() + super().__init__(*args, **kwargs) - class Atedgi(h.Struct): - this = h.UInt8(1) - aviv = h.FieldPlaceholder() + @h.from_bytes_hook(aviv) + def set_vec_field(self): + d = {len(x()): x for x in (h.UInt8, h.UInt16, h.UInt32, h.UInt64)} + self.aviv = d[self.this.value]() - def __init__(self, this=1, *args, **kwargs): - self.this = this - self.set_vec_field() - super().__init__(*args, **kwargs) - @h.from_bytes_hook(aviv) - def set_vec_field(self): - d = {len(x()): x for x in (h.UInt8, h.UInt16, h.UInt32, h.UInt64)} - self.aviv = d[self.this.value]() - - class Maor(h.Struct): - vec_len = h.UInt16() - vec = h.Vector(vec_len, Atedgi) +class Maor(h.Struct): + vec_len = h.UInt16() + vec = h.Vector(vec_len, Atedgi) + + +def test_vector_with_dynamic_item_size(): sizes = (1, 2, 4, 8) - from random import randint x = [Atedgi(this=size, aviv=5) for size in sizes] for obj, size in zip(x, sizes): assert len(obj) == size + 1 - real_deal = Maor(vec=x) + +def test_vector_with_dynamic_item_size_from_bytes(): + real_deal = Maor(vec=[Atedgi(this=size, aviv=5) for size in (1, 2, 4, 8)]) identical = Maor.from_bytes(bytes(real_deal)) assert real_deal.vec.type == identical.vec.type for a1, a2 in zip(real_deal.vec.value, identical.vec.value): assert a1.aviv == a2.aviv - real_deal = Maor(vec=x) + +def test_vector_with_dynamic_item_size_from_stream(): + real_deal = Maor(vec=[Atedgi(this=size, aviv=5) for size in (1, 2, 4, 8)]) identical = Maor.from_stream(as_reader(bytes(real_deal))) assert real_deal.vec.type == identical.vec.type From 91dae0ab7cdc2fa3ba934c6668f44dbb49773353 Mon Sep 17 00:00:00 2001 From: Aviv Atedgi Date: Fri, 19 Mar 2021 21:23:51 +0200 Subject: [PATCH 3/3] Updated the Message.from_stream functionallity to work with recursive OpcodeField logic, also removed all of the `from_bytes` logic in all of the Field and now its calling to `from_stream` with a mock reader --- hydration/base.py | 28 +------------ hydration/fields.py | 4 +- hydration/helpers.py | 13 ++++++ hydration/message.py | 85 +++++++++++++------------------------- hydration/scalars.py | 14 ++++--- hydration/vectors.py | 31 ++++++-------- tests/test_message.py | 95 ++++++++++++++++++++++++++++++------------- 7 files changed, 134 insertions(+), 136 deletions(-) diff --git a/hydration/base.py b/hydration/base.py index 33abdf0..fdd2b6a 100644 --- a/hydration/base.py +++ b/hydration/base.py @@ -6,7 +6,7 @@ from pyhooks import Hook, precall_register, postcall_register from typing import Callable, List, Iterable, Optional -from .helpers import as_obj, assert_no_property_override, as_type +from .helpers import as_obj, assert_no_property_override, as_type, as_stream from .scalars import Scalar, Enum from .fields import Field, VLA from .endianness import Endianness @@ -220,31 +220,7 @@ def from_bytes(cls, data: bytes, *args): :return The deserialized struct """ - obj = cls(*args) - - for field_name in obj._field_names: - - # Get field for current field name - field = getattr(obj, field_name) - - obj.invoke_from_bytes_hooks(field) - - # Bytes hooks can change the field object, so get it again by name - field = getattr(obj, field_name) - - if isinstance(field, VLA): - field.length = int(getattr(obj, field.length_field_name)) - field.from_bytes(data) - data = data[len(bytes(field)):] - else: - split_index = field.size - - field_data, data = data[:split_index], data[split_index:] - field.value = field.from_bytes(field_data).value - with suppress(AttributeError): - field.validator.validate(field.value) - - return obj + return cls.from_stream(as_stream(data), *args) @classmethod def from_stream(cls, read_func: Callable[[int], bytes], *args): diff --git a/hydration/fields.py b/hydration/fields.py index 2a92865..7c784d6 100644 --- a/hydration/fields.py +++ b/hydration/fields.py @@ -3,6 +3,7 @@ from typing import Union, Callable from .validators import ValidatorABC +from .helpers import as_stream class Field(ABC): @@ -47,9 +48,8 @@ def size(self): def __bytes__(self) -> bytes: raise NotImplementedError - @abc.abstractmethod def from_bytes(self, data: bytes): - raise NotImplementedError + return self.from_stream(as_stream(data)) @abc.abstractmethod def from_stream(self, read_func: Callable[[int], bytes]): diff --git a/hydration/helpers.py b/hydration/helpers.py index 928d6b3..3d25b1d 100644 --- a/hydration/helpers.py +++ b/hydration/helpers.py @@ -1,4 +1,5 @@ import inspect +from typing import Callable def as_type(obj): @@ -9,6 +10,18 @@ def as_obj(obj): return obj if not inspect.isclass(obj) else obj() +def as_stream(data: bytes) -> Callable[[int], bytes]: + class _StreamReader: + def __init__(self, _data: bytes): + self._data = _data + + def read(self, size: int) -> bytes: + user_data, self._data = self._data[:size], self._data[size:] + return user_data + + return _StreamReader(data).read + + def assert_no_property_override(obj, base_class): """ Use this to ensure that a Struct doesn't override properties of Field when using it as one. diff --git a/hydration/message.py b/hydration/message.py index 7ed555f..4924d15 100644 --- a/hydration/message.py +++ b/hydration/message.py @@ -4,7 +4,7 @@ from typing import List, Union, Type, Mapping, Callable from bidict import bidict, ValueDuplicationError -from hydration.helpers import as_obj +from .helpers import as_obj, as_stream from .base import Struct from .fields import Field from .validators import ValidatorABC, as_validator @@ -139,61 +139,28 @@ def __len__(self): return len(self.layers) @classmethod - def from_bytes(cls, header_class: Type[Struct], data: bytes, *additional_classes: List[Type[Struct]]): + def from_bytes(cls, data: bytes, header_class: Type[Struct], *layers: Type[Struct]): """ Create a message from bytes data, using a header with an OpcodeField. - :param header_class: A struct class which is the header of the message :param data: Data containing the message (in bytes) - :param additional_classes: Additional classes to deserialize after the header and the body - :return: A message created from `data`,based on `header_class` and `additional_classes` + :param header_class: The header class of the message + :param layers: The struct classes that represent the layers of the message + :return: A message created from `data`, based on `header_class` and `layers` """ - # Find the opcode field in the header - for opcode_name, opcode_field in as_obj(header_class): - if isinstance(opcode_field, OpcodeField): - break - else: - raise ValueError(f'Header {header_class.__name__} must have an opcode field in order to deserialize a message') - - # Create the header object - header = header_class.from_bytes(data) - data = data[len(header):] - - # Extract body class from header's opcode field - header_opcode_value = getattr(header, opcode_name).value - body_class: Type[Struct] = bidict(opcode_field.opcode_dictionary).inverse[header_opcode_value] - - # Create the body - body = body_class.from_bytes(data) - data = data[len(body):] - additional_layers = [] - - # Add the additional classes - for additional_struct in additional_classes: - try: - # Try to deserialize the struct as an header, if it doesn't have an OpcodeField - # it will raise a ValueError and we will treat it as a normal struct - msg = Message.from_bytes(additional_struct, data) - additional_layers.extend(msg.layers) - data = data[msg.size:] - except ValueError: - obj = additional_struct.from_bytes(data) - additional_layers.append(obj) - data = data[len(obj):] - - return cls(header, body, *additional_layers, update_metadata=False) + return cls.from_stream(as_stream(data), header_class, *layers) @classmethod - def from_stream(cls, header_class: Type[Struct], read_func: Callable[[int], bytes], *additional_classes: List[Type[Struct]]): + def from_stream(cls, read_func: Callable[[int], bytes], header_class: Type[Struct], *layers: Type[Struct]): """ - Create a message from bytes data, using a header with an OpcodeField. - - :param header_class: A struct class which is the header of the message + Create a message from bytes data, using a header with an OpcodeField and the layers that represent the message. + :param read_func: The stream's reader function The function needs to receive an int as a positional parameter and return a bytes object. - :param additional_classes: Additional classes to deserialize after the header and the body - :return: A message created from `read_func`,based on `header_class` and `additional_classes` + :param header_class: The header class of the message + :param layers: The struct classes that represent the layers of the message + :return: A message created from `read_func`, based on `header_class` and `layers` """ # Find the opcode field in the header @@ -201,7 +168,8 @@ def from_stream(cls, header_class: Type[Struct], read_func: Callable[[int], byte if isinstance(opcode_field, OpcodeField): break else: - raise ValueError(f'Header {header_class.__name__} must have an opcode field in order to deserialize a message') + raise AttributeError(f'Header {header_class.__name__} ' + f'must have an opcode field in order to deserialize a message') # Create the header object header = header_class.from_stream(read_func) @@ -211,20 +179,25 @@ def from_stream(cls, header_class: Type[Struct], read_func: Callable[[int], byte body_class: Type[Struct] = bidict(opcode_field.opcode_dictionary).inverse[header_opcode_value] # Create the body - body = body_class.from_stream(read_func) - additional_layers = [] + try: + # Try to treat the body as a message in case it's also contains an OpcodeField + body = Message.from_stream(read_func, body_class) + except AttributeError: + # If it doesn't contain an OpcodeField treat it like a normal struct + body = body_class.from_stream(read_func) - # Add the additional classes - for additional_struct in additional_classes: + additional_layers = [] + for layer in layers: try: - # Try to deserialize the struct as an header, if it doesn't have an OpcodeField - # it will raise a ValueError and we will treat it as a normal struct - msg = Message.from_stream(additional_struct, read_func) + # Try to treat the body as a message in case it's also contains an OpcodeField + msg = cls.from_stream(read_func, layer) additional_layers.extend(msg.layers) - except ValueError: - obj = additional_struct.from_stream(read_func) + except AttributeError: + # If it doesn't contain an OpcodeField treat it like a normal struct + obj = layer.from_stream(read_func) additional_layers.append(obj) + print(*additional_layers) return cls(header, body, *additional_layers, update_metadata=False) @property @@ -299,7 +272,7 @@ class OpcodeField(MetaField): def __init__(self, data_field: FieldType, opcode_dictionary: Mapping): super().__init__(data_field) self.opcode_dictionary = opcode_dictionary - + try: # Validate that there are no duplicate opcodes bidict(opcode_dictionary) diff --git a/hydration/scalars.py b/hydration/scalars.py index 043bb47..fca70a4 100644 --- a/hydration/scalars.py +++ b/hydration/scalars.py @@ -122,14 +122,16 @@ def __int__(self) -> int: def __float__(self) -> float: return float(self.value) - def from_bytes(self, data: bytes): + def from_stream(self, read_func: Callable[[int], bytes]): format_string = '{}{}'.format(self.endianness_format, self.scalar_format) - # noinspection PyAttributeOutsideInit - self.value = struct.unpack(format_string, data)[0] - return self - def from_stream(self, read_func: Callable[[int], bytes]): - return self.from_bytes(read_func(len(self))) + try: + # noinspection PyAttributeOutsideInit + self.value = struct.unpack(format_string, read_func(len(self)))[0] + except struct.error: + raise ValueError(f'Not enough bytes to unpack {self.__class__.__name__}') + + return self def __trunc__(self): return trunc(self.value) diff --git a/hydration/vectors.py b/hydration/vectors.py index 814f9de..2211344 100644 --- a/hydration/vectors.py +++ b/hydration/vectors.py @@ -50,11 +50,6 @@ def __bytes__(self) -> bytes: return bytes(result) - def from_bytes(self, data: bytes): - field_type = copy.deepcopy(self.type) - self.value = tuple(field_type.from_bytes(chunk).value for chunk in byte_chunks(data, len(field_type))) - return self - def from_stream(self, read_func: Callable[[int], bytes]): field_type = copy.deepcopy(self.type) self.value = tuple(field_type.from_stream(read_func).value for _ in range(len(self))) @@ -116,7 +111,7 @@ def extend(self, other: Iterable) -> None: self.value = self.data -class Array(_Sequence): +class Array(_Sequence, ABC): def __init__(self, length: int, field_type: FieldType = UInt8, value: Optional[Sequence[Any]] = (), @@ -177,21 +172,21 @@ def value(self, value): # This assumes that the Struct will update the length field's value self.length = len(value) - def from_bytes(self, data: bytes): - if isinstance(self.type, Field): - return super().from_bytes(data[:len(self) * len(self.type)]) - else: - val = [] - for _ in range(len(self)): - next_obj = self.type.from_bytes(data) - val.append(next_obj) - data = data[len(bytes(next_obj)):] - self.value = val - return self + # def from_bytes(self, data: bytes): + # if isinstance(self.type, Field): + # return super().from_bytes(data[:len(self) * len(self.type)]) + # else: + # val = [] + # for _ in range(len(self)): + # next_obj = self.type.from_bytes(data) + # val.append(next_obj) + # data = data[len(bytes(next_obj)):] + # self.value = val + # return self def from_stream(self, read_func: Callable[[int], bytes]): if isinstance(self.type, Field): - return super().from_bytes(read_func(len(self) * len(self.type))) + return super().from_stream(read_func) self.value = [self.type.from_stream(read_func) for _ in range(len(self))] return self diff --git a/tests/test_message.py b/tests/test_message.py index 24e26e5..88b2ce6 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -90,7 +90,6 @@ class D(h.Struct): def test_opcode_field(): - class C1(h.Struct): x = h.UInt32 @@ -190,64 +189,104 @@ class FooterThatIsAlsoHeader(h.Struct): my_opcode = h.OpcodeField(h.UInt16, {FooterBody1: 0x02, FooterBody2: 0x00}) -def test_message_deserialization(): +def test_message_deserialization(): msg = MessageHeader() / MessageBody1() assert msg[MessageHeader].opcode == 0 - assert h.Message.from_bytes(MessageHeader, bytes(msg)) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg))) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader) == msg msg = MessageHeader() / MessageBody2() assert msg[MessageHeader].opcode == 1 - assert h.Message.from_bytes(MessageHeader, bytes(msg)) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg))) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader) == msg def test_message_deserialization_with_footer(): - msg = MessageHeader() / MessageBody1() / MessageFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter) == msg + msg = MessageHeader() / MessageBody1() / MessageFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter) == msg - msg = MessageHeader() / MessageBody2() / MessageFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter) == msg + msg = MessageHeader() / MessageBody2() / MessageFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter) == msg def test_message_deserialization_with_multiple_footers(): msg = MessageHeader() / MessageBody1() / MessageFooter() / AnotherFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, AnotherFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, AnotherFooter) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, AnotherFooter) == msg msg = MessageHeader() / MessageBody2() / MessageFooter() / AnotherFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, AnotherFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, AnotherFooter) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, AnotherFooter) == msg def test_message_deserialization_with_multiple_footers_and_multiple_opcodes(): msg = MessageHeader() / MessageBody1() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody1() / AnotherFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg msg = MessageHeader() / MessageBody2() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody1() / AnotherFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg msg = MessageHeader() / MessageBody1() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody2() / AnotherFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg msg = MessageHeader() / MessageBody2() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody2() / AnotherFooter() - assert h.Message.from_bytes(MessageHeader, bytes(msg), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg - assert h.Message.from_stream(MessageHeader, as_reader(bytes(msg)), MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg def test_message_deserialization_with_missing_opcode_field(): class MissingOpcode(h.Struct): x = h.UInt16() - with pytest.raises(ValueError): - h.Message.from_bytes(MissingOpcode, bytes(MissingOpcode())) + with pytest.raises(AttributeError): + h.Message.from_bytes(bytes(MissingOpcode()), MissingOpcode) + + with pytest.raises(AttributeError): + h.Message.from_stream(as_reader(bytes(MissingOpcode())), MissingOpcode) + + +def test_message_deserialization_with_recursive_opcode_field(): + class LiorN(h.Struct): + y = h.UInt32(0x16051999) + + class Fuad(h.Struct): + x = h.UInt32(0x17082001) + + class Ralon(h.Struct): + z = h.UInt32(0x29062000) + + class Sherman(h.Struct): + a = h.UInt32(0x15101996) + + class H3(h.Struct): + opcode = h.OpcodeField(h.UInt16, {Fuad: 0x02, LiorN: 0x03}) + + class H2(h.Struct): + opcode = h.OpcodeField(h.UInt16, {H3: 0x00, Ralon: 0x01}) + + class H1(h.Struct): + opcode = h.OpcodeField(h.UInt16, {H2: 0x00, Sherman: 0x01}) + + class F1(h.Struct): + damn = h.UInt32(0xFACEB00C) - with pytest.raises(ValueError): - h.Message.from_stream(MissingOpcode, as_reader(bytes(MissingOpcode()))) + msg = H1() / Sherman() / F1() + assert h.Message.from_bytes(bytes(msg), H1, F1) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), H1, F1) == msg + msg = H1() / H2() / Ralon() / Footer() + assert h.Message.from_bytes(bytes(msg), H1, Footer) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), H1, Footer) == msg + msg = H1() / H2() / H3() / Fuad() / F1() + assert h.Message.from_bytes(bytes(msg), H1, F1) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), H1, F1) == msg