Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/test_yuanrong_client_zero_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,22 @@ def storage_client(self, mock_kv_client):

def test_mset_mget_p2p(self, storage_client, mocker):
# Mock serialization/deserialization
def mock_serialization(obj):
def mock_encode(obj):
if isinstance(obj, torch.Tensor):
return [obj.numpy().tobytes()]
return [str(obj).encode("utf-8")]

def mock_deserialization(items):
data = items[0]
def mock_decode(frames):
data = frames[0]
if len(data) == 12:
return torch.from_numpy(np.frombuffer(data, dtype=np.float32).copy())
try:
return data.tobytes().decode("utf-8")
except UnicodeDecodeError:
return data

mocker.patch("transfer_queue.storage.clients.yuanrong_client._encoder.encode", side_effect=mock_serialization)
mocker.patch("transfer_queue.storage.clients.yuanrong_client._decoder.decode", side_effect=mock_deserialization)
mocker.patch("transfer_queue.utils.serial_utils.encode", side_effect=mock_encode)
mocker.patch("transfer_queue.utils.serial_utils.decode", side_effect=mock_decode)

stored_raw_buffers = []

Expand Down
102 changes: 19 additions & 83 deletions transfer_queue/storage/clients/yuanrong_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import struct
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional
Expand All @@ -23,7 +22,7 @@

from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.serial_utils import _decoder, _encoder
from transfer_queue.utils.serial_utils import batch_decode_from, batch_encode_into
from transfer_queue.utils.yuanrong_utils import find_reachable_host

logger = get_logger(__name__)
Expand Down Expand Up @@ -193,19 +192,11 @@ def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) ->
class GeneralKVClientAdapter(StorageStrategy):
"""Adapter for general-purpose KV storage with serialization.
Using yr.datasystem.KVClient to connect datasystem backends.
The serialization method uses '_decoder' and '_encoder' from 'transfer_queue.utils.serial_utils'.
The serialization method uses 'batch_encode_into' and 'batch_decode_from' from 'transfer_queue.utils.serial_utils'.
"""

PUT_KEYS_LIMIT: int = 10_000
GET_CLEAR_KEYS_LIMIT: int = 10_000

# Header: number of entries (uint32, little-endian)
HEADER_FMT = "<I"
HEADER_SIZE = struct.calcsize(HEADER_FMT)
# Entry: (payload_offset: uint32, payload_size: uint32)
ENTRY_FMT = "<II"
ENTRY_SIZE = struct.calcsize(ENTRY_FMT)

DS_MAX_WORKERS: int = 16

def __init__(self, config: dict):
Expand Down Expand Up @@ -270,84 +261,23 @@ def clear(self, keys: list[str]) -> None:
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
self._ds_client.delete(batch_keys)

@classmethod
def calc_packed_size(cls, items: list[memoryview]) -> int:
    """Return the buffer size, in bytes, that ``pack_into`` needs for ``items``.

    The size is the sum of three parts: the fixed header (``HEADER_SIZE``),
    one table entry per item (``ENTRY_SIZE`` each), and the raw payload
    bytes of every item.

    Args:
        items: Memoryview objects that are going to be packed.

    Returns:
        Total number of bytes required for the packed representation.
    """
    # One fixed-size (offset, size) entry is reserved per item.
    table_bytes = cls.ENTRY_SIZE * len(items)

    # Payloads are stored back to back, so their sizes simply add up.
    payload_bytes = 0
    for view in items:
        payload_bytes += view.nbytes

    return cls.HEADER_SIZE + table_bytes + payload_bytes

@classmethod
def pack_into(cls, target: memoryview, items: list[memoryview]):
    """
    Pack multiple contiguous buffers into a single buffer.
    ┌───────────────┐
    │ item_count    │ uint32
    ├───────────────┤
    │ entries       │ N * item entries
    ├───────────────┤
    │ payload blob  │ N * concatenated buffers
    └───────────────┘

    Each entry is a ``(payload_offset, payload_size)`` pair written with
    ``ENTRY_FMT``; offsets are absolute positions inside ``target``.
    Zero-length items are supported: their entry is written with size 0
    and no payload bytes are copied.

    Args:
        target (memoryview): A writable buffer (e.g. the memoryview returned by
            a DataSystem buffer's ``MutableData()``). It must be large enough to
            hold HEADER + ENTRY_TABLE + all items; ``calc_packed_size`` returns
            the required size.
        items (List[memoryview]): Buffers to pack (e.g., frames from serialized
            objects). Each item must expose ``nbytes`` and be readable as raw
            bytes through the buffer protocol.

    Raises:
        struct.error: If an offset or size does not fit the entry format, or
            ``target`` is too small for the header / entry table.
    """
    struct.pack_into(cls.HEADER_FMT, target, 0, len(items))

    entry_offset = cls.HEADER_SIZE
    payload_offset = cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE

    # View the destination as a flat uint8 tensor so payloads are copied with
    # tensor copy_ rather than Python-level byte assignment.
    target_tensor = torch.frombuffer(target, dtype=torch.uint8)

    for item in items:
        size = item.nbytes
        struct.pack_into(cls.ENTRY_FMT, target, entry_offset, payload_offset, size)
        if size:
            # torch.frombuffer raises ValueError on a zero-length buffer, so
            # empty items only get their table entry; there is nothing to copy.
            src_tensor = torch.frombuffer(item, dtype=torch.uint8)
            target_tensor[payload_offset : payload_offset + size].copy_(src_tensor)
        entry_offset += cls.ENTRY_SIZE
        payload_offset += size

@classmethod
def unpack_from(cls, source: memoryview) -> list[memoryview]:
    """Split a packed buffer back into its individual item buffers.

    Reads the item count from the header, then one ``(offset, length)``
    entry per item from the entry table, and returns the matching slices.
    The returned memoryviews are slices of ``source``; no payload bytes
    are copied.

    Args:
        source (memoryview): The packed source buffer.

    Returns:
        list[memoryview]: The unpacked buffers, in the order they were packed.
    """
    view = memoryview(source)
    header = struct.unpack_from(cls.HEADER_FMT, view, 0)
    count = header[0]

    frames = []
    entry_pos = cls.HEADER_SIZE
    for _ in range(count):
        # Each table entry stores the absolute payload offset and its length.
        start, size = struct.unpack_from(cls.ENTRY_FMT, view, entry_pos)
        frames.append(view[start : start + size])
        entry_pos += cls.ENTRY_SIZE
    return frames

def mset_zero_copy(self, keys: list[str], objs: list[Any]):
"""Store multiple objects in zero-copy mode using parallel serialization and buffer packing.

Args:
keys (list[str]): List of string keys under which the objects will be stored.
objs (list[Any]): List of Python objects to store (e.g., tensors, strings).
"""
items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs]
packed_sizes = [self.calc_packed_size(items) for items in items_list]
buffers = self._ds_client.mcreate(keys, packed_sizes)
tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=True)]
with ThreadPoolExecutor(max_workers=self.DS_MAX_WORKERS) as executor:
list(executor.map(lambda p: self.pack_into(*p), tasks))
buffers: list = []

def alloc(sizes):
# DataSystem buffers must be converted via MutableData() to obtain
# a memoryview-compatible data structure for zero-copy packing.
mcreate_bufs = self._ds_client.mcreate(keys, sizes)
buffers.extend(mcreate_bufs)
return [buf.MutableData() for buf in mcreate_bufs]

batch_encode_into(objs, alloc, num_workers=self.DS_MAX_WORKERS)
self._ds_client.mset_buffer(buffers)

def mget_zero_copy(self, keys: list[str]) -> list[Any]:
Expand All @@ -360,7 +290,13 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]:
list[Any]: List of deserialized objects corresponding to the input keys.
"""
buffers = self._ds_client.get_buffers(keys)
return [_decoder.decode(self.unpack_from(buffer)) if buffer is not None else None for buffer in buffers]
valid_indexes = [i for i, buf in enumerate(buffers) if buf is not None]
valid_bufs = [buffers[i] for i in valid_indexes]
decoded_objs = batch_decode_from(valid_bufs)
results = [None] * len(keys)
for idx, obj in zip(valid_indexes, decoded_objs, strict=True):
results[idx] = obj
return results


@StorageClientFactory.register("YuanrongStorageClient")
Expand Down
Loading