From 5ee17216535f772b9fcc7308e818c9b9f5096e4e Mon Sep 17 00:00:00 2001 From: Pathways-on-Cloud Team Date: Wed, 2 Apr 2025 17:57:12 -0700 Subject: [PATCH] Group bulk write requests by device list. It avoids setting devices repeatedly for each request. PiperOrigin-RevId: 743346435 --- pathwaysutils/persistence/helper.py | 91 ++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 21 deletions(-) diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 48d3021..5ab2245 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -168,35 +168,52 @@ def get_read_request( dtype: np.dtype, shape: Sequence[int], sharding: jax.sharding.Sharding, - devices: Sequence[jax.Device], + devices: Sequence[jax.Device] | None, timeout: datetime.timedelta, return_dict: bool = False, ) -> Union[str, dict[str, Any]]: """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" - if not isinstance(devices, np.ndarray): - devices = np.array(devices) - timeout_seconds, timeout_fractional_seconds = divmod( timeout.total_seconds(), 1 ) timeout_nanoseconds = timeout_fractional_seconds * 1e9 - d = { - "persistenceReadRequest": { - "b64_location": string_to_base64(location_path), - "shape": get_shape_info(dtype, shape), - "b64_name": string_to_base64(name), - "b64_hlo_sharding_string": get_hlo_sharding_string( - sharding, len(shape) - ), - "devices": { - "device_ids": [device.id for device in devices.flatten()] - }, - "timeout": { - "seconds": int(timeout_seconds), - "nanos": int(timeout_nanoseconds), - }, - } - } + + if devices is None: + d = { + "persistenceReadRequest": { + "b64_location": string_to_base64(location_path), + "shape": get_shape_info(dtype, shape), + "b64_name": string_to_base64(name), + "b64_hlo_sharding_string": get_hlo_sharding_string( + sharding, len(shape) + ), + "timeout": { + "seconds": int(timeout_seconds), + "nanos": int(timeout_nanoseconds), + }, + } + } + else: + if not isinstance(devices, np.ndarray): + devices = np.array(devices) + + d = { + "persistenceReadRequest": { + "b64_location": string_to_base64(location_path), + "shape": get_shape_info(dtype, shape), + "b64_name": string_to_base64(name), + "b64_hlo_sharding_string": get_hlo_sharding_string( + sharding, len(shape) + ), + "devices": { + "device_ids": [device.id for device in devices.flatten()] + }, + "timeout": { + "seconds": int(timeout_seconds), + "nanos": int(timeout_nanoseconds), + }, + } + } if return_dict: return d @@ -224,6 +241,38 @@ def get_bulk_read_request( ) +def get_bulk_read_request_per_device_list( + location_path: str, + names: Sequence[str], + dtypes: Sequence[np.dtype], + shapes: Sequence[Sequence[int]], + shardings: Sequence[jax.sharding.Sharding], + devices: Sequence[jax.Device], + timeout: datetime.timedelta, +) -> str: + """Returns a string representation of a bulk read request, reads multiple arrays with one call.""" + read_requests = [ + get_read_request( + location_path, name, dtype, shape, sharding, None, timeout, True + )["persistenceReadRequest"] + for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings) + ] + + if not isinstance(devices, np.ndarray): + devices = np.array(devices) + + return json.dumps({ + "bulk_persistence_read_request": { + "read_requests_per_device_list": { + "device_list": { + "device_ids": [device.id for device in devices.flatten()] + }, + "read_requests": read_requests, + } + } + }) + + def write_one_array( location: str, name: str,