From 89c4651e3d611a8e53eb11170d12f4f6135e699d Mon Sep 17 00:00:00 2001 From: Mridul Sahu Date: Thu, 24 Apr 2025 12:23:10 -0700 Subject: [PATCH] Make CloudPathwaysArrayHandler compatible with async directory creation feature in orbax. PiperOrigin-RevId: 751089745 --- pathwaysutils/persistence/orbax_handler.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index 7228254..13b47cf 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -63,6 +63,18 @@ def __init__( raise ValueError("OCDBT not supported for Pathways.") super().__init__() + async def _background_serialize( + self, + values: Sequence[jax.Array], + locations: Sequence[str], + names: Sequence[str], + ) -> None: + """Uses Pathways Persistence API to serialize a jax array.""" + f = functools.partial(helper.write_one_array, timeout=self._read_timeout) + futures_results = list(map(f, locations, names, values)) + for future_result in futures_results: + future_result.result() + async def serialize( self, values: Sequence[jax.Array], @@ -76,8 +88,12 @@ async def serialize( raise ValueError("Casting during save not supported for Pathways.") locations, names = extract_parent_dir_and_name(infos) - f = functools.partial(helper.write_one_array, timeout=self._read_timeout) - return list(map(f, locations, names, values)) + return [ + future.CommitFutureAwaitingContractedSignals( + self._background_serialize(values, locations, names), + name="cloud_pathways_array_handler", + ) + ] async def deserialize( self,