Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils


Expand All @@ -36,7 +37,8 @@
_BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')


class CheckpointablesMetadataCompatibilityTest(parameterized.TestCase):
class CheckpointablesMetadataCompatibilityTestBase(parameterized.TestCase):
"""Tests for V1 checkpointables_metadata API against generated Checkpoints."""

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -94,6 +96,14 @@ def test_checkpointables_metadata_compatibility(
is_direct_checkpoint: bool,
is_pytree: bool,
) -> None:
"""Tests checkpointables_metadata against various checkpoint formats.

Args:
version: v0 or v1.
metadata_present: Whether the checkpoint has metadata files.
is_direct_checkpoint: Whether the checkpoint is a direct checkpoint.
is_pytree: Whether the checkpoint is a pytree checkpoint.
"""
path = compatibility_test_utils.get_checkpoint_path(
version, metadata_present, is_direct_checkpoint, is_pytree
)
Expand Down Expand Up @@ -133,7 +143,11 @@ def test_checkpointables_metadata_compatibility(
else:
expected = self.expected_checkpointables_metadata

test_utils.assert_tree_equal(self, expected, loaded.metadata)
actual = loaded.metadata
if multihost.is_pathways_backend() or jax.process_count() > 1:
expected = compatibility_test_utils.strip_sharding_metadata(expected)
actual = compatibility_test_utils.strip_sharding_metadata(actual)
test_utils.assert_tree_equal(self, expected, actual)
else:
with self.assertRaisesRegex(error_type, expected_error_msg):
ocp.checkpointables_metadata(path)
Expand All @@ -153,6 +167,12 @@ def test_checkpointables_metadata_compatibility(
def test_checkpointables_metadata_non_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests checkpointables_metadata against non-critical corruptions.

Args:
version: The checkpoint version to test against.
alteration: The alteration to apply to the checkpoint.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -162,16 +182,24 @@ def test_checkpointables_metadata_non_critical_corruptions(
# Missing sharding metadata results in a pytree identical to expected
# values except sharding metadata is None.
loaded = ocp.checkpointables_metadata(path)
test_utils.assert_tree_equal(
self, self.expected_checkpointables_metadata, loaded.metadata
)
expected = self.expected_checkpointables_metadata
actual = loaded.metadata
if multihost.is_pathways_backend() or jax.process_count() > 1:
expected = compatibility_test_utils.strip_sharding_metadata(expected)
actual = compatibility_test_utils.strip_sharding_metadata(actual)
test_utils.assert_tree_equal(self, expected, actual)

@parameterized.product(
version=['v0', 'v1'],
)
def test_checkpointables_metadata_missing_sharding_corruption(
self, version: str
) -> None:
"""Tests checkpointables_metadata against missing sharding corruption.

Args:
version: The checkpoint version to test against.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -193,6 +221,12 @@ def test_checkpointables_metadata_missing_sharding_corruption(
def test_checkpointables_metadata_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests checkpointables_metadata against critical corruptions.

Args:
version: The checkpoint version to test against.
alteration: The alteration to apply to the checkpoint.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates V0 CheckpointManager checkpoints for compatibility testing.

The checkpoints generated by this script are checked into the repository
statically to ensure long-term backward compatibility against runtime changes.
"""
import itertools
import os
from typing import Any

from absl import app
from absl import flags
from etils import epath
import jax
import jax.numpy as jnp
from orbax.checkpoint import args
from orbax.checkpoint import checkpoint_manager
from orbax.checkpoint import handlers
from orbax.checkpoint import test_utils
from orbax.checkpoint.experimental.v1._src.testing import handler_utils

FLAGS = flags.FLAGS


def _get_base_dir():
if 'BUILD_WORKING_DIRECTORY' in os.environ:
return os.path.join(
os.environ['BUILD_WORKING_DIRECTORY'],
'orbax/checkpoint/experimental/v1/_src/testing/compatibility/managed_checkpoints',
)
return os.path.join(
os.path.dirname(__file__),
'managed_checkpoints',
)


_BASE_DIR = flags.DEFINE_string(
'base_dir',
_get_base_dir(),
'Base directory to save checkpoints.',
)
_OVERWRITE = flags.DEFINE_bool(
'overwrite',
False,
'Overwrite existing checkpoints.',
)


def create_pytree() -> dict[str, Any]:
return {
'a': jax.device_put(jnp.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=jnp.int32)),
'b': {'c': jax.device_put(jnp.array([1, 2, 3], dtype=jnp.int32))},
}


def generate_v0_checkpoint_manager_checkpoint(
path: epath.Path,
has_metrics: bool = True,
has_custom_metadata: bool = True,
) -> None:
"""Saves a composite manager checkpoint using V0 CheckpointManager."""
if _OVERWRITE.value:
path.rmtree(missing_ok=True)
path.mkdir(parents=True, exist_ok=True)

pytree = create_pytree()
json_object = {'metadata': 'json_data'}
baz = handler_utils.Baz(123, 'hi')
registry = handlers.create_default_handler_registry(
state=handlers.PyTreeCheckpointHandler(),
metadata=handlers.JsonCheckpointHandler(),
baz=handler_utils.DataclassCheckpointHandler(),
)
options = checkpoint_manager.CheckpointManagerOptions(
create=True,
enable_async_checkpointing=False,
prevent_write_metrics=not has_metrics,
best_fn=lambda metrics: metrics['loss'] if has_metrics else None,
)

manager = checkpoint_manager.CheckpointManager(
path,
options=options,
handler_registry=registry,
metadata={'foo': 'bar'} if has_custom_metadata else None,
)
with manager:
manager.save(
0,
args=args.Composite(
state=args.PyTreeSave(pytree),
metadata=args.JsonSave(json_object),
baz=handler_utils.DataclassSaveArgs(baz),
),
metrics={'loss': 0.5} if has_metrics else None,
custom_metadata={'custom': 'meta'} if has_custom_metadata else None,
)
manager.wait_until_finished()


def main(argv):
del argv
epath.Path(_BASE_DIR.value).mkdir(parents=True, exist_ok=True)

test_utils.set_tensorstore_driver_for_test()

print('Generating V0 CheckpointManager Checkpoints...')

for has_metrics, has_custom_metadata in itertools.product(
[True, False], repeat=2
):
metrics_str = 'has_metrics' if has_metrics else 'no_metrics'
meta_str = (
'has_custom_metadata' if has_custom_metadata else 'no_custom_metadata'
)
path = (
epath.Path(_BASE_DIR.value)
/ 'v0_managed'
/ metrics_str
/ meta_str
)
generate_v0_checkpoint_manager_checkpoint(
path,
has_metrics=has_metrics,
has_custom_metadata=has_custom_metadata,
)

print(f'V0 CheckpointManager Checkpoints generated at {_BASE_DIR.value}')


if __name__ == '__main__':
app.run(main)
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates V1 Checkpointer checkpoints for compatibility testing.

The checkpoints generated by this script are checked into the repository
statically to ensure long-term backward compatibility against runtime changes.
"""
import itertools
import os
from typing import Any

from absl import app
from absl import flags
from etils import epath
import jax
import jax.numpy as jnp
from orbax.checkpoint import test_utils
import orbax.checkpoint.experimental.v1 as ocp
from orbax.checkpoint.experimental.v1._src.testing import handler_utils

FLAGS = flags.FLAGS


def _get_base_dir():
if 'BUILD_WORKING_DIRECTORY' in os.environ:
return os.path.join(
os.environ['BUILD_WORKING_DIRECTORY'],
'orbax/checkpoint/experimental/v1/_src/testing/compatibility/managed_checkpoints',
)
return os.path.join(
os.path.dirname(__file__),
'managed_checkpoints',
)


_BASE_DIR = flags.DEFINE_string(
'base_dir',
_get_base_dir(),
'Base directory to save checkpoints.',
)
_OVERWRITE = flags.DEFINE_bool(
'overwrite',
False,
'Overwrite existing checkpoints.',
)


def create_pytree() -> dict[str, Any]:
return {
'a': jax.device_put(jnp.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=jnp.int32)),
'b': {'c': jax.device_put(jnp.array([1, 2, 3], dtype=jnp.int32))},
}


def generate_v1_checkpointer_checkpoint(
path: epath.Path,
has_metrics: bool = True,
has_custom_metadata: bool = True,
) -> None:
"""Saves a composite checkpoint using V1 Checkpointer."""
if _OVERWRITE.value:
path.rmtree(missing_ok=True)
path.mkdir(parents=True, exist_ok=True)

pytree = create_pytree()
json_object = {'metadata': 'json_data'}
baz = handler_utils.Baz(123, 'hi')

registry = ocp.handlers.local_registry()
registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='state')
registry.add(ocp.handlers.JsonHandler, checkpointable_name='metadata')
registry.add(handler_utils.BazHandler, checkpointable_name='baz')

checkpointables = {
'state': pytree,
'metadata': json_object,
'baz': baz,
}

with ocp.Context(
checkpointables_options=ocp.options.CheckpointablesOptions(
registry=registry
)
):
checkpointer = ocp.training.Checkpointer(
path,
custom_metadata={'foo': 'bar'} if has_custom_metadata else None,
)
checkpointer.save_checkpointables(
0,
checkpointables,
metrics={'loss': 0.5} if has_metrics else None,
custom_metadata={'custom': 'meta'} if has_custom_metadata else None,
)


def main(argv):
del argv
epath.Path(_BASE_DIR.value).mkdir(parents=True, exist_ok=True)

test_utils.set_tensorstore_driver_for_test()

print('Generating V1 Checkpointer Checkpoints...')

for has_metrics, has_custom_metadata in itertools.product(
[True, False], repeat=2
):
metrics_str = 'has_metrics' if has_metrics else 'no_metrics'
meta_str = (
'has_custom_metadata' if has_custom_metadata else 'no_custom_metadata'
)
path = (
epath.Path(_BASE_DIR.value)
/ 'v1_managed'
/ metrics_str
/ meta_str
)
generate_v1_checkpointer_checkpoint(
path,
has_metrics=has_metrics,
has_custom_metadata=has_custom_metadata,
)

print(f'V1 Checkpointer Checkpoints generated at {_BASE_DIR.value}')


if __name__ == '__main__':
app.run(main)
Loading
Loading