From 78ef58c3e2e0c34d01d9e5175feb955713c40978 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Wed, 29 Apr 2026 11:28:39 -0700 Subject: [PATCH] Redesign partial saving to use pending directories. PiperOrigin-RevId: 907692436 --- .../base_pytree_checkpoint_handler.py | 124 +++++--- .../checkpoint/_src/path/snapshot/snapshot.py | 11 +- .../v1/_src/handlers/pytree_handler.py | 5 + .../layout/orbax_layout_multiprocess_test.py | 42 ++- .../experimental/v1/_src/partial/saving.py | 274 +++++++++++++++++- .../experimental/v1/_src/saving/execution.py | 8 +- 6 files changed, 376 insertions(+), 88 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 014a88527..95d91c65e 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -48,8 +48,10 @@ from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.path import types as path_types +from orbax.checkpoint._src.path.snapshot import snapshot from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import memory_regulator from orbax.checkpoint._src.serialization import ocdbt_utils @@ -81,6 +83,8 @@ PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE PLACEHOLDER = type_handlers.PLACEHOLDER PLACEHOLDER_TYPESTR = type_handlers.PLACEHOLDER_TYPESTR +TMP_DIR_SUFFIX = atomicity_types.TMP_DIR_SUFFIX +PENDING_DIR_SUFFIX = snapshot.PENDING_DIR_SUFFIX DEFAULT_CONCURRENT_GB = 96 @@ -563,48 +567,82 @@ def _param_info(keypath, name, value): _param_info, names, item, is_leaf=utils.is_empty_or_leaf ) - async def _async_partial_save( + async def _get_partial_save_additions( self, directory: epath.Path, - item: PyTree, - batch_requests: list[_BatchRequest], - param_infos: PyTree, - save_args: BasePyTreeSaveArgs, - ): - value_metadata_tree = ( - await self._read_metadata_file(directory) - ).as_nested_tree() + flat_item: dict[Any, Any], + ) -> set[Any]: + # Reconstruct the base partial path from the temporary directory. + # The temporary directory should be named + # `{checkpoint_name}.partial_save.orbax-checkpoint-tmp...` + # and we want to find all other pending saves for this checkpoint. + tmp_dir = directory.parent + if not tmp_dir.name.endswith(TMP_DIR_SUFFIX): + raise ValueError( + f'Expected temporary directory name to end with {TMP_DIR_SUFFIX}, ' + f'but got {tmp_dir.name}. Partial saving requires a TemporaryPath ' + 'class that supports snapshots.' + ) + base_name = tmp_dir.name[: -len(TMP_DIR_SUFFIX)] + partial_path = tmp_dir.parent / base_name - tree_diff = tree_structure_utils.tree_difference(item, value_metadata_tree) + # Glob for metadata files written by previous partial saves in this session. + pending_metadata_files = await async_path.glob( + partial_path, + f'*{PENDING_DIR_SUFFIX}*/{directory.name}/{PYTREE_METADATA_FILE}', + ) - additions = set() + # Merge tree_metadata from all pending metadata files. 
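+    # Each pending directory left by an earlier partial save in this session
+    # holds a PYTREE_METADATA_FILE for this checkpointable; merging their
+    # 'tree_metadata' payloads gives a single view of every key written so
+    # far, which the replacement check below compares against.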
+ merged_tree_meta = {} + for meta_file in pending_metadata_files: + meta = json.loads(await async_path.read_text(meta_file)) + merged_tree_meta.update(meta.get('tree_metadata', {})) - def _handle_diffs(keypath, diff): - keypath = tree_utils.tuple_path_from_keypath(keypath) - if diff.lhs is not None: # Leaf is present in the current item - if diff.rhs is None: # Leaf was not in the on-disk metadata - additions.add(keypath) - else: # Leaf was also in the on-disk metadata - raise PartialSaveReplacementError( - f'Key "{keypath}" was found in the on-disk PyTree metadata and' - ' supplied item. Partial saving currently does not support' - ' REPLACEMENT. Please reach out to the Orbax team if you need' - ' this feature.' - ) + def _is_prefix(t1, t2): + return len(t1) < len(t2) and t2[: len(t1)] == t1 - jax.tree.map_with_path( - _handle_diffs, - tree_diff, - is_leaf=lambda x: isinstance(x, tree_structure_utils.Diff), - ) + # Extract tuple keys from the inner metadata payload to avoid parsing string + # keys with `ast.literal_eval`, which is brittle. + merged_tuples = [] + for v in merged_tree_meta.values(): + if 'key_metadata' in v: + merged_tuples.append(tuple(entry['key'] for entry in v['key_metadata'])) + + # Check for replacements vs. additions by comparing keys from the current + # save request against the merged metadata of previous pending saves. + additions = set() + for key in flat_item: + key_str = str(key) + is_replacement = False + if key_str in merged_tree_meta: + is_replacement = True + else: + for mt in merged_tuples: + if isinstance(mt, tuple) and isinstance(key, tuple): + if _is_prefix(key, mt) or _is_prefix(mt, key): + is_replacement = True + break + + if is_replacement: + raise PartialSaveReplacementError( + f'Key "{key}" was found in a previous partial save in this session.' + ' Partial saving currently does not support REPLACEMENT.' + ) + else: + additions.add(key) logging.info( '[process=%d] Found the following additions during partial save: %s', multihost.process_index(), additions, ) + return additions - # Filter out requests that don't have any additions. 
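+  # `_filter_batch_requests` below keeps only the parameters classified as
+  # additions; keys already written by a pending save are rejected as
+  # replacements above and never reach it.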
+ def _filter_batch_requests( + self, + batch_requests: list[_BatchRequest], + additions: set[Any], + ) -> list[_BatchRequest]: filtered_requests = [] for request in batch_requests: filtered_items = [] @@ -626,6 +664,19 @@ def _handle_diffs(keypath, diff): args=list(args), ) ) + return filtered_requests + + async def _async_partial_save( + self, + directory: epath.Path, + item: PyTree, + batch_requests: list[_BatchRequest], + param_infos: PyTree, + save_args: BasePyTreeSaveArgs, + ): + flat_item = tree_utils.to_flat_dict(item) + additions = await self._get_partial_save_additions(directory, flat_item) + filtered_requests = self._filter_batch_requests(batch_requests, additions) serialize_ops = [] tree_memory_size = 0 @@ -733,12 +784,9 @@ async def async_save( self._type_handler_registry, ) - is_partial_save = args.partial_save_mode and await async_path.exists( - directory / PYTREE_METADATA_FILE - ) batch_requests_ready_time = time.time() with _memory_profiler_context(): - if is_partial_save: + if args.partial_save_mode: serialize_ops, tree_memory_size, param_infos, save_args = ( await self._async_partial_save( directory, item, batch_requests, param_infos, save_args @@ -784,7 +832,6 @@ async def async_save( custom_metadata=custom_metadata, use_ocdbt=self._use_ocdbt, use_zarr3=self._use_zarr3, - partial_save=is_partial_save, ), name='write_metadata_after_commits', ) @@ -1244,7 +1291,6 @@ async def _write_metadata_file( custom_metadata: tree_types.JsonType | None, use_ocdbt: bool, use_zarr3: bool, - partial_save: bool, ) -> None: if utils.is_primary_host(self._primary_host): metadata_write_start_time = time.time() @@ -1258,12 +1304,6 @@ async def _write_metadata_file( pytree_metadata_options=self._pytree_metadata_options, ) - if partial_save: - old_metadata = await self._read_metadata_file(directory) - metadata_content = tree_metadata.InternalTreeMetadata.merge( - old_metadata, metadata_content, overwrite=True - ) - logging.vlog( 1, 'Writing pytree metadata file: %s with pytree_metadata_options: %s', @@ -1288,7 +1328,6 @@ async def _write_metadata_after_commits( custom_metadata: tree_types.JsonType | None, use_ocdbt: bool, use_zarr3: bool, - partial_save: bool, ) -> None: start_time = time.time() if not utils.is_primary_host(self._primary_host): @@ -1322,7 +1361,6 @@ async def _write_metadata_after_commits( custom_metadata=custom_metadata, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, - partial_save=partial_save, ) end_time = time.time() logging.info( diff --git a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py index d22b9ba9b..9cd558170 100644 --- a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py +++ b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py @@ -26,12 +26,14 @@ from orbax.checkpoint._src.path import utils as ocp_path_utils -SNAPSHOTTING_TIME = "snapshotting_time" -PENDING_DIR_SUFFIX = ".pending_" + +PENDING_DIR_SUFFIX = ".pending" def get_pending_dir_name(source_name: str) -> str: - return f"{source_name}{PENDING_DIR_SUFFIX}{uuid.uuid4().hex}" + return ( + f"{source_name}{PENDING_DIR_SUFFIX}_{time.time_ns()}_{uuid.uuid4().hex}" + ) def get_uuid_from_pending_dir_name(pending_dir_name: str) -> str: @@ -169,8 +171,7 @@ async def replace_source(self) -> None: if not await async_path.exists(self._snapshot): raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}") - if not await async_path.exists(self._source): - await async_path.mkdir(self._source, parents=True, exist_ok=True) + await 
async_path.mkdir(self._source, parents=True, exist_ok=True) # Move files from inside the tmp snapshot into the original source # directory under a pending suffix. This is to avoid potentially wiping # out previous files. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index b1ec35e73..1654e89f4 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -318,6 +318,11 @@ def __init__( ) async def _finalize(self, directory: path_types.Path): + # Keep non-finalized checkpoint state during partial saves to be merged + # later during partial save finalization. + if self._partial_save_mode: + return + if multihost.is_primary_host(self._multiprocessing_options.primary_host): await self._handler_impl._finalize_async(directory) # pylint: disable=protected-access diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py index b880ad7f6..30a15436f 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py @@ -23,6 +23,8 @@ from etils import epath from orbax.checkpoint import test_utils from orbax.checkpoint._src.metadata import step_metadata_serialization +from orbax.checkpoint._src.path import atomicity_types +from orbax.checkpoint._src.path.snapshot import snapshot from orbax.checkpoint._src.testing import multiprocess_test from orbax.checkpoint._src.tree import structure_utils as tree_structure_utils from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler @@ -88,11 +90,17 @@ def save( *, partial_save: bool = False, ): + if partial_save: + final_dir = directory + directory = ( + directory.parent / f'{directory.name}{atomicity_types.TMP_DIR_SUFFIX}' + ) + test_utils.sync_global_processes('CompositeHandlerTest:save:start') if multihost.is_primary_host(0): - directory.mkdir(parents=False, exist_ok=partial_save) + directory.mkdir(parents=True, exist_ok=True) for k in checkpointables: - (directory / k).mkdir(parents=False, exist_ok=partial_save) + (directory / k).mkdir(parents=True, exist_ok=True) test_utils.sync_global_processes('CompositeHandlerTest:save:mkdir') async def _save(): @@ -107,20 +115,8 @@ async def _save(): for name in checkpointables.keys() } - checkpoint_metadata_path = ( - metadata_serialization.checkpoint_metadata_file_path(directory) - ) - if partial_save and checkpoint_metadata_path.exists(): - checkpoint_metadata = await orbax_layout.read_checkpoint_metadata( - directory - ) - old_handler_typestrs = checkpoint_metadata.item_handlers - handler_typestrs = old_handler_typestrs | handler_typestrs - await multihost.sync_global_processes( - 'CompositeHandlerTest:save:checkpoint_metadata_read', - operation_id='op', - processes=None, - ) + # For partial save in this test, we skip reading existing global metadata + # here since it will be merged during finalize, just like real execution. # Metadata expected to be created outside the handler. 
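+      # Each partial save thus writes metadata covering only its own
+      # checkpointables into its pending directory; `partial_saving.finalize`
+      # is expected to merge the per-save metadata files into the final
+      # checkpoint, mirroring the real execution path.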
if multihost.is_primary_host(0): @@ -145,6 +141,11 @@ async def _save(): ) await awaitable + if partial_save and multihost.is_primary_host(0): + final_dir.mkdir(parents=True, exist_ok=True) + pending_dir = final_dir / snapshot.get_pending_dir_name(final_dir.name) + directory.rename(pending_dir) + asyncio.run(_save()) test_utils.sync_global_processes('CompositeHandlerTest:save:complete') @@ -354,7 +355,6 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool): layout, partial_path, first_save_checkpointables, partial_save=True ) self.assertTrue(partial_path.exists()) - self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) self.save( layout, @@ -363,14 +363,6 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool): partial_save=True, ) self.assertTrue(partial_path.exists()) - self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) - - restored_checkpointables = self.load( - layout, partial_path, merged_checkpointables - ) - test_utils.assert_tree_equal( - self, restored_checkpointables, merged_checkpointables - ) partial_saving.finalize( partial_path if finalize_with_partial_path else final_path diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index f36a01801..c02812480 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -14,19 +14,27 @@ """Defines free-function interface for partial saving and finalizing.""" +import ast import asyncio import dataclasses +import json +import logging from typing import Awaitable +from etils import epath from orbax.checkpoint._src import asyncio_utils +from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.path import utils as ocp_path_utils +from orbax.checkpoint._src.path.snapshot import snapshot from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.handlers import global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization from orbax.checkpoint.experimental.v1._src.partial import path as partial_path_lib from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.saving import execution @@ -36,6 +44,7 @@ PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +CHECKPOINT_METADATA_FILENAME = metadata_serialization._CHECKPOINT_METADATA_FILENAME # pylint: disable=protected-access StatefulCheckpointableHandler = ( stateful_checkpointable_handler.StatefulCheckpointableHandler @@ -230,6 +239,245 @@ def save_pytree_async( ) +async def _extract_use_zarr3(pending_dirs: list[epath.Path]) -> bool: + """Extracts the use_zarr3 flag from the first pending directory's metadata.""" + if not pending_dirs: + return False + + for p_dir in pending_dirs: + for child in await async_path.iterdir(p_dir): 
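+      # Each child is one checkpointable directory (e.g. a pytree); the first
+      # one carrying a pytree metadata file decides the zarr version used for
+      # the whole merge.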
+ if await async_path.is_dir(child): + meta_file = child / format_utils.PYTREE_METADATA_FILE + if await async_path.exists(meta_file): + try: + meta = json.loads(await async_path.read_text(meta_file)) + return meta.get('use_zarr3', False) + except json.JSONDecodeError as e: + raise ValueError( + f'Failed to read use_zarr3 from metadata file {meta_file}: {e}' + ) from e + return False + + +def _recursive_dict_merge(dst, src): + for k, v in src.items(): + if k in dst and isinstance(dst[k], dict) and isinstance(v, dict): + _recursive_dict_merge(dst[k], v) + else: + dst[k] = v + + +def _is_prefix(t1, t2): + return len(t1) < len(t2) and t2[: len(t1)] == t1 + + +def _filter_conflicting_keys(d): + """Filters out conflicting keys from a dictionary of PyTree metadata.""" + keys = sorted(d.keys()) + to_remove = set() + + parsed_keys = {} + for k in keys: + try: + parsed_keys[k] = ast.literal_eval(k) + except (ValueError, SyntaxError): + parsed_keys[k] = k + + # Since keys are lexicographically sorted, any extensions of a prefix + # will immediately follow the prefix in the list. Thus, we only need + # to compare adjacent elements. + for i in range(len(keys) - 1): + k1, k2 = keys[i], keys[i + 1] + t1, t2 = parsed_keys[k1], parsed_keys[k2] + if isinstance(t1, tuple) and isinstance(t2, tuple): + if _is_prefix(t1, t2): + to_remove.add(k1) + elif isinstance(k1, str) and isinstance(k2, str): + if k2.startswith(k1 + '.') or k2.startswith(k1 + '/'): + to_remove.add(k1) + + for k in to_remove: + del d[k] + return d + + +async def _merge_pytree_metadata(src_item: epath.Path, dst_item: epath.Path): + """Merges PyTree metadata files (_METADATA or _sharding).""" + if not await async_path.exists(dst_item): + await async_path.rename(src_item, dst_item) + return + + src_meta = json.loads(await async_path.read_text(src_item)) + dst_meta = json.loads(await async_path.read_text(dst_item)) + + _recursive_dict_merge(dst_meta, src_meta) + + if 'tree_metadata' in dst_meta: + dst_meta['tree_metadata'] = _filter_conflicting_keys( + dst_meta['tree_metadata'] + ) + + await async_path.write_text(dst_item, json.dumps(dst_meta)) + await async_path.unlink(src_item) + + +async def _rename_ocdbt_process_dir( + item: epath.Path, pytree_dst: epath.Path, uuid_str: str +): + """Renames an ocdbt.process_ directory to avoid collisions across partial saves.""" + # To avoid collisions across different partial save pending directories, + # we append the pending dir's UUID to the original process ID. + # We must avoid using '_' in the new ID because `ocdbt_utils.py` splits + # the directory name by '_' to extract the process ID. 
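+  # For example (hypothetical names): with UUID hex '3f2a9c...', a pending
+  # directory's 'ocdbt.process_0' becomes 'ocdbt.process_03f2a9c...', which
+  # still yields a single token after the final '_'.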
+ new_name = f'{item.name}{uuid_str.replace("-", "")}' + await async_path.rename(item, pytree_dst / new_name) + + +async def _merge_array_metadatas(src_dir: epath.Path, dst_dir: epath.Path): + """Merges array_metadatas JSON files (process_0, process_1, etc.).""" + await async_path.mkdir(dst_dir, parents=True, exist_ok=True) + + async def _process_child(src_child: epath.Path): + dst_child = dst_dir / src_child.name + + if not await async_path.exists(dst_child): + await async_path.rename(src_child, dst_child) + return + + src_meta = json.loads(await async_path.read_text(src_child)) + dst_meta = json.loads(await async_path.read_text(dst_child)) + + src_arr_meta = src_meta.get('array_metadatas', []) + dst_arr_meta = dst_meta.get('array_metadatas', []) + dst_arr_meta.extend(src_arr_meta) + dst_meta['array_metadatas'] = dst_arr_meta + + await async_path.write_text(dst_child, json.dumps(dst_meta)) + await async_path.unlink(src_child) + + await asyncio.gather(*[ + _process_child(src_child) + for src_child in await async_path.iterdir(src_dir) + ]) + + +async def _recursive_merge(src: epath.Path, dst: epath.Path): + """Recursively merge src into dst.""" + if not await async_path.exists(src): + return + + if not await async_path.exists(dst): + await async_path.rename(src, dst) + return + + if await async_path.is_dir(src): + items = await async_path.iterdir(src) + await asyncio.gather( + *[_recursive_merge(item, dst / item.name) for item in items] + ) + await async_path.rmtree(src) + return + + logging.warning( + 'File collision on %s during finalize. Overwriting destination file.', + src.name, + ) + if await async_path.is_dir(dst): + await async_path.rmtree(dst) + else: + await async_path.unlink(dst) + await async_path.rename(src, dst) + + +async def _merge_pytree_directory( + pytree_src: epath.Path, + pytree_dst: epath.Path, + uuid_str: str, +): + """Merges a single pending pytree directory into the destination.""" + if not await async_path.exists(pytree_src): + return + + async def _merge_item(item: epath.Path): + if item.name in [format_utils.PYTREE_METADATA_FILE, '_sharding']: + await _merge_pytree_metadata(item, pytree_dst / item.name) + elif item.name.startswith('ocdbt.process_'): + await _rename_ocdbt_process_dir(item, pytree_dst, uuid_str) + elif item.name == 'array_metadatas': + await _merge_array_metadatas(item, pytree_dst / item.name) + else: + await _recursive_merge(item, pytree_dst / item.name) + + await asyncio.gather( + *[_merge_item(item) for item in await async_path.iterdir(pytree_src)] + ) + + await async_path.rmtree(pytree_src) + + +async def _merge_checkpoint_metadata(src: epath.Path, dst: epath.Path): + """Merges checkpoint metadata.""" + if not await async_path.exists(dst): + await async_path.rename(src, dst) + return + + src_meta = json.loads(await async_path.read_text(src)) + dst_meta = json.loads(await async_path.read_text(dst)) + + _recursive_dict_merge(dst_meta, src_meta) + + await async_path.write_text(dst, json.dumps(dst_meta)) + await async_path.unlink(src) + + +async def _merge_all(partial_path: epath.Path): + """Merge all pending directories into the partial path.""" + + # Each partial save call results in a new pending directory containing unique + # PyTree keypaths and corresponding data. During finalization, all pending + # directories are merged to form the final checkpoint state. 
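+  # The merge proceeds in stages: top-level checkpoint metadata files are
+  # merged, each pytree directory is merged piecewise (tree metadata,
+  # shardings, ocdbt.process_* dirs, array_metadatas), everything else falls
+  # back to a recursive move, and the OCDBT manifests are consolidated last
+  # via BasePyTreeCheckpointHandler.finalize.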
+ pending_dirs = await snapshot.list_pending_dirs(partial_path) + + # Ensure deterministic merge order (alphabetical glob + timestamp) + pending_dirs.sort() + + use_zarr3 = await _extract_use_zarr3(pending_dirs) + + pytree_directories = set() + + for p_dir in pending_dirs: + uuid_str = snapshot.get_uuid_from_pending_dir_name(p_dir.name) + + async def _process_item(item: epath.Path, uuid_str: str): + if item.name == CHECKPOINT_METADATA_FILENAME: + await _merge_checkpoint_metadata(item, partial_path / item.name) + elif await async_path.is_dir(item) and await async_path.exists( + item / format_utils.PYTREE_METADATA_FILE + ): + pytree_directories.add(item.name) + pytree_dst = partial_path / item.name + await async_path.mkdir(pytree_dst, parents=True, exist_ok=True) + await _merge_pytree_directory(item, pytree_dst, uuid_str) + else: + await _recursive_merge(item, partial_path / item.name) + + await asyncio.gather(*[ + _process_item(item, uuid_str) + for item in await async_path.iterdir(p_dir) + ]) + + await async_path.rmtree(p_dir) + + # 3. Call PyTreeHandler.finalize to perform OCDBT merge. + # This merges the individual ocdbt.process_xxx directories into a single + # valid manifest for the final partial state. + handler = base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler( + use_zarr3=use_zarr3 + ) + for pytree_dir_name in pytree_directories: + await asyncio.to_thread(handler.finalize, partial_path / pytree_dir_name) + + def finalize(path: path_types.PathLike) -> None: """Finalizes a partially-saved checkpoint, making it permanent and readable. @@ -303,17 +551,23 @@ async def _finalize_impl(): processes=context.multiprocessing_options.active_processes, ) - rename_failed = False - rename_error = None + finalize_failed = False + finalize_error = None if multihost.is_primary_host(context.multiprocessing_options.primary_host): + try: + await _merge_all(partial_path) + except ValueError as e: + finalize_failed = True + finalize_error = e + try: await async_path.rename(partial_path, final_path) except OSError as e: - rename_failed = True - rename_error = e + finalize_failed = True + finalize_error = e - rename_failed = multihost.broadcast_one_to_all( - rename_failed, + finalize_failed = multihost.broadcast_one_to_all( + finalize_failed, is_source=multihost.is_primary_host( context.multiprocessing_options.primary_host ), @@ -328,9 +582,9 @@ async def _finalize_impl(): processes=context.multiprocessing_options.active_processes, ) - if rename_failed: - if rename_error is not None: - raise rename_error - raise OSError('Partial checkpoint finalization failed during rename.') + if finalize_failed: + if finalize_error is not None: + raise finalize_error + raise OSError('Partial checkpoint finalization failed.') asyncio_utils.run_sync(_finalize_impl()) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index f5f20373b..f99229fa2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -85,7 +85,7 @@ def add_internal_checkpointables( class _SaveResponse(AsyncResponse[None]): - """An :py:class:`.AsyncResponse` representing the result of:py:func:`.save_pytree_async`.""" + """:py:class:`.AsyncResponse`, result of :py:func:`.save_pytree_async`.""" def __init__( self, @@ -389,17 +389,15 @@ def save_checkpointables_impl( path = context.file_options.path_class(path) _check_directory_consistency(path) - 
path_exists = path.exists() if partial_save else False # Prevent internal mutation from affecting the caller. checkpointables = dict(checkpointables) checkpointables = add_internal_checkpointables( checkpointables, context=context ) - subdirectories = [] if path_exists else checkpointables.keys() - snapshot_type = snapshot_lib.SnapshotType.IN_PLACE if path_exists else None + snapshot_type = snapshot_lib.SnapshotType.EMPTY if partial_save else None temporary_path = _TemporaryPathAwaitingCreation( path, - subdirectories=subdirectories, + subdirectories=checkpointables.keys(), snapshot_type=snapshot_type, ) background_awaitable = asyncio_utils.run_sync(