Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.path.snapshot import snapshot
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import memory_regulator
from orbax.checkpoint._src.serialization import ocdbt_utils
Expand Down Expand Up @@ -81,6 +83,8 @@
PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE
PLACEHOLDER = type_handlers.PLACEHOLDER
PLACEHOLDER_TYPESTR = type_handlers.PLACEHOLDER_TYPESTR
TMP_DIR_SUFFIX = atomicity_types.TMP_DIR_SUFFIX
PENDING_DIR_SUFFIX = snapshot.PENDING_DIR_SUFFIX

DEFAULT_CONCURRENT_GB = 96

Expand Down Expand Up @@ -563,48 +567,82 @@ def _param_info(keypath, name, value):
_param_info, names, item, is_leaf=utils.is_empty_or_leaf
)

async def _async_partial_save(
async def _get_partial_save_additions(
self,
directory: epath.Path,
item: PyTree,
batch_requests: list[_BatchRequest],
param_infos: PyTree,
save_args: BasePyTreeSaveArgs,
):
value_metadata_tree = (
await self._read_metadata_file(directory)
).as_nested_tree()
flat_item: dict[Any, Any],
) -> set[Any]:
# Reconstruct the base partial path from the temporary directory.
# The temporary directory should be named
# `{checkpoint_name}.partial_save.orbax-checkpoint-tmp...`
# and we want to find all other pending saves for this checkpoint.
tmp_dir = directory.parent
if not tmp_dir.name.endswith(TMP_DIR_SUFFIX):
raise ValueError(
f'Expected temporary directory name to end with {TMP_DIR_SUFFIX}, '
f'but got {tmp_dir.name}. Partial saving requires a TemporaryPath '
'class that supports snapshots.'
)
base_name = tmp_dir.name[: -len(TMP_DIR_SUFFIX)]
partial_path = tmp_dir.parent / base_name

tree_diff = tree_structure_utils.tree_difference(item, value_metadata_tree)
# Glob for metadata files written by previous partial saves in this session.
pending_metadata_files = await async_path.glob(
partial_path,
f'*{PENDING_DIR_SUFFIX}*/{directory.name}/{PYTREE_METADATA_FILE}',
)

additions = set()
# Merge tree_metadata from all pending metadata files.
merged_tree_meta = {}
for meta_file in pending_metadata_files:
meta = json.loads(await async_path.read_text(meta_file))
merged_tree_meta.update(meta.get('tree_metadata', {}))

def _handle_diffs(keypath, diff):
keypath = tree_utils.tuple_path_from_keypath(keypath)
if diff.lhs is not None: # Leaf is present in the current item
if diff.rhs is None: # Leaf was not in the on-disk metadata
additions.add(keypath)
else: # Leaf was also in the on-disk metadata
raise PartialSaveReplacementError(
f'Key "{keypath}" was found in the on-disk PyTree metadata and'
' supplied item. Partial saving currently does not support'
' REPLACEMENT. Please reach out to the Orbax team if you need'
' this feature.'
)
def _is_prefix(t1, t2):
return len(t1) < len(t2) and t2[: len(t1)] == t1

jax.tree.map_with_path(
_handle_diffs,
tree_diff,
is_leaf=lambda x: isinstance(x, tree_structure_utils.Diff),
)
# Extract tuple keys from the inner metadata payload to avoid parsing string
# keys with `ast.literal_eval`, which is brittle.
merged_tuples = []
for v in merged_tree_meta.values():
if 'key_metadata' in v:
merged_tuples.append(tuple(entry['key'] for entry in v['key_metadata']))

# Check for replacements vs. additions by comparing keys from the current
# save request against the merged metadata of previous pending saves.
additions = set()
for key in flat_item:
key_str = str(key)
is_replacement = False
if key_str in merged_tree_meta:
is_replacement = True
else:
for mt in merged_tuples:
if isinstance(mt, tuple) and isinstance(key, tuple):
if _is_prefix(key, mt) or _is_prefix(mt, key):
is_replacement = True
break

if is_replacement:
raise PartialSaveReplacementError(
f'Key "{key}" was found in a previous partial save in this session.'
' Partial saving currently does not support REPLACEMENT.'
)
else:
additions.add(key)

logging.info(
'[process=%d] Found the following additions during partial save: %s',
multihost.process_index(),
additions,
)
return additions

# Filter out requests that don't have any additions.
def _filter_batch_requests(
self,
batch_requests: list[_BatchRequest],
additions: set[Any],
) -> list[_BatchRequest]:
filtered_requests = []
for request in batch_requests:
filtered_items = []
Expand All @@ -626,6 +664,19 @@ def _handle_diffs(keypath, diff):
args=list(args),
)
)
return filtered_requests

async def _async_partial_save(
self,
directory: epath.Path,
item: PyTree,
batch_requests: list[_BatchRequest],
param_infos: PyTree,
save_args: BasePyTreeSaveArgs,
):
flat_item = tree_utils.to_flat_dict(item)
additions = await self._get_partial_save_additions(directory, flat_item)
filtered_requests = self._filter_batch_requests(batch_requests, additions)

serialize_ops = []
tree_memory_size = 0
Expand Down Expand Up @@ -733,12 +784,9 @@ async def async_save(
self._type_handler_registry,
)

is_partial_save = args.partial_save_mode and await async_path.exists(
directory / PYTREE_METADATA_FILE
)
batch_requests_ready_time = time.time()
with _memory_profiler_context():
if is_partial_save:
if args.partial_save_mode:
serialize_ops, tree_memory_size, param_infos, save_args = (
await self._async_partial_save(
directory, item, batch_requests, param_infos, save_args
Expand Down Expand Up @@ -784,7 +832,6 @@ async def async_save(
custom_metadata=custom_metadata,
use_ocdbt=self._use_ocdbt,
use_zarr3=self._use_zarr3,
partial_save=is_partial_save,
),
name='write_metadata_after_commits',
)
Expand Down Expand Up @@ -1244,7 +1291,6 @@ async def _write_metadata_file(
custom_metadata: tree_types.JsonType | None,
use_ocdbt: bool,
use_zarr3: bool,
partial_save: bool,
) -> None:
if utils.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
Expand All @@ -1258,12 +1304,6 @@ async def _write_metadata_file(
pytree_metadata_options=self._pytree_metadata_options,
)

if partial_save:
old_metadata = await self._read_metadata_file(directory)
metadata_content = tree_metadata.InternalTreeMetadata.merge(
old_metadata, metadata_content, overwrite=True
)

logging.vlog(
1,
'Writing pytree metadata file: %s with pytree_metadata_options: %s',
Expand All @@ -1288,7 +1328,6 @@ async def _write_metadata_after_commits(
custom_metadata: tree_types.JsonType | None,
use_ocdbt: bool,
use_zarr3: bool,
partial_save: bool,
) -> None:
start_time = time.time()
if not utils.is_primary_host(self._primary_host):
Expand Down Expand Up @@ -1322,7 +1361,6 @@ async def _write_metadata_after_commits(
custom_metadata=custom_metadata,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
partial_save=partial_save,
)
end_time = time.time()
logging.info(
Expand Down
11 changes: 6 additions & 5 deletions checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
from orbax.checkpoint._src.path import utils as ocp_path_utils


SNAPSHOTTING_TIME = "snapshotting_time"
PENDING_DIR_SUFFIX = ".pending_"

PENDING_DIR_SUFFIX = ".pending"


def get_pending_dir_name(source_name: str) -> str:
  """Returns a unique pending-directory name derived from `source_name`.

  The name embeds a nanosecond timestamp and a random UUID hex so that
  multiple pending saves for the same source never collide and can be
  ordered by creation time.

  Note: the span previously contained an unreachable duplicate `return`
  (a stale copy of the old, timestamp-less implementation left above the
  current one); the dead line is removed here and only the current
  timestamped form is kept.
  """
  return (
      f"{source_name}{PENDING_DIR_SUFFIX}_{time.time_ns()}_{uuid.uuid4().hex}"
  )


def get_uuid_from_pending_dir_name(pending_dir_name: str) -> str:
Expand Down Expand Up @@ -169,8 +171,7 @@ async def replace_source(self) -> None:

if not await async_path.exists(self._snapshot):
raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}")
if not await async_path.exists(self._source):
await async_path.mkdir(self._source, parents=True, exist_ok=True)
await async_path.mkdir(self._source, parents=True, exist_ok=True)
# Move files from inside the tmp snapshot into the original source
# directory under a pending suffix. This is to avoid potentially wiping
# out previous files.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ def __init__(
)

  async def _finalize(self, directory: path_types.Path):
    """Finalizes the checkpoint at `directory` on the primary host.

    In partial-save mode this is a no-op: the not-yet-finalized state is
    intentionally left on disk so it can be merged later during partial-save
    finalization (see the partial-saving flow; NOTE(review): merge step is
    outside this view — confirm against `partial_saving.finalize`).
    """
    # Keep non-finalized checkpoint state during partial saves to be merged
    # later during partial save finalization.
    if self._partial_save_mode:
      return

    # Only the primary host performs the actual finalize; other hosts fall
    # through without touching the directory.
    if multihost.is_primary_host(self._multiprocessing_options.primary_host):
      await self._handler_impl._finalize_async(directory)  # pylint: disable=protected-access

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from etils import epath
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path.snapshot import snapshot
from orbax.checkpoint._src.testing import multiprocess_test
from orbax.checkpoint._src.tree import structure_utils as tree_structure_utils
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
Expand Down Expand Up @@ -88,11 +90,17 @@ def save(
*,
partial_save: bool = False,
):
if partial_save:
final_dir = directory
directory = (
directory.parent / f'{directory.name}{atomicity_types.TMP_DIR_SUFFIX}'
)

test_utils.sync_global_processes('CompositeHandlerTest:save:start')
if multihost.is_primary_host(0):
directory.mkdir(parents=False, exist_ok=partial_save)
directory.mkdir(parents=True, exist_ok=True)
for k in checkpointables:
(directory / k).mkdir(parents=False, exist_ok=partial_save)
(directory / k).mkdir(parents=True, exist_ok=True)
test_utils.sync_global_processes('CompositeHandlerTest:save:mkdir')

async def _save():
Expand All @@ -107,20 +115,8 @@ async def _save():
for name in checkpointables.keys()
}

checkpoint_metadata_path = (
metadata_serialization.checkpoint_metadata_file_path(directory)
)
if partial_save and checkpoint_metadata_path.exists():
checkpoint_metadata = await orbax_layout.read_checkpoint_metadata(
directory
)
old_handler_typestrs = checkpoint_metadata.item_handlers
handler_typestrs = old_handler_typestrs | handler_typestrs
await multihost.sync_global_processes(
'CompositeHandlerTest:save:checkpoint_metadata_read',
operation_id='op',
processes=None,
)
# For partial save in this test, we skip reading existing global metadata
# here since it will be merged during finalize, just like real execution.

# Metadata expected to be created outside the handler.
if multihost.is_primary_host(0):
Expand All @@ -145,6 +141,11 @@ async def _save():
)
await awaitable

if partial_save and multihost.is_primary_host(0):
final_dir.mkdir(parents=True, exist_ok=True)
pending_dir = final_dir / snapshot.get_pending_dir_name(final_dir.name)
directory.rename(pending_dir)

asyncio.run(_save())
test_utils.sync_global_processes('CompositeHandlerTest:save:complete')

Expand Down Expand Up @@ -354,7 +355,6 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool):
layout, partial_path, first_save_checkpointables, partial_save=True
)
self.assertTrue(partial_path.exists())
self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists())

self.save(
layout,
Expand All @@ -363,14 +363,6 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool):
partial_save=True,
)
self.assertTrue(partial_path.exists())
self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists())

restored_checkpointables = self.load(
layout, partial_path, merged_checkpointables
)
test_utils.assert_tree_equal(
self, restored_checkpointables, merged_checkpointables
)

partial_saving.finalize(
partial_path if finalize_with_partial_path else final_path
Expand Down
Loading
Loading