diff --git a/hydration/__init__.py b/hydration/__init__.py index 97e3c8b..67733b7 100644 --- a/hydration/__init__.py +++ b/hydration/__init__.py @@ -8,9 +8,13 @@ from .message import Message, InclusiveLengthField, ExclusiveLengthField, OpcodeField from .fields import FieldPlaceholder -pre_bytes_hook = Struct.pre_bytes_hook -post_bytes_hook = Struct.post_bytes_hook -from_bytes_hook = Struct.from_bytes_hook +pre_serialization_hook = Struct.pre_serialization_hook +post_serialization_hook = Struct.post_serialization_hook +deserialization_hook = Struct.deserialization_hook + +pre_bytes_hook = Struct.pre_serialization_hook +post_bytes_hook = Struct.post_serialization_hook +from_bytes_hook = Struct.deserialization_hook LittleEndian = Endianness.LittleEndian BigEndian = Endianness.BigEndian @@ -25,5 +29,6 @@ 'Array', 'Vector', 'IPv4', 'FieldPlaceholder', 'ExactValueValidator', 'RangeValidator', 'FunctionValidator', 'SetValidator', 'Message', 'InclusiveLengthField', 'ExclusiveLengthField', 'OpcodeField', - 'pre_bytes_hook', 'post_bytes_hook', 'from_bytes_hook', + 'pre_serialization_hook', 'post_serialization_hook', 'deserialization_hook', + 'from_bytes_hook', 'pre_bytes_hook', 'post_bytes_hook', 'LittleEndian', 'BigEndian', 'NativeEndian', 'NetworkEndian'] diff --git a/hydration/base.py b/hydration/base.py index 0e849a6..a78ea59 100644 --- a/hydration/base.py +++ b/hydration/base.py @@ -6,7 +6,7 @@ from pyhooks import Hook, precall_register, postcall_register from typing import Callable, List, Iterable, Optional -from .helpers import as_obj, assert_no_property_override, as_type +from .helpers import as_obj, assert_no_property_override, as_type, as_stream from .scalars import Scalar, Enum from .fields import Field, VLA, FieldPlaceholder from .endianness import Endianness @@ -101,7 +101,7 @@ def __prepare__(mcs, name, bases, *args, **kwargs): class Struct(metaclass=StructMeta): __frozen = False _field_names: List[str] - _from_bytes_hooks = {} + _deserialization_hooks = {} @property def value(self): @@ -144,7 +144,7 @@ def __init__(self, *args, **kwargs): self.from_bytes = lambda data: self._from_bytes(data, *args) self.from_stream = lambda data: self._from_stream(data, *args) - self._from_bytes_hooks = {} + self._deserialization_hooks = {} # Deepcopy the fields so different instances of Struct have unique fields for name, field in self: @@ -207,8 +207,8 @@ def serialize(self) -> bytes: except struct.error as e: raise ValueError(str(e)) from e - pre_bytes_hook = precall_register('__bytes__') - post_bytes_hook = postcall_register('__bytes__') + pre_serialization_hook = precall_register('__bytes__') + post_serialization_hook = postcall_register('__bytes__') @classmethod def from_bytes(cls, data: bytes, *args): @@ -219,32 +219,7 @@ def from_bytes(cls, data: bytes, *args): :param args: Arguments for the __init__ of the Struct, if there's any :return The deserialized struct """ - - obj = cls(*args) - - for field_name in obj._field_names: - - # Get field for current field name - field = getattr(obj, field_name) - - obj.invoke_from_bytes_hooks(field) - - # Bytes hooks can change the field object, so get it again by name - field = getattr(obj, field_name) - - if isinstance(field, VLA): - field.length = int(getattr(obj, field.length_field_name)) - field.from_bytes(data) - data = data[len(bytes(field)):] - else: - split_index = field.size - - field_data, data = data[:split_index], data[split_index:] - field.value = field.from_bytes(field_data).value - with suppress(AttributeError): - field.validator.validate(field.value) - - return obj + return cls.from_stream(as_stream(data), *args) @classmethod def from_stream(cls, read_func: Callable[[int], bytes], *args): @@ -260,19 +235,21 @@ def from_stream(cls, read_func: Callable[[int], bytes], *args): obj = cls(*args) - for field in obj._fields: + for field_name in obj._field_names: + + # Get field for current field name + field = getattr(obj, field_name) - obj.invoke_from_bytes_hooks(field) + obj.invoke_deserialization_hooks(field) + + # Bytes hooks can change the field object, so get it again by name + field = getattr(obj, field_name) if isinstance(field, VLA): field.length = int(getattr(obj, field.length_field_name)) - data = read_func(field.length) - field.from_bytes(data) + field.from_stream(read_func) else: - read_size = field.size - - data = read_func(read_size) - field.value = field.from_bytes(data).value + field.value = field.from_stream(read_func).value with suppress(AttributeError): field.validator.validate(field.value) @@ -304,29 +281,33 @@ def __setattr__(self, key, value): # Overriding fields but saving the hooks elif key in self._field_names: # Save the hooks from the field - hooks = getattr(getattr(self, key), '_from_bytes_hooks', []) + hooks = getattr(getattr(self, key), '_deserialization_hooks', []) # Set the field to the new value super().__setattr__(key, value) # Inject the old hooks to the new field - setattr(getattr(self, key), '_from_bytes_hooks', hooks) + setattr(getattr(self, key), '_deserialization_hooks', hooks) elif hasattr(self, key) or not self.__frozen: super().__setattr__(key, value) else: raise AttributeError("Struct doesn't allow defining new attributes") - def invoke_from_bytes_hooks(self, field: Field): - for f in getattr(field, '_from_bytes_hooks', ()): + def invoke_deserialization_hooks(self, field: Field): + for f in getattr(field, '_deserialization_hooks', ()): f(self) - + @classmethod def from_bytes_hook(cls, field): + return cls.deserialization_hook(field) + + @classmethod + def deserialization_hook(cls, field): # noinspection PyProtectedMember def register_field_hook(func: callable): - if hasattr(field, '_from_bytes_hooks'): - field._from_bytes_hooks.append(func) + if hasattr(field, '_deserialization_hooks'): + field._deserialization_hooks.append(func) else: - field._from_bytes_hooks = [func] + field._deserialization_hooks = [func] return func return register_field_hook diff --git a/hydration/fields.py b/hydration/fields.py index 3f22eb2..408d2e4 100644 --- a/hydration/fields.py +++ b/hydration/fields.py @@ -1,6 +1,7 @@ import abc from abc import ABC -from typing import Union +from hydration.helpers import as_stream +from typing import Union, Callable from .validators import ValidatorABC @@ -47,8 +48,11 @@ def size(self): def __bytes__(self) -> bytes: raise NotImplementedError - @abc.abstractmethod def from_bytes(self, data: bytes): + return self.from_stream(as_stream(data)) + + @abc.abstractmethod + def from_stream(self, read_func: Callable[[int], bytes]): raise NotImplementedError def __eq__(self, other): @@ -116,5 +120,5 @@ def __len__(self) -> int: def __bytes__(self) -> bytes: raise AttributeError('Placeholders cannot be serialized') - def from_bytes(self, data: bytes): + def from_stream(self, read_func: Callable[[int], bytes]): raise AttributeError('Placeholders cannot be deserialized') diff --git a/hydration/helpers.py b/hydration/helpers.py index 928d6b3..20cea6b 100644 --- a/hydration/helpers.py +++ b/hydration/helpers.py @@ -20,3 +20,14 @@ def assert_no_property_override(obj, base_class): if (isinstance(getattr(base_class, attr_name), property) and not isinstance(getattr(type(obj), attr_name), property)): raise NameError(f"'{attr_name}' is an invalid name for an attribute in a sequenced or nested struct") + +def as_stream(data: bytes): + class Reader: + def __init__(self, content: bytes): + self._data = content + + def read(self, size=0): + user_data, self._data = self._data[:size], self._data[size:] + return user_data + + return Reader(data).read \ No newline at end of file diff --git a/hydration/message.py b/hydration/message.py index d54cc83..76eb35f 100644 --- a/hydration/message.py +++ b/hydration/message.py @@ -1,7 +1,7 @@ import inspect from abc import ABC, abstractmethod from contextlib import suppress -from typing import List, Union, Type, Mapping +from typing import List, Union, Type, Mapping, Callable from hydration.helpers import as_obj from .base import Struct @@ -182,9 +182,9 @@ def size(self): def __bytes__(self) -> bytes: return bytes(self.data_field) - - def from_bytes(self, data: bytes): - return self.data_field.from_bytes(data) + + def from_stream(self, read_func: Callable[[int], bytes]): + return self.data_field.from_stream(read_func) @abstractmethod def update(self, message: Message, struct: Struct, struct_index: int): diff --git a/hydration/scalars.py b/hydration/scalars.py index 6b0af12..2de934d 100644 --- a/hydration/scalars.py +++ b/hydration/scalars.py @@ -122,8 +122,9 @@ def __int__(self) -> int: def __float__(self) -> float: return float(self.value) - def from_bytes(self, data: bytes): + def from_stream(self, read_func: Callable[[int], bytes]): format_string = '{}{}'.format(self.endianness_format, self.scalar_format) + data = read_func(struct.calcsize(format_string)) # noinspection PyAttributeOutsideInit self.value = struct.unpack(format_string, data)[0] return self @@ -311,9 +312,9 @@ def __bytes__(self) -> bytes: return bytes(self.type) except ValueError as e: raise ValueError(f'Error serializing {repr(self)}:\n{str(e)}') - - def from_bytes(self, data: bytes): - self.type.from_bytes(data) + + def from_stream(self, read_func: Callable[[int], bytes]): + self.type.from_stream(read_func) self.value = self.type.value return self diff --git a/hydration/vectors.py b/hydration/vectors.py index 343d953..3bd0f9c 100644 --- a/hydration/vectors.py +++ b/hydration/vectors.py @@ -1,7 +1,7 @@ import copy from abc import ABC from collections import UserList -from typing import Sequence, Optional, Any, Union, Iterable +from typing import Sequence, Optional, Any, Union, Iterable, Callable from itertools import islice from .base import Struct @@ -153,6 +153,9 @@ def __delitem__(self, key) -> None: def size(self): return len(self) * len(self.type) + def from_stream(self, read_func: Callable[[int], bytes]): + return self.from_bytes(read_func(self.size)) + class Vector(_Sequence, VLA): @@ -172,15 +175,14 @@ def value(self, value): # This assumes that the Struct will update the length field's value self.length = len(value) - def from_bytes(self, data: bytes): + def from_stream(self, read_func: Callable[[int], bytes]): if isinstance(self.type, Field): - return super().from_bytes(data[:len(self) * len(self.type)]) + return super().from_bytes(read_func(len(self) * len(self.type))) else: val = [] for _ in range(len(self)): - next_obj = self.type.from_bytes(data) + next_obj = self.type.from_stream(read_func) val.append(next_obj) - data = data[len(bytes(next_obj)):] self.value = val return self