diff --git a/docs/messages.md b/docs/messages.md index 54b7c83..6d0ec3c 100644 --- a/docs/messages.md +++ b/docs/messages.md @@ -92,4 +92,39 @@ Header3: opcode: UInt32(2) Body3: data3: UInt64(40) -``` \ No newline at end of file +``` + +Once you have an `OpcodeField` in your message, you can also construct a message straight from bytes data.
+Using the same example as above: + +```pycon +>>> print(bytes(Header3() / Body2())) +b'\x01\x00\x00\x00\x14\x00\x00\x00' +>>> print(Message.from_bytes(Header3, b'\x01\x00\x00\x00\x14\x00\x00\x00')) +Header3: + opcode: UInt32(1) +Body2: + data: UInt32(20) +``` + +You can also deserialize more structs than just the header using the `Message.from_bytes` function, for example: + +```python +class Footer(Struct): + x = UInt32() +``` +```pycon +>>> print(bytes(Header3() / Body2() / Footer())) +b'\x01\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00' +>>> print(Message.from_bytes(Header3, b'\x01\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00', Footer)) +Header3: + opcode: UInt32(1) +Body2: + data: UInt32(20) +Footer: + x: UInt32(0) +``` + +* You can also use the `from_stream` function instead `from_bytes`. +* You can pass more than one footer to the `Message.from_bytes` or `Message.from_stream` functions. +* You can pass footers that also has an `OpcodeField` in them and it will deserialize it as a message. diff --git a/hydration/base.py b/hydration/base.py index 0e849a6..fdd2b6a 100644 --- a/hydration/base.py +++ b/hydration/base.py @@ -6,9 +6,9 @@ from pyhooks import Hook, precall_register, postcall_register from typing import Callable, List, Iterable, Optional -from .helpers import as_obj, assert_no_property_override, as_type +from .helpers import as_obj, assert_no_property_override, as_type, as_stream from .scalars import Scalar, Enum -from .fields import Field, VLA, FieldPlaceholder +from .fields import Field, VLA from .endianness import Endianness illegal_field_names = ['value', 'validate', '_fields'] @@ -220,31 +220,7 @@ def from_bytes(cls, data: bytes, *args): :return The deserialized struct """ - obj = cls(*args) - - for field_name in obj._field_names: - - # Get field for current field name - field = getattr(obj, field_name) - - obj.invoke_from_bytes_hooks(field) - - # Bytes hooks can change the field object, so get it again by name - field = getattr(obj, field_name) - - if isinstance(field, VLA): - field.length = int(getattr(obj, field.length_field_name)) - field.from_bytes(data) - data = data[len(bytes(field)):] - else: - split_index = field.size - - field_data, data = data[:split_index], data[split_index:] - field.value = field.from_bytes(field_data).value - with suppress(AttributeError): - field.validator.validate(field.value) - - return obj + return cls.from_stream(as_stream(data), *args) @classmethod def from_stream(cls, read_func: Callable[[int], bytes], *args): @@ -260,19 +236,22 @@ def from_stream(cls, read_func: Callable[[int], bytes], *args): obj = cls(*args) - for field in obj._fields: + for field_name in obj._field_names: + + # Get field for current field name + field = getattr(obj, field_name) obj.invoke_from_bytes_hooks(field) + # Bytes hooks can change the field object, so get it again by name + field = getattr(obj, field_name) + if isinstance(field, VLA): field.length = int(getattr(obj, field.length_field_name)) - data = read_func(field.length) - field.from_bytes(data) + field.from_stream(read_func) else: - read_size = field.size + field.from_stream(read_func) - data = read_func(read_size) - field.value = field.from_bytes(data).value with suppress(AttributeError): field.validator.validate(field.value) diff --git a/hydration/fields.py b/hydration/fields.py index 3f22eb2..7c784d6 100644 --- a/hydration/fields.py +++ b/hydration/fields.py @@ -1,8 +1,9 @@ import abc from abc import ABC -from typing import Union +from typing import Union, Callable from .validators import ValidatorABC +from .helpers import as_stream class Field(ABC): @@ -47,8 +48,11 @@ def size(self): def __bytes__(self) -> bytes: raise NotImplementedError - @abc.abstractmethod def from_bytes(self, data: bytes): + return self.from_stream(as_stream(data)) + + @abc.abstractmethod + def from_stream(self, read_func: Callable[[int], bytes]): raise NotImplementedError def __eq__(self, other): @@ -118,3 +122,6 @@ def __bytes__(self) -> bytes: def from_bytes(self, data: bytes): raise AttributeError('Placeholders cannot be deserialized') + + def from_stream(self, read_func: Callable[[int], bytes]): + raise AttributeError('Placeholders cannot be deserialized') diff --git a/hydration/helpers.py b/hydration/helpers.py index 928d6b3..3d25b1d 100644 --- a/hydration/helpers.py +++ b/hydration/helpers.py @@ -1,4 +1,5 @@ import inspect +from typing import Callable def as_type(obj): @@ -9,6 +10,18 @@ def as_obj(obj): return obj if not inspect.isclass(obj) else obj() +def as_stream(data: bytes) -> Callable[[int], bytes]: + class _StreamReader: + def __init__(self, _data: bytes): + self._data = _data + + def read(self, size: int) -> bytes: + user_data, self._data = self._data[:size], self._data[size:] + return user_data + + return _StreamReader(data).read + + def assert_no_property_override(obj, base_class): """ Use this to ensure that a Struct doesn't override properties of Field when using it as one. diff --git a/hydration/message.py b/hydration/message.py index d54cc83..4924d15 100644 --- a/hydration/message.py +++ b/hydration/message.py @@ -1,9 +1,10 @@ import inspect from abc import ABC, abstractmethod from contextlib import suppress -from typing import List, Union, Type, Mapping +from typing import List, Union, Type, Mapping, Callable +from bidict import bidict, ValueDuplicationError -from hydration.helpers import as_obj +from .helpers import as_obj, as_stream from .base import Struct from .fields import Field from .validators import ValidatorABC, as_validator @@ -137,6 +138,68 @@ def __contains__(self, item): def __len__(self): return len(self.layers) + @classmethod + def from_bytes(cls, data: bytes, header_class: Type[Struct], *layers: Type[Struct]): + """ + Create a message from bytes data, using a header with an OpcodeField. + + :param data: Data containing the message (in bytes) + :param header_class: The header class of the message + :param layers: The struct classes that represent the layers of the message + :return: A message created from `data`, based on `header_class` and `layers` + """ + + return cls.from_stream(as_stream(data), header_class, *layers) + + @classmethod + def from_stream(cls, read_func: Callable[[int], bytes], header_class: Type[Struct], *layers: Type[Struct]): + """ + Create a message from bytes data, using a header with an OpcodeField and the layers that represent the message. + + :param read_func: The stream's reader function + The function needs to receive an int as a positional parameter and return a bytes object. + :param header_class: The header class of the message + :param layers: The struct classes that represent the layers of the message + :return: A message created from `read_func`, based on `header_class` and `layers` + """ + + # Find the opcode field in the header + for opcode_name, opcode_field in as_obj(header_class): + if isinstance(opcode_field, OpcodeField): + break + else: + raise AttributeError(f'Header {header_class.__name__} ' + f'must have an opcode field in order to deserialize a message') + + # Create the header object + header = header_class.from_stream(read_func) + + # Extract body class from header's opcode field + header_opcode_value = getattr(header, opcode_name).value + body_class: Type[Struct] = bidict(opcode_field.opcode_dictionary).inverse[header_opcode_value] + + # Create the body + try: + # Try to treat the body as a message in case it's also contains an OpcodeField + body = Message.from_stream(read_func, body_class) + except AttributeError: + # If it doesn't contain an OpcodeField treat it like a normal struct + body = body_class.from_stream(read_func) + + additional_layers = [] + for layer in layers: + try: + # Try to treat the body as a message in case it's also contains an OpcodeField + msg = cls.from_stream(read_func, layer) + additional_layers.extend(msg.layers) + except AttributeError: + # If it doesn't contain an OpcodeField treat it like a normal struct + obj = layer.from_stream(read_func) + additional_layers.append(obj) + + print(*additional_layers) + return cls(header, body, *additional_layers, update_metadata=False) + @property def size(self): # layers are structs or bytes, so use len instead of size @@ -186,6 +249,9 @@ def __bytes__(self) -> bytes: def from_bytes(self, data: bytes): return self.data_field.from_bytes(data) + def from_stream(self, read_func: Callable[[int], bytes]): + return self.data_field.from_stream(read_func) + @abstractmethod def update(self, message: Message, struct: Struct, struct_index: int): raise NotImplementedError @@ -207,6 +273,12 @@ def __init__(self, data_field: FieldType, opcode_dictionary: Mapping): super().__init__(data_field) self.opcode_dictionary = opcode_dictionary + try: + # Validate that there are no duplicate opcodes + bidict(opcode_dictionary) + except ValueDuplicationError: + raise ValueError("Opcode values must be unique") + def update(self, message: Message, struct: Struct, struct_index: int): with suppress(IndexError): if not self.validator: diff --git a/hydration/scalars.py b/hydration/scalars.py index 6b0af12..fca70a4 100644 --- a/hydration/scalars.py +++ b/hydration/scalars.py @@ -122,10 +122,15 @@ def __int__(self) -> int: def __float__(self) -> float: return float(self.value) - def from_bytes(self, data: bytes): + def from_stream(self, read_func: Callable[[int], bytes]): format_string = '{}{}'.format(self.endianness_format, self.scalar_format) - # noinspection PyAttributeOutsideInit - self.value = struct.unpack(format_string, data)[0] + + try: + # noinspection PyAttributeOutsideInit + self.value = struct.unpack(format_string, read_func(len(self)))[0] + except struct.error: + raise ValueError(f'Not enough bytes to unpack {self.__class__.__name__}') + return self def __trunc__(self): @@ -317,6 +322,9 @@ def from_bytes(self, data: bytes): self.value = self.type.value return self + def from_stream(self, read_func: Callable[[int], bytes]): + return self.from_bytes(read_func(len(self))) + @property def name(self): return self.enum_class(self.value).name diff --git a/hydration/vectors.py b/hydration/vectors.py index 343d953..2211344 100644 --- a/hydration/vectors.py +++ b/hydration/vectors.py @@ -1,7 +1,7 @@ import copy from abc import ABC from collections import UserList -from typing import Sequence, Optional, Any, Union, Iterable +from typing import Sequence, Optional, Any, Union, Iterable, Callable from itertools import islice from .base import Struct @@ -50,11 +50,11 @@ def __bytes__(self) -> bytes: return bytes(result) - def from_bytes(self, data: bytes): + def from_stream(self, read_func: Callable[[int], bytes]): field_type = copy.deepcopy(self.type) - self.value = tuple(field_type.from_bytes(chunk).value for chunk in byte_chunks(data, len(field_type))) + self.value = tuple(field_type.from_stream(read_func).value for _ in range(len(self))) return self - + def __str__(self): return '{}{}'.format(self.__class__.__qualname__, self.value) @@ -111,7 +111,7 @@ def extend(self, other: Iterable) -> None: self.value = self.data -class Array(_Sequence): +class Array(_Sequence, ABC): def __init__(self, length: int, field_type: FieldType = UInt8, value: Optional[Sequence[Any]] = (), @@ -172,17 +172,24 @@ def value(self, value): # This assumes that the Struct will update the length field's value self.length = len(value) - def from_bytes(self, data: bytes): + # def from_bytes(self, data: bytes): + # if isinstance(self.type, Field): + # return super().from_bytes(data[:len(self) * len(self.type)]) + # else: + # val = [] + # for _ in range(len(self)): + # next_obj = self.type.from_bytes(data) + # val.append(next_obj) + # data = data[len(bytes(next_obj)):] + # self.value = val + # return self + + def from_stream(self, read_func: Callable[[int], bytes]): if isinstance(self.type, Field): - return super().from_bytes(data[:len(self) * len(self.type)]) - else: - val = [] - for _ in range(len(self)): - next_obj = self.type.from_bytes(data) - val.append(next_obj) - data = data[len(bytes(next_obj)):] - self.value = val - return self + return super().from_stream(read_func) + + self.value = [self.type.from_stream(read_func) for _ in range(len(self))] + return self def __len__(self) -> int: return VLA.__len__(self) diff --git a/requirements.txt b/requirements.txt index 1053340..0eb0416 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pytest -pyhooks>=1.0.3 \ No newline at end of file +pyhooks>=1.0.3 +bidict \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_message.py b/tests/test_message.py index 07dc361..88b2ce6 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,9 +1,10 @@ -from zlib import crc32 - import pytest - import hydration as h +from zlib import crc32 + +from .utils import as_reader + class Tomer(h.Struct): b = h.UInt8(5) @@ -89,7 +90,6 @@ class D(h.Struct): def test_opcode_field(): - class C1(h.Struct): x = h.UInt32 @@ -153,3 +153,140 @@ class Footer(h.Struct, endianness=h.BigEndian): def test_crc(): msg = Header(magic=0x01052000) / Footer() assert crc32(bytes(msg)[:-4]) == msg[Footer].crc.value + + +class MessageBody1(h.Struct): + x = h.UInt16(0x1D14) + + +class MessageBody2(h.Struct): + y = h.UInt32(0x06072001) + + +class MessageHeader(h.Struct): + magic = h.UInt32(0xDEADBEEF) + opcode = h.OpcodeField(h.UInt16, {MessageBody1: 0x00, MessageBody2: 0x01}) + + +class MessageFooter(h.Struct): + magic = h.UInt32(0x000C0FEE) + + +class AnotherFooter(h.Struct): + another_magic = h.UInt32(0x01052000) + + +class FooterBody1(h.Struct): + x = h.UInt8(0xFF) + + +class FooterBody2(h.Struct): + y = h.UInt64(0x0123456789ABCDEF) + + +class FooterThatIsAlsoHeader(h.Struct): + magic = h.UInt32(0x09012001) + my_opcode = h.OpcodeField(h.UInt16, {FooterBody1: 0x02, FooterBody2: 0x00}) + + +def test_message_deserialization(): + msg = MessageHeader() / MessageBody1() + assert msg[MessageHeader].opcode == 0 + assert h.Message.from_bytes(bytes(msg), MessageHeader) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader) == msg + + msg = MessageHeader() / MessageBody2() + assert msg[MessageHeader].opcode == 1 + assert h.Message.from_bytes(bytes(msg), MessageHeader) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader) == msg + + +def test_message_deserialization_with_footer(): + msg = MessageHeader() / MessageBody1() / MessageFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter) == msg + + +def test_message_deserialization_with_multiple_footers(): + msg = MessageHeader() / MessageBody1() / MessageFooter() / AnotherFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, AnotherFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() / AnotherFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, AnotherFooter) == msg + + +def test_message_deserialization_with_multiple_footers_and_multiple_opcodes(): + msg = MessageHeader() / MessageBody1() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody1() / AnotherFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody1() / AnotherFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg + + msg = MessageHeader() / MessageBody1() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody2() / AnotherFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg + + msg = MessageHeader() / MessageBody2() / MessageFooter() / FooterThatIsAlsoHeader() / FooterBody2() / AnotherFooter() + assert h.Message.from_bytes(bytes(msg), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, AnotherFooter) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), MessageHeader, MessageFooter, FooterThatIsAlsoHeader, + AnotherFooter) == msg + + +def test_message_deserialization_with_missing_opcode_field(): + class MissingOpcode(h.Struct): + x = h.UInt16() + + with pytest.raises(AttributeError): + h.Message.from_bytes(bytes(MissingOpcode()), MissingOpcode) + + with pytest.raises(AttributeError): + h.Message.from_stream(as_reader(bytes(MissingOpcode())), MissingOpcode) + + +def test_message_deserialization_with_recursive_opcode_field(): + class LiorN(h.Struct): + y = h.UInt32(0x16051999) + + class Fuad(h.Struct): + x = h.UInt32(0x17082001) + + class Ralon(h.Struct): + z = h.UInt32(0x29062000) + + class Sherman(h.Struct): + a = h.UInt32(0x15101996) + + class H3(h.Struct): + opcode = h.OpcodeField(h.UInt16, {Fuad: 0x02, LiorN: 0x03}) + + class H2(h.Struct): + opcode = h.OpcodeField(h.UInt16, {H3: 0x00, Ralon: 0x01}) + + class H1(h.Struct): + opcode = h.OpcodeField(h.UInt16, {H2: 0x00, Sherman: 0x01}) + + class F1(h.Struct): + damn = h.UInt32(0xFACEB00C) + + msg = H1() / Sherman() / F1() + assert h.Message.from_bytes(bytes(msg), H1, F1) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), H1, F1) == msg + + msg = H1() / H2() / Ralon() / Footer() + assert h.Message.from_bytes(bytes(msg), H1, Footer) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), H1, Footer) == msg + + msg = H1() / H2() / H3() / Fuad() / F1() + assert h.Message.from_bytes(bytes(msg), H1, F1) == msg + assert h.Message.from_stream(as_reader(bytes(msg)), H1, F1) == msg diff --git a/tests/test_nested_struct.py b/tests/test_nested_struct.py index cd6c7f9..c560a64 100644 --- a/tests/test_nested_struct.py +++ b/tests/test_nested_struct.py @@ -1,6 +1,8 @@ import pytest import hydration as h +from .utils import as_reader + class Time(h.Struct): time = h.UInt64(3) @@ -20,6 +22,7 @@ class Ron(h.Struct): def test_nested(): a = Ron() assert a.from_bytes(bytes(a)) == a + assert a.from_stream(as_reader(bytes(a))) == a class Card(h.Struct): diff --git a/tests/test_placeholder.py b/tests/test_placeholder.py index 933c6a3..57d635a 100644 --- a/tests/test_placeholder.py +++ b/tests/test_placeholder.py @@ -3,6 +3,8 @@ import hydration as h from hydration.fields import Field +from .utils import as_reader + class IDVector(h.Struct): vec_len = h.UInt16() @@ -20,7 +22,7 @@ def set_vec_field(self): self.vec.type = d[int(self.id_len)]() -def test_id_vector(): +def test_id_vector_using_from_bytes(): idv = IDVector(h.UInt8, vec=[0, 1, 1, 0]) assert idv.vec_len == 4 @@ -31,3 +33,16 @@ def test_id_vector(): assert idv.vec_len == 2 assert idv.id_len == 2 assert IDVector.from_bytes(bytes(idv)) == idv + + +def test_id_vector_using_from_stream(): + + idv = IDVector(h.UInt8, vec=[0, 1, 1, 0]) + assert idv.vec_len == 4 + assert idv.id_len == 1 + assert IDVector.from_stream(as_reader(bytes(idv))) == idv + + idv = IDVector(h.UInt16, vec=[256, 1]) + assert idv.vec_len == 2 + assert idv.id_len == 2 + assert IDVector.from_stream(as_reader(bytes(idv))) == idv diff --git a/tests/test_structs.py b/tests/test_structs.py index 63dfa48..2bb169a 100644 --- a/tests/test_structs.py +++ b/tests/test_structs.py @@ -1,6 +1,8 @@ import pytest import hydration as h +from .utils import as_reader + class Omri(h.Struct): a = h.UInt16(256) @@ -75,6 +77,8 @@ def test_footer(): def test_from_stream(): + from .utils import MockReader + class Shustin(h.Struct): length = h.UInt16() data = h.Vector(length=length, field_type=h.UInt8()) @@ -85,27 +89,21 @@ class NadavLoYazam(h.Struct): yazam = h.UInt32() bihlal = h.UInt64() - class MockReader: - def __init__(self, data: bytes): - self._data = data - - def read(self, size=0): - user_data, self._data = self._data[:size], self._data[size:] - return user_data - shustin = Shustin() shustin.length = 32 shustin.data = [x for x in range(0, 32, 1)] - my_shustin = Shustin.from_stream(MockReader(bytes(shustin)).read) - assert bytes(shustin) == bytes(my_shustin) + my_shustin = Shustin.from_stream(as_reader(bytes(shustin))) + # assert bytes(shustin) == bytes(my_shustin) + assert shustin == my_shustin nadav = NadavLoYazam() nadav.nadav = 3 nadav.lo = 854 nadav.yazam = 1512 nadav.bihlal = 38272 - nadav_lo_yazam = NadavLoYazam.from_stream(MockReader(bytes(nadav)).read) - assert bytes(nadav) == bytes(nadav_lo_yazam) + nadav_lo_yazam = NadavLoYazam.from_stream(as_reader(bytes(nadav))) + # assert bytes(nadav) == bytes(nadav_lo_yazam) + assert nadav == nadav_lo_yazam def test_new_attributes(): @@ -135,6 +133,7 @@ def __init__(self, x, *args, **kwargs): a = Amadeus(3) assert Amadeus(3).from_bytes(bytes(a)) == a + assert Amadeus(3).from_stream(as_reader(bytes(a))) == a class Mozart(Amadeus): y = h.UInt16 @@ -145,6 +144,7 @@ def __init__(self, x, y, **kwargs): m = Mozart(4, 5) assert Mozart(0, 0).from_bytes(bytes(m)) == m + assert Mozart(0, 0).from_stream(as_reader(bytes(m))) == m def test_from_bytes_hooks(): @@ -160,5 +160,7 @@ def foo(self): r.arr.type = h.UInt16() r2 = Ronen.from_bytes(bytes(r)) + r3 = Ronen.from_stream(as_reader(bytes(r))) assert r2.arr == list(range(10)) + assert r3.arr == list(range(10)) diff --git a/tests/test_validation.py b/tests/test_validation.py index bf14e56..2e2da51 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -3,6 +3,8 @@ import pytest import hydration as h +from .utils import as_reader + class Tst(h.Struct): a = h.UInt8(validator=0) @@ -16,6 +18,11 @@ def test_init(): def test_from_bytes(): with pytest.raises(ValueError): Tst.from_bytes(b'\x02') + + +def test_from_stream(): + with pytest.raises(ValueError): + Tst.from_stream(as_reader(b'\x02')) def test_bad_struct(): diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 546cbbd..f471335 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -1,7 +1,8 @@ import pytest - import hydration as h +from .utils import as_reader + class Garzon(h.Struct): arr = h.Array(field_type=h.UInt8(3), length=3, value=(1, 2), fill=True) @@ -23,12 +24,18 @@ class Shine(h.Struct): x = h.UInt16(104) -def test_vector(): +def test_vector_from_bytes(): x = Shine() new_x = Shine.from_bytes(bytes(x)) assert x == new_x +def test_vector_from_stream(): + x = Shine() + new_x = Shine.from_stream(as_reader(bytes(x))) + assert x == new_x + + def test_vector_len_update(): x = Shine() tmp = (1, 2, 3) @@ -49,12 +56,19 @@ def test_bad_val(): assert bytes(x) == b'\x04\x03\x03\x05' -def test_array(): +def test_array_from_bytes(): x = Isaac() assert x == Isaac.from_bytes(bytes(x)) assert bytes(x) == b'\x01\x02\x03\x05' +def test_array_from_stream(): + x = Isaac() + new_x = Isaac.from_stream(as_reader(bytes(x))) + assert x == new_x + assert bytes(x) == b'\x01\x02\x03\x05' + + def test_good_validator(): class Shustin(h.Struct): arr = h.Array(3, h.UInt8(8), validator=lambda x: x > 7, fill=True) @@ -68,17 +82,33 @@ class Shustin2(h.Struct): Shustin2() -def test_ipv4(): - class Venice(h.Struct): - ip = h.IPv4() +class Venice(h.Struct): + ip = h.IPv4() + +def test_ipv4(): assert str(Venice().ip) == '0.0.0.0' - assert Venice.from_bytes(bytes(Venice(ip='127.0.0.1'))) == Venice(ip='127.0.0.1') + + x = Venice(ip='127.0.0.1') + assert Venice.from_bytes(bytes(x)) == x with pytest.raises(ValueError): Venice.from_bytes(bytes(Venice(ip='127.0.0.1'))[:-1]) +def test_ipv4_from_stream(): + x = Venice(ip='127.0.0.1') + assert Venice.from_stream(as_reader(bytes(x))) == x + + """ + No need to test this test-case because it won't work + If there is no bytes left in the reader, the UInt8 + will raise an struct.error: unpack requires a buffer of 1 bytes + """ + # with pytest.raises(ValueError): + # Venice.from_stream(as_reader(bytes(x)[:-1])) + + def test_type_field(): class Lior(h.Struct): a = h.Array(5, h.UInt16, fill=True) @@ -136,38 +166,49 @@ class Data(h.Struct): d = Data(data=[3] * 10) b = bytes(d) assert Data.from_bytes(b).data == d.data + assert Data.from_stream(as_reader(b)).data == d.data + +class Atedgi(h.Struct): + this = h.UInt8(1) + aviv = h.FieldPlaceholder() -def test_dynamic_vec_size(): + def __init__(self, this=1, *args, **kwargs): + self.this = this + self.set_vec_field() + super().__init__(*args, **kwargs) - class Atedgi(h.Struct): - this = h.UInt8(1) - aviv = h.FieldPlaceholder() + @h.from_bytes_hook(aviv) + def set_vec_field(self): + d = {len(x()): x for x in (h.UInt8, h.UInt16, h.UInt32, h.UInt64)} + self.aviv = d[self.this.value]() - def __init__(self, this=1, *args, **kwargs): - self.this = this - self.set_vec_field() - super().__init__(*args, **kwargs) - @h.from_bytes_hook(aviv) - def set_vec_field(self): - d = {len(x()): x for x in (h.UInt8, h.UInt16, h.UInt32, h.UInt64)} - self.aviv = d[self.this.value]() +class Maor(h.Struct): + vec_len = h.UInt16() + vec = h.Vector(vec_len, Atedgi) - class Maor(h.Struct): - vec_len = h.UInt16() - vec = h.Vector(vec_len, Atedgi) + +def test_vector_with_dynamic_item_size(): sizes = (1, 2, 4, 8) - from random import randint x = [Atedgi(this=size, aviv=5) for size in sizes] for obj, size in zip(x, sizes): assert len(obj) == size + 1 - real_deal = Maor(vec=x) +def test_vector_with_dynamic_item_size_from_bytes(): + real_deal = Maor(vec=[Atedgi(this=size, aviv=5) for size in (1, 2, 4, 8)]) identical = Maor.from_bytes(bytes(real_deal)) + assert real_deal.vec.type == identical.vec.type + + for a1, a2 in zip(real_deal.vec.value, identical.vec.value): + assert a1.aviv == a2.aviv + +def test_vector_with_dynamic_item_size_from_stream(): + real_deal = Maor(vec=[Atedgi(this=size, aviv=5) for size in (1, 2, 4, 8)]) + identical = Maor.from_stream(as_reader(bytes(real_deal))) assert real_deal.vec.type == identical.vec.type for a1, a2 in zip(real_deal.vec.value, identical.vec.value): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..7f1f883 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,11 @@ +class MockReader: + def __init__(self, data: bytes): + self._data = data + + def read(self, size=0): + user_data, self._data = self._data[:size], self._data[size:] + return user_data + + +def as_reader(data: bytes): + return MockReader(data).read