Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions transfer_queue/storage/clients/mooncake_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
MOONCAKE_STORE_IMPORTED = False

BATCH_SIZE_LIMIT: int = 400
MAX_WORKER_THREADS = 4
MAX_BATCH_WORKER_THREADS = 4
MAX_SERIAL_WORKER_THREADS = 4
Comment thread
0oshowero0 marked this conversation as resolved.
MAX_RETRIES = 3
RETRY_DELAY_SECONDS = 1.0

Expand Down Expand Up @@ -131,7 +132,7 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]:

tensor_futures: list[Future[None]] = []
bytes_futures: list[Future[list[int]]] = []
with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor:
with ThreadPoolExecutor(max_workers=MAX_BATCH_WORKER_THREADS) as executor:
for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT):
batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT]
batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT]
Expand All @@ -142,12 +143,13 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]:
batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT]
bytes_futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values))

for tf in tensor_futures:
tf.result()
packed_sizes: list[int] = []
for bf in bytes_futures:
packed_sizes.extend(bf.result())

for tf in tensor_futures:
tf.result()

# bytes results arrive in non-tensor submit order, which matches the order of
# non-tensor values; walk values once to scatter packed_size back to its key slot.
sizes_iter = iter(packed_sizes)
Expand Down Expand Up @@ -187,7 +189,9 @@ def alloc(sizes: list[int]) -> list[Tensor]:
buffers, _, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
return buffers

buffers, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc)
buffers, batch_sizes = serial_utils.batch_encode_into(
batch_values, alloc, num_workers=MAX_SERIAL_WORKER_THREADS
)
Comment thread
0oshowero0 marked this conversation as resolved.
batch_ptrs = [cast(Tensor, b).data_ptr() for b in buffers]

self._register_all_buffers(region_ptrs, region_sizes)
Expand Down Expand Up @@ -223,38 +227,53 @@ def get(
if not (len(keys) == len(shapes) == len(dtypes)):
raise ValueError("Lengths of keys, shapes, dtypes must match")

tensor_indices = []
non_tensor_indices = []
tensor_indices: list[int] = []
tensor_keys: list[str] = []
tensor_shapes: list[Any] = []
tensor_dtypes: list[Any] = []
non_tensor_indices: list[int] = []
non_tensor_keys: list[str] = []
non_tensor_packed_sizes: list[int] = []

for i, dtype in enumerate(dtypes):
if dtype is not None:
tensor_indices.append(i)
tensor_keys.append(keys[i])
tensor_shapes.append(shapes[i])
tensor_dtypes.append(dtype)
else:
non_tensor_indices.append(i)
non_tensor_keys.append(keys[i])

if non_tensor_indices and (custom_backend_meta is None or len(custom_backend_meta) != len(keys)):
raise ValueError("custom_backend_meta with per-key packed_size is required when any dtype is None.")

if non_tensor_indices:
assert custom_backend_meta is not None
for j in non_tensor_indices:
meta = custom_backend_meta[j]
assert meta is not None
non_tensor_packed_sizes.append(meta["packed_size"])

results = [None] * len(keys)

futures = []
with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor:
with ThreadPoolExecutor(max_workers=MAX_BATCH_WORKER_THREADS) as executor:
for i in range(0, len(tensor_indices), BATCH_SIZE_LIMIT):
batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT]
batch_shapes = tensor_shapes[i : i + BATCH_SIZE_LIMIT]
batch_dtypes = tensor_dtypes[i : i + BATCH_SIZE_LIMIT]
batch_indexes = tensor_indices[i : i + BATCH_SIZE_LIMIT]
batch_keys = [keys[i] for i in batch_indexes]
batch_shapes = [shapes[i] for i in batch_indexes]
batch_dtypes = [dtypes[i] for i in batch_indexes]
futures.append(
executor.submit(
self._get_tensors_thread_worker, batch_keys, batch_shapes, batch_dtypes, batch_indexes
)
)

for i in range(0, len(non_tensor_indices), BATCH_SIZE_LIMIT):
batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT]
batch_packed_sizes = non_tensor_packed_sizes[i : i + BATCH_SIZE_LIMIT]
batch_indexes = non_tensor_indices[i : i + BATCH_SIZE_LIMIT]
batch_keys = [keys[i] for i in batch_indexes]
assert custom_backend_meta is not None # guaranteed by the check above
batch_packed_sizes = [cast(dict, custom_backend_meta[j])["packed_size"] for j in batch_indexes]
futures.append(
executor.submit(self._get_bytes_thread_worker, batch_keys, batch_packed_sizes, batch_indexes)
)
Expand Down
3 changes: 1 addition & 2 deletions transfer_queue/utils/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,7 @@ def batch_encode_into(
the corresponding buffer list. ``buffers[i]`` must be an
``np.ndarray`` or ``memoryview`` holding at least ``sizes[i]``
bytes.
num_workers: Thread count for parallel packing. Default 1 (serial);
set ``>1`` only when the upper layer is single-threaded.
num_workers: Thread count for parallel packing. Default 1 (serial).

Returns:
tuple[list[np.ndarray | memoryview], list[int]]: The buffers returned by
Expand Down
Loading