diff --git a/tests/test_yuanrong_client_zero_copy.py b/tests/test_yuanrong_client_zero_copy.py index c100d39..ea6a698 100644 --- a/tests/test_yuanrong_client_zero_copy.py +++ b/tests/test_yuanrong_client_zero_copy.py @@ -50,13 +50,13 @@ def storage_client(self, mock_kv_client): def test_mset_mget_p2p(self, storage_client, mocker): # Mock serialization/deserialization - def mock_serialization(obj): + def mock_encode(obj): if isinstance(obj, torch.Tensor): return [obj.numpy().tobytes()] return [str(obj).encode("utf-8")] - def mock_deserialization(items): - data = items[0] + def mock_decode(frames): + data = frames[0] if len(data) == 12: return torch.from_numpy(np.frombuffer(data, dtype=np.float32).copy()) try: @@ -64,8 +64,8 @@ def mock_deserialization(items): except UnicodeDecodeError: return data - mocker.patch("transfer_queue.storage.clients.yuanrong_client._encoder.encode", side_effect=mock_serialization) - mocker.patch("transfer_queue.storage.clients.yuanrong_client._decoder.decode", side_effect=mock_deserialization) + mocker.patch("transfer_queue.utils.serial_utils.encode", side_effect=mock_encode) + mocker.patch("transfer_queue.utils.serial_utils.decode", side_effect=mock_decode) stored_raw_buffers = [] diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 399cb58..0aebf22 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import struct from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Optional @@ -23,7 +22,7 @@ from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient from transfer_queue.utils.logging_utils import get_logger -from transfer_queue.utils.serial_utils import _decoder, _encoder +from transfer_queue.utils.serial_utils import batch_decode_from, batch_encode_into from transfer_queue.utils.yuanrong_utils import find_reachable_host logger = get_logger(__name__) @@ -193,19 +192,11 @@ def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) -> class GeneralKVClientAdapter(StorageStrategy): """Adapter for general-purpose KV storage with serialization. Using yr.datasystem.KVClient to connect datasystem backends. - The serialization method uses '_decoder' and '_encoder' from 'transfer_queue.utils.serial_utils'. + The serialization method uses 'batch_encode_into' and 'batch_decode_from' from 'transfer_queue.utils.serial_utils'. """ PUT_KEYS_LIMIT: int = 10_000 GET_CLEAR_KEYS_LIMIT: int = 10_000 - - # Header: number of entries (uint32, little-endian) - HEADER_FMT = " None: batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT] self._ds_client.delete(batch_keys) - @classmethod - def calc_packed_size(cls, items: list[memoryview]) -> int: - """ - Calculate the total size (in bytes) required to pack a list of memoryview items - into the structured binary format used by pack_into. - - Args: - items: List of memoryview objects to be packed. - - Returns: - Total buffer size in bytes. - """ - return cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE + sum(item.nbytes for item in items) - - @classmethod - def pack_into(cls, target: memoryview, items: list[memoryview]): - """ - Pack multiple contiguous buffers into a single buffer. - ┌───────────────┐ - │ item_count │ uint32 - ├───────────────┤ - │ entries │ N * item entries - ├───────────────┤ - │ payload blob │ N * concatenated buffers - └───────────────┘ - - Args: - target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData(). - It must be large enough to accommodate the total number of bytes of HEADER + ENTRY_TABLE + all items. - This buffer is usually mapped to shared memory or Zero-Copy memory area. - items (List[memoryview]): List of read-only memory views (e.g., from serialized objects). - Each item must support the buffer protocol and be readable as raw bytes. - - """ - struct.pack_into(cls.HEADER_FMT, target, 0, len(items)) - - entry_offset = cls.HEADER_SIZE - payload_offset = cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE - - target_tensor = torch.frombuffer(target, dtype=torch.uint8) - - for item in items: - struct.pack_into(cls.ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes) - src_tensor = torch.frombuffer(item, dtype=torch.uint8) - target_tensor[payload_offset : payload_offset + item.nbytes].copy_(src_tensor) - entry_offset += cls.ENTRY_SIZE - payload_offset += item.nbytes - - @classmethod - def unpack_from(cls, source: memoryview) -> list[memoryview]: - """ - Unpack multiple contiguous buffers from a single packed buffer. - Args: - source (memoryview): The packed source buffer. - Returns: - list[memoryview]: List of unpacked contiguous buffers. - """ - mv = memoryview(source) - item_count = struct.unpack_from(cls.HEADER_FMT, mv, 0)[0] - offsets = [] - for i in range(item_count): - offset, length = struct.unpack_from(cls.ENTRY_FMT, mv, cls.HEADER_SIZE + i * cls.ENTRY_SIZE) - offsets.append((offset, length)) - return [mv[offset : offset + length] for offset, length in offsets] - def mset_zero_copy(self, keys: list[str], objs: list[Any]): """Store multiple objects in zero-copy mode using parallel serialization and buffer packing. @@ -342,12 +268,16 @@ def mset_zero_copy(self, keys: list[str], objs: list[Any]): keys (list[str]): List of string keys under which the objects will be stored. objs (list[Any]): List of Python objects to store (e.g., tensors, strings). """ - items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs] - packed_sizes = [self.calc_packed_size(items) for items in items_list] - buffers = self._ds_client.mcreate(keys, packed_sizes) - tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=True)] - with ThreadPoolExecutor(max_workers=self.DS_MAX_WORKERS) as executor: - list(executor.map(lambda p: self.pack_into(*p), tasks)) + buffers: list = [] + + def alloc(sizes): + # DataSystem buffers must be converted via MutableData() to obtain + # a memoryview-compatible data structure for zero-copy packing. + mcreate_bufs = self._ds_client.mcreate(keys, sizes) + buffers.extend(mcreate_bufs) + return [buf.MutableData() for buf in mcreate_bufs] + + batch_encode_into(objs, alloc, num_workers=self.DS_MAX_WORKERS) self._ds_client.mset_buffer(buffers) def mget_zero_copy(self, keys: list[str]) -> list[Any]: @@ -360,7 +290,13 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]: list[Any]: List of deserialized objects corresponding to the input keys. """ buffers = self._ds_client.get_buffers(keys) - return [_decoder.decode(self.unpack_from(buffer)) if buffer is not None else None for buffer in buffers] + valid_indexes = [i for i, buf in enumerate(buffers) if buf is not None] + valid_bufs = [buffers[i] for i in valid_indexes] + decoded_objs = batch_decode_from(valid_bufs) + results = [None] * len(keys) + for idx, obj in zip(valid_indexes, decoded_objs, strict=True): + results[idx] = obj + return results @StorageClientFactory.register("YuanrongStorageClient")