diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index ddf59b4..297b7cf 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -35,7 +35,8 @@ MOONCAKE_STORE_IMPORTED = False BATCH_SIZE_LIMIT: int = 400 -MAX_WORKER_THREADS = 4 +MAX_BATCH_WORKER_THREADS = 4 +MAX_SERIAL_WORKER_THREADS = 4 MAX_RETRIES = 3 RETRY_DELAY_SECONDS = 1.0 @@ -131,7 +132,7 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: tensor_futures: list[Future[None]] = [] bytes_futures: list[Future[list[int]]] = [] - with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: + with ThreadPoolExecutor(max_workers=MAX_BATCH_WORKER_THREADS) as executor: for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] @@ -142,12 +143,13 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT] bytes_futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) - for tf in tensor_futures: - tf.result() packed_sizes: list[int] = [] for bf in bytes_futures: packed_sizes.extend(bf.result()) + for tf in tensor_futures: + tf.result() + # bytes results arrive in non-tensor submit order, which matches the order of # non-tensor values; walk values once to scatter packed_size back to its key slot. sizes_iter = iter(packed_sizes) @@ -187,7 +189,9 @@ def alloc(sizes: list[int]) -> list[Tensor]: buffers, _, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) return buffers - buffers, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc) + buffers, batch_sizes = serial_utils.batch_encode_into( + batch_values, alloc, num_workers=MAX_SERIAL_WORKER_THREADS + ) batch_ptrs = [cast(Tensor, b).data_ptr() for b in buffers] self._register_all_buffers(region_ptrs, region_sizes) @@ -223,27 +227,43 @@ def get( if not (len(keys) == len(shapes) == len(dtypes)): raise ValueError("Lengths of keys, shapes, dtypes must match") - tensor_indices = [] - non_tensor_indices = [] + tensor_indices: list[int] = [] + tensor_keys: list[str] = [] + tensor_shapes: list[Any] = [] + tensor_dtypes: list[Any] = [] + non_tensor_indices: list[int] = [] + non_tensor_keys: list[str] = [] + non_tensor_packed_sizes: list[int] = [] for i, dtype in enumerate(dtypes): if dtype is not None: tensor_indices.append(i) + tensor_keys.append(keys[i]) + tensor_shapes.append(shapes[i]) + tensor_dtypes.append(dtype) else: non_tensor_indices.append(i) + non_tensor_keys.append(keys[i]) if non_tensor_indices and (custom_backend_meta is None or len(custom_backend_meta) != len(keys)): raise ValueError("custom_backend_meta with per-key packed_size is required when any dtype is None.") + if non_tensor_indices: + assert custom_backend_meta is not None + for j in non_tensor_indices: + meta = custom_backend_meta[j] + assert meta is not None + non_tensor_packed_sizes.append(meta["packed_size"]) + results = [None] * len(keys) futures = [] - with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: + with ThreadPoolExecutor(max_workers=MAX_BATCH_WORKER_THREADS) as executor: for i in range(0, len(tensor_indices), BATCH_SIZE_LIMIT): + batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_shapes = tensor_shapes[i : i + BATCH_SIZE_LIMIT] + batch_dtypes = tensor_dtypes[i : i + BATCH_SIZE_LIMIT] batch_indexes = tensor_indices[i : i + BATCH_SIZE_LIMIT] - batch_keys = [keys[i] for i in batch_indexes] - batch_shapes = [shapes[i] for i in batch_indexes] - batch_dtypes = [dtypes[i] for i in batch_indexes] futures.append( executor.submit( self._get_tensors_thread_worker, batch_keys, batch_shapes, batch_dtypes, batch_indexes @@ -251,10 +271,9 @@ def get( ) for i in range(0, len(non_tensor_indices), BATCH_SIZE_LIMIT): + batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_packed_sizes = non_tensor_packed_sizes[i : i + BATCH_SIZE_LIMIT] batch_indexes = non_tensor_indices[i : i + BATCH_SIZE_LIMIT] - batch_keys = [keys[i] for i in batch_indexes] - assert custom_backend_meta is not None # guaranteed by the check above - batch_packed_sizes = [cast(dict, custom_backend_meta[j])["packed_size"] for j in batch_indexes] futures.append( executor.submit(self._get_bytes_thread_worker, batch_keys, batch_packed_sizes, batch_indexes) ) diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 2379839..170e2e8 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -459,8 +459,7 @@ def batch_encode_into( the corresponding buffer list. ``buffers[i]`` must be an ``np.ndarray`` or ``memoryview`` holding at least ``sizes[i]`` bytes. - num_workers: Thread count for parallel packing. Default 1 (serial); - set ``>1`` only when the upper layer is single-threaded. + num_workers: Thread count for parallel packing. Default 1 (serial). Returns: tuple[list[np.ndarray | memoryview], list[int]]: The buffers returned by