diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index 7228254..13b47cf 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -63,6 +63,18 @@ def __init__( raise ValueError("OCDBT not supported for Pathways.") super().__init__() + async def _background_serialize( + self, + values: Sequence[jax.Array], + locations: Sequence[str], + names: Sequence[str], + ) -> None: + """Uses Pathways Persistence API to serialize a jax array.""" + f = functools.partial(helper.write_one_array, timeout=self._read_timeout) + futures_results = list(map(f, locations, names, values)) + for future_result in futures_results: + future_result.result() + async def serialize( self, values: Sequence[jax.Array], @@ -76,8 +88,12 @@ async def serialize( raise ValueError("Casting during save not supported for Pathways.") locations, names = extract_parent_dir_and_name(infos) - f = functools.partial(helper.write_one_array, timeout=self._read_timeout) - return list(map(f, locations, names, values)) + return [ + future.CommitFutureAwaitingContractedSignals( + self._background_serialize(values, locations, names), + name="cloud_pathways_array_handler", + ) + ] async def deserialize( self,