def _is_prioritized_for_saving(info: types.ParamInfo) -> types.Prioritization:
  """Identifies key priority for saving.

  A PRIORITIZED key is one that is scheduled for D2H transfer synchronously,
  otherwise it may be scheduled from a background thread. Defaults to
  PRIORITIZED since async D2H is likely to result in errors if the arrays are
  donated by the training step.

  Key functions written against the earlier boolean contract are still
  accepted: any result that is not a `types.Prioritization` member is
  interpreted by truthiness (truthy -> PRIORITIZED, falsy -> DEPRIORITIZED).
  Without this normalization a `bool` result compares unequal to every enum
  member, so a caller that dispatches on enum equality would not place the
  array in any save bucket.

  Args:
    info: The ParamInfo to check. Its `is_prioritized_key_fn` (may be None) is
      called with `info.keypath`.

  Returns:
    The prioritization of the key for saving.
  """
  is_prioritized_key_fn = info.is_prioritized_key_fn
  if is_prioritized_key_fn is None:
    return types.Prioritization.PRIORITIZED

  prioritization = is_prioritized_key_fn(info.keypath)
  if isinstance(prioritization, types.Prioritization):
    return prioritization

  # Legacy contract: the function returned True/False ("is prioritized").
  if prioritization:
    return types.Prioritization.PRIORITIZED
  return types.Prioritization.DEPRIORITIZED
zip(infos, args, arrays): - if device_host_max_bytes is None: - prioritized.append((value, info, arg)) - elif _is_prioritized_for_saving(info): - logging.info( - 'Key prioritized for saving: %s', - tree_utils.str_keypath(info.keypath), - ) + transfer_tracker.register_batch( + [info.keypath for info in infos] + ) + if device_host_max_bytes is None: + for info, arg, value in zip(infos, args, arrays): prioritized.append((value, info, arg)) - else: - logging.info( - 'Key not prioritized for saving: %s', - tree_utils.str_keypath(info.keypath), - ) - deprioritized.append((value, info, arg)) + else: + for info, arg, value in zip(infos, args, arrays): + prioritization = _is_prioritized_for_saving(info) + if prioritization == types.Prioritization.PRIORITIZED: + prioritized.append((value, info, arg)) + elif prioritization == types.Prioritization.PRIORITIZED_ASYNC: + prioritized_async.append((value, info, arg)) + elif prioritization == types.Prioritization.DEPRIORITIZED: + deprioritized.append((value, info, arg)) + elif prioritization == types.Prioritization.UNKNOWN: + raise ValueError( + f'Prioritization is unknown for key {info.keypath}.' + ) + + # Combine async lists, placing prioritized async first. 
+ deprioritized = prioritized_async + deprioritized if dispatcher is None: return _serialize_arrays_batches_without_dispatcher( @@ -478,6 +496,7 @@ def _serialize_arrays( enable_replica_parallel_separate_folder, ext_metadata, infos[0].enable_pinned_host_transfer, + transfer_tracker, ) else: @@ -541,6 +560,8 @@ async def _serialize(): if prioritized: arrays, infos, args = zip(*prioritized) _serialize_batch(infos, args, arrays) + for info in infos: + transfer_tracker.finish_transfer(info.keypath) if deprioritized: assert device_host_max_bytes is not None for ( @@ -554,6 +575,8 @@ async def _serialize(): dispatcher=dispatcher, ): _serialize_batch(b_infos, b_args, b_arrays) + for info in b_infos: + transfer_tracker.finish_transfer(info.keypath) return future.CommitFutureAwaitingContractedSignals( _serialize(), @@ -1090,6 +1113,7 @@ def __init__( ) self._ext_metadata = dict() self._dispatcher = dispatcher + self._transfer_tracker = jax_array_transfer_tracker.TransferTracker() logging.vlog( 1, @@ -1289,11 +1313,16 @@ async def serialize( metadata_key=self._metadata_key, ext_metadata=self._ext_metadata, array_metadata_store=self._array_metadata_store, + transfer_tracker=self._transfer_tracker, ) ) return future_list + def wait_for_transfer(self, prefix_tuple: tuple[Any, ...]): + """Waits for outstanding D2H transfers with given prefix to complete.""" + self._transfer_tracker.wait_for_transfer(prefix_tuple) + async def _maybe_read_metadata_and_update_restore_args( self, infos: Sequence[types.ParamInfo], diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_transfer_tracker.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_transfer_tracker.py new file mode 100644 index 000000000..6040367a7 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_transfer_tracker.py @@ -0,0 +1,127 @@ +# Copyright 2026 The Orbax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tracks in-flight D2H transfers using prefix reference counting.""" + +import collections +import threading +from typing import Any + +from absl import logging + + +def _to_string_tuple(key_tuple: tuple[Any, ...] | None) -> tuple[str, ...]: + """Converts a tuple of JAX keys or strings to a tuple of raw strings.""" + if key_tuple is None: + return () + + def _to_str(key: Any) -> str: + if isinstance(key, str): + return key + if hasattr(key, 'key'): + return str(key.key) + if hasattr(key, 'idx'): + return str(key.idx) + return str(key) + + return tuple(_to_str(k) for k in key_tuple) + + +class TransferTracker: + """Tracks in-flight D2H transfers using prefix reference counting. + + O(1) lookup, O(tuple_length) update. 
+ """ + + def __init__(self): + self.in_flight_counts = {} + self.lock = threading.Lock() + self.condition = threading.Condition(self.lock) + + def __deepcopy__(self, memo): + """Support deepcopying the TransferTracker (constructs new locks).""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + result.in_flight_counts = dict(self.in_flight_counts) + result.lock = threading.Lock() + result.condition = threading.Condition(result.lock) + return result + + def __getstate__(self): + """Support pickling TransferTracker (serializes counts).""" + return {'in_flight_counts': dict(self.in_flight_counts)} + + def __setstate__(self, state): + """Support unpickling TransferTracker (reconstructs locks).""" + self.in_flight_counts = state['in_flight_counts'] + self.lock = threading.Lock() + self.condition = threading.Condition(self.lock) + + def register_batch(self, keys: list[tuple[Any, ...] | None]): + """Registers a batch of keys for transfer tracking. + + Args: + keys: A list of keys to register for transfer tracking. + """ + # Compute prefix counts + local_counts = collections.defaultdict(int) + for key in keys: + str_key = _to_string_tuple(key) + # Generate all prefixes: e.g., (A,), (A, B), (A, B, C) + for i in range(1, len(str_key) + 1): + local_counts[str_key[:i]] += 1 + + with self.lock: + for prefix, count in local_counts.items(): + current = self.in_flight_counts.get(prefix, 0) + self.in_flight_counts[prefix] = current + count + + logging.info('Registered batch of %d keys', len(keys)) + + def finish_transfer(self, key_tuple: tuple[Any, ...] | None): + """Finishes a transfer for a single key. + + Args: + key_tuple: The key tuple of the finished transfer. 
+ """ + str_key = _to_string_tuple(key_tuple) + + with self.lock: + for i in range(1, len(str_key) + 1): + prefix = str_key[:i] + if prefix not in self.in_flight_counts: + continue + new_count = self.in_flight_counts[prefix] - 1 + + if new_count <= 0: + logging.info('Transfer for prefix %s is done', prefix) + if prefix in self.in_flight_counts: + del self.in_flight_counts[prefix] + else: + self.in_flight_counts[prefix] = new_count + + self.condition.notify_all() + + def wait_for_transfer(self, prefix_tuple: tuple[Any, ...] | None): + """Waits until all transfers under a specific prefix are complete. + + Args: + prefix_tuple: The prefix tuple of the transfer to wait for. + """ + str_prefix = _to_string_tuple(prefix_tuple) + + with self.condition: + while str_prefix in self.in_flight_counts: + self.condition.wait() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_transfer_tracker_test.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_transfer_tracker_test.py new file mode 100644 index 000000000..f38d3d188 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_transfer_tracker_test.py @@ -0,0 +1,158 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class StopLoopError(Exception):
  """Raised by a mocked `Condition.wait` to escape a blocking wait."""


class DictKey:
  """Minimal keypath entry exposing a `key` attribute."""

  def __init__(self, key: str):
    self.key = key

  def __eq__(self, other):
    if not isinstance(other, DictKey):
      return False
    return self.key == other.key

  def __hash__(self):
    return hash(self.key)

  def __repr__(self):
    return f"DictKey(key={self.key!r})"


class JaxArrayTransferTrackerTest(parameterized.TestCase):
  """Tests for prefix reference counting in `TransferTracker`."""

  def _assert_wait_returns_immediately(self, transfers, prefix):
    """Checks that waiting on `prefix` never reaches `Condition.wait`."""
    with mock.patch.object(transfers.condition, "wait") as wait_mock:
      transfers.wait_for_transfer(prefix)
      wait_mock.assert_not_called()

  def _assert_wait_blocks(self, transfers, prefix):
    """Checks that waiting on `prefix` reaches `Condition.wait` once."""
    with mock.patch.object(
        transfers.condition, "wait", side_effect=StopLoopError
    ) as wait_mock:
      with self.assertRaises(StopLoopError):
        transfers.wait_for_transfer(prefix)
      wait_mock.assert_called_once()

  def test_finish_transfer_twice_no_error(self):
    transfers = jax_array_transfer_tracker.TransferTracker()
    only_key = ("key1",)

    transfers.register_batch([only_key])
    # Completing the same key repeatedly must not raise.
    for _ in range(2):
      transfers.finish_transfer(only_key)

  def test_prefix_matching(self):
    transfers = jax_array_transfer_tracker.TransferTracker()
    leaf = ("key1", "key2")

    transfers.register_batch([leaf])
    transfers.finish_transfer(leaf)
    # Finishing the only leaf releases its parent prefix as well.
    self.assertNotIn(("key1",), transfers.in_flight_counts)

  def test_wait_for_transfer(self):
    finished = ("key1", "key2")
    pending = ("key1", "key3")
    parent = ("key1",)
    untracked = ("key1", "key4")

    transfers = jax_array_transfer_tracker.TransferTracker()
    transfers.register_batch([finished, pending])
    transfers.finish_transfer(finished)

    with self.subTest("test_prefix_completed"):
      self._assert_wait_returns_immediately(transfers, finished)
      self.assertNotIn(finished, transfers.in_flight_counts)

    with self.subTest("test_prefix_not_completed_child"):
      self._assert_wait_blocks(transfers, pending)

    with self.subTest("test_prefix_not_completed_parent"):
      self._assert_wait_blocks(transfers, parent)
      self.assertIn(parent, transfers.in_flight_counts)

    with self.subTest("test_not_tracked_prefix"):
      transfers.finish_transfer(untracked)
      self._assert_wait_returns_immediately(transfers, untracked)
      self.assertNotIn(untracked, transfers.in_flight_counts)

  def test_nested_dict_keys(self):

    def path(*names):
      return tuple(DictKey(name) for name in names)

    transfers = jax_array_transfer_tracker.TransferTracker()
    leaf_f = path("a", "b", "c", "f")
    leaf_g = path("a", "b", "c", "g")
    leaf_h = path("a", "b", "cd", "h")
    leaf_i = path("a", "b", "ce", "i")
    transfers.register_batch([leaf_f, leaf_g, leaf_h, leaf_i])

    counts = transfers.in_flight_counts
    level_1 = ("a",)
    level_2 = ("a", "b")
    level_3 = ("a", "b", "c")

    # All four leaves sit under a/b; only two of them sit under a/b/c.
    self.assertEqual(counts[level_1], 4)
    self.assertEqual(counts[level_2], 4)
    self.assertEqual(counts[level_3], 2)

    transfers.finish_transfer(leaf_f)
    self.assertEqual(counts[level_1], 3)
    self.assertEqual(counts[level_2], 3)
    self.assertEqual(counts[level_3], 1)

    transfers.finish_transfer(leaf_g)
    self.assertEqual(counts[level_1], 2)
    self.assertEqual(counts[level_2], 2)
    self.assertNotIn(level_3, counts)

    for leaf in (leaf_h, leaf_i):
      transfers.finish_transfer(leaf)
    self.assertNotIn(level_1, counts)

  def test_none_keypath(self):
    transfers = jax_array_transfer_tracker.TransferTracker()

    # A None keypath must be accepted on registration and track nothing.
    transfers.register_batch([None])
    self.assertEmpty(transfers.in_flight_counts)

    # A None keypath must be accepted on completion as well.
    transfers.finish_transfer(None)
    self.assertEmpty(transfers.in_flight_counts)


if __name__ == "__main__":
  absltest.main()
@enum.unique
class Prioritization(enum.Enum):
  """How the D2H transfer of an array is scheduled during a save."""

  # Transfer is scheduled before the save call returns to the caller.
  PRIORITIZED = 0
  # Transfer is scheduled asynchronously, ahead of DEPRIORITIZED keys.
  PRIORITIZED_ASYNC = 1
  # Transfer is scheduled asynchronously, after PRIORITIZED_ASYNC keys.
  DEPRIORITIZED = 2
  # No decision was made for the key.
  UNKNOWN = 3


class IsPrioritizedKeyFn(Protocol):
  """Protocol for a callable that assigns a `Prioritization` to a key.

  The callable receives a PyTree keypath (obtained using
  jax.tree.map_with_path) and returns the `Prioritization` member describing
  how the D2H transfer for that key should be scheduled.

  PRIORITIZED keys have their D2H transfer scheduled before control returns to
  the caller, so their values cannot be corrupted by a concurrent update or
  donation.

  All other keys have their D2H transfer scheduled asynchronously, with
  PRIORITIZED_ASYNC keys scheduled ahead of DEPRIORITIZED keys. Such keys are
  not scheduled for transfer until every prioritized key has been fully
  written to the checkpoint, so their values may be altered if they are
  updated concurrently.
  """

  def __call__(self, keypath: Tuple[Any, ...]) -> Prioritization:
    """Returns the prioritization of the key."""
    ...