Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 70 additions & 21 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,35 +168,52 @@ def get_read_request(
dtype: np.dtype,
shape: Sequence[int],
sharding: jax.sharding.Sharding,
devices: Sequence[jax.Device],
devices: Sequence[jax.Device] | None,
timeout: datetime.timedelta,
return_dict: bool = False,
) -> Union[str, dict[str, Any]]:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
devices = np.array(devices)

timeout_seconds, timeout_fractional_seconds = divmod(
timeout.total_seconds(), 1
)
timeout_nanoseconds = timeout_fractional_seconds * 1e9
d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"shape": get_shape_info(dtype, shape),
"b64_name": string_to_base64(name),
"b64_hlo_sharding_string": get_hlo_sharding_string(
sharding, len(shape)
),
"devices": {
"device_ids": [device.id for device in devices.flatten()]
},
"timeout": {
"seconds": int(timeout_seconds),
"nanos": int(timeout_nanoseconds),
},
}
}

if devices is None:
d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"shape": get_shape_info(dtype, shape),
"b64_name": string_to_base64(name),
"b64_hlo_sharding_string": get_hlo_sharding_string(
sharding, len(shape)
),
"timeout": {
"seconds": int(timeout_seconds),
"nanos": int(timeout_nanoseconds),
},
}
}
else:
if not isinstance(devices, np.ndarray):
devices = np.array(devices)

d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"shape": get_shape_info(dtype, shape),
"b64_name": string_to_base64(name),
"b64_hlo_sharding_string": get_hlo_sharding_string(
sharding, len(shape)
),
"devices": {
"device_ids": [device.id for device in devices.flatten()]
},
"timeout": {
"seconds": int(timeout_seconds),
"nanos": int(timeout_nanoseconds),
},
}
}

if return_dict:
return d
Expand Down Expand Up @@ -224,6 +241,38 @@ def get_bulk_read_request(
)


def get_bulk_read_request_per_device_list(
location_path: str,
names: Sequence[str],
dtypes: Sequence[np.dtype],
shapes: Sequence[Sequence[int]],
shardings: Sequence[jax.sharding.Sharding],
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
) -> str:
"""Returns a string representation of a bulk read request, reads multiple arrays with one call."""
read_requests = [
get_read_request(
location_path, name, dtype, shape, sharding, None, timeout, True
)["persistenceReadRequest"]
for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
]

if not isinstance(devices, np.ndarray):
devices = np.array(devices)

return json.dumps({
"bulk_persistence_read_request": {
"read_requests_per_device_list": {
"device_list": {
"device_ids": [device.id for device in devices.flatten()]
},
"read_requests": read_requests,
}
}
})


def write_one_array(
location: str,
name: str,
Expand Down