From c660e2f127199ee9db0996a411f52836aa4f1190 Mon Sep 17 00:00:00 2001 From: Nikhil Bansal Date: Fri, 15 May 2026 04:49:24 -0700 Subject: [PATCH] Allow Snapshotter to handle PyTrees with non-Jax Array types. PiperOrigin-RevId: 915939920 --- .../v1/_src/training/pathways/snapshotter.py | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py index 02bba0a51..08c5e7553 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py @@ -47,25 +47,24 @@ def _worker(self): finally: self._queue.task_done() - def save_pytree( - self, step: int, state: tree_types.PyTreeOf[jax.Array] - ) -> None: + def save_pytree(self, step: int, state: tree_types.PyTree) -> None: """Move arrays onto CPU worker devices.""" if self._queue.full(): logging.warning("Snapshotter busy. Skipping snapshot for step %d", step) return - pinned_shardings = jax.tree.map( - lambda x: x.sharding.with_memory_kind("pinned_host"), state + pinned_state = jax.tree.map( + lambda x: jax.device_put(x, x.sharding.with_memory_kind("pinned_host")) + if hasattr(x, "sharding") + else x, + state, ) - pinned_state = jax.device_put(state, pinned_shardings) - self._queue.put((pinned_state, step)) def load_pytree( self, - abstract_state: tree_types.PyTreeOf[jax.Array], + abstract_state: tree_types.PyTree, *, reset_snapshot_state: bool = True, ) -> tree_types.PyTree: @@ -119,20 +118,32 @@ def get_active_pytree(x): ) return reconstructed_state - pinned_state = jax.tree.map(get_active_pytree, pinned_state) - - # Re-shard on host to the target device mesh - host_target_shardings = jax.tree.map( - lambda x: x.sharding.with_memory_kind("pinned_host"), abstract_state + pinned_state = jax.tree.map( + lambda x: get_active_pytree(x) if hasattr(x, "sharding") else x, + pinned_state, ) - host_target_state = jax.device_put( - pinned_state, host_target_shardings + def _device_put_pinned(x, abs_x): + if hasattr(abs_x, "sharding"): + return jax.device_put( + x, abs_x.sharding.with_memory_kind("pinned_host") + ) + return x + + # Re-shard on host to the target device mesh + host_target_state = jax.tree.map( + _device_put_pinned, + pinned_state, + abstract_state, ) # Move from host back to device (TPU) memory. - restored_state = jax.device_put( - host_target_state, jax.tree.map(lambda x: x.sharding, abstract_state) + restored_state = jax.tree.map( + lambda x, abs_x: jax.device_put(x, abs_x.sharding) + if hasattr(abs_x, "sharding") + else x, + host_target_state, + abstract_state, ) jax.block_until_ready(restored_state)