From 73a55145768c2cdae58bd427c0bb617e59b6105d Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Thu, 14 May 2026 14:46:04 -0700 Subject: [PATCH] Modularize BasePyTreeCheckpointHandler metadata persistence logic into MetadataManager. PiperOrigin-RevId: 915626554 --- .../checkpoint/_src/engine/async_io_engine.py | 202 ++++++++ .../_src/engine/async_io_engine_test.py | 158 ++++++ .../base_pytree_checkpoint_handler.py | 463 +++--------------- .../handlers/pytree_checkpoint_handler.py | 5 +- .../_src/metadata/metadata_manager.py | 231 +++++++++ .../_src/metadata/metadata_manager_test.py | 216 ++++++++ .../v1/_src/handlers/pytree_handler.py | 10 +- .../orbax/checkpoint/metadata/__init__.py | 1 + 8 files changed, 882 insertions(+), 404 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/engine/async_io_engine.py create mode 100644 checkpoint/orbax/checkpoint/_src/engine/async_io_engine_test.py create mode 100644 checkpoint/orbax/checkpoint/_src/metadata/metadata_manager.py create mode 100644 checkpoint/orbax/checkpoint/_src/metadata/metadata_manager_test.py diff --git a/checkpoint/orbax/checkpoint/_src/engine/async_io_engine.py b/checkpoint/orbax/checkpoint/_src/engine/async_io_engine.py new file mode 100644 index 000000000..259822e0e --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/engine/async_io_engine.py @@ -0,0 +1,202 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIoEngine module. 
+ +Provides the `AsyncIoEngine` class and supporting helper functions responsible +for managing concurrent I/O execution, thread-pooling, and performance telemetry +collection during PyTree saving and restoration workflows. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import dataclasses +import sys +import threading +import time +from typing import Any, List, Optional, Sequence, Tuple, Union + +from absl import logging +import humanize +import jax +from orbax.checkpoint._src.futures import future +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.serialization import memory_regulator +from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.serialization import types + +TypeHandler = types.TypeHandler +ParamInfo = types.ParamInfo +SaveArgs = type_handlers.SaveArgs +RestoreArgs = type_handlers.RestoreArgs + + +def _default_sizeof_values(values: Sequence[Any]) -> Sequence[int]: + return [sys.getsizeof(v) for v in values] + + +def get_batch_memory_size( + handler: TypeHandler, values: Sequence[Any] +) -> Tuple[int, int]: + """Gets memory size for a batch of leaf values.""" + try: + write_sizes, read_sizes = zip(*handler.memory_size(values)) + except NotImplementedError: + logging.warning( + '`memory_size` is not implemented for `TypeHandler` of type: %s. 
Using' + ' the a default implementation to measure value memory consumption that' + ' may result in inaccurate estimation.', + type(handler), + ) + write_sizes = read_sizes = _default_sizeof_values(values) + assert len(write_sizes) == len(values) + assert len(read_sizes) == len(values) + return sum(write_sizes), sum(read_sizes) + + +def log_io_metrics( + size: int, + start_time: float, + gbytes_per_sec_metric: str, + gbytes_metric: Optional[str] = None, +): + """Logs the bytes per second metric.""" + time_elapsed = time.time() - start_time + bytes_per_sec = ( + float('nan') if time_elapsed == 0 else float(size) / time_elapsed + ) + note = 'per-host' + logging.info( + '[process=%d] %s: %s/s (total gbytes: %s) (time elapsed: %s s) (%s)', + multihost.process_index(), + gbytes_per_sec_metric, + humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f'), + humanize.naturalsize(size, binary=True), + time_elapsed, + note, + ) + jax.monitoring.record_scalar( + gbytes_per_sec_metric, value=bytes_per_sec / (1024**3) + ) + if gbytes_metric is not None: + jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3)) + + +async def logging_serialize( + handler: TypeHandler, + serialize: asyncio.Coroutine[Any, Any, Sequence[future.Future]], +) -> Sequence[future.Future]: + """Logs the time taken to serialize.""" + start = time.time() + commit_futures = await serialize + handler_name = f'{type(handler).__module__}.{type(handler).__qualname__}' + logging.info( + '[process=%s][thread=%s] Initiated %s.serialize. Time taken: %fs', + multihost.process_index(), + threading.current_thread().name, + f'"{handler_name}"', + time.time() - start, + ) + return commit_futures + + +@dataclasses.dataclass +class BatchRequest: + """Represents a a request for batched serialization or deserialization. + + Attributes: + handler: Used to serialize or deserialize the parameters. + keys: Used to identify the original tree keys so that the PyTree can be + reconstructed. 
+ values: Values to serialize. + infos: ParamInfos. + args: List of SaveArgs or RestoreArgs. + """ + + handler: TypeHandler + keys: List[str] + values: List[Any] + infos: List[ParamInfo] + args: List[Union[SaveArgs, RestoreArgs]] + + def __post_init__(self): + length = len(self.values) + if not all(( + length == len(self.infos), + length == len(self.args), + length == len(self.keys), + )): + raise AssertionError('Found `_BatchRequest` with mismatched parameters.') + + +@contextlib.contextmanager +def memory_profiler_context(): + """Context manager for memory_regulator profiler.""" + memory_regulator.profiler_start() + try: + yield + finally: + # Explicitly stop the bg thread if an exception occurs + memory_regulator.profiler_end() + + +class AsyncIoEngine: + """Encapsulates concurrency, thread-pooling, and I/O telemetry logic.""" + + async def execute_save( + self, batch_requests: Sequence[BatchRequest] + ) -> Tuple[List[Any], int]: + """Executes save requests asynchronously with I/O telemetry.""" + serialize_ops = [] + tree_memory_size = 0 + with memory_profiler_context(): + for request in batch_requests: + serialize_ops.append( + logging_serialize( + request.handler, + request.handler.serialize( + request.values, request.infos, request.args + ), + ) + ) + write_size, _ = get_batch_memory_size(request.handler, request.values) + tree_memory_size += write_size + + commit_futures = await asyncio.gather(*serialize_ops) + + logging.info( + 'MemoryRegulated: Peak usage: %f GiB', + memory_regulator.profiler_peak_usage_gib(), + ) + return commit_futures, tree_memory_size + + async def execute_restore( + self, batch_requests: Sequence[BatchRequest] + ) -> Tuple[List[Any], int]: + """Executes restore requests asynchronously with I/O telemetry.""" + deserialized_batches_ops = [] + for request in batch_requests: + deserialized_batches_ops.append( + request.handler.deserialize(request.infos, request.args) + ) + deserialized_batches = await 
asyncio.gather(*deserialized_batches_ops) + + tree_memory_size = 0 + for request, deserialized in zip(batch_requests, deserialized_batches): + _, read_size = get_batch_memory_size(request.handler, deserialized) + tree_memory_size += read_size + + return deserialized_batches, tree_memory_size diff --git a/checkpoint/orbax/checkpoint/_src/engine/async_io_engine_test.py b/checkpoint/orbax/checkpoint/_src/engine/async_io_engine_test.py new file mode 100644 index 000000000..7a41d7f36 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/engine/async_io_engine_test.py @@ -0,0 +1,158 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest +from unittest import mock + +from absl.testing import absltest +from orbax.checkpoint._src.engine import async_io_engine +from orbax.checkpoint._src.serialization import types + +AsyncIoEngine = async_io_engine.AsyncIoEngine +BatchRequest = async_io_engine.BatchRequest + + +class AsyncIoEngineTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): + + def test_get_batch_memory_size_success(self): + handler = mock.create_autospec(types.TypeHandler, instance=True) + handler.memory_size.return_value = [(10, 20), (30, 40)] + + write_size, read_size = async_io_engine.get_batch_memory_size( + handler, ['a', 'b'] + ) + self.assertEqual(write_size, 40) + self.assertEqual(read_size, 60) + + def test_get_batch_memory_size_not_implemented(self): + handler = mock.create_autospec(types.TypeHandler, instance=True) + handler.memory_size.side_effect = NotImplementedError() + + values = ['dummy1', 'dummy2'] + expected_size = sum(sys.getsizeof(v) for v in values) + + write_size, read_size = async_io_engine.get_batch_memory_size( + handler, values + ) + self.assertEqual(write_size, expected_size) + self.assertEqual(read_size, expected_size) + + def test_batch_request_validation_success(self): + handler = mock.create_autospec(types.TypeHandler, instance=True) + req = BatchRequest( + handler=handler, + keys=['k1', 'k2'], + values=['v1', 'v2'], + infos=[mock.Mock(), mock.Mock()], + args=[mock.Mock(), mock.Mock()], + ) + self.assertLen(req.values, 2) + + def test_batch_request_validation_mismatch(self): + handler = mock.create_autospec(types.TypeHandler, instance=True) + with self.assertRaises(AssertionError): + BatchRequest( + handler=handler, + keys=['k1'], + values=['v1', 'v2'], + infos=[mock.Mock(), mock.Mock()], + args=[mock.Mock(), mock.Mock()], + ) + + async def test_execute_save(self): + engine = AsyncIoEngine() + + handler1 = mock.create_autospec(types.TypeHandler, instance=True) + handler2 = mock.create_autospec(types.TypeHandler, 
instance=True) + + async def dummy_serialize1(*args, **kwargs): + del args, kwargs + return ['fut1', 'fut2'] + + async def dummy_serialize2(*args, **kwargs): + del args, kwargs + return ['fut3'] + + handler1.serialize.side_effect = dummy_serialize1 + handler2.serialize.side_effect = dummy_serialize2 + + handler1.memory_size.return_value = [(100, 0)] + handler2.memory_size.return_value = [(200, 0)] + + req1 = BatchRequest( + handler=handler1, + keys=['k1'], + values=['v1'], + infos=[mock.Mock()], + args=[mock.Mock()], + ) + req2 = BatchRequest( + handler=handler2, + keys=['k2'], + values=['v2'], + infos=[mock.Mock()], + args=[mock.Mock()], + ) + + commit_futures, tree_memory_size = await engine.execute_save([req1, req2]) + + self.assertEqual(commit_futures, [['fut1', 'fut2'], ['fut3']]) + self.assertEqual(tree_memory_size, 300) + + async def test_execute_restore(self): + engine = AsyncIoEngine() + + handler1 = mock.create_autospec(types.TypeHandler, instance=True) + handler2 = mock.create_autospec(types.TypeHandler, instance=True) + + async def dummy_deserialize1(*args, **kwargs): + del args, kwargs + return ['restored1'] + + async def dummy_deserialize2(*args, **kwargs): + del args, kwargs + return ['restored2'] + + handler1.deserialize.side_effect = dummy_deserialize1 + handler2.deserialize.side_effect = dummy_deserialize2 + + handler1.memory_size.return_value = [(0, 50)] + handler2.memory_size.return_value = [(0, 150)] + + req1 = BatchRequest( + handler=handler1, + keys=['k1'], + values=['v1'], + infos=[mock.Mock()], + args=[mock.Mock()], + ) + req2 = BatchRequest( + handler=handler2, + keys=['k2'], + values=['v2'], + infos=[mock.Mock()], + args=[mock.Mock()], + ) + + deserialized_batches, tree_memory_size = await engine.execute_restore( + [req1, req2] + ) + + self.assertEqual(deserialized_batches, [['restored1'], ['restored2']]) + self.assertEqual(tree_memory_size, 200) + + +if __name__ == '__main__': + absltest.main() diff --git 
a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 014a88527..b978e301e 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -22,15 +22,12 @@ from __future__ import annotations import asyncio -import contextlib import dataclasses import functools import json -import sys import threading import time -from typing import Any, List, Optional, Sequence, Tuple, Union -import uuid +from typing import Any, List, Optional, Tuple, Union from absl import logging from etils import epath @@ -40,11 +37,12 @@ from orbax.checkpoint import options as options_lib from orbax.checkpoint import utils from orbax.checkpoint._src import asyncio_utils +from orbax.checkpoint._src.engine import async_io_engine from orbax.checkpoint._src.futures import future from orbax.checkpoint._src.handlers import async_checkpoint_handler -from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib from orbax.checkpoint._src.metadata import empty_values +from orbax.checkpoint._src.metadata import metadata_manager as metadata_manager_lib from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path @@ -52,7 +50,6 @@ from orbax.checkpoint._src.path import types as path_types from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import memory_regulator -from orbax.checkpoint._src.serialization import ocdbt_utils from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handler_registry as type_handler_registry_lib from orbax.checkpoint._src.serialization import type_handlers 
@@ -72,6 +69,7 @@ ParamInfo = types.ParamInfo TypeHandler = types.TypeHandler TypeHandlerRegistry = types.TypeHandlerRegistry +BatchRequest = async_io_engine.BatchRequest # TODO(b/298487158) Clean up protected access. LimitInFlightBytes = limits.LimitInFlightBytes @@ -94,110 +92,12 @@ class PartialSaveReplacementError(PartialSaveError): """Raised when a replacement is attempted during partial saving.""" -def _default_sizeof_values(values: Sequence[Any]) -> Sequence[int]: - return [sys.getsizeof(v) for v in values] - - -def _get_batch_memory_size( - handler: TypeHandler, values: Sequence[Any] -) -> Tuple[int, int]: - """Gets memory size for a batch of leaf values.""" - try: - write_sizes, read_sizes = zip(*handler.memory_size(values)) - except NotImplementedError: - logging.warning( - '`memory_size` is not implemented for `TypeHandler` of type: %s. Using' - ' the a default implementation to measure value memory consumption that' - ' may result in inaccurate estimation.', - type(handler), - ) - write_sizes = read_sizes = _default_sizeof_values(values) - assert len(write_sizes) == len(values) - assert len(read_sizes) == len(values) - return sum(write_sizes), sum(read_sizes) - - -def _log_io_metrics( - size: int, - start_time: float, - gbytes_per_sec_metric: str, - gbytes_metric: Optional[str] = None, -): - """Logs the bytes per second metric.""" - time_elapsed = time.time() - start_time - bytes_per_sec = ( - float('nan') if time_elapsed == 0 else float(size) / time_elapsed - ) - note = 'per-host' - logging.info( - '[process=%d] %s: %s/s (total gbytes: %s) (time elapsed: %s s) (%s)', - multihost.process_index(), - gbytes_per_sec_metric, - humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f'), - humanize.naturalsize(size, binary=True), - time_elapsed, - note, - ) - jax.monitoring.record_scalar( - gbytes_per_sec_metric, value=bytes_per_sec / (1024**3) - ) - if gbytes_metric is not None: - jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3)) - - 
-async def _logging_serialize( - handler: TypeHandler, - serialize: asyncio.Coroutine[Any, Any, Sequence[future.Future]], -) -> Sequence[future.Future]: - """Logs the time taken to serialize.""" - start = time.time() - commit_futures = await serialize - handler_name = f'{type(handler).__module__}.{type(handler).__qualname__}' - logging.info( - '[process=%s][thread=%s] Initiated %s.serialize. Time taken: %fs', - multihost.process_index(), - threading.current_thread().name, - f'"{handler_name}"', - time.time() - start, - ) - return commit_futures - - -@dataclasses.dataclass -class _BatchRequest: - """Represents a a request for batched serialization or deserialization. - - Attributes: - handler: Used to serialize or deserialize the parameters. - keys: Used to identify the original tree keys so that the PyTree can be - reconstructed. - values: Values to serialize. - infos: ParamInfos. - args: List of SaveArgs or RestoreArgs. - """ - - handler: TypeHandler - keys: List[str] - values: List[Any] - infos: List[ParamInfo] - args: List[Union[SaveArgs, RestoreArgs]] - - def __post_init__(self): - length = len(self.values) - if not all(( - length == len(self.infos), - length == len(self.args), - length == len(self.keys), - )): - raise AssertionError('Found `_BatchRequest` with mismatched parameters.') - - def batched_serialization_requests( tree: PyTree, param_infos: PyTree, args: PyTree, registry: TypeHandlerRegistry, -) -> List[_BatchRequest]: +) -> List[BatchRequest]: """Gets a list of batched serialization or deserialization requests.""" grouped = {} @@ -246,7 +146,7 @@ def _group_value( ) from e if handler not in grouped: - grouped[handler] = _BatchRequest(handler, [], [], [], []) + grouped[handler] = BatchRequest(handler, [], [], [], []) request = grouped[handler] grouped[handler] = dataclasses.replace( request, @@ -325,17 +225,6 @@ def _maybe_set_default_save_restore_args(v, leaf_args): ) -@contextlib.contextmanager -def _memory_profiler_context(): - """Context manager 
for memory_regulator profiler.""" - memory_regulator.profiler_start() - try: - yield - finally: - # Explicitly stop the bg thread if an exception occurs - memory_regulator.profiler_end() - - def _format_bytes(bytes_value: Optional[int]) -> str: @@ -376,6 +265,7 @@ def __init__( ), enable_pinned_host_transfer: Optional[bool] = None, is_prioritized_key_fn: Optional[types.IsPrioritizedKeyFn] = None, + metadata_manager: Optional[metadata_manager_lib.MetadataManager] = None, ): """Creates BasePyTreeCheckpointHandler. @@ -420,6 +310,8 @@ def __init__( not prioritized. Note that any "prioritized" keys are assumed to be lightweight, and `save_device_host_concurrent_gb` will be ignored for them. + metadata_manager: Optional `MetadataManager` instance to manage + persistence. """ self._save_concurrent_bytes = save_concurrent_bytes self._restore_concurrent_bytes = restore_concurrent_bytes @@ -463,6 +355,11 @@ def __init__( if self._array_metadata_store: self._array_metadata_store.set_primary_host(self._primary_host) self._array_metadata_validator = array_metadata_validator + self._metadata_manager = ( + metadata_manager + if metadata_manager is not None + else metadata_manager_lib.MetadataManager() + ) if enable_pinned_host_transfer is None: enable_pinned_host_transfer = jax.default_backend() == 'gpu' @@ -489,6 +386,7 @@ def __init__( _format_bytes(self._save_concurrent_bytes), _format_bytes(self._restore_concurrent_bytes), ) + self._async_io_engine = async_io_engine.AsyncIoEngine() def get_param_names(self, item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" @@ -567,12 +465,12 @@ async def _async_partial_save( self, directory: epath.Path, item: PyTree, - batch_requests: list[_BatchRequest], - param_infos: PyTree, - save_args: BasePyTreeSaveArgs, + batch_requests: list[BatchRequest], ): value_metadata_tree = ( - await self._read_metadata_file(directory) + await self._metadata_manager.read_metadata_file( + directory, 
pytree_metadata_options=self._pytree_metadata_options + ) ).as_nested_tree() tree_diff = tree_structure_utils.tree_difference(item, value_metadata_tree) @@ -627,21 +525,7 @@ def _handle_diffs(keypath, diff): ) ) - serialize_ops = [] - tree_memory_size = 0 - for request in filtered_requests: - serialize_ops += [ - _logging_serialize( - request.handler, - request.handler.serialize( - request.values, request.infos, request.args - ), - ) - ] - write_size, _ = _get_batch_memory_size(request.handler, request.values) - tree_memory_size += write_size - - return serialize_ops, tree_memory_size, param_infos, save_args + return filtered_requests async def async_save( self, @@ -725,7 +609,6 @@ async def async_save( leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos) ) - serialize_ops = [] # List of (coros -> List of futures) batch_requests = batched_serialization_requests( item, param_infos, @@ -737,33 +620,15 @@ async def async_save( directory / PYTREE_METADATA_FILE ) batch_requests_ready_time = time.time() - with _memory_profiler_context(): - if is_partial_save: - serialize_ops, tree_memory_size, param_infos, save_args = ( - await self._async_partial_save( - directory, item, batch_requests, param_infos, save_args - ) - ) - else: - tree_memory_size = 0 - for request in batch_requests: - serialize_ops += [ - _logging_serialize( - request.handler, - request.handler.serialize( - request.values, request.infos, request.args - ), - ) - ] - write_size, _ = _get_batch_memory_size( - request.handler, request.values - ) - tree_memory_size += write_size - # Await copy futures. Returns List[List[future.Future]]. 
- commit_futures = await asyncio.gather(*serialize_ops) - logging.info( - 'MemoryRegulated: Peak usage: %f GiB', - memory_regulator.profiler_peak_usage_gib(), + if is_partial_save: + requests_to_save = await self._async_partial_save( + directory, item, batch_requests + ) + else: + requests_to_save = batch_requests + + commit_futures, tree_memory_size = await self._async_io_engine.execute_save( + requests_to_save ) # Flatten to List[future.Future]. commit_futures, _ = jax.tree.flatten(commit_futures) @@ -793,7 +658,7 @@ async def async_save( save_futures += commit_futures - _log_io_metrics( + async_io_engine.log_io_metrics( tree_memory_size, start_time, '/jax/orbax/write/blocking_gbytes_per_sec', @@ -802,7 +667,7 @@ async def async_save( future.ChainedFuture( save_futures, functools.partial( - _log_io_metrics, + async_io_engine.log_io_metrics, tree_memory_size, start_time, '/jax/orbax/write/gbytes_per_sec', @@ -867,19 +732,11 @@ async def _maybe_deserialize( restore_args, self._type_handler_registry, ) - deserialized_batches = [] - deserialized_batches_ops = [] - for request in batch_requests: - deserialized_batches_ops.append( - request.handler.deserialize(request.infos, request.args) - ) - deserialized_batches += await asyncio.gather(*deserialized_batches_ops) - - tree_memory_size = 0 + deserialized_batches, tree_memory_size = ( + await self._async_io_engine.execute_restore(batch_requests) + ) flat_restored = {} for request, deserialized in zip(batch_requests, deserialized_batches): - _, read_size = _get_batch_memory_size(request.handler, deserialized) - tree_memory_size += read_size for key, value in zip(request.keys, deserialized): flat_restored[key] = value # Add in empty nodes from the metadata tree. @@ -1065,7 +922,9 @@ class TrainState: ) # Get value metadata tree and use_zarr3 from serialized pytree metadata. 
internal_tree_metadata = asyncio_utils.run_sync( - self._read_metadata_file(directory) + self._metadata_manager.read_metadata_file( + directory, pytree_metadata_options=self._pytree_metadata_options + ) ) value_metadata_tree = internal_tree_metadata.as_nested_tree() if not value_metadata_tree: @@ -1174,7 +1033,7 @@ class TrainState: ) - _log_io_metrics( + async_io_engine.log_io_metrics( tree_memory_size, start_time, '/jax/checkpoint/read/gbytes_per_sec', @@ -1182,103 +1041,6 @@ class TrainState: ) return restored_item - async def _get_param_infos_with_write_shape( - self, - param_infos: PyTree, - checkpoint_dir: epath.Path, - array_metadata_store: array_metadata_store_lib.Store, - ) -> PyTree: - """Returns `param_infos` updated with `write_shape`. - - Args: - param_infos: A PyTree of ParamInfo to be updated. - checkpoint_dir: The checkpoint directory where write_shape metadata is - saved in ArrayMetadata store. - array_metadata_store: The ArrayMetadata store to read write_shape metadata - from. - """ - if not utils.is_primary_host(self._primary_host): - return param_infos - # Extract write_shape from ArrayMetadata for current process_index. - process_index = multihost.process_index() - array_metadatas = await array_metadata_store.read( - checkpoint_dir, process_index=process_index - ) - if array_metadatas is None: - jax_array_param_info = type_handlers.any_jax_array_param_info(param_infos) - if jax_array_param_info is not None: - raise ValueError( - f'No ArrayMetadata found for process_index={process_index} in the' - f' checkpoint directory: {checkpoint_dir}. But input PyTree' - ' contains at least one jax.Array param_info:' - f' {jax_array_param_info}.' 
- ) - return param_infos - - assert isinstance(array_metadatas, list) - array_metadatas_cache = { - array_metadata.param_name: array_metadata - for array_metadata in array_metadatas - } - - def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo: - if not type_handlers.represents_jax_array(param_info): - return param_info - if param_info.name not in array_metadatas_cache: - raise ValueError( - f'No ArrayMetadata found for param_info: {param_info}, checkpoint' - f' directory: {checkpoint_dir}, process_index={process_index}.' - ) - return param_info.replace( - write_shape=array_metadatas_cache[param_info.name].write_shape - ) - - return jax.tree.map(update_param_info, param_infos) - - async def _write_metadata_file( - self, - directory: epath.Path, - *, - param_infos: PyTree, - save_args: PyTree, - custom_metadata: tree_types.JsonType | None, - use_ocdbt: bool, - use_zarr3: bool, - partial_save: bool, - ) -> None: - if utils.is_primary_host(self._primary_host): - metadata_write_start_time = time.time() - path = directory / PYTREE_METADATA_FILE - metadata_content = tree_metadata.InternalTreeMetadata.build( - param_infos, - save_args=save_args, - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - custom_metadata=custom_metadata, - pytree_metadata_options=self._pytree_metadata_options, - ) - - if partial_save: - old_metadata = await self._read_metadata_file(directory) - metadata_content = tree_metadata.InternalTreeMetadata.merge( - old_metadata, metadata_content, overwrite=True - ) - - logging.vlog( - 1, - 'Writing pytree metadata file: %s with pytree_metadata_options: %s', - path, - self._pytree_metadata_options, - ) - await async_path.write_text( - path, - json.dumps(metadata_content.to_json()), - ) - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/metadata_write_duration_secs', - time.time() - metadata_write_start_time, - ) - async def _write_metadata_after_commits( self, commit_futures: List[future.Future], @@ -1302,20 +1064,17 @@ 
async def _write_metadata_after_commits( checkpoint_dir = jax.tree.leaves(param_infos)[0].parent_dir commit_time = time.time() - # `write_shape` is extracted from ArrayMetadata store saved during - # materialization of commit_futures. Then it is written to the pytree - # metadata. - # TODO(b/390465017): Simplify all metadata related code in this module after - # removing overriding of self._write_metadata_file() in subclasses. All - # metadata related code can be moved to a separate class and - # BasePyTreeCheckpointHandler should delegate all metadata related code to - # that class. if self._array_metadata_store is not None: - param_infos = await self._get_param_infos_with_write_shape( - param_infos, checkpoint_dir, self._array_metadata_store + param_infos = ( + await self._metadata_manager.get_param_infos_with_write_shape( + param_infos, + checkpoint_dir, + array_metadata_store=self._array_metadata_store, + primary_host=self._primary_host, + ) ) - await self._write_metadata_file( + await self._metadata_manager.write_metadata_file( checkpoint_dir, param_infos=param_infos, save_args=save_args, @@ -1323,6 +1082,8 @@ async def _write_metadata_after_commits( use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, partial_save=partial_save, + primary_host=self._primary_host, + pytree_metadata_options=self._pytree_metadata_options, ) end_time = time.time() logging.info( @@ -1335,71 +1096,11 @@ async def _write_metadata_after_commits( end_time - commit_time, ) - async def _read_metadata_file( - self, directory: epath.Path - ) -> tree_metadata.InternalTreeMetadata: - """Reads metadata file and returns a tree of restore types. - - Args: - directory: directory - - Returns: - orbax.checkpoint.metadata.InternalTreeMetadata - - Raises: - FileNotFoundError: if the metadata file is not found. - """ - path = directory / PYTREE_METADATA_FILE - if not await async_path.exists(path): - raise FileNotFoundError( - f'Metadata file (named {PYTREE_METADATA_FILE}) does not exist at' - f' {directory}.' 
- ) - logging.vlog( - 1, - 'Reading pytree metadata file: %s with pytree_metadata_options: %s', - path, - self._pytree_metadata_options, - ) - metadata = tree_metadata.InternalTreeMetadata.from_json( - json.loads(await async_path.read_text(path)), - pytree_metadata_options=self._pytree_metadata_options, - ) - - # Log the read event for the checkpoint to the DM log. - event_tracking.record_read_metadata_event(directory) - - return metadata - - def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: - """Returns tree metadata. - - The result will be a PyTree matching the structure of the saved checkpoint. - Note that if the item saved was a custom class, the restored metadata will - be returned as a nested dictionary representation. - - Example:: - - { - 'layer0': { - 'w': ArrayMetadata(dtype=jnp.float32, shape=(8, 8), shards=(1, 2)), - 'b': ArrayMetadata(dtype=jnp.float32, shape=(8,), shards=(1,)), - }, - 'step': ScalarMetadata(dtype=jnp.int64), - } - - If the required metadata file is not present, this method will raise an - error. - - Args: - directory: checkpoint location. - - Returns: - tree containing metadata. - """ internal_tree_metadata = asyncio_utils.run_sync( - self._read_metadata_file(directory) + self._metadata_manager.read_metadata_file( + directory, pytree_metadata_options=self._pytree_metadata_options + ) ) return tree_metadata.build_default_tree_metadata( internal_tree_metadata.as_custom_metadata( @@ -1410,57 +1111,6 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: use_zarr3=internal_tree_metadata.use_zarr3, ) - async def _finalize_async(self, directory: epath.Path) -> None: - start_time = time.time() - finalize_coros = [] - if self._array_metadata_store is not None: - if self._primary_host is None: - logging.log_first_n( - logging.INFO, - '[process=%s] Skipped cross-host ArrayMetadata validation' - ' because all hosts are primary (e.g. 
local storage).', - 1, # log only once - multihost.process_index(), - ) - elif utils.is_primary_host(self._primary_host): - finalize_coros.append( - array_metadata_store_lib.validate_all_array_metadatas( - self._array_metadata_validator, - self._array_metadata_store, - directory, - ) - ) - - async def merge_ocdbt_per_process_files(): - merge_start_time = time.time() - ts_context = ts_utils.get_ts_context(use_ocdbt=True) - await ocdbt_utils.merge_ocdbt_per_process_files( - directory, - ts_context=ts_context, - use_zarr3=self._use_zarr3, - enable_validation=self._enable_post_merge_validation, - ) - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/ocdbt_merge_duration_secs', - time.time() - merge_start_time, - ) - - finalize_coros.append(merge_ocdbt_per_process_files()) - - await asyncio.gather(*finalize_coros) - end_time = time.time() - logging.info( - '[process=%s][thread=%s] Pytree save finalize (merge_ocdbt +' - ' ArrayMetadata validation) completed. Time taken: %fs. use_zarr3=%s,' - ' enable_post_merge_validation=%s, directory=%s', - multihost.process_index(), - threading.current_thread().name, - end_time - start_time, - self._use_zarr3, - self._enable_post_merge_validation, - directory, - ) - def finalize(self, directory: epath.Path) -> None: """Finalization step. @@ -1471,7 +1121,16 @@ def finalize(self, directory: epath.Path) -> None: Args: directory: Path where the checkpoint is located. 
""" - asyncio_utils.run_sync(self._finalize_async(directory)) + asyncio_utils.run_sync( + self._metadata_manager.finalize_async( + directory, + array_metadata_store=self._array_metadata_store, + primary_host=self._primary_host, + array_metadata_validator=self._array_metadata_validator, + use_zarr3=self._use_zarr3, + enable_post_merge_validation=self._enable_post_merge_validation, + ) + ) @register_with_handler(BasePyTreeCheckpointHandler, for_save=True) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index 0c5de7e4f..fc296470c 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -1005,8 +1005,11 @@ def _get_internal_metadata( """ # Try reading metadata file. try: + impl = self._handler_impl internal_tree_metadata = asyncio_utils.run_sync( - self._handler_impl._read_metadata_file(directory) # pylint: disable=protected-access + impl._metadata_manager.read_metadata_file( # pylint: disable=protected-access + directory, pytree_metadata_options=impl._pytree_metadata_options # pylint: disable=protected-access + ) ) use_zarr3 = internal_tree_metadata.use_zarr3 value_metadata_tree = internal_tree_metadata.as_nested_tree() diff --git a/checkpoint/orbax/checkpoint/_src/metadata/metadata_manager.py b/checkpoint/orbax/checkpoint/_src/metadata/metadata_manager.py new file mode 100644 index 000000000..c0f3f8597 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/metadata_manager.py @@ -0,0 +1,231 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class MetadataManager:
  """Manages file-system and metadata persistence for PyTree checkpoints."""

  async def get_param_infos_with_write_shape(
      self,
      param_infos: Any,
      checkpoint_dir: epath.Path,
      *,
      array_metadata_store: Any | None,
      primary_host: int | None,
  ) -> Any:
    """Returns `param_infos` updated with `write_shape`.

    `write_shape` for each jax.Array param is looked up in the ArrayMetadata
    store entry written for the current process. Non-primary hosts, or a
    missing store, return the input unchanged.
    """
    # Nothing to do when no store is configured or this host is not primary.
    if array_metadata_store is None or not utils.is_primary_host(primary_host):
      return param_infos
    process_index = multihost.process_index()
    array_metadatas = await array_metadata_store.read(
        checkpoint_dir, process_index=process_index
    )
    if array_metadatas is None:
      # A missing store entry is only acceptable when the tree contains no
      # jax.Array params at all.
      jax_array_param_info = type_handlers.any_jax_array_param_info(param_infos)
      if jax_array_param_info is not None:
        raise ValueError(
            f'No ArrayMetadata found for process_index={process_index} in the'
            f' checkpoint directory: {checkpoint_dir}. But input PyTree'
            ' contains at least one jax.Array param_info:'
            f' {jax_array_param_info}.'
        )
      return param_infos

    assert isinstance(array_metadatas, list)
    metadata_by_name = {m.param_name: m for m in array_metadatas}

    def _with_write_shape(param_info: types.ParamInfo) -> types.ParamInfo:
      if not type_handlers.represents_jax_array(param_info):
        return param_info
      if param_info.name not in metadata_by_name:
        raise ValueError(
            f'No ArrayMetadata found for param_info: {param_info}, checkpoint'
            f' directory: {checkpoint_dir}, process_index={process_index}.'
        )
      return param_info.replace(
          write_shape=metadata_by_name[param_info.name].write_shape
      )

    return jax.tree.map(_with_write_shape, param_infos)

  async def write_metadata_file(
      self,
      directory: epath.Path,
      *,
      param_infos: Any,
      save_args: Any,
      custom_metadata: tree_types.JsonType | None,
      use_ocdbt: bool,
      use_zarr3: bool,
      partial_save: bool,
      primary_host: int | None,
      pytree_metadata_options: Any,
  ) -> None:
    """Writes the pytree metadata file (`_METADATA`)."""
    # Only the primary host writes metadata; all other hosts are no-ops.
    if not utils.is_primary_host(primary_host):
      return
    write_start = time.time()
    path = directory / format_utils.PYTREE_METADATA_FILE
    metadata_content = tree_metadata.InternalTreeMetadata.build(
        param_infos,
        save_args=save_args,
        use_ocdbt=use_ocdbt,
        use_zarr3=use_zarr3,
        custom_metadata=custom_metadata,
        pytree_metadata_options=pytree_metadata_options,
    )
    if partial_save:
      # Merge onto the previously written metadata so earlier params survive.
      old_metadata = await self.read_metadata_file(
          directory, pytree_metadata_options=pytree_metadata_options
      )
      metadata_content = tree_metadata.InternalTreeMetadata.merge(
          old_metadata, metadata_content, overwrite=True
      )
    logging.vlog(
        1,
        'Writing pytree metadata file: %s with pytree_metadata_options: %s',
        path,
        pytree_metadata_options,
    )
    await async_path.write_text(
        path,
        json.dumps(metadata_content.to_json()),
    )
    jax.monitoring.record_event_duration_secs(
        '/jax/checkpoint/write/async/metadata_write_duration_secs',
        time.time() - write_start,
    )

  async def read_metadata_file(
      self, directory: epath.Path, *, pytree_metadata_options: Any
  ) -> tree_metadata.InternalTreeMetadata:
    """Reads metadata file and returns internal tree metadata.

    Raises:
      FileNotFoundError: if the metadata file is absent from `directory`.
    """
    path = directory / format_utils.PYTREE_METADATA_FILE
    if not await async_path.exists(path):
      raise FileNotFoundError(
          f'Metadata file (named {format_utils.PYTREE_METADATA_FILE}) does not'
          f' exist at {directory}.'
      )
    logging.vlog(
        1,
        'Reading pytree metadata file: %s with pytree_metadata_options: %s',
        path,
        pytree_metadata_options,
    )
    raw_json = await async_path.read_text(path)
    metadata = tree_metadata.InternalTreeMetadata.from_json(
        json.loads(raw_json),
        pytree_metadata_options=pytree_metadata_options,
    )
    # Log the read event for the checkpoint to the DM log.
    event_tracking.record_read_metadata_event(directory)
    return metadata

  async def finalize_async(
      self,
      directory: epath.Path,
      *,
      array_metadata_store: Any | None,
      primary_host: int | None,
      array_metadata_validator: Any,
      use_zarr3: bool,
      enable_post_merge_validation: bool,
  ) -> None:
    """Finalizes checkpoint save (merging OCDBT and validating ArrayMetadata)."""
    start_time = time.time()

    async def _merge_ocdbt() -> None:
      # Merges the per-process OCDBT files into the final checkpoint layout.
      merge_start = time.time()
      ts_context = ts_utils.get_ts_context(use_ocdbt=True)
      await ocdbt_utils.merge_ocdbt_per_process_files(
          directory,
          ts_context=ts_context,
          use_zarr3=use_zarr3,
          enable_validation=enable_post_merge_validation,
      )
      jax.monitoring.record_event_duration_secs(
          '/jax/checkpoint/write/async/ocdbt_merge_duration_secs',
          time.time() - merge_start,
      )

    finalize_coros = []
    if array_metadata_store is not None:
      if primary_host is None:
        logging.log_first_n(
            logging.INFO,
            '[process=%s] Skipped cross-host ArrayMetadata validation'
            ' because all hosts are primary (e.g. local storage).',
            1,  # log only once
            multihost.process_index(),
        )
      elif utils.is_primary_host(primary_host):
        finalize_coros.append(
            array_metadata_store_lib.validate_all_array_metadatas(
                array_metadata_validator,
                array_metadata_store,
                directory,
            )
        )
    finalize_coros.append(_merge_ocdbt())

    await asyncio.gather(*finalize_coros)
    end_time = time.time()
    logging.info(
        '[process=%s][thread=%s] Pytree save finalize (merge_ocdbt +'
        ' ArrayMetadata validation) completed. Time taken: %fs. use_zarr3=%s,'
        ' enable_post_merge_validation=%s, directory=%s',
        multihost.process_index(),
        threading.current_thread().name,
        end_time - start_time,
        use_zarr3,
        enable_post_merge_validation,
        directory,
    )
class MetadataManagerTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Unit tests for `MetadataManager`."""

  def setUp(self):
    super().setUp()
    self.directory = epath.Path(self.create_tempdir().full_path)
    self.options = tree_metadata.PYTREE_METADATA_OPTIONS
    self.manager = metadata_manager.MetadataManager()
    self.registry = type_handler_registry_lib.GLOBAL_TYPE_HANDLER_REGISTRY

  def _param_info(self, name, typestr):
    # Builds a minimal ParamInfo rooted at the test directory.
    return type_handlers.ParamInfo(
        name=name,
        parent_dir=self.directory,
        skip_deserialize=False,
        value_typestr=typestr,
    )

  async def _write(self, param_infos, save_args, custom_metadata, partial_save):
    # Writes metadata with the common defaults shared by these tests.
    await self.manager.write_metadata_file(
        self.directory,
        param_infos=param_infos,
        save_args=save_args,
        custom_metadata=custom_metadata,
        use_ocdbt=False,
        use_zarr3=False,
        partial_save=partial_save,
        primary_host=0,
        pytree_metadata_options=self.options,
    )

  async def test_write_and_read_metadata_file(self):
    typestr = type_handler_registry_lib.get_param_typestr(
        0, self.registry, self.options
    )
    param_infos = {
        'x': self._param_info('x', typestr),
        'y': self._param_info('y', typestr),
    }
    save_args = {'x': type_handlers.SaveArgs(), 'y': type_handlers.SaveArgs()}
    custom_metadata = {'step': 10}

    await self._write(
        param_infos, save_args, custom_metadata, partial_save=False
    )

    metadata_path = self.directory / format_utils.PYTREE_METADATA_FILE
    self.assertTrue(metadata_path.exists())

    actual = await self.manager.read_metadata_file(
        self.directory, pytree_metadata_options=self.options
    )
    self.assertEqual(actual.custom_metadata, custom_metadata)
    tree = actual.as_nested_tree()
    self.assertIn('x', tree)
    self.assertIn('y', tree)

  async def test_write_metadata_file_partial_save(self):
    typestr = type_handler_registry_lib.get_param_typestr(
        0, self.registry, self.options
    )
    # First full save with only 'x'.
    await self._write(
        {'x': self._param_info('x', typestr)},
        {'x': type_handlers.SaveArgs()},
        {'version': 1},
        partial_save=False,
    )
    # Partial save adds 'y'; both params must survive the merge.
    await self._write(
        {'y': self._param_info('y', typestr)},
        {'y': type_handlers.SaveArgs()},
        {'version': 2},
        partial_save=True,
    )

    actual = await self.manager.read_metadata_file(
        self.directory, pytree_metadata_options=self.options
    )
    tree = actual.as_nested_tree()
    self.assertIn('x', tree)
    self.assertIn('y', tree)
    self.assertEqual(actual.custom_metadata, {'version': 2})

  async def test_read_metadata_file_not_found(self):
    with self.assertRaises(FileNotFoundError):
      await self.manager.read_metadata_file(
          self.directory, pytree_metadata_options=self.options
      )

  async def test_finalize_async(self):
    mock_store = mock.MagicMock(spec=array_metadata_store_lib.Store)
    mock_validator = mock.MagicMock(spec=array_metadata_store_lib.Validator)

    with mock.patch.object(
        array_metadata_store_lib,
        'validate_all_array_metadatas',
        new_callable=mock.AsyncMock,
    ) as mock_validate, mock.patch.object(
        ocdbt_utils,
        'merge_ocdbt_per_process_files',
        new_callable=mock.AsyncMock,
    ) as mock_merge:
      await self.manager.finalize_async(
          self.directory,
          array_metadata_store=mock_store,
          primary_host=0,
          array_metadata_validator=mock_validator,
          use_zarr3=False,
          enable_post_merge_validation=True,
      )

    mock_validate.assert_awaited_once_with(
        mock_validator, mock_store, self.directory
    )
    mock_merge.assert_awaited_once()

  async def test_get_param_infos_with_write_shape(self):
    mock_store = mock.MagicMock(spec=array_metadata_store_lib.Store)
    mock_array_metadata = mock.MagicMock()
    mock_array_metadata.param_name = 'x'
    mock_array_metadata.write_shape = (10, 20)
    mock_store.read = mock.AsyncMock(return_value=[mock_array_metadata])

    typestr = type_handler_registry_lib.get_param_typestr(
        jnp.zeros((10, 20), dtype=jnp.float32),
        self.registry,
        self.options,
    )
    param_infos = {'x': self._param_info('x', typestr)}

    with mock.patch.object(multihost, 'process_index', return_value=0):
      updated = await self.manager.get_param_infos_with_write_shape(
          param_infos,
          self.directory,
          array_metadata_store=mock_store,
          primary_host=0,
      )

    self.assertEqual(updated['x'].write_shape, (10, 20))


if __name__ == '__main__':
  absltest.main()
-319,7 +319,15 @@ def __init__( async def _finalize(self, directory: path_types.Path): if multihost.is_primary_host(self._multiprocessing_options.primary_host): - await self._handler_impl._finalize_async(directory) # pylint: disable=protected-access + impl = self._handler_impl + await impl._metadata_manager.finalize_async( # pylint: disable=protected-access + directory, + array_metadata_store=impl._array_metadata_store, # pylint: disable=protected-access + primary_host=impl._primary_host, # pylint: disable=protected-access + array_metadata_validator=impl._array_metadata_validator, # pylint: disable=protected-access + use_zarr3=impl._use_zarr3, # pylint: disable=protected-access + enable_post_merge_validation=impl._enable_post_merge_validation, # pylint: disable=protected-access + ) async def _background_save( self, diff --git a/checkpoint/orbax/checkpoint/metadata/__init__.py b/checkpoint/orbax/checkpoint/metadata/__init__.py index f23544de3..61496f96a 100644 --- a/checkpoint/orbax/checkpoint/metadata/__init__.py +++ b/checkpoint/orbax/checkpoint/metadata/__init__.py @@ -21,6 +21,7 @@ from orbax.checkpoint._src.metadata.checkpoint import MetadataStore from orbax.checkpoint._src.metadata.checkpoint import metadata_store from orbax.checkpoint._src.metadata.step_metadata_serialization import get_step_metadata +from orbax.checkpoint._src.metadata.metadata_manager import MetadataManager from orbax.checkpoint._src.metadata.sharding import ShardingMetadata from orbax.checkpoint._src.metadata.sharding import NamedShardingMetadata