Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import functools
import os
import time
from typing import Any, Dict, Sequence, Set, Tuple, TypeAlias, Union, cast
from typing import Any, cast, Dict, Sequence, Set, Tuple, TypeAlias, Union
import warnings

from absl import logging
Expand All @@ -41,6 +41,7 @@
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import utils as path_utils
from orbax.checkpoint._src.serialization import jax_array_restore_args
from orbax.checkpoint._src.serialization import jax_array_transfer_tracker
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import ocdbt_utils
from orbax.checkpoint._src.serialization import replica_slices
Expand All @@ -51,6 +52,7 @@
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts


Pytree: TypeAlias = Any
ArrayRestoreArgs = jax_array_restore_args.ArrayRestoreArgs
SingleReplicaArrayRestoreArgs = (
Expand Down Expand Up @@ -274,23 +276,23 @@ def _worker_serialize_arrays(
)


def _is_prioritized_for_saving(info: types.ParamInfo) -> bool:
"""Identifies prioritized keys.
def _is_prioritized_for_saving(info: types.ParamInfo) -> types.Prioritization:
"""Identifies key priority for saving.

A prioritized key is one that is scheduled for D2H transfer synchronously,
otherwise it may be scheduled from a background thread. Defaults to True,
since async D2H is likely to result in errors if the arrays are donated by
the training step.
otherwise it may be scheduled from a background thread. Defaults to
PRIORITIZED since async D2H is likely to result in errors if the arrays are
donated by the training step.

Args:
info: The ParamInfo to check.

Returns:
True if the key is prioritized for saving.
The prioritization of the key for saving.
"""
is_prioritized_key_fn = info.is_prioritized_key_fn
if is_prioritized_key_fn is None:
return True
return types.Prioritization.PRIORITIZED
return is_prioritized_key_fn(info.keypath)


Expand Down Expand Up @@ -350,6 +352,7 @@ def _serialize_arrays_batches_without_dispatcher(
enable_replica_parallel_separate_folder: bool,
ext_metadata: Dict[str, Any],
enable_pinned_host_transfer: bool,
transfer_tracker: jax_array_transfer_tracker.TransferTracker,
) -> future.Future:
"""Serializes arrays batches without dispatcher."""
# Complete D2H transfer in parallel for each array for prioritized values.
Expand Down Expand Up @@ -382,6 +385,8 @@ def _serialize_arrays_batches_without_dispatcher(
prioritized_values_on_host = replica_slices_transfer_arrays_to_host(
prioritized_arrays
)
for info in prioritized_infos:
transfer_tracker.finish_transfer(info.keypath)
else:
logging.warning(
'No prioritized params found for saving. D2H for all values will be'
Expand Down Expand Up @@ -411,6 +416,8 @@ async def _serialize_without_dispatcher():
dispatcher=None,
):
b_arrays_on_host = replica_slices_transfer_arrays_to_host(b_arrays)
for info in b_infos:
transfer_tracker.finish_transfer(info.keypath)
await async_serialize_replica_slices_batch(
b_arrays_on_host,
b_infos,
Expand All @@ -437,6 +444,7 @@ def _serialize_arrays(
array_metadata_store: array_metadata_store_lib.Store | None,
enable_replica_parallel_separate_folder: bool,
ext_metadata: Dict[str, Any],
transfer_tracker: jax_array_transfer_tracker.TransferTracker,
) -> future.Future:
"""D2H transfer and serialize arrays using dispatcher if provided."""

Expand All @@ -446,22 +454,32 @@ def _serialize_arrays(
device_host_max_bytes = byte_limiter.max_bytes

prioritized: list[tuple[jax.Array, types.ParamInfo, types.SaveArgs]] = []
prioritized_async: list[tuple[jax.Array, types.ParamInfo, types.SaveArgs]] = (
[]
)
deprioritized: list[tuple[jax.Array, types.ParamInfo, types.SaveArgs]] = []
for info, arg, value in zip(infos, args, arrays):
if device_host_max_bytes is None:
prioritized.append((value, info, arg))
elif _is_prioritized_for_saving(info):
logging.info(
'Key prioritized for saving: %s',
tree_utils.str_keypath(info.keypath),
)
transfer_tracker.register_batch(
[info.keypath for info in infos]
)
if device_host_max_bytes is None:
for info, arg, value in zip(infos, args, arrays):
prioritized.append((value, info, arg))
else:
logging.info(
'Key not prioritized for saving: %s',
tree_utils.str_keypath(info.keypath),
)
deprioritized.append((value, info, arg))
else:
for info, arg, value in zip(infos, args, arrays):
prioritization = _is_prioritized_for_saving(info)
if prioritization == types.Prioritization.PRIORITIZED:
prioritized.append((value, info, arg))
elif prioritization == types.Prioritization.PRIORITIZED_ASYNC:
prioritized_async.append((value, info, arg))
elif prioritization == types.Prioritization.DEPRIORITIZED:
deprioritized.append((value, info, arg))
elif prioritization == types.Prioritization.UNKNOWN:
raise ValueError(
f'Prioritization is unknown for key {info.keypath}.'
)

# Combine async lists, placing prioritized async first.
deprioritized = prioritized_async + deprioritized

if dispatcher is None:
return _serialize_arrays_batches_without_dispatcher(
Expand All @@ -478,6 +496,7 @@ def _serialize_arrays(
enable_replica_parallel_separate_folder,
ext_metadata,
infos[0].enable_pinned_host_transfer,
transfer_tracker,
)
else:

Expand Down Expand Up @@ -541,6 +560,8 @@ async def _serialize():
if prioritized:
arrays, infos, args = zip(*prioritized)
_serialize_batch(infos, args, arrays)
for info in infos:
transfer_tracker.finish_transfer(info.keypath)
if deprioritized:
assert device_host_max_bytes is not None
for (
Expand All @@ -554,6 +575,8 @@ async def _serialize():
dispatcher=dispatcher,
):
_serialize_batch(b_infos, b_args, b_arrays)
for info in b_infos:
transfer_tracker.finish_transfer(info.keypath)

return future.CommitFutureAwaitingContractedSignals(
_serialize(),
Expand Down Expand Up @@ -1090,6 +1113,7 @@ def __init__(
)
self._ext_metadata = dict()
self._dispatcher = dispatcher
self._transfer_tracker = jax_array_transfer_tracker.TransferTracker()

logging.vlog(
1,
Expand Down Expand Up @@ -1289,11 +1313,16 @@ async def serialize(
metadata_key=self._metadata_key,
ext_metadata=self._ext_metadata,
array_metadata_store=self._array_metadata_store,
transfer_tracker=self._transfer_tracker,
)
)

return future_list

def wait_for_transfer(self, prefix_tuple: tuple[Any, ...]):
"""Waits for outstanding D2H transfers with given prefix to complete."""
self._transfer_tracker.wait_for_transfer(prefix_tuple)

async def _maybe_read_metadata_and_update_restore_args(
self,
infos: Sequence[types.ParamInfo],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tracks in-flight D2H transfers using prefix reference counting."""

import collections
import threading
from typing import Any

from absl import logging


def _to_string_tuple(key_tuple: tuple[Any, ...] | None) -> tuple[str, ...]:
"""Converts a tuple of JAX keys or strings to a tuple of raw strings."""
if key_tuple is None:
return ()

def _to_str(key: Any) -> str:
if isinstance(key, str):
return key
if hasattr(key, 'key'):
return str(key.key)
if hasattr(key, 'idx'):
return str(key.idx)
return str(key)

return tuple(_to_str(k) for k in key_tuple)


class TransferTracker:
"""Tracks in-flight D2H transfers using prefix reference counting.

O(1) lookup, O(tuple_length) update.
"""

def __init__(self):
self.in_flight_counts = {}
self.lock = threading.Lock()
self.condition = threading.Condition(self.lock)

def __deepcopy__(self, memo):
"""Support deepcopying the TransferTracker (constructs new locks)."""
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
result.in_flight_counts = dict(self.in_flight_counts)
result.lock = threading.Lock()
result.condition = threading.Condition(result.lock)
return result

def __getstate__(self):
"""Support pickling TransferTracker (serializes counts)."""
return {'in_flight_counts': dict(self.in_flight_counts)}

def __setstate__(self, state):
"""Support unpickling TransferTracker (reconstructs locks)."""
self.in_flight_counts = state['in_flight_counts']
self.lock = threading.Lock()
self.condition = threading.Condition(self.lock)

def register_batch(self, keys: list[tuple[Any, ...] | None]):
"""Registers a batch of keys for transfer tracking.

Args:
keys: A list of keys to register for transfer tracking.
"""
# Compute prefix counts
local_counts = collections.defaultdict(int)
for key in keys:
str_key = _to_string_tuple(key)
# Generate all prefixes: e.g., (A,), (A, B), (A, B, C)
for i in range(1, len(str_key) + 1):
local_counts[str_key[:i]] += 1

with self.lock:
for prefix, count in local_counts.items():
current = self.in_flight_counts.get(prefix, 0)
self.in_flight_counts[prefix] = current + count

logging.info('Registered batch of %d keys', len(keys))

def finish_transfer(self, key_tuple: tuple[Any, ...] | None):
"""Finishes a transfer for a single key.

Args:
key_tuple: The key tuple of the finished transfer.
"""
str_key = _to_string_tuple(key_tuple)

with self.lock:
for i in range(1, len(str_key) + 1):
prefix = str_key[:i]
if prefix not in self.in_flight_counts:
continue
new_count = self.in_flight_counts[prefix] - 1

if new_count <= 0:
logging.info('Transfer for prefix %s is done', prefix)
if prefix in self.in_flight_counts:
del self.in_flight_counts[prefix]
else:
self.in_flight_counts[prefix] = new_count

self.condition.notify_all()

def wait_for_transfer(self, prefix_tuple: tuple[Any, ...] | None):
"""Waits until all transfers under a specific prefix are complete.

Args:
prefix_tuple: The prefix tuple of the transfer to wait for.
"""
str_prefix = _to_string_tuple(prefix_tuple)

with self.condition:
while str_prefix in self.in_flight_counts:
self.condition.wait()
Loading
Loading