Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions hydration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from .message import Message, InclusiveLengthField, ExclusiveLengthField, OpcodeField
from .fields import FieldPlaceholder

pre_bytes_hook = Struct.pre_bytes_hook
post_bytes_hook = Struct.post_bytes_hook
from_bytes_hook = Struct.from_bytes_hook
pre_serialization_hook = Struct.pre_serialization_hook
post_serialization_hook = Struct.post_serialization_hook
deserialization_hook = Struct.deserialization_hook

pre_bytes_hook = Struct.pre_serialization_hook
post_bytes_hook = Struct.post_serialization_hook
from_bytes_hook = Struct.deserialization_hook

LittleEndian = Endianness.LittleEndian
BigEndian = Endianness.BigEndian
Expand All @@ -25,5 +29,6 @@
'Array', 'Vector', 'IPv4', 'FieldPlaceholder',
'ExactValueValidator', 'RangeValidator', 'FunctionValidator', 'SetValidator',
'Message', 'InclusiveLengthField', 'ExclusiveLengthField', 'OpcodeField',
'pre_bytes_hook', 'post_bytes_hook', 'from_bytes_hook',
'pre_serialization_hook', 'post_serialization_hook', 'deserialization_hook',
'from_bytes_hook', 'pre_bytes_hook', 'post_bytes_hook',
'LittleEndian', 'BigEndian', 'NativeEndian', 'NetworkEndian']
75 changes: 28 additions & 47 deletions hydration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyhooks import Hook, precall_register, postcall_register
from typing import Callable, List, Iterable, Optional

from .helpers import as_obj, assert_no_property_override, as_type
from .helpers import as_obj, assert_no_property_override, as_type, as_stream
from .scalars import Scalar, Enum
from .fields import Field, VLA, FieldPlaceholder
from .endianness import Endianness
Expand Down Expand Up @@ -101,7 +101,7 @@ def __prepare__(mcs, name, bases, *args, **kwargs):
class Struct(metaclass=StructMeta):
__frozen = False
_field_names: List[str]
_from_bytes_hooks = {}
_deserialization_hooks = {}

@property
def value(self):
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(self, *args, **kwargs):
self.from_bytes = lambda data: self._from_bytes(data, *args)
self.from_stream = lambda data: self._from_stream(data, *args)

self._from_bytes_hooks = {}
self._deserialization_hooks = {}

# Deepcopy the fields so different instances of Struct have unique fields
for name, field in self:
Expand Down Expand Up @@ -207,8 +207,8 @@ def serialize(self) -> bytes:
except struct.error as e:
raise ValueError(str(e)) from e

pre_bytes_hook = precall_register('__bytes__')
post_bytes_hook = postcall_register('__bytes__')
pre_serialization_hook = precall_register('__bytes__')
post_serialization_hook = postcall_register('__bytes__')

@classmethod
def from_bytes(cls, data: bytes, *args):
Expand All @@ -219,32 +219,7 @@ def from_bytes(cls, data: bytes, *args):
:param args: Arguments for the __init__ of the Struct, if there's any
:return The deserialized struct
"""

obj = cls(*args)

for field_name in obj._field_names:

# Get field for current field name
field = getattr(obj, field_name)

obj.invoke_from_bytes_hooks(field)

# Bytes hooks can change the field object, so get it again by name
field = getattr(obj, field_name)

if isinstance(field, VLA):
field.length = int(getattr(obj, field.length_field_name))
field.from_bytes(data)
data = data[len(bytes(field)):]
else:
split_index = field.size

field_data, data = data[:split_index], data[split_index:]
field.value = field.from_bytes(field_data).value
with suppress(AttributeError):
field.validator.validate(field.value)

return obj
return cls.from_stream(as_stream(data), *args)

@classmethod
def from_stream(cls, read_func: Callable[[int], bytes], *args):
Expand All @@ -260,19 +235,21 @@ def from_stream(cls, read_func: Callable[[int], bytes], *args):

obj = cls(*args)

for field in obj._fields:
for field_name in obj._field_names:

# Get field for current field name
field = getattr(obj, field_name)

obj.invoke_from_bytes_hooks(field)
obj.invoke_deserialization_hooks(field)

# Bytes hooks can change the field object, so get it again by name
field = getattr(obj, field_name)

if isinstance(field, VLA):
field.length = int(getattr(obj, field.length_field_name))
data = read_func(field.length)
field.from_bytes(data)
field.from_stream(read_func)
else:
read_size = field.size

data = read_func(read_size)
field.value = field.from_bytes(data).value
field.value = field.from_stream(read_func).value
with suppress(AttributeError):
field.validator.validate(field.value)

Expand Down Expand Up @@ -304,29 +281,33 @@ def __setattr__(self, key, value):
# Overriding fields but saving the hooks
elif key in self._field_names:
# Save the hooks from the field
hooks = getattr(getattr(self, key), '_from_bytes_hooks', [])
hooks = getattr(getattr(self, key), '_deserialization_hooks', [])
# Set the field to the new value
super().__setattr__(key, value)
# Inject the old hooks to the new field
setattr(getattr(self, key), '_from_bytes_hooks', hooks)
setattr(getattr(self, key), '_deserialization_hooks', hooks)
elif hasattr(self, key) or not self.__frozen:
super().__setattr__(key, value)
else:
raise AttributeError("Struct doesn't allow defining new attributes")

def invoke_from_bytes_hooks(self, field: Field):
for f in getattr(field, '_from_bytes_hooks', ()):
def invoke_deserialization_hooks(self, field: Field):
for f in getattr(field, '_deserialization_hooks', ()):
f(self)

@classmethod
def from_bytes_hook(cls, field):
return cls.deserialization_hook(field)

@classmethod
def deserialization_hook(cls, field):

# noinspection PyProtectedMember
def register_field_hook(func: callable):
if hasattr(field, '_from_bytes_hooks'):
field._from_bytes_hooks.append(func)
if hasattr(field, '_deserialization_hooks'):
field._deserialization_hooks.append(func)
else:
field._from_bytes_hooks = [func]
field._deserialization_hooks = [func]
return func

return register_field_hook
10 changes: 7 additions & 3 deletions hydration/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
from abc import ABC
from typing import Union
from hydration.helpers import as_stream
from typing import Union, Callable

from .validators import ValidatorABC

Expand Down Expand Up @@ -47,8 +48,11 @@ def size(self):
def __bytes__(self) -> bytes:
raise NotImplementedError

@abc.abstractmethod
def from_bytes(self, data: bytes):
return self.from_stream(as_stream(data))

@abc.abstractmethod
def from_stream(self, read_func: Callable[[int], bytes]):
raise NotImplementedError

def __eq__(self, other):
Expand Down Expand Up @@ -116,5 +120,5 @@ def __len__(self) -> int:
def __bytes__(self) -> bytes:
raise AttributeError('Placeholders cannot be serialized')

def from_bytes(self, data: bytes):
def from_stream(self, read_func: Callable[[int], bytes]):
raise AttributeError('Placeholders cannot be deserialized')
11 changes: 11 additions & 0 deletions hydration/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,14 @@ def assert_no_property_override(obj, base_class):
if (isinstance(getattr(base_class, attr_name), property) and
not isinstance(getattr(type(obj), attr_name), property)):
raise NameError(f"'{attr_name}' is an invalid name for an attribute in a sequenced or nested struct")

def as_stream(data: bytes):
class Reader:
def __init__(self, content: bytes):
self._data = content

def read(self, size=0):
user_data, self._data = self._data[:size], self._data[size:]
return user_data

return Reader(data).read
8 changes: 4 additions & 4 deletions hydration/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import List, Union, Type, Mapping
from typing import List, Union, Type, Mapping, Callable

from hydration.helpers import as_obj
from .base import Struct
Expand Down Expand Up @@ -182,9 +182,9 @@ def size(self):

def __bytes__(self) -> bytes:
return bytes(self.data_field)

def from_bytes(self, data: bytes):
return self.data_field.from_bytes(data)
def from_stream(self, read_func: Callable[[int], bytes]):
return self.data_field.from_stream(read_func)

@abstractmethod
def update(self, message: Message, struct: Struct, struct_index: int):
Expand Down
9 changes: 5 additions & 4 deletions hydration/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ def __int__(self) -> int:
def __float__(self) -> float:
return float(self.value)

def from_bytes(self, data: bytes):
def from_stream(self, read_func: Callable[[int], bytes]):
format_string = '{}{}'.format(self.endianness_format, self.scalar_format)
data = read_func(struct.calcsize(format_string))
# noinspection PyAttributeOutsideInit
self.value = struct.unpack(format_string, data)[0]
return self
Expand Down Expand Up @@ -311,9 +312,9 @@ def __bytes__(self) -> bytes:
return bytes(self.type)
except ValueError as e:
raise ValueError(f'Error serializing {repr(self)}:\n{str(e)}')

def from_bytes(self, data: bytes):
self.type.from_bytes(data)
def from_stream(self, read_func: Callable[[int], bytes]):
self.type.from_stream(read_func)
self.value = self.type.value
return self

Expand Down
12 changes: 7 additions & 5 deletions hydration/vectors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
from abc import ABC
from collections import UserList
from typing import Sequence, Optional, Any, Union, Iterable
from typing import Sequence, Optional, Any, Union, Iterable, Callable
from itertools import islice

from .base import Struct
Expand Down Expand Up @@ -153,6 +153,9 @@ def __delitem__(self, key) -> None:
def size(self):
return len(self) * len(self.type)

def from_stream(self, read_func: Callable[[int], bytes]):
return self.from_bytes(read_func(self.size))


class Vector(_Sequence, VLA):

Expand All @@ -172,15 +175,14 @@ def value(self, value):
# This assumes that the Struct will update the length field's value
self.length = len(value)

def from_bytes(self, data: bytes):
def from_stream(self, read_func: Callable[[int], bytes]):
if isinstance(self.type, Field):
return super().from_bytes(data[:len(self) * len(self.type)])
return super().from_bytes(read_func(len(self) * len(self.type)))
else:
val = []
for _ in range(len(self)):
next_obj = self.type.from_bytes(data)
next_obj = self.type.from_stream(read_func)
val.append(next_obj)
data = data[len(bytes(next_obj)):]
self.value = val
return self

Expand Down