diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5aff526be..9b6e3371c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -76,7 +76,7 @@ jobs: --ignore=orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py \ --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py \ --ignore=orbax/checkpoint/_src/testing/oss/multiprocess_test.py \ - $(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/oss/tagged_tests.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))") + $(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))") # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. 
- name: Report success or failure as github status @@ -350,13 +350,13 @@ jobs: env: TEST_TMPDIR: /tmp/orbax_test run: | - python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=2 + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml --processes=2 - name: Run 4 multiprocess tests run: | - python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4 + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml --processes=4 - name: Run single process tests run: | - python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=1 + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml --processes=1 - name: Report success or failure as github status if: always() shell: bash diff --git a/.github/workflows/multiprocess_tests.yml b/.github/workflows/multiprocess_tests.yml index 14f1ab69e..ed6cb0fc6 100644 --- a/.github/workflows/multiprocess_tests.yml +++ b/.github/workflows/multiprocess_tests.yml @@ -93,3 +93,62 @@ jobs: "description": "'$status'", "context": "github-actions/build" }' + + 
multiprocess-unit-tests: + name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-x86-ct5lp-224-8tpu + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e + defaults: + run: + working-directory: checkpoint + strategy: + matrix: + python-version: ["3.11"] + jax-version: ["newest"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install -e . + pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip uninstall -y orbax + pip install gcsfs + pip install portpicker pytest chex pyyaml tensorboard + if [ "${{ matrix.jax-version }}" = "newest" ]; then + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [ "${{ matrix.jax-version }}" = "nightly" ]; then + pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + else + pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + fi + - name: Run 2 multiprocess tests + env: + TEST_TMPDIR: /tmp/orbax_test + run: | + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml --processes=2 + - name: Run 4 multiprocess tests + run: | + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 
orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml --processes=4 + - name: Run single process tests + run: | + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml --processes=1 + - name: Report success or failure as github status + if: always() + shell: bash + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "'$status'", + "context": "github-actions/build" + }' diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py b/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py index a2f62d79b..c0f4ac22b 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py @@ -32,11 +32,14 @@ ] EXCLUDED_PATHS = [ 'orbax/checkpoint/experimental/model_surgery', - 'orbax/checkpoint/experimental/v1', 'orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py', 'orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py', 'orbax/checkpoint/google', ] +EXCLUDED_PATHS_FOR_PRESUBMIT = [ + 'orbax/checkpoint/experimental/v1', + 'orbax/checkpoint/_src/checkpointers', +] def get_kwargs(call_node): @@ -111,13 +114,14 @@ def get_build_targets(build_file_path): yield name, tags, srcs, args -def run(root_dir, 
output_file): +def run(root_dir, output_file, extra_excluded_paths=None): """Runs the script to generate tagged tests file.""" tests_by_tag = collections.defaultdict(list) + all_excluded_paths = EXCLUDED_PATHS + (extra_excluded_paths or []) count = 0 for dirpath, dirnames, filenames in os.walk(root_dir): - if any(dirpath.startswith(p) for p in EXCLUDED_PATHS): + if any(dirpath.startswith(p) for p in all_excluded_paths): dirnames[:] = [] continue @@ -125,7 +129,7 @@ def run(root_dir, output_file): dirnames[:] = [] for d in original_dirs: if not any( - os.path.join(dirpath, d).startswith(p) for p in EXCLUDED_PATHS + os.path.join(dirpath, d).startswith(p) for p in all_excluded_paths ): dirnames.append(d) @@ -137,7 +141,8 @@ def run(root_dir, output_file): if not any(tag in TAG_MAPPING for tag in tags): continue if srcs and any( - os.path.join(dirpath, srcs[0]).startswith(p) for p in EXCLUDED_PATHS + os.path.join(dirpath, srcs[0]).startswith(p) + for p in all_excluded_paths ): continue target_path = f'{package_path}:{name}' @@ -167,5 +172,15 @@ def run(root_dir, output_file): if 'BUILD_WORKING_DIRECTORY' in os.environ: os.chdir(os.environ['BUILD_WORKING_DIRECTORY']) orbax_dir = 'orbax/checkpoint' - output = 'orbax/checkpoint/_src/testing/oss/tagged_tests.yaml' - run(orbax_dir, output) + + # Generate whole suite file + output_whole = 'orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml' + run(orbax_dir, output_whole) + + # Generate presubmit file + output_presubmit = 'orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml' + run( + orbax_dir, + output_presubmit, + extra_excluded_paths=EXCLUDED_PATHS_FOR_PRESUBMIT, + ) diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml old mode 100755 new mode 100644 similarity index 94% rename from checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml rename to 
checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml index f2a2b536f..179181c9d --- a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml @@ -13,8 +13,6 @@ processes:1: - orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_replicator_checkpoint_manager_test - orbax/checkpoint:single_host_test processes:2: -- orbax/checkpoint/_src/checkpointers:async_checkpointer_test -- orbax/checkpoint/_src/checkpointers:checkpointer_test - orbax/checkpoint/_src/handlers:array_checkpoint_handler_test - orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test - orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml new file mode 100644 index 000000000..afa72e373 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml @@ -0,0 +1,63 @@ +# DO NOT EDIT! 
+processes:1: +- orbax/checkpoint/_src/futures:future_test +- orbax/checkpoint/_src/metadata:sharding_tpu_test +- orbax/checkpoint/_src/multihost:multislice_test +- orbax/checkpoint/_src/serialization:colocated_pathways_local_type_handlers_test +- orbax/checkpoint/_src/serialization:colocated_pathways_memory_usage_test +- orbax/checkpoint/_src/serialization:pathways_local_type_handlers_test +- orbax/checkpoint/_src/serialization:pathways_memory_usage_test +- orbax/checkpoint/_src/serialization:replica_slices_test +- orbax/checkpoint/_src/serialization:serialization_test +- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_process_metadata_checkpoint_handler_test +- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_replicator_checkpoint_manager_test +- orbax/checkpoint/experimental/v1/_src/emergency:deleter_test +- orbax/checkpoint/experimental/v1/_src/emergency:path_utils_test +- orbax/checkpoint/experimental/v1/_src/testing/pathways:array_leaf_handler_test_multi_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:array_leaf_handler_test_single_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:pytree_handler_test_multi_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:pytree_handler_test_single_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:save_load_test_multi_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:save_load_test_single_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:v1_compatibility_load_test_multi_worker +- orbax/checkpoint/experimental/v1/_src/testing/pathways:v1_compatibility_load_test_single_worker +- orbax/checkpoint/experimental/v1/_src/training/pathways:checkpointer_test_multi_worker +- orbax/checkpoint/experimental/v1/_src/training/pathways:checkpointer_test_single_worker +- orbax/checkpoint/experimental/v1/_src/training/pathways:snapshotter_test +- 
orbax/checkpoint/experimental/v1/_src/training/pathways:v0v1_compatibility_checkpointer_test_single_worker +- orbax/checkpoint:single_host_test +processes:2: +- orbax/checkpoint/_src/checkpointers:async_checkpointer_test +- orbax/checkpoint/_src/checkpointers:checkpointer_test +- orbax/checkpoint/_src/handlers:array_checkpoint_handler_test +- orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test +- orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test +- orbax/checkpoint/_src/serialization:local_type_handlers_test +- orbax/checkpoint/_src/serialization:type_handlers_test +- orbax/checkpoint/experimental/emergency/p2p:checkpoint_manager_multiprocess_test +- orbax/checkpoint/experimental/emergency/p2p:local_multiprocess_test +- orbax/checkpoint/experimental/emergency/p2p:persistent_multiprocess_test +- orbax/checkpoint/experimental/v1/_src/handlers:pytree_handler_partial_save_test +- orbax/checkpoint/experimental/v1/_src/handlers:pytree_handler_test +- orbax/checkpoint/experimental/v1/_src/layout:orbax_layout_multiprocess_test +- orbax/checkpoint/experimental/v1/_src/partial:saving_multihost_test +- orbax/checkpoint/experimental/v1/_src/serialization:array_leaf_handler_test +- orbax/checkpoint/experimental/v1/_src/serialization:numpy_leaf_handler_test +- orbax/checkpoint/experimental/v1/_src/serialization:scalar_leaf_handler_test +- orbax/checkpoint/experimental/v1/_src/serialization:string_leaf_handler_test +- orbax/checkpoint/experimental/v1/_src/testing:save_load_test +- orbax/checkpoint/experimental/v1/_src/testing:v1_compatibility_load_test_multiprocess +- orbax/checkpoint/experimental/v1/_src/training:checkpointer_test +- orbax/checkpoint/experimental/v1/_src/training:v0v1_compatibility_checkpointer_test +processes:4: +- orbax/checkpoint/_src/multihost:multihost_test +- orbax/checkpoint/_src/testing/tree_verity:checkpoint_manager_test +- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:process_metadata_checkpoint_handler_test 
+- orbax/checkpoint/experimental/emergency:local_checkpoint_data_debugging_test +- orbax/checkpoint/experimental/emergency:local_checkpoint_manager_test +- orbax/checkpoint/experimental/emergency:single_slice_checkpoint_manager_test +- orbax/checkpoint/experimental/v1/_src/layout:safetensors_layout_multiprocess_test +- orbax/checkpoint/experimental/v1/_src/synchronization:multihost_test +- orbax/checkpoint/testing:local_path_test +- orbax/checkpoint:checkpoint_manager_slice_test +- orbax/checkpoint:checkpoint_manager_test diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py new file mode 100644 index 000000000..ce87dff86 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py @@ -0,0 +1,2365 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Common test cases for PyTreeHandler.""" + +# pylint: disable=protected-access, missing-function-docstring + +from __future__ import annotations + +import asyncio +import contextlib +import dataclasses +import datetime +import functools +import json +import threading +from typing import Any, Awaitable, Iterator, List, Sequence, Type +from unittest import mock + +from absl import flags +from absl.testing import parameterized +import aiofiles +from etils import epath +import flax +import flax.training.train_state +import jax +from jax import numpy as jnp +from jax.experimental import mesh_utils +import numpy as np +import optax +from orbax.checkpoint import test_utils +from orbax.checkpoint import utils +from orbax.checkpoint._src.arrays import abstract_arrays +from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.metadata import array_metadata +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib +from orbax.checkpoint._src.metadata import empty_values +from orbax.checkpoint._src.metadata import sharding as sharding_metadata +from orbax.checkpoint._src.metadata import tree as tree_metadata +from orbax.checkpoint._src.metadata import value as value_metadata +from orbax.checkpoint._src.serialization import limits +from orbax.checkpoint._src.serialization import replica_slices +from orbax.checkpoint._src.serialization import serialization +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils +from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.testing import multiprocess_test +from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler +from orbax.checkpoint.experimental.v1._src.path 
import types as path_types +from orbax.checkpoint.experimental.v1._src.serialization import array_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import numpy_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import registry +from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils +from orbax.checkpoint.experimental.v1._src.testing import handler_utils as handler_test_utils +from orbax.checkpoint.experimental.v1._src.tree import types as tree_types + + +PyTree = tree_types.PyTree +ParamInfo = pytree_checkpoint_handler.ParamInfo + +_SHARDING = '_sharding' +PYTREE_METADATA_FILE = pytree_checkpoint_handler.PYTREE_METADATA_FILE +ARRAY_METADATA_STORE = array_metadata_store_lib.Store() +PLACEHOLDER = type_handlers.PLACEHOLDER + +create_sharded_array = array_test_utils.create_sharded_array +create_numpy_pytree = array_test_utils.create_numpy_pytree +create_sharded_pytree = array_test_utils.create_sharded_pytree +as_abstract_type = array_test_utils.as_abstract_type + + +PathAwaitingCreation = path_types.PathAwaitingCreation +PathLike = path_types.PathLike +Path = path_types.Path + + +FLAGS = flags.FLAGS + +jax.config.update('jax_enable_x64', True) + + +async def _run_awaitable(awaitable: Awaitable[Any]) -> Any: + return await awaitable + + +# Custom dataclasses for testing custom leaf handlers. PyType check requires +# these defines outside of the test. 
+@dataclasses.dataclass +class Point: + x: int + y: float + + +@dataclasses.dataclass +class AbstractPoint: + x: Type[int] = int + y: Type[float] = float + + +class PointLeafHandler(serialization_types.LeafHandler[Point, AbstractPoint]): + """A custom leaf handler for testing.""" + + def __init__(self, context: context_lib.Context | None = None): + del context + + async def serialize( + self, + params: Sequence[serialization_types.SerializationParam[Point]], + serialization_context: serialization_types.SerializationContext, + ) -> Awaitable[None]: + + async def _background_serialize(): + if multihost.is_primary_host(0): + # make sure the parent directory is created + await serialization_context.parent_dir.await_creation() + + for param in params: + async with aiofiles.open( + serialization_context.parent_dir.path / f'{param.name}.txt', + 'w', + ) as f: + await f.write(json.dumps(dataclasses.asdict(param.value))) + + return _background_serialize() + + async def deserialize( + self, + params: Sequence[ + serialization_types.DeserializationParam[ + AbstractPoint | Type[AbstractPoint] + ] + ], + deserialization_context: serialization_types.DeserializationContext, + ) -> Awaitable[Sequence[Point]]: + + async def _background_deserialize(): + ret = [] + for param in params: + async with aiofiles.open( + deserialization_context.parent_dir / f'{param.name}.txt', + 'r', + ) as f: + ret.append(Point(**json.loads(await f.read()))) + + return ret + + return _background_deserialize() + + async def metadata( + self, + params: Sequence[serialization_types.DeserializationParam[None]], + deserialization_context: serialization_types.DeserializationContext, + ) -> Sequence[AbstractPoint]: + return [AbstractPoint()] * len(params) + + +def create_mixed_format_pytree( + *, + add: int = 0, + strings: bool = False, + parent_key: str | None = None, + include_scalars: bool = True, +) -> PyTree: + """Creates a PyTree with different leaf types for testing. 
+ + Args: + add: Adds the specified value to numeric leafs. + strings: If true, adds string leaves to the tree. + parent_key: If provided, keys will be contained within a dictionary under + this key. + include_scalars: If true, adds scalar leaves to the tree. + + Returns: + PyTree + """ + numpy_pytree, abstract_numpy_pytree = create_numpy_pytree( + add=add, include_scalars=include_scalars + ) + sharded_pytree, abstract_sharded_pytree = create_sharded_pytree( + add=add, include_scalars=include_scalars + ) + if parent_key: + numpy_pytree = {parent_key: numpy_pytree} + sharded_pytree = {parent_key: sharded_pytree} + abstract_numpy_pytree = {parent_key: abstract_numpy_pytree} + abstract_sharded_pytree = {parent_key: abstract_sharded_pytree} + mixed_pytree = { + 'numpy': numpy_pytree, + 'sharded': sharded_pytree, + } + abstract_mixed_pytree = { + 'numpy': abstract_numpy_pytree, + 'sharded': abstract_sharded_pytree, + } + if strings: + mixed_pytree['foo'] = 'foo_val' + mixed_pytree['bar'] = 'bar_val' + abstract_mixed_pytree['foo'] = '' + abstract_mixed_pytree['bar'] = '' + return mixed_pytree, abstract_mixed_pytree + + +def _raise_file_not_found_error(*args, **kwargs): + del args, kwargs + raise FileNotFoundError() + + +# Not in common util because we need to eliminate OSS dependency on flax. 
+def init_flax_model(model): + params = model.init(jax.random.PRNGKey(0), jnp.ones([8, 8])) + tx = optax.adamw(learning_rate=0.001) + state = flax.training.train_state.TrainState.create( + apply_fn=model.apply, params=params, tx=tx + ) + return jax.tree.map(np.asarray, state) + + +def get_d_files(path: Path) -> list[Path]: + files = [] + for idx in range(multihost.process_count()): + d_path = path / f'ocdbt.process_{idx}' / 'd' + if not d_path.exists(): + continue + files.extend(list(d_path.iterdir())) + return files + + +@contextlib.contextmanager +def handler_with_options( + *, + scoped_storage_options_creator: ( + options_lib.ArrayOptions.Saving.ScopedStorageOptionsCreator | None + ) = None, + array_storage_options: ( + options_lib.ArrayOptions.Saving.StorageOptions + ) = options_lib.ArrayOptions.Saving.StorageOptions(), + save_concurrent_bytes: int | None = None, + restore_concurrent_bytes: int | None = None, + use_ocdbt: bool = True, + use_zarr3: bool = False, + use_compression: bool = True, + enable_padding_and_truncation: bool = True, + ocdbt_target_data_file_size: int | None = None, + enable_pinned_host_transfer: bool | None = None, + pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = ( + tree_metadata.PYTREE_METADATA_OPTIONS + ), + array_metadata_store: array_metadata_store_lib.Store | None = ( + ARRAY_METADATA_STORE + ), + enable_write_sharding_file: bool = True, + partial_load: bool = False, + leaf_handler_registry: ( + serialization_types.LeafHandlerRegistry | None + ) = None, +): + """Registers handlers with OCDBT support and resets when done.""" + context = context_lib.Context( + array_options=options_lib.ArrayOptions( + saving=options_lib.ArrayOptions.Saving( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_compression=use_compression, + ocdbt_target_data_file_size=ocdbt_target_data_file_size, + enable_pinned_host_transfer=enable_pinned_host_transfer, + array_metadata_store=array_metadata_store, + 
enable_write_sharding_file=enable_write_sharding_file, + use_replica_parallel=not utils.is_pathways_backend(), + storage_options=array_storage_options, + scoped_storage_options_creator=scoped_storage_options_creator, + ), + loading=options_lib.ArrayOptions.Loading( + enable_padding_and_truncation=enable_padding_and_truncation, + ), + ), + memory_options=options_lib.MemoryOptions( + write_concurrent_bytes=save_concurrent_bytes, + read_concurrent_bytes=restore_concurrent_bytes, + ), + pytree_options=options_lib.PyTreeOptions( + saving=options_lib.PyTreeOptions.Saving( + pytree_metadata_options=pytree_metadata_options, + ), + loading=options_lib.PyTreeOptions.Loading( + partial_load=partial_load, + ), + leaf_handler_registry=leaf_handler_registry, + ), + ) + + handler = handler_test_utils.create_test_handler( + pytree_handler.PyTreeHandler, context=context + ) + + try: + yield handler + finally: + pass + + +class PyTreeHandlerTest( + parameterized.TestCase, + multiprocess_test.MultiProcessTest, +): + + def setUp(self): + super().setUp() + + self.pytree, self.abstract_pytree = create_sharded_pytree() + self.numpy_pytree, self.abstract_numpy_pytree = create_numpy_pytree() + + self.directory = epath.Path( + self.multiprocess_create_tempdir(name='checkpointing_test') + ) + # TODO: b/365169723 - Add tests for support_rich_types=True. 
+ self.pytree_metadata_options = tree_metadata.PyTreeMetadataOptions( + support_rich_types=False + ) + + # default to use_ocdbt=False, so we can test non-ocdbt handler first + self.handler = self.enter_context( + handler_with_options( + use_ocdbt=False, array_metadata_store=ARRAY_METADATA_STORE + ) + ) + test_utils.set_tensorstore_driver_for_test() + + test_utils.sync_global_processes( + 'PyTreeCheckpointHandlerTest:setup_complete' + ) + + def tearDown(self): + test_utils.sync_global_processes( + 'PyTreeCheckpointHandlerTest:tests_complete' + ) + super().tearDown() + + def validate_save( + self, + path: epath.Path, + abstract_pytree: PyTree | None, + expected: PyTree, + checkpoint_handler, + ): + """Validate save was performed correctly.""" + actual = checkpoint_handler.load(path, abstract_pytree) + test_utils.assert_tree_equal(self, expected, actual) + + def validate_metadata( + self, + *, + expected_reference_metadata_tree: PyTree, + actual_metadata: PyTree, + pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, + array_metadata_store: array_metadata_store_lib.Store | None, + ): + """Validate metadata, provided the original tree that was saved.""" + expected_reference_metadata_tree = tree_metadata.serialize_tree( + expected_reference_metadata_tree, pytree_metadata_options + ) + + def _metadata(value): + if empty_values.is_supported_empty_value(value, pytree_metadata_options): + return value + if isinstance(value, np.ndarray): + return numpy_leaf_handler.NumpyMetadata( + shape=value.shape, + dtype=value.dtype, + storage_metadata=value_metadata.StorageMetadata( + chunk_shape=value.shape, + ), + ) + if isinstance(value, jax.Array): + expected_sharding = sharding_metadata.from_jax_sharding(value.sharding) + expected_chunk_shape = test_utils.get_expected_chunk_shape(value) + return array_leaf_handler.ArrayMetadata( + shape=value.shape, + sharding_metadata=expected_sharding, + dtype=value.dtype, + storage_metadata=value_metadata.StorageMetadata( + 
chunk_shape=expected_chunk_shape, + write_shape=( + expected_chunk_shape + if array_metadata_store is not None + else None + ), + ), + ) + if isinstance(value, float): + return 0.0 + elif isinstance(value, int): + return 0 + if isinstance(value, str): + return 'string' + if isinstance(value, optax.EmptyState): + return None + if isinstance(value, Point): + return AbstractPoint() + raise ValueError(f'Unrecognized type: {type(value)}.') + + expected_metadata = jax.tree.map( + _metadata, + expected_reference_metadata_tree, + is_leaf=tree_utils.is_empty_or_leaf, + ) + test_utils.assert_tree_equal(self, expected_metadata, actual_metadata) + + def test_get_param_names(self): + param_names = pytree_checkpoint_handler.get_param_names(self.pytree) + expected = { + 'a': 'a', + 'b': 'b', + 'c': { + 'a': 'c.a', + 'e': 'c.e', + }, + 'x': 'x', + 'y': 'y', + } + test_utils.assert_tree_equal(self, expected, param_names) + + def test_save_format(self): + pytree = {'a': 0, 'c': {'d': np.arange(3), 'e': {'f': 5}}, 'g': 10} + self.handler.save(self.directory, pytree) + fnames = ['a', 'c.d', 'c.e.f', 'g'] + paths = [self.directory / name for name in fnames] + for p in paths: + self.assertTrue(p.exists()) + self.assertTrue((p / '.zarray').exists()) + + @parameterized.product(use_ocdbt=(True, False)) + def test_save_sharding(self, use_ocdbt: bool): + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present on Pathways.') + with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: + pytree = { + 'mlp/~/linear_0': { + 'a': self.pytree['a'], + 'b': self.pytree['b'], + 'c': {'a': self.pytree['c']['a'], 'e': self.pytree['c']['e']}, + } + } + abstract_pytree = jax.tree.map(array_test_utils.as_abstract_type, pytree) + checkpoint_handler.save(self.directory, pytree) + + self.validate_save( + self.directory, + abstract_pytree, + pytree, + checkpoint_handler, + ) + + self.assertTrue((self.directory / _SHARDING).exists()) + with open(self.directory / _SHARDING, 
'r') as file: + data = json.load(file) + self.assertCountEqual( + data.keys(), + { + 'bWxwL34vbGluZWFyXzAuYQ==', # mlp/~/linear_0.a + 'bWxwL34vbGluZWFyXzAuYg==', # mlp/~/linear_0.b + 'bWxwL34vbGluZWFyXzAuYy5h', # mlp/~/linear_0.c.a + 'bWxwL34vbGluZWFyXzAuYy5l', # mlp/~/linear_0.c.e + }, + ) + # mlp/~/linear_0.a + self.assertEqual( + sharding_metadata.NamedShardingMetadata.from_deserialized_dict( + json.loads(data['bWxwL34vbGluZWFyXzAuYQ==']) + ), + sharding_metadata.NamedShardingMetadata.from_jax_sharding( + pytree['mlp/~/linear_0']['a'].sharding + ), + ) + # mlp/~/linear_0.b + self.assertEqual( + sharding_metadata.NamedShardingMetadata.from_deserialized_dict( + json.loads(data['bWxwL34vbGluZWFyXzAuYg==']) + ), + sharding_metadata.NamedShardingMetadata.from_jax_sharding( + pytree['mlp/~/linear_0']['b'].sharding + ), + ) + # mlp/~/linear_0.c.a + self.assertEqual( + sharding_metadata.NamedShardingMetadata.from_deserialized_dict( + json.loads(data['bWxwL34vbGluZWFyXzAuYy5h']) + ), + sharding_metadata.NamedShardingMetadata.from_jax_sharding( + pytree['mlp/~/linear_0']['c']['a'].sharding + ), + ) + # mlp/~/linear_0.c.e + self.assertEqual( + sharding_metadata.NamedShardingMetadata.from_deserialized_dict( + json.loads(data['bWxwL34vbGluZWFyXzAuYy5l']) + ), + sharding_metadata.NamedShardingMetadata.from_jax_sharding( + pytree['mlp/~/linear_0']['c']['e'].sharding + ), + ) + + @parameterized.product( + use_ocdbt=(True, False), + array_metadata_store=(None, ARRAY_METADATA_STORE), + ) + def test_disable_write_sharding_file( + self, + use_ocdbt: bool, + array_metadata_store: array_metadata_store_lib.Store | None, + ): + pytree, abstract_pytree = create_mixed_format_pytree() + with handler_with_options( + use_ocdbt=use_ocdbt, + array_metadata_store=array_metadata_store, + enable_write_sharding_file=False, + ) as checkpoint_handler: + checkpoint_handler.save(self.directory, pytree) + self.validate_save( + self.directory, + abstract_pytree, + pytree, + checkpoint_handler, + ) + 
self.assertFalse((self.directory / _SHARDING).exists()) + + def test_sharding_variable_devices(self): + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present on Pathways.') + mesh_axes = jax.sharding.PartitionSpec( + 'x', + ) + devices_subset = [] + for idx in range(jax.process_count()): + for d in jax.devices(): + if d.process_index == idx: + devices_subset.append(d) + break + pytree = { + 'a': test_utils.create_sharded_array( + np.arange(16), + jax.sharding.Mesh(devices_subset, ('x',)), + mesh_axes, + ), + 'b': test_utils.create_sharded_array( + np.arange(16), jax.sharding.Mesh(jax.devices(), ('x',)), mesh_axes + ), + } + + self.handler.save(self.directory, pytree) + self.assertTrue((self.directory / _SHARDING).exists()) + a_sharding_metadata = sharding_metadata.NamedShardingMetadata( + shape=np.array([2]), + axis_names=['x'], + partition_spec=('x',), + axis_types=(jax.sharding.AxisType.Auto,), + device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh( + jax.sharding.Mesh(devices_subset, ('x',)) + ), + ) + b_sharding_metadata = sharding_metadata.NamedShardingMetadata( + shape=np.array([8]), + axis_names=['x'], + partition_spec=('x',), + axis_types=(jax.sharding.AxisType.Auto,), + device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh( + jax.sharding.Mesh(jax.devices(), ('x',)) + ), + ) + + restored_metadata = self.handler.metadata(self.directory) + self.assertEqual( + a_sharding_metadata, + restored_metadata['a'].sharding_metadata, + ) + self.assertEqual( + b_sharding_metadata, + restored_metadata['b'].sharding_metadata, + ) + self.assertEqual( + pytree['a'].sharding, + restored_metadata['a'].sharding, + ) + self.assertEqual( + pytree['b'].sharding, + restored_metadata['b'].sharding, + ) + + @parameterized.product(use_ocdbt=(True, False)) + def test_save_main(self, use_ocdbt: bool): + with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: + checkpoint_handler.save(self.directory, self.pytree) + 
self.validate_save( + self.directory, + self.abstract_pytree, + self.pytree, + checkpoint_handler, + ) + self.assertEqual( + type_handlers.is_ocdbt_checkpoint(self.directory), use_ocdbt + ) + + @parameterized.product(use_ocdbt=(True, False)) + def test_save_keys_with_slashes(self, use_ocdbt: bool): + with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: + pytree = { + 'a': np.arange(2), + 'b/c': np.arange(4), + } + checkpoint_handler.save(self.directory, pytree) + self.validate_save( + self.directory, + None, + pytree, + checkpoint_handler, + ) + + def test_save_non_sharded(self): + self.handler.save(self.directory, self.numpy_pytree) + self.validate_save( + self.directory, + None, + self.numpy_pytree, + self.handler, + ) + + @parameterized.product( + use_ocdbt=(True, False), + array_metadata_store=(None, ARRAY_METADATA_STORE), + ) + def test_save_mixed( + self, + use_ocdbt: bool, + array_metadata_store: array_metadata_store_lib.Store | None, + ): + with handler_with_options( + use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store + ) as checkpoint_handler: + pytree, abstract_pytree = create_mixed_format_pytree(strings=True) + checkpoint_handler.save(self.directory, pytree) + self.validate_save( + self.directory, + abstract_pytree, + pytree, + checkpoint_handler, + ) + if use_ocdbt: + expected_files_and_directories = [ + '_strings.json', + 'manifest.ocdbt', + 'ocdbt.process_0', + ] + else: + expected_files_and_directories = [ + '_strings.json', + 'numpy.a', + 'numpy.b', + 'numpy.c.a', + 'numpy.c.e', + ] + self.assertContainsSubset( + expected_files_and_directories, + [f.name for f in self.directory.iterdir()], + ) + self.validate_metadata( + expected_reference_metadata_tree=pytree, + actual_metadata=checkpoint_handler.metadata(self.directory), + pytree_metadata_options=self.pytree_metadata_options, + array_metadata_store=array_metadata_store, + ) + + @parameterized.product( + use_ocdbt=(True, False), + array_metadata_store=(None,), + ) + 
  def test_save_strings(
      self,
      use_ocdbt: bool,
      array_metadata_store: array_metadata_store_lib.Store | None,
  ):
    """Saves string leaves and checks the '_strings.json' aggregate file."""
    if use_ocdbt and multihost.is_pathways_backend():
      self.skipTest('Pathways + OCDBT not supported.')

    with handler_with_options(
        use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store
    ) as checkpoint_handler:
      pytree, abstract_pytree = create_mixed_format_pytree(strings=True)

      checkpoint_handler.save(self.directory, pytree)
      self.validate_save(
          self.directory,
          abstract_pytree,
          pytree,
          checkpoint_handler,
      )
      self.validate_metadata(
          expected_reference_metadata_tree=pytree,
          actual_metadata=checkpoint_handler.metadata(self.directory),
          pytree_metadata_options=self.pytree_metadata_options,
          array_metadata_store=array_metadata_store,
      )
      self.assertTrue((self.directory / '_strings.json').exists())
      with open(self.directory / '_strings.json') as file:
        data = json.load(file)
      self.assertCountEqual(
          data.keys(),
          {'foo', 'bar'},
          None,
      )
      self.assertEqual(data['foo'], 'foo_val')
      self.assertEqual(data['bar'], 'bar_val')

  def test_cast(self):
    """Casts dtypes on save (storage options) and on restore (abstract dtypes)."""
    pytree, abstract_pytree = create_mixed_format_pytree(include_scalars=False)
    origin_dtype = np.int64
    save_dtype = np.uint32
    restore_dtype = np.float64

    def check_dtype(x, dtype):
      # Scalars carry no dtype; only array leaves are checked.
      if not utils.is_scalar(x):
        self.assertEqual(x.dtype, dtype)

    def set_dtype(v, dtype):
      # ShapeDtypeStruct is immutable — rebuild it; otherwise mutate in place.
      if hasattr(v, 'dtype'):
        if isinstance(v, jax.ShapeDtypeStruct):
          v = v.update(dtype=dtype)
        else:
          setattr(v, 'dtype', dtype)
      return v

    with self.subTest('check_origin_dtype'):
      jax.tree.map(functools.partial(check_dtype, dtype=origin_dtype), pytree)
      jax.tree.map(
          functools.partial(check_dtype, dtype=origin_dtype), abstract_pytree
      )

    # Save with a per-key storage option forcing save_dtype on disk.
    with handler_with_options(
        use_ocdbt=False,
        scoped_storage_options_creator=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions(
            dtype=save_dtype
        ),
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, pytree)

    with self.subTest('check_restore_dtype'):
      abstract_pytree = jax.tree.map(
          functools.partial(set_dtype, dtype=restore_dtype), abstract_pytree
      )
      restored = self.handler.load(self.directory, abstract_pytree)
      jax.tree.map(
          functools.partial(check_dtype, dtype=restore_dtype), restored
      )

    with self.subTest('check_save_dtype'):
      if multihost.is_pathways_backend():
        self.skipTest(
            'Pathways does not allow restoring without abstract tree.'
        )
      # Without an abstract tree, leaves come back in the on-disk dtype.
      restored = self.handler.load(self.directory)
      jax.tree.map(functools.partial(check_dtype, dtype=save_dtype), restored)

  def test_save_with_callback_falling_back_to_global_options(self):
    """Per-key storage options fall back to globals when the callback returns None."""
    # Setup global options
    global_opts = options_lib.ArrayOptions.Saving.StorageOptions(
        dtype=np.int16
    )

    # Callback returns empty options for some fields
    def my_callback(k, v):
      # For 'sharded.a' and 'sharded.c.e', return specific dtype.
      # For others, return None, which should fall back to global_opts (int16).
      del v
      key_path_tuple = tuple(getattr(p, 'key', None) for p in k)
      if key_path_tuple == ('sharded', 'a'):
        return options_lib.ArrayOptions.Saving.StorageOptions(dtype=np.int32)
      elif key_path_tuple == ('sharded', 'c', 'e'):
        return options_lib.ArrayOptions.Saving.StorageOptions(dtype=np.float32)
      return None

    with handler_with_options(
        use_ocdbt=False,
        array_storage_options=global_opts,
        scoped_storage_options_creator=my_callback,
    ) as checkpoint_handler:
      pytree, _ = create_mixed_format_pytree(
          include_scalars=False
      )
      checkpoint_handler.save(self.directory, pytree)

    # Load and verify it restored as int16 (falling back to global)
    restored = self.handler.load(self.directory)

    # Check only sharded leaves
    def check_dtype(keypath, x):
      if hasattr(x, 'dtype'):
        key_path_tuple = tuple(getattr(p, 'key', None) for p in keypath)
        if key_path_tuple == ('a',):
          self.assertEqual(x.dtype, np.dtype(np.int32))
        elif key_path_tuple == ('c', 'e'):
          self.assertEqual(x.dtype, np.dtype(np.float32))
        else:
          self.assertEqual(x.dtype, np.dtype(np.int16))

    jax.tree_util.tree_map_with_path(check_dtype, restored['sharded'])

  @parameterized.product(cast_to=(int, float, 0, 0.0))
  def test_cast_scalar_types(self, cast_to):
    """Scalar leaves restore as the requested Python scalar type."""
    pytree = {'a': 5, 'b': 6.1}
    abstract_pytree = {
        'a': cast_to,
        'b': cast_to,
    }

    self.handler.save(self.directory, pytree)
    restored = self.handler.load(self.directory, abstract_pytree)
    # cast_to may be a type (int) or an instance (0); normalize to a type.
    expected_type = cast_to if isinstance(cast_to, type) else type(cast_to)
    self.assertIsInstance(restored['a'], expected_type)
    self.assertIsInstance(restored['b'], expected_type)

  @parameterized.product(
      use_ocdbt=(True, False),
      use_zarr3=(True, False),
      array_metadata_store=(None, ARRAY_METADATA_STORE),
  )
  def test_save_restore(
      self,
      use_ocdbt: bool,
      use_zarr3: bool,
      array_metadata_store: array_metadata_store_lib.Store | None,
  ):
    """Basic save/load round-trip plus metadata validation."""
    with handler_with_options(
        use_ocdbt=use_ocdbt,
        use_zarr3=use_zarr3,
        array_metadata_store=array_metadata_store,
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, self.pytree)
      restored = checkpoint_handler.load(
          self.directory,
          self.abstract_pytree,
      )
      test_utils.assert_tree_equal(self, self.pytree, restored)
      self.validate_metadata(
          expected_reference_metadata_tree=self.pytree,
          actual_metadata=checkpoint_handler.metadata(self.directory),
          pytree_metadata_options=self.pytree_metadata_options,
          array_metadata_store=array_metadata_store,
      )

  def test_save_async(self):
    """Async save writes nothing to disk until serialization is unblocked."""
    # The pytree must be larger so that saving doesn't complete too quickly.
    mesh = jax.sharding.Mesh(jax.devices(), 'x')
    np.random.seed(42)
    pytree = {
        'a': array_test_utils.create_sharded_array(
            np.arange(2**20),
            sharding=jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec('x')
            ),
        ),
        'b': array_test_utils.create_sharded_array(
            np.random.uniform(size=2**15),
            sharding=jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec(None)
            ),
        ),
    }
    abstract_pytree = jax.tree.map(array_test_utils.as_abstract_type, pytree)

    start_serialize = threading.Event()
    original_serialize = serialization.async_serialize_from_host

    def mock_serialize(*args, **kwargs):
      start_serialize.wait()  # Wait for explicit signal before proceeding.
      return original_serialize(*args, **kwargs)

    def is_save_complete(directory):
      return (directory / 'manifest.ocdbt').exists()

    # Serialization to disk does not start until receiving an explicit signal.
    self.enter_context(
        mock.patch.object(
            serialization, 'async_serialize_from_host', new=mock_serialize
        )
    )

    with handler_with_options() as checkpoint_handler:
      awaitable = checkpoint_handler.save_async(self.directory, pytree)
      initial_d_files = get_d_files(self.directory)
      self.assertFalse(is_save_complete(self.directory))
      start_serialize.set()

      asyncio.run(_run_awaitable(awaitable))
      final_d_files = get_d_files(self.directory)
      self.assertNotEmpty(final_d_files)
      self.assertNotEqual(len(initial_d_files), len(final_d_files))
      self.assertTrue(is_save_complete(self.directory))

      restored = checkpoint_handler.load(
          self.directory,
          abstract_pytree,
      )
      test_utils.assert_tree_equal(self, pytree, restored)

  def test_load_async(self):
    """Async load returns the same tree as a synchronous load."""
    with handler_with_options() as checkpoint_handler:
      checkpoint_handler.save(self.directory, self.pytree)
      load_awaitable = checkpoint_handler.load_async(
          self.directory,
          self.abstract_pytree,
      )
      restored = asyncio.run(_run_awaitable(load_awaitable))
      test_utils.assert_tree_equal(self, self.pytree, restored)

  @parameterized.product(use_ocdbt=(True, False))
  def test_load_reverse_mesh(self, use_ocdbt: bool):
    """Round-trips a pytree sharded over a reversed device order."""
    if use_ocdbt and multihost.is_pathways_backend():
      self.skipTest('Pathways + OCDBT not supported.')
    with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler:
      pytree, abstract_pytree = array_test_utils.create_sharded_pytree(
          reverse_devices=True
      )
      checkpoint_handler.save(self.directory, pytree)
      restored = checkpoint_handler.load(self.directory, abstract_pytree)
      test_utils.assert_tree_equal(self, pytree, restored)

  def test_load_multiple_steps(self):
    """Saves and loads independent checkpoints under per-step subdirectories."""
    for step in [0, 1]:
      directory = self.directory / str(step)
      # Only process 0 creates the directory; others wait at the barrier.
      if multihost.process_index() == 0:
        directory.mkdir()
      test_utils.sync_global_processes(
          'PyTreeCheckpointHandlerTest:test_load_different_mkdir'
      )

      pytree, abstract_pytree = create_mixed_format_pytree(add=step)
      self.handler.save(directory, pytree)

      restored = self.handler.load(directory, abstract_pytree)
      test_utils.assert_tree_equal(self, pytree, restored)

  def test_load_missing_checkpoint(self):
    """Loading a nonexistent directory raises FileNotFoundError."""
    directory = self.directory / 'nothing'
    with self.assertRaises(FileNotFoundError):
      self.handler.load(directory)

  @parameterized.product(
      use_ocdbt=(True, False),
      array_metadata_store=(None, ARRAY_METADATA_STORE),
  )
  def test_flax_model(
      self,
      use_ocdbt: bool,
      array_metadata_store: array_metadata_store_lib.Store | None,
  ):
    """Round-trips a flax dataclass state, with and without an abstract tree."""

    @flax.struct.dataclass
    class Params(flax.struct.PyTreeNode):
      params: PyTree
      opt_state: PyTree

    def make_state_with_optax():
      return Params(
          params=self.numpy_pytree,
          opt_state=(optax.EmptyState(), optax.EmptyState()),
      )

    def make_state_with_nones():
      # Empty optax states round-trip as None placeholders.
      return Params(
          params=self.numpy_pytree,
          opt_state=(None, None),
      )

    state = make_state_with_optax()

    with handler_with_options(
        use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, state)

      with self.subTest('with_abstract_state'):
        abstract_state = jax.tree.map(array_test_utils.as_abstract_type, state)
        restored = checkpoint_handler.load(self.directory, abstract_state)
        expected_state = state
        test_utils.assert_tree_equal(self, expected_state, restored)
        self.validate_metadata(
            expected_reference_metadata_tree=expected_state,
            actual_metadata=checkpoint_handler.metadata(self.directory),
            pytree_metadata_options=self.pytree_metadata_options,
            array_metadata_store=array_metadata_store,
        )

      with self.subTest('without_abstract_state'):
        if multihost.is_pathways_backend():
          self.skipTest('Must provide abstract_pytree for Pathways.')
        restored = checkpoint_handler.load(self.directory)
        expected_state = tree_utils.serialize_tree(
            make_state_with_nones(),
            keep_empty_nodes=True,
        )
        test_utils.assert_tree_equal(self, expected_state, restored)
        self.validate_metadata(
            expected_reference_metadata_tree=expected_state,
            actual_metadata=checkpoint_handler.metadata(self.directory),
            pytree_metadata_options=self.pytree_metadata_options,
            array_metadata_store=array_metadata_store,
        )

  @parameterized.product(
      use_ocdbt=(
          True,
          False,
      ),
      data=(
          {},
          {'a': {}, 'b': 3},
          [1, {}, 2],
          None,
          {'a': None, 'b': 3},
          [1, None, 2],
          [],
          [1, [], 2],
          {'a': [], 'b': 3},
      ),
      array_metadata_store=(None, ARRAY_METADATA_STORE),
  )
  def test_empty_data(
      self,
      use_ocdbt: bool,
      data: Any,
      array_metadata_store: array_metadata_store_lib.Store | None,
  ):
    """Entirely-empty items are rejected; trees containing empty nodes round-trip."""
    with handler_with_options(
        use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store
    ) as checkpoint_handler:
      # {} / [] / None at the top level are treated as an empty item.
      if not data:
        with self.assertRaisesRegex(ValueError, 'Found empty item'):
          checkpoint_handler.save(
              self.directory,
              data,
          )
        return

      checkpoint_handler.save(self.directory, data)
      restored = checkpoint_handler.load(self.directory)
      self.assertEqual(restored, data)

      self.validate_metadata(
          expected_reference_metadata_tree=data,
          actual_metadata=checkpoint_handler.metadata(self.directory),
          pytree_metadata_options=self.pytree_metadata_options,
          array_metadata_store=array_metadata_store,
      )

  @parameterized.product(
      use_ocdbt=(True, False),
      array_metadata_store=(None, ARRAY_METADATA_STORE),
  )
  def test_list(
      self,
      use_ocdbt: bool,
      array_metadata_store: array_metadata_store_lib.Store | None,
  ):
    """Lists restore as ints with an abstract item, and as np arrays without one."""
    item = [1, 2, 5, 6]
    with handler_with_options(
        use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, item)
      abstract_item = [0, 0, 0, 0]
      restored = checkpoint_handler.load(self.directory, abstract_item)
      self.assertListEqual(restored, item)
      self.validate_metadata(
          expected_reference_metadata_tree=[0, 0, 0, 0],
          actual_metadata=checkpoint_handler.metadata(self.directory),
          pytree_metadata_options=self.pytree_metadata_options,
          array_metadata_store=array_metadata_store,
      )

      # Without an abstract item, scalars come back as 1-element np arrays.
      restored = checkpoint_handler.load(self.directory)
      self.assertListEqual(
          restored,
          [
              np.asarray([1]),
              np.asarray([2]),
              np.asarray([5]),
              np.asarray([6]),
          ],
      )

  def test_no_metadata_file(self):
    """metadata() raises once the pytree metadata file has been deleted."""
    self.handler.save(self.directory, self.pytree)
    metadata_file = self.directory / PYTREE_METADATA_FILE
    # Only process 0 deletes; all processes then observe the deletion.
    if multihost.process_index() == 0:
      self.assertTrue(metadata_file.exists())
      metadata_file.unlink()
    test_utils.sync_global_processes('delete_metadata_file')
    self.assertFalse(metadata_file.exists())
    with self.assertRaises(FileNotFoundError):
      self.handler.metadata(self.directory)

  @parameterized.parameters((True,), (False,))
  def test_reshape_padding(self, enable_padding_and_truncation: bool):
    """Restoring into a larger shape zero-pads only when padding is enabled."""
    mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
    axes = jax.sharding.PartitionSpec(
        'x',
    )
    dtype = np.float32
    pytree = {
        'x': test_utils.create_sharded_array(
            np.arange(8, dtype=dtype), mesh, axes
        )
    }
    abstract_pytree = {
        'x': jax.ShapeDtypeStruct(
            shape=(16,), dtype=dtype, sharding=pytree['x'].sharding
        )
    }
    with handler_with_options(
        enable_padding_and_truncation=enable_padding_and_truncation
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, pytree)
      if enable_padding_and_truncation:
        restored = checkpoint_handler.load(self.directory, abstract_pytree)
        expected = {
            'x': test_utils.create_sharded_array(
                np.concatenate(
                    (np.arange(8, dtype=dtype), np.zeros(8, dtype=dtype))
                ),
                mesh,
                axes,
            )
        }
        test_utils.assert_tree_equal(self, expected, restored)
      else:
        with self.assertRaises(BaseException):
          checkpoint_handler.load(self.directory, abstract_pytree)

  @parameterized.parameters((True,), (False,))
  def test_reshape_truncate(self, enable_padding_and_truncation: bool):
    """Restoring into a smaller shape truncates only when enabled."""
    mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
    axes = jax.sharding.PartitionSpec(
        'x',
    )
    dtype = np.float32
    pytree = {
        'x': test_utils.create_sharded_array(
            np.arange(16, dtype=dtype), mesh, axes
        )
    }
    abstract_pytree = {
        'x': jax.ShapeDtypeStruct(
            shape=(8,), dtype=dtype, sharding=pytree['x'].sharding
        )
    }

    with handler_with_options(
        enable_padding_and_truncation=enable_padding_and_truncation
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, pytree)
      if enable_padding_and_truncation:
        restored = checkpoint_handler.load(self.directory, abstract_pytree)
        expected = {
            'x': test_utils.create_sharded_array(
                np.arange(8, dtype=dtype), mesh, axes
            )
        }
        test_utils.assert_tree_equal(self, expected, restored)
      else:
        with self.assertRaises(BaseException):
          checkpoint_handler.load(self.directory, abstract_pytree)

  @parameterized.parameters(
      (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('x', 'y'))),
      (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('y', 'x'))),
      (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('x',))),
      (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('y',))),
      (jax.sharding.PartitionSpec(('x', 'y')), jax.sharding.PartitionSpec()),
      (
          jax.sharding.PartitionSpec(('x', 'y')),
          jax.sharding.PartitionSpec(('x',)),
      ),
      (
          jax.sharding.PartitionSpec(('x', 'y')),
          jax.sharding.PartitionSpec(('y',)),
      ),
      (
          jax.sharding.PartitionSpec(('x', 'y')),
          jax.sharding.PartitionSpec(('y', 'x')),
      ),
      (
          jax.sharding.PartitionSpec(('x',)),
          jax.sharding.PartitionSpec(('y',)),
      ),
  )
  def test_reshard(self, save_spec, restore_spec):
    """Restores with a different partition spec than was used for saving."""
    devices = jax.devices()
    len_devices = len(devices)
    self.assertGreaterEqual(len_devices, 4)

    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((4, len_devices // 4)), ('x', 'y')
    )
    dtype = np.int32
    pytree = {
        'x': test_utils.create_sharded_array(
            np.arange(len_devices, dtype=dtype), mesh, save_spec
        )
    }
    abstract_pytree = {
        'x': jax.ShapeDtypeStruct(
            shape=(len_devices,),
            dtype=dtype,
            sharding=jax.sharding.NamedSharding(mesh, restore_spec),
        )
    }

    self.handler.save(self.directory, pytree)
    restored = self.handler.load(self.directory, abstract_pytree)
    expected = {
        'x': test_utils.create_sharded_array(
            np.arange(len_devices, dtype=dtype), mesh, restore_spec
        )
    }
    test_utils.assert_tree_equal(self, expected, restored)

  def test_load_non_ocdbt(self):
    """An OCDBT-enabled handler can load a non-OCDBT checkpoint."""
    with handler_with_options(use_ocdbt=False) as checkpoint_handler:
      checkpoint_handler.save(self.directory, self.pytree)
      self.assertFalse(type_handlers.is_ocdbt_checkpoint(self.directory))
    with handler_with_options(use_ocdbt=True) as checkpoint_handler:
      restored = checkpoint_handler.load(
          self.directory,
          self.abstract_pytree,
      )
      test_utils.assert_tree_equal(self, self.pytree, restored)

  def test_load_non_ocdbt_mixed(self):
    """Same as test_load_non_ocdbt but with a mixed-format (strings) pytree."""
    pytree, abstract_pytree = create_mixed_format_pytree(strings=True)
    with handler_with_options(use_ocdbt=False) as checkpoint_handler:
      checkpoint_handler.save(self.directory, pytree)
      self.assertFalse(type_handlers.is_ocdbt_checkpoint(self.directory))
    with handler_with_options(use_ocdbt=True) as checkpoint_handler:
      restored = checkpoint_handler.load(self.directory, abstract_pytree)
      test_utils.assert_tree_equal(self, pytree, restored)

  def test_check_zarray(self):
    """Loading fails with FileNotFoundError when a .zarray file is missing."""
    self.handler.save(self.directory, self.pytree)
    zarr_path = self.directory / 'a' / '.zarray'
    # NOTE(review): every process calls unlink (missing_ok tolerates the
    # race); presumably intentional — verify if flakes appear here.
    zarr_path.unlink(missing_ok=True)
    test_utils.sync_global_processes(
        'PyTreeCheckpointHandlerTest:delete_zarray'
    )
    self.assertFalse(zarr_path.exists())
    with self.assertRaises(FileNotFoundError):
      self.handler.load(
          self.directory,
          self.abstract_pytree,
      )

  def test_without_abstract_pytree(self):
    """Loads without providing an abstract pytree (non-Pathways only)."""
    if multihost.is_pathways_backend():
      self.skipTest('Must provide abstract_pytree when using Pathways.')
    arr = test_utils.create_sharded_array(
        np.arange(8),
        jax.sharding.Mesh(jax.devices(), ('x',)),
        jax.sharding.PartitionSpec('x'),
    )
    pytree = [arr]
    self.handler.save(self.directory, pytree)
    restored = self.handler.load(self.directory)
    test_utils.assert_tree_equal(self, pytree, restored)

  @parameterized.product(use_ocdbt=(True, False))
  def test_masked_shape_dtype_struct(self, use_ocdbt: bool):
    """Masked (optax.MaskedNode) leaves save and restore as None placeholders."""

    def _should_mask(keypath):
      return keypath[0].key == 'a' or (
          keypath[0].key == 'c' and keypath[1].key == 'e'
      )

    def _mask(keypath, x):
      return optax.MaskedNode() if _should_mask(keypath) else x

    def _none(keypath, x):
      return None if _should_mask(keypath) else x

    masked_tree = jax.tree_util.tree_map_with_path(_mask, self.pytree)
    expected = jax.tree_util.tree_map_with_path(_none, self.pytree)

    with handler_with_options(use_ocdbt=use_ocdbt) as handler:
      handler.save(self.directory, masked_tree)
      if use_ocdbt:
        self.assertTrue(type_handlers.is_ocdbt_checkpoint(self.directory))

      # Restore it with state which was given before applying masking.
      restored = handler.load(
          self.directory,
          jax.tree.map(abstract_arrays.to_shape_dtype_struct, self.pytree),
      )
      test_utils.assert_tree_equal(self, expected, restored)

      # Restore it with state after applying masking to it.
      restored = handler.load(
          self.directory,
          jax.tree.map(abstract_arrays.to_shape_dtype_struct, masked_tree),
      )
      test_utils.assert_tree_equal(self, expected, restored)

      # Restore it without any state.
      restored = handler.load(
          self.directory,
          self.abstract_pytree,
      )
      test_utils.assert_tree_equal(self, expected, restored)

  def test_finalize(self):
    """After an OCDBT save, the per-process dir and manifest are present."""
    with handler_with_options(use_ocdbt=True) as checkpoint_handler:
      checkpoint_handler.save(self.directory, self.pytree)
      process_index = multihost.process_index()
      process_dir = (
          self.directory / f'{ts_utils.PROCESS_SUBDIR_PREFIX}{process_index}'
      )
      self.assertTrue(process_dir.exists())
      self.assertTrue(process_dir.is_dir())
      self.assertTrue(type_handlers.is_ocdbt_checkpoint(self.directory))

  @parameterized.product(use_ocdbt=(True, False))
  def test_unregistered_types(self, use_ocdbt: bool):
    """Saving a leaf of an unregistered type raises UnregisteredTypeError."""
    data = {'uncheckpointable_field': datetime.timedelta(seconds=5)}
    with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler:
      with self.assertRaisesRegex(
          registry.UnregisteredTypeError,
          'The following leaf types are not registered',
      ):
        checkpoint_handler.save(
            self.directory,
            data,
        )

  @parameterized.product(
      target_data_file_size=[
          50 * 1024,  # 50KB
          10 * 1024,  # 10KB
          0,
          None,
      ],
      chunk_byte_size=[
          None,  # unspecified
          5 * 1024,  # 5KB
          100 * 1024,  # greater than target_data_file_size
      ],
      use_per_key_options=[True, False],
      patch_default_ocdbt_data_file_size=[True, False],
  )
  def test_ocdbt_target_data_file_size(
      self,
      target_data_file_size,
      chunk_byte_size,
      use_per_key_options,
      patch_default_ocdbt_data_file_size,
  ):
    """Test ocdbt_target_data_file_size."""
    array_len = 16 * 1024  # ~ 64KB of float data
    custom_pytree = {
        'a': np.arange(array_len, dtype=np.int32),
        'b': np.arange(array_len * 2, dtype=np.float32),
        'c': {
            'a': (
                np.arange(array_len, dtype=np.int32).reshape(2, array_len // 2)
            ),
            'e': (
                np.arange(array_len * 2, dtype=np.float32).reshape(2, array_len)
            ),
        },
    }
    shardings = {
        'a': self.abstract_pytree['a'].sharding,
        'b': self.abstract_pytree['b'].sharding,
        'c': {
            'a': self.abstract_pytree['c']['a'].sharding,
            'e': self.abstract_pytree['c']['e'].sharding,
        },
    }
    pytree = jax.tree.map(create_sharded_array, custom_pytree, shardings)
    abstract_pytree = jax.tree.map(as_abstract_type, pytree)

    # chunk_byte_size may be applied globally or via a per-key callback.
    if use_per_key_options:
      scoped_storage_options_creator = (
          lambda key, value: options_lib.ArrayOptions.Saving.StorageOptions(
              chunk_byte_size=chunk_byte_size
          )
      )
      array_storage_options = None
    else:
      scoped_storage_options_creator = None
      array_storage_options = options_lib.ArrayOptions.Saving.StorageOptions(
          chunk_byte_size=chunk_byte_size
      )
    new_ocdbt_target_data_file_size = (
        1024
        if patch_default_ocdbt_data_file_size
        else ts_utils._DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
    )
    with mock.patch.object(
        ts_utils,
        '_DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE',
        new_ocdbt_target_data_file_size,
    ):
      if patch_default_ocdbt_data_file_size:
        assert ts_utils._DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE == 1024
      with handler_with_options(
          use_ocdbt=True,
          ocdbt_target_data_file_size=target_data_file_size,
          array_storage_options=array_storage_options,
          scoped_storage_options_creator=scoped_storage_options_creator,
      ) as checkpoint_handler:
        checkpoint_handler.save(self.directory, pytree)

        data_dir = self.directory / 'd'
        self.assertTrue(data_dir.exists())
        self.assertTrue(data_dir.is_dir())

        # NOTE(review): .stat().length is the etils epath stat attribute
        # (not os.stat's st_size) — confirm self.directory is an epath.Path.
        for f in data_dir.iterdir():
          if f.is_file():
            if target_data_file_size not in (0, None):
              # it's expected the resulting file sizes can be larger than
              # the target_data_file_size, so we give some buffer here
              self.assertLessEqual(
                  f.stat().length,
                  target_data_file_size * 2.0,
              )  # some room
            if patch_default_ocdbt_data_file_size:
              self.assertLessEqual(
                  f.stat().length,
                  (
                      new_ocdbt_target_data_file_size * 4.0
                  ),  # TODO(niketkb): revisit culprit cl/786790774.
              )

        restored = checkpoint_handler.load(self.directory, abstract_pytree)

        test_utils.assert_tree_equal(self, pytree, restored)

  def test_local_registry(self):
    """A custom leaf-handler registry restricts and transforms saved leaves."""

    if multihost.is_pathways_backend():
      # This does not test anything on the pathways backend
      # TODO(b/333114195): add proper pathways testing.
      return

    class PlusOneHandler(scalar_leaf_handler.ScalarLeafHandler):
      """A custom handler that adds one to all scalar values."""

      def __init__(self, context: context_lib.Context | None = None):
        super().__init__(context=context)

      async def serialize(
          self,
          params: Sequence[scalar_leaf_handler.ScalarSerializationParam],
          serialization_context: serialization_types.SerializationContext,
      ) -> Awaitable[None]:
        updated_params = [
            scalar_leaf_handler.ScalarSerializationParam(
                keypath=param.keypath, value=param.value + 1
            )
            for param in params
        ]

        return await super().serialize(updated_params, serialization_context)

    # Only int is registered, so float leaves must be rejected.
    leaf_registry = registry.BaseLeafHandlerRegistry()
    leaf_registry.add(int, int, PlusOneHandler)

    with handler_with_options(
        leaf_handler_registry=leaf_registry,
        array_metadata_store=None,
        use_zarr3=True,
    ) as handler:
      with self.assertRaisesRegex(
          registry.UnregisteredTypeError,
          'The following leaf types are not registered',
      ):
        handler.save(self.directory, {'a': 3, 'b': 1.0})

      handler.save(self.directory, {'a': 3})

      with self.assertRaisesRegex(
          registry.UnregisteredTypeError,
          'The following abstract leaf types are not registered',
      ):
        handler.load(self.directory, {'a': 3.0})

      # PlusOneHandler incremented the value at save time.
      restored = handler.load(self.directory)
      expected = {'a': 4}
      self.assertEqual(restored, expected)

  def test_empty_custom_node(self):
    """Empty custom pytree nodes save as empty dicts and restore by type."""

    class PyTreeDict(dict):
      pass

    jax.tree_util.register_pytree_node(
        PyTreeDict,
        lambda d: (tuple(d.values()), tuple(d.keys())),
        lambda keys, values: PyTreeDict(dict(zip(keys, values))),
    )

    with self.assertRaisesRegex(ValueError, 'Found empty item'):
      self.handler.save(self.directory, PyTreeDict())

    self.handler.save(self.directory, {'a': PyTreeDict()})
    restored = self.handler.load(self.directory)
    self.assertDictEqual({'a': {}}, restored)

    restored = self.handler.load(self.directory, {'a': PyTreeDict()})
    test_utils.assert_tree_equal(self, {'a': PyTreeDict()}, restored)

  @parameterized.parameters((5,), (9,))
  def test_concurrent_gb_save(self, limit_bytes):
    """The byte limiter throttles concurrent writes to limit_bytes at a time."""
    # TODO(b/346811105): Enable for Pathways.
    if multihost.is_pathways_backend():
      self.skipTest(
          'Disabled on Pathways because completion_times cannot updated by'
          ' reference outside remote Python.'
      )
    sleep_time = 1.0
    sharding = jax.sharding.NamedSharding(
        jax.sharding.Mesh(
            jax.devices(),
            ('x',),
        ),
        jax.sharding.PartitionSpec(
            None,
        ),
    )
    # 4 arrays, each has a single chunk, with 4 bytes each.
    tree = jax.tree.map(
        functools.partial(
            array_test_utils.create_sharded_array, sharding=sharding
        ),
        {
            'a': np.arange(1, dtype=np.int32),
            'b': np.arange(1, dtype=np.int32),
            'c': np.arange(1, dtype=np.int32),
            'd': np.arange(1, dtype=np.int32),
        },
    )
    byte_limiter = test_utils.get_byte_limiter(limit_bytes, sleep_time)
    with mock.patch.object(
        limits,
        'get_byte_limiter',
        new=lambda _: byte_limiter,
    ), handler_with_options(
        save_concurrent_bytes=limit_bytes,
    ) as handler:
      handler.save(self.directory, tree)
      # Replicated shards are handled within the _write_array_shard function.
      # Since shards are only saved once per replica, we only have to check
      # the primary process.
      completion_times = byte_limiter.completion_times
      if multihost.process_index() == 0:
        self.assertLen(completion_times, len(jax.tree.leaves(tree)))
        test_utils.assert_every_n_is_x_apart(
            self,
            completion_times,
            limit_bytes // np.int32().itemsize,
            sleep_time,
        )

  @parameterized.parameters((5,), (9,))
  def test_concurrent_gb_restore(self, limit_bytes):
    """The byte limiter throttles concurrent reads to limit_bytes at a time."""
    # TODO(b/346811105): Enable for Pathways.
    if multihost.is_pathways_backend():
      self.skipTest(
          'Disabled on Pathways because completion_times cannot updated by'
          ' reference outside remote Python.'
      )
    sleep_time = 1.0
    sharding = jax.sharding.NamedSharding(
        jax.sharding.Mesh(
            jax.devices(),
            ('x',),
        ),
        jax.sharding.PartitionSpec(
            None,
        ),
    )
    # 4 arrays, each has a single chunk, with 4 bytes each.
    tree = jax.tree.map(
        functools.partial(
            array_test_utils.create_sharded_array, sharding=sharding
        ),
        {
            'a': np.arange(1, dtype=np.int32),
            'b': np.arange(1, dtype=np.int32),
            'c': np.arange(1, dtype=np.int32),
            'd': np.arange(1, dtype=np.int32),
        },
    )
    self.handler.save(self.directory, tree)

    byte_limiter = test_utils.get_byte_limiter(limit_bytes, sleep_time)
    with mock.patch.object(
        limits,
        'get_byte_limiter',
        new=lambda _,: byte_limiter,
    ), handler_with_options(restore_concurrent_bytes=limit_bytes) as handler:
      restored = handler.load(self.directory)
      test_utils.assert_tree_equal(self, tree, restored)
      completion_times = byte_limiter.completion_times
      self.assertLen(
          completion_times,
          len(jax.tree.leaves(tree)),
      )
      test_utils.assert_every_n_is_x_apart(
          self,
          completion_times,
          limit_bytes // np.int32().itemsize,
          sleep_time,
      )

  @parameterized.product(enable_pinned_host_transfer=(True, False))
  def test_enable_pinned_host_transfer(self, enable_pinned_host_transfer):
    """The pinned-host flag is propagated into transfer_arrays_to_host."""
    if multihost.is_pathways_backend():
      self.skipTest(
          'Disabled on Pathways because local variables cannot updated by'
          ' reference outside remote Python.'
      )
    true_count = 0
    false_count = 0

    original_transfer_arrays_to_host = replica_slices.transfer_arrays_to_host

    # Wrapper that counts which flag value reaches the real implementation.
    def _transfer_arrays_to_host(
        arrays,
        replica_id,
        use_replica_parallel,
        min_slice_bytes_for_replica_parallel,
        max_replicas_for_replica_parallel,
        enable_pinned_host_transfer,
    ):
      nonlocal true_count, false_count
      if enable_pinned_host_transfer:
        true_count += 1
      else:
        false_count += 1
      return original_transfer_arrays_to_host(
          arrays,
          replica_id,
          use_replica_parallel=use_replica_parallel,
          min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel,
          max_replicas_for_replica_parallel=max_replicas_for_replica_parallel,
          enable_pinned_host_transfer=enable_pinned_host_transfer,
      )

    with mock.patch.object(
        replica_slices,
        'transfer_arrays_to_host',
        new=_transfer_arrays_to_host,
    ), handler_with_options(
        enable_pinned_host_transfer=enable_pinned_host_transfer,
    ) as handler:
      handler.save(self.directory, self.pytree)

    if enable_pinned_host_transfer:
      self.assertGreater(true_count, 0)
      self.assertEqual(false_count, 0)
    else:
      self.assertEqual(true_count, 0)
      self.assertGreater(false_count, 0)

  @parameterized.product(
      use_ocdbt=(True, False),
      pytree_metadata_options=(
          tree_metadata.PyTreeMetadataOptions(support_rich_types=False),
          tree_metadata.PyTreeMetadataOptions(support_rich_types=True),
      ),
  )
  def test_write_shape_metadata_missing_for_all_types_other_than_jax_array(
      self,
      use_ocdbt: bool,
      pytree_metadata_options: tree_metadata.PyTreeMetadataOptions,
  ):
    """No array_metadatas dir is written when the tree has no jax.Array leaves."""
    checkpoint = {
        'a': 1,
        'b': np.array([2]),
        'c': 'hello',
    }
    expected_metadata = {
        'a': 0,
        'b': numpy_leaf_handler.NumpyMetadata(
            shape=(1,),
            dtype=checkpoint['b'].dtype,
            storage_metadata=value_metadata.StorageMetadata(
                chunk_shape=(1,), write_shape=None
            ),
        ),
        'c': 'string',
    }
    with handler_with_options(
        use_ocdbt=use_ocdbt,
        pytree_metadata_options=pytree_metadata_options,
        array_metadata_store=ARRAY_METADATA_STORE,
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, checkpoint)

      self.assertFalse((self.directory / 'array_metadatas').exists())
      restored_metadata = checkpoint_handler.metadata(self.directory)
      self.assertEqual(
          expected_metadata,
          restored_metadata,
      )

  @parameterized.product(
      use_ocdbt=(True, False),
      pytree_metadata_options=(
          tree_metadata.PyTreeMetadataOptions(support_rich_types=False),
          tree_metadata.PyTreeMetadataOptions(support_rich_types=True),
      ),
  )
  def test_write_shape_in_metadata_disabled(
      self,
      use_ocdbt: bool,
      pytree_metadata_options: tree_metadata.PyTreeMetadataOptions,
  ):
    """write_shape is None in metadata when no array metadata store is set."""
    with handler_with_options(
        use_ocdbt=use_ocdbt,
        pytree_metadata_options=pytree_metadata_options,
        array_metadata_store=None,
    ) as checkpoint_handler:
      checkpoint_handler.save(self.directory, self.pytree)
      expected_tree_with_write_shapes = {
          'a': {'write_shape': None},
          'b': {'write_shape': None},
          'c': {
              'a': {'write_shape': None},
              'e': {'write_shape': None},
          },
          'x': {'write_shape': None},
          'y': {'write_shape': None},
      }
      metadata = checkpoint_handler.metadata(self.directory)
      tree_with_write_shapes = jax.tree.map(
          lambda m: {'write_shape': m.storage_metadata.write_shape}, metadata
      )
      self.assertDictEqual(
          expected_tree_with_write_shapes, tree_with_write_shapes
      )

  # TODO(b/382230550): Add test for chunk_shape != write_shape.
+ @parameterized.product( + use_ocdbt=(True, False), + pytree_metadata_options=( + tree_metadata.PyTreeMetadataOptions(support_rich_types=False), + tree_metadata.PyTreeMetadataOptions(support_rich_types=True), + ), + ) + def test_write_shape_in_metadata( + self, + use_ocdbt: bool, + pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, + ): + with handler_with_options( + use_ocdbt=use_ocdbt, pytree_metadata_options=pytree_metadata_options + ) as checkpoint_handler: + checkpoint_handler.save(self.directory, self.pytree) + + expected_tree_with_write_shapes = { + 'a': { + 'write_shape': test_utils.get_expected_chunk_shape( + self.pytree['a'] + ) + }, + 'b': {'write_shape': (2,)}, + 'c': { + 'a': {'write_shape': (1, 1)}, + 'e': {'write_shape': (2, 1)}, + }, + 'x': {'write_shape': ()}, + 'y': {'write_shape': ()}, + } + metadata = checkpoint_handler.metadata(self.directory) + tree_with_write_shapes = jax.tree.map( + lambda m: {'write_shape': m.storage_metadata.write_shape}, metadata + ) + self.assertDictEqual( + expected_tree_with_write_shapes, tree_with_write_shapes + ) + + @parameterized.product(use_ocdbt=(True, False)) + def test_array_metadata_disabled(self, use_ocdbt: bool): + with handler_with_options( + use_ocdbt=use_ocdbt, array_metadata_store=None + ) as checkpoint_handler: + pytree, abstract_pytree = create_mixed_format_pytree() + + checkpoint_handler.save(self.directory, pytree) + + self.validate_save( + self.directory, + abstract_pytree, + pytree, + checkpoint_handler, + ) + + self.assertFalse((self.directory / 'array_metadatas').exists()) + + @parameterized.product(use_ocdbt=(True, False)) + def test_array_metadata(self, use_ocdbt: bool): + with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: + + checkpoint_handler.save(self.directory, self.pytree) + + self.validate_save( + self.directory, + self.abstract_pytree, + self.pytree, + checkpoint_handler, + ) + + self.assertTrue((self.directory / 'array_metadatas').exists()) + if 
multihost.is_primary_host(0): + array_metadatas = asyncio.run(ARRAY_METADATA_STORE.read(self.directory)) + self.assertIsInstance(array_metadatas, dict) + per_process_metadatas = [ + array_metadata.SerializedArrayMetadata( + param_name='a', + write_shape=test_utils.get_expected_chunk_shape(self.pytree['a']), + chunk_shape=test_utils.get_expected_chunk_shape(self.pytree['a']), + ), + array_metadata.SerializedArrayMetadata( + param_name='b', + write_shape=(2,), + chunk_shape=(2,), + ), + array_metadata.SerializedArrayMetadata( + param_name='c.a', + write_shape=(1, 1), + chunk_shape=(1, 1), + ), + array_metadata.SerializedArrayMetadata( + param_name='c.e', + write_shape=(2, 1), + chunk_shape=(2, 1), + ), + array_metadata.SerializedArrayMetadata( + param_name='x', + write_shape=(), + chunk_shape=(), + ), + array_metadata.SerializedArrayMetadata( + param_name='y', + write_shape=(), + chunk_shape=(), + ), + ] + processes = range(multihost.process_count()) + if multihost.is_pathways_backend(): + process_ids = set( + [f'{d.slice_index}.{d.process_index}' for d in jax.devices()] + ) + processes = range(len(process_ids)) + expected_array_metadatas = { + idx: per_process_metadatas for idx in processes + } + self.assertSameElements( + expected_array_metadatas.keys(), array_metadatas.keys() + ) + for process_index in expected_array_metadatas: + self.assertEqual( # pylint: disable=g-generic-assert + sorted( + expected_array_metadatas[process_index], + key=lambda x: x.param_name, + ), + sorted(array_metadatas[process_index], key=lambda x: x.param_name), + ) + + @parameterized.product(use_ocdbt=(True, False)) + def test_save_with_missing_array_metadata_file(self, use_ocdbt: bool): + if multihost.process_index() != 0: # only test on primary host + self.skipTest('Test only for primary host to avoid barrier timeout.') + + class PathResolverReturningNoMetadataFiles( + array_metadata_store_lib.PathResolver + ): + + async def get_read_file_paths( + self, checkpoint_dir: epath.Path, 
process_index: int | None = None + ) -> Iterator[epath.Path] | epath.Path | None: + return None + + with handler_with_options( + use_ocdbt=use_ocdbt, + array_metadata_store=array_metadata_store_lib.Store( + path_resolver=PathResolverReturningNoMetadataFiles() + ), + ) as checkpoint_handler: + with self.assertRaisesRegex( + ValueError, 'No ArrayMetadata found for process_index' + ): + checkpoint_handler.save(self.directory, self.pytree) + + @parameterized.product(use_ocdbt=(True, False)) + def test_save_with_missing_array_metadata_for_params(self, use_ocdbt: bool): + if multihost.process_index() != 0: # only test on primary host + self.skipTest('Test only for primary host to avoid barrier timeout.') + + class MissingArrayMetadataSerializer(array_metadata_store_lib.Serializer): + + def deserialize( + self, serialized: str + ) -> List[array_metadata.SerializedArrayMetadata]: + true_data = super().deserialize(serialized) + return [true_data.pop(0)] # Delete the rest and return partial data. 
+ + with handler_with_options( + use_ocdbt=use_ocdbt, + array_metadata_store=array_metadata_store_lib.Store( + serializer=MissingArrayMetadataSerializer() + ), + ) as checkpoint_handler: + with self.assertRaisesRegex( + ValueError, 'No ArrayMetadata found for param_info' + ): + checkpoint_handler.save(self.directory, self.pytree) + + @parameterized.parameters((True,), (False,)) + def test_zero_size_array(self, use_jax_array: bool): + arr = np.ones(shape=(0,)) + mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('x',)) + pspec = jax.sharding.PartitionSpec() + if use_jax_array: + arr = test_utils.create_sharded_array(arr, mesh, pspec) + pytree = [arr] + with self.assertRaisesRegex(ValueError, 'zero size'): + self.handler.save(self.directory, pytree) + + @parameterized.product(use_ocdbt=(True, False)) + def test_save_restore_random_keys(self, use_ocdbt: bool): + """Test saving and restoring random keys within a pytree.""" + + # TODO(b/393160483) investigate Pathways remote Python support for + # random.keys. + if multihost.is_pathways_backend(): + self.skipTest( + 'Disabled on Pathways because random keys are not supported by' + ' remote Python.' 
+ ) + + mesh = jax.sharding.Mesh(jax.devices(), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + pytree = { + 'keys': { + 'kone': jax.random.key(jnp.array(0, device=sharding)), + 'impl_key': { + 'rbg': jax.random.key( + jnp.array(1, device=sharding), impl='rbg' + ), + 'unsafe_rbg': jax.random.key( + jnp.array(2, device=sharding), impl='unsafe_rbg' + ), + }, + 'split_keys': jax.random.split( + jax.random.key(jnp.array(123, device=sharding)), num=10 + ), + }, + 'arrays': self.pytree, + } + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as save_handler: + save_handler.save(self.directory, pytree) + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as load_handler: + restored = load_handler.load(self.directory) + test_utils.assert_tree_equal(self, pytree, restored) + + def test_pinned_host_loading(self): + if multihost.is_pathways_backend(): + # TODO(b/404915487): Reenable when possible. + self.skipTest('Disabled due to b/404915487.') + + mesh = jax.sharding.Mesh( + np.asarray(jax.devices()).reshape((1, len(jax.devices()))), ('x', 'y') + ) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('x', 'y') + ).with_memory_kind('pinned_host') + + pytree = dict(arr=jnp.ones((1024, 512), device=sharding)) + self.handler.save(self.directory, pytree) + + abstract_pytree = dict( + arr=jax.ShapeDtypeStruct( + pytree['arr'].shape, pytree['arr'].dtype, sharding=sharding + ) + ) + restored = self.handler.load(self.directory, abstract_pytree) + expected = dict(arr=jax.device_put(np.ones((1024, 512)), sharding)) + test_utils.assert_tree_equal(self, expected, restored) + + @parameterized.product( + use_ocdbt=(True, False), + reference_item=( + { + 'a': 0, + 'b': 0, + 'c': { + 'e': 0, + }, + }, + { + 'a': 0, + 'c': { + 'a': 0, + 'e': 0, + }, + }, + { + 'a': 0, + 'b': 0, + }, + ), + ) + def test_restore_item_has_missing_leaves( + self, use_ocdbt: bool, reference_item: dict[str, Any] + ): + with 
handler_with_options( + use_ocdbt=use_ocdbt, + ) as handler: + handler.save(self.directory, self.pytree) + + with self.assertRaisesRegex( + ValueError, 'User-provided restore item and on-disk value' + ): + handler.load(self.directory, reference_item) + + def test_partial_restore_with_placeholder_simple(self): + original_item = { + 'a': np.arange(8), + 'b': np.arange(8), + 'c': { + 'a': np.arange(8), + 'e': np.arange(8), + }, + } + reference_item = jax.tree.map(as_abstract_type, original_item) + reference_item['b'] = PLACEHOLDER + reference_item['c']['e'] = PLACEHOLDER + expected = { + 'a': original_item['a'], + 'b': PLACEHOLDER, + 'c': { + 'a': original_item['c']['a'], + 'e': PLACEHOLDER, + }, + } + + simple_dir = epath.Path( + self.multiprocess_create_tempdir(name='simple_placeholder_dir') + ) + + with handler_with_options() as handler: + handler.save(simple_dir, original_item) + restored = handler.load(simple_dir, reference_item) + test_utils.assert_tree_equal(self, expected, restored) + + @parameterized.product(use_ocdbt=(True, False)) + def test_partial_restore_with_placeholder(self, use_ocdbt: bool): + """Test saving and restoring placeholder.""" + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as save_handler: + save_handler.save(self.directory, self.pytree) + + with self.subTest('success'): + reference_item = self.abstract_pytree.copy() + reference_item['b'] = PLACEHOLDER + reference_item['c']['e'] = PLACEHOLDER + + expected = self.pytree.copy() + expected['b'] = PLACEHOLDER + expected['c']['e'] = PLACEHOLDER + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as restore_handler: + restored = restore_handler.load(self.directory, reference_item) + test_utils.assert_tree_equal(self, expected, restored) + + with self.subTest('missing_leaf'): + reference_item = self.abstract_pytree.copy() + reference_item['b'] = PLACEHOLDER + reference_item['c']['e'] = PLACEHOLDER + del reference_item['c']['a'] + + with handler_with_options( + use_ocdbt=use_ocdbt, + 
) as restore_handler: + with self.assertRaisesRegex( + ValueError, 'User-provided restore item and on-disk value' + ): + restore_handler.load(self.directory, reference_item) + + with self.subTest('non_leaf_placeholder'): + reference_item = self.abstract_pytree.copy() + reference_item['c'] = PLACEHOLDER + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as restore_handler: + with self.assertRaisesRegex( + ValueError, 'User-provided restore item and on-disk value' + ): + restore_handler.load(self.directory, reference_item) + + @parameterized.product(use_ocdbt=(True, False)) + def test_partial_restore_with_omission(self, use_ocdbt: bool): + """Basic save and restore test.""" + directory = self.directory / 'partial_restore' + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as save_handler: + save_handler.save(directory, self.pytree) + + with self.subTest('success'): + with handler_with_options( + use_ocdbt=use_ocdbt, + partial_load=True, + ) as restore_handler: + # Create a new pytree structure with the same leaves. + # Leaves (ShapeDtypeStruct) are immutable and can be shared. + reference_item = jax.tree.map(lambda x: x, self.abstract_pytree) + # Omit 'b', 'c.e', and 'x' from the reference item. + del reference_item['b'] + del reference_item['c']['e'] + del reference_item['x'] + expected = { + 'a': self.pytree['a'], + 'c': { + 'a': self.pytree['c']['a'], + }, + 'y': self.pytree['y'], + } + restored = restore_handler.load(directory, reference_item) + test_utils.assert_tree_equal(self, expected, restored) + + @parameterized.product(use_ocdbt=(True, False)) + def test_partial_restore_with_placeholder_unexpected_keys( + self, use_ocdbt: bool + ): + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as save_handler: + save_handler.save(self.directory, self.pytree) + + reference_item = self.abstract_pytree.copy() + reference_item['b'] = PLACEHOLDER + reference_item['c']['e'] = PLACEHOLDER + reference_item['c']['f'] = PLACEHOLDER # Unexpected key. 
+ reference_item['z'] = PLACEHOLDER # Unexpected key. + + expected = self.pytree.copy() + expected['b'] = PLACEHOLDER + expected['c']['e'] = PLACEHOLDER + expected['c']['f'] = PLACEHOLDER + expected['z'] = PLACEHOLDER + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as restore_handler: + restored = restore_handler.load(self.directory, reference_item) + test_utils.assert_tree_equal(self, expected, restored) + + @parameterized.product(use_ocdbt=(True, False)) + def test_partial_restore_with_placeholder_unexpected_keys_no_placeholder( + self, use_ocdbt: bool + ): + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as save_handler: + save_handler.save(self.directory, self.pytree) + + reference_item = self.abstract_pytree.copy() + reference_item['b'] = PLACEHOLDER + reference_item['c']['e'] = PLACEHOLDER + reference_item['z'] = 0 # Unexpected key, but not a placeholder. + + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as restore_handler: + with self.assertRaisesRegex( + ValueError, 'User-provided restore item and on-disk value' + ): + restore_handler.load(self.directory, reference_item) + + @parameterized.product( + use_ocdbt=(True, False), + use_placeholder=(True, False), + ) + def test_partial_restore_with_omission_unexpected_keys( + self, use_ocdbt: bool, use_placeholder: bool + ): + with handler_with_options( + use_ocdbt=use_ocdbt, + ) as save_handler: + save_handler.save(self.directory, self.pytree) + + reference_item = self.abstract_pytree.copy() + reference_item['c']['f'] = ( + PLACEHOLDER if use_placeholder else 0 + ) # Unexpected key. + reference_item['z'] = ( + PLACEHOLDER if not use_placeholder else 0 + ) # Unexpected key. 
+ + expected = self.pytree.copy() + expected['c']['f'] = PLACEHOLDER if use_placeholder else 0 + expected['z'] = PLACEHOLDER if not use_placeholder else 0 + + with handler_with_options( + use_ocdbt=use_ocdbt, + partial_load=True, + ) as restore_handler: + restored = restore_handler.load(self.directory, reference_item) + test_utils.assert_tree_equal(self, expected, restored) + + @parameterized.product(use_zarr3=(True, False), use_ocdbt=(True, False)) + def test_custom_leaf_handler(self, use_zarr3: bool, use_ocdbt: bool): + + pytree = { + 'point1': Point(1, 2), + 'point2': Point(3, 4), + 'nested': { + 'point3': Point(5, 6), + 'point4': Point(7, 8), + }, + 'string_leaf': 'string_leaf', + 'number': 123, + 'pytree': self.pytree, + } + + array_metadata_store = ARRAY_METADATA_STORE + + leaf_handler_registry = registry.StandardLeafHandlerRegistry() + leaf_handler_registry.add(Point, AbstractPoint, PointLeafHandler) + + def _as_abstract_type(x): + if isinstance(x, Point): + return AbstractPoint + return as_abstract_type(x) + + with handler_with_options( + use_ocdbt=use_ocdbt, + leaf_handler_registry=leaf_handler_registry, + array_metadata_store=array_metadata_store, + use_zarr3=use_zarr3, + ) as checkpoint_handler: + checkpoint_handler.save(self.directory, pytree) + abstract_pytree = jax.tree.map(_as_abstract_type, pytree) + restored = checkpoint_handler.load(self.directory, abstract_pytree) + + test_utils.assert_tree_equal(self, pytree, restored) + + self.validate_metadata( + expected_reference_metadata_tree=pytree, + actual_metadata=checkpoint_handler.metadata(self.directory), + pytree_metadata_options=self.pytree_metadata_options, + array_metadata_store=array_metadata_store, + ) + + def test_abstract_array_loading(self): + replicated_sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), ('x',)), + jax.sharding.PartitionSpec(), + ) + value = array_test_utils.create_sharded_array( + np.arange(8), replicated_sharding + ) + abstract_value = 
jax.ShapeDtypeStruct( + value.shape, value.dtype, sharding=replicated_sharding + ) + with handler_with_options() as handler: + handler.save(self.directory, [value]) + restored = handler.load(self.directory, [abstract_value]) + test_utils.assert_tree_equal(self, [value], restored) + if not multihost.is_pathways_backend(): + restored = handler.load(self.directory, [jax.ShapeDtypeStruct]) + test_utils.assert_tree_equal(self, [value], restored) + + @parameterized.parameters( + (np.arange(8, dtype=np.int32), np.empty(8, dtype=np.int32)), + (np.arange(8), np.ndarray), + (1, 0), + (1, int), + (1.1, 0.0), + (1.1, float), + ('hi', '_'), + ('hi', str), + ) + def test_abstract_loading(self, value, abstract_value): + with handler_with_options() as handler: + handler.save(self.directory, [value]) + restored = handler.load(self.directory, [abstract_value]) + test_utils.assert_tree_equal(self, [value], restored) + + @parameterized.product( + use_ocdbt=(True, False), + use_zarr3=(True, False), + use_compression=(True, False), + ) + def test_compression( + self, use_ocdbt: bool, use_zarr3: bool, use_compression: bool + ): + + mesh = jax.sharding.Mesh(jax.devices(), 'x') + mesh_axes = jax.sharding.PartitionSpec( + 'x', + ) + pytree = { + 'a': test_utils.create_sharded_array( + np.arange(16), + mesh, + mesh_axes, + ), + } + with handler_with_options( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_compression=use_compression, + ) as handler: + handler.save(self.directory, pytree) + + self.assertEqual( + test_utils.is_compression_used( + checkpoint_directory=self.directory, + param_name='a', + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ), + use_compression, + ) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py new file mode 100644 index 000000000..b880ad7f6 --- /dev/null +++ 
b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py @@ -0,0 +1,510 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import time +from typing import Any +from unittest import mock + +from absl.testing import parameterized +from etils import epath +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.metadata import step_metadata_serialization +from orbax.checkpoint._src.testing import multiprocess_test +from orbax.checkpoint._src.tree import structure_utils as tree_structure_utils +from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler +from orbax.checkpoint.experimental.v1._src.handlers import registration +from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler +from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types +import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import +from orbax.checkpoint.experimental.v1._src.layout import orbax_layout +from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization +from orbax.checkpoint.experimental.v1._src.partial import path as partial_path_lib +from orbax.checkpoint.experimental.v1._src.partial import saving as partial_saving +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from 
orbax.checkpoint.experimental.v1._src.testing import handler_utils +from orbax.checkpoint.experimental.v1._src.testing import path_utils + + +CHECKPOINT_METADATA = orbax_layout.CHECKPOINT_METADATA +ORBAX_CHECKPOINT_INDICATOR_FILE = orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE +InternalCheckpointMetadata = ( + step_metadata_serialization.InternalCheckpointMetadata +) +PyTreeHandler = pytree_handler.PyTreeHandler +FooHandler = handler_utils.FooHandler +BarHandler = handler_utils.BarHandler +BazHandler = handler_utils.BazHandler +Foo = handler_utils.Foo +Bar = handler_utils.Bar +Baz = handler_utils.Baz +StatefulCheckpointableHandler = ( + stateful_checkpointable_handler.StatefulCheckpointableHandler +) +PartialSavePyTree = partial_saving._PartialSavePyTree + + +OrbaxLayout = orbax_layout.OrbaxLayout +InvalidLayoutError = orbax_layout.InvalidLayoutError + + +class OrbaxLayoutCompositeTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.directory = ( + epath.Path( + self.create_tempdir(name='orbax_layout_multiprocess_test_dir') + ) + / 'ckpt' + ) + self._mock_global_registry = registration.local_registry() + self.enter_context( + mock.patch.object( + registration, '_GLOBAL_REGISTRY', new=self._mock_global_registry + ) + ) + # Baz is registered globally, while the others are not. 
+ registration.register_handler(BazHandler) + + def save( + self, + layout: OrbaxLayout, + directory: epath.Path, + checkpointables: dict[str, Any], + *, + partial_save: bool = False, + ): + test_utils.sync_global_processes('CompositeHandlerTest:save:start') + if multihost.is_primary_host(0): + directory.mkdir(parents=False, exist_ok=partial_save) + for k in checkpointables: + (directory / k).mkdir(parents=False, exist_ok=partial_save) + test_utils.sync_global_processes('CompositeHandlerTest:save:mkdir') + + async def _save(): + handler_typestrs = { + name: handler_types.typestr( + type( + registration.resolve_handler_for_save( + layout._handler_registry, checkpointables[name], name=name + ) + ) + ) + for name in checkpointables.keys() + } + + checkpoint_metadata_path = ( + metadata_serialization.checkpoint_metadata_file_path(directory) + ) + if partial_save and checkpoint_metadata_path.exists(): + checkpoint_metadata = await orbax_layout.read_checkpoint_metadata( + directory + ) + old_handler_typestrs = checkpoint_metadata.item_handlers + handler_typestrs = old_handler_typestrs | handler_typestrs + await multihost.sync_global_processes( + 'CompositeHandlerTest:save:checkpoint_metadata_read', + operation_id='op', + processes=None, + ) + + # Metadata expected to be created outside the handler. 
+ if multihost.is_primary_host(0): + internal_metadata = InternalCheckpointMetadata.create( + handler_typestrs=handler_typestrs, + init_timestamp_nsecs=time.time_ns(), + commit_timestamp_nsecs=time.time_ns(), + custom_metadata={}, + ) + await metadata_serialization.write( + metadata_serialization.checkpoint_metadata_file_path(directory), + internal_metadata.serialize(), + ) + await multihost.sync_global_processes( + 'CompositeHandlerTest:save:checkpoint_metadata_write', + operation_id='op', + processes=None, + ) + awaitable = await layout.save( + path_utils.PathAwaitingCreationWrapper(directory), + checkpointables=checkpointables, + ) + await awaitable + + asyncio.run(_save()) + test_utils.sync_global_processes('CompositeHandlerTest:save:complete') + + def load(self, layout, directory, checkpointable): + test_utils.sync_global_processes('CompositeHandlerTest:load:start') + + async def _load(): + awaitable = await layout.load_checkpointables(directory, checkpointable) + return await awaitable + + result = asyncio.run(_load()) + test_utils.sync_global_processes('CompositeHandlerTest:load:complete') + return result + + def create_registry( + self, include_global_registry: bool = True + ) -> registration.CheckpointableHandlerRegistry: + return registration.local_registry( + include_global_registry=include_global_registry + ) + + def test_init(self): + patch_registry = self.create_registry().add( + PyTreeHandler, checkpointable_name='pytree_foo' + ) + layout = OrbaxLayout() + layout._handler_registry = patch_registry + self.assertTrue(layout._handler_registry.has('pytree_foo')) + self.assertEqual(layout._handler_registry.get('pytree_foo'), PyTreeHandler) + + self.assertTrue(layout._handler_registry.has('pytree')) + self.assertEqual(layout._handler_registry.get('pytree'), PyTreeHandler) + + @parameterized.product( + save_checkpointables=({'foo': {'a': 1}, 'bar': {'x': 5}},), + abstract_checkpointables=( + None, + {}, + {'foo': None, 'bar': None}, + {'foo': {'a': 0}, 
'bar': {'x': 0}}, + {'foo': {'a': 0}}, # Skip loading 'bar'. + ), + ) + def test_save_load( + self, + save_checkpointables, + abstract_checkpointables, + ): + patch_registry = self.create_registry() + for k in save_checkpointables: + patch_registry.add(PyTreeHandler, checkpointable_name=k) + layout = OrbaxLayout() + layout._handler_registry = patch_registry + + self.save( + layout, + self.directory, + save_checkpointables, + ) + for k in save_checkpointables: + self.assertTrue((self.directory / k).exists()) + + result = self.load( + layout, + self.directory, + abstract_checkpointables, + ) + if abstract_checkpointables: + expected_result = { + k: v + for k, v in save_checkpointables.items() + if k in abstract_checkpointables + } + else: + expected_result = save_checkpointables + self.assertDictEqual(expected_result, result) + + @parameterized.product( + with_name=(True, False), + ) + def test_save_load_checkpointables( + self, + with_name: bool, + ): + if with_name: + pairs_to_register = [ + (PyTreeHandler, 'pytree'), + (FooHandler, 'foo'), + ] + else: + pairs_to_register = [ + (PyTreeHandler, None), + (FooHandler, None), + ] + registry = self.create_registry(include_global_registry=False) + for handler_type, checkpointable in pairs_to_register: + registry.add(handler_type, checkpointable_name=checkpointable) + layout = OrbaxLayout() + layout._handler_registry = registry + + checkpointables = {'pytree': {'a': 1}, 'foo': Foo(x=1, y='foo')} + self.save( + layout, + self.directory, + checkpointables, + ) + for k in checkpointables: + self.assertTrue((self.directory / k).exists()) + + result = self.load( + layout, + self.directory, + None, + ) + self.assertDictEqual(checkpointables, result) + + def test_save_unregistered_checkpointable(self): + checkpointables = {'foo': Foo(x=1, y='foo')} + registry = self.create_registry() + layout = OrbaxLayout() + layout._handler_registry = registry + with self.assertRaises(registration.NoEntryError): + self.save( + layout, + 
self.directory, + checkpointables, + ) + + def test_save_custom_object_with_global_registry(self): + checkpointables = {'baz': Baz(int_val=2, str_val='baz')} + registry = self.create_registry() + layout = OrbaxLayout() + layout._handler_registry = registry + registry.add(BazHandler, checkpointable_name='baz') + + self.save( + layout, + self.directory, + checkpointables, + ) + result = self.load(layout, self.directory, None) + self.assertDictEqual(checkpointables, result) + + def test_save_and_load_with_different_handlers(self): + checkpointables = {'foo': Foo(x=1, y='foo'), 'bar': Bar(a=5, b='bar')} + + registry = ( + self.create_registry() + .add(FooHandler, checkpointable_name='foo') + .add(BarHandler, checkpointable_name='bar') + ) + layout = OrbaxLayout() + layout._handler_registry = registry + self.save(layout, self.directory, checkpointables) + for k in checkpointables: + self.assertTrue((self.directory / k).exists()) + + registry = ( + self.create_registry() + .add(FooHandler, checkpointable_name='bar') + .add(BarHandler, checkpointable_name='foo') + ) + layout = OrbaxLayout() + layout._handler_registry = registry + + result = self.load(layout, self.directory, None) + expected_result = {'foo': Bar(a=1, b='foo'), 'bar': Foo(x=5, y='bar')} + self.assertDictEqual(expected_result, result) + + def test_orbax_identifier_file_exists(self): + checkpointables = {'foo': Foo(x=1, y='foo')} + registry = self.create_registry().add(FooHandler, checkpointable_name='foo') + layout = OrbaxLayout() + layout._handler_registry = registry + self.save(layout, self.directory, checkpointables) + self.assertTrue((self.directory / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) + test_utils.sync_global_processes( + 'CompositeHandlerTest:test_orbax_identifier_file_exists' + ) + + @parameterized.parameters(True, False) + def test_partial_save_and_finalize(self, finalize_with_partial_path: bool): + final_path = self.directory + partial_path = 
partial_path_lib.add_partial_save_suffix(final_path) + + first_save_checkpointables = { + 'foo': PartialSavePyTree({'a': 1}), + 'bar': PartialSavePyTree({'x': 5}), + 'foo_list': PartialSavePyTree([{'a1': 1, 'b1': 2}, {'a2': 3}]), + } + second_save_checkpointables = { + 'baz': PartialSavePyTree({'a': 2}), + 'foo': PartialSavePyTree({'b': 3}), + 'foo_list': PartialSavePyTree([{}, {'b2': 4}]), + } + merged_checkpointables = tree_structure_utils.merge_trees( + {k: v.pytree for k, v in first_save_checkpointables.items()}, + {k: v.pytree for k, v in second_save_checkpointables.items()}, + ) + registry = self.create_registry(include_global_registry=False) + registry.add(StatefulCheckpointableHandler) + registry.add(PyTreeHandler) + layout = OrbaxLayout() + layout._handler_registry = registry + + self.save( + layout, partial_path, first_save_checkpointables, partial_save=True + ) + self.assertTrue(partial_path.exists()) + self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) + + self.save( + layout, + partial_path, + second_save_checkpointables, + partial_save=True, + ) + self.assertTrue(partial_path.exists()) + self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) + + restored_checkpointables = self.load( + layout, partial_path, merged_checkpointables + ) + test_utils.assert_tree_equal( + self, restored_checkpointables, merged_checkpointables + ) + + partial_saving.finalize( + partial_path if finalize_with_partial_path else final_path + ) + self.assertTrue(final_path.exists()) + self.assertTrue((final_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) + + restored_checkpointables = self.load( + layout, final_path, merged_checkpointables + ) + test_utils.assert_tree_equal( + self, restored_checkpointables, merged_checkpointables + ) + + @parameterized.product( + second_save_checkpointables=({'foo': {'a': 2}}, {'bar': {'x': 6}}) + ) + def test_partial_save_replacement_raises_error( + self, second_save_checkpointables + ): + 
final_path = self.directory + partial_path = partial_path_lib.add_partial_save_suffix(final_path) + + first_save_checkpointables = { + 'foo': PartialSavePyTree({'a': 1}), + 'bar': PartialSavePyTree({'x': 5}), + } + + registry = self.create_registry(include_global_registry=False) + registry.add(StatefulCheckpointableHandler) + registry.add(PyTreeHandler) + layout = OrbaxLayout() + layout._handler_registry = registry + + self.save( + layout, partial_path, first_save_checkpointables, partial_save=True + ) + + wrapped_second_save = { + k: PartialSavePyTree(v) for k, v in second_save_checkpointables.items() + } + with self.assertRaisesRegex( + pytree_handler.PartialSaveReplacementError, + 'Partial saving currently does not support REPLACEMENT.', + ): + self.save( + layout, + partial_path, + wrapped_second_save, + partial_save=True, + ) + + @parameterized.product( + checkpointable_name=('foo', 'bar'), + first_save_leaf_is_subtree=(True, False), + ) + def test_partial_save_subtree_replacement_raises_error( + self, checkpointable_name: str, first_save_leaf_is_subtree: bool + ): + final_path = self.directory + partial_path = partial_path_lib.add_partial_save_suffix(final_path) + + if first_save_leaf_is_subtree: + tree1 = {'a': {'b': 1}} + tree2 = {'a': 2} + else: + tree1 = {'a': 2} + tree2 = {'a': {'b': 1}} + + first_save_checkpointables = { + checkpointable_name: PartialSavePyTree(tree1), + 'other': PartialSavePyTree({'c': 3}), + } + second_save_checkpointables = { + checkpointable_name: PartialSavePyTree(tree2) + } + + registry = self.create_registry(include_global_registry=False) + registry.add(StatefulCheckpointableHandler) + registry.add(PyTreeHandler) + layout = OrbaxLayout() + layout._handler_registry = registry + + self.save( + layout, partial_path, first_save_checkpointables, partial_save=True + ) + with self.assertRaisesRegex( + pytree_handler.PartialSaveReplacementError, + 'Partial saving currently does not support REPLACEMENT.', + ): + self.save( + layout, + 
partial_path, + second_save_checkpointables, + partial_save=True, + ) + + def test_partial_save_with_mixed_handlers(self): + final_path = self.directory + partial_path = partial_path_lib.add_partial_save_suffix(final_path) + + # PyTreeHandler supports partial save, FooHandler does not. + registry = self.create_registry(include_global_registry=False) + registry.add(StatefulCheckpointableHandler, checkpointable_name='pytree') + registry.add(FooHandler, checkpointable_name='foo') + layout = OrbaxLayout() + layout._handler_registry = registry + + first_save = { + 'pytree': PartialSavePyTree({'a': 1}), + 'foo': Foo(x=1, y='foo1'), + } + self.save(layout, partial_path, first_save, partial_save=True) + + second_save = { + 'pytree': PartialSavePyTree({'b': 2}), + 'foo': Foo(x=2, y='foo2'), + } + self.save(layout, partial_path, second_save, partial_save=True) + + partial_saving.finalize(final_path) + + # PyTreeHandler should have merged the results. + # FooHandler should have overwritten. + expected = {'pytree': {'a': 1, 'b': 2}, 'foo': Foo(x=2, y='foo2')} + load_registry = self.create_registry(include_global_registry=False) + load_registry.add(PyTreeHandler, checkpointable_name='pytree') + load_registry.add(FooHandler, checkpointable_name='foo') + load_layout = OrbaxLayout() + load_layout._handler_registry = load_registry + + restored = self.load(load_layout, final_path, None) + test_utils.assert_tree_equal(self, expected, restored) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py new file mode 100644 index 000000000..8cb6456b6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py @@ -0,0 +1,479 @@ +# Copyright 2026 The Orbax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sharded loading with SafetensorsLayout.""" + +import gc +import tracemalloc +import unittest +from absl.testing import parameterized +from etils import epath +import jax +import jax.experimental.multihost_utils +import jax.sharding +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.testing import multiprocess_test +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout +import safetensors.numpy + +SafetensorsLayout = safetensors_layout.SafetensorsLayout +np_save_file = safetensors.numpy.save_file +Mesh = jax.sharding.Mesh +NamedSharding = jax.sharding.NamedSharding +PartitionSpec = jax.sharding.PartitionSpec +jnp = jax.numpy + + +def _get_partition_spec( + mesh_config, array_shape, sharding_type +) -> PartitionSpec | None: + """Returns the partition spec for a given sharding type, or None if invalid.""" + mesh_shape, mesh_axes = mesh_config["shape"], mesh_config["axes"] + rank = len(array_shape) + is_scalar_like = rank == 0 or (rank == 1 and array_shape[0] == 1) + if is_scalar_like and sharding_type != "fully_replicated": + return None + + if sharding_type == "fully_replicated": + pspec = PartitionSpec() + elif sharding_type == "fully_sharded": + if rank > len(mesh_axes): + return None + 
pspec = PartitionSpec(*mesh_axes[:rank]) + else: + pspec_list = [mesh_axes[0]] + for _ in range(rank - 1): + pspec_list.append(None) + pspec = PartitionSpec(*pspec_list) + + # Need to verify that an array dimension is divisible by the size of the + # mesh axis. + for i, axis_name in enumerate(pspec): + if axis_name is not None: + array_dim_size = array_shape[i] + mesh_axis_index = mesh_axes.index(axis_name) + mesh_axis_size = mesh_shape[mesh_axis_index] + if array_dim_size % mesh_axis_size != 0: + return None + + # If all checks pass, the combination is possible. + return pspec + + +class ShardedSafetensorsLayoutTest( + unittest.IsolatedAsyncioTestCase, + parameterized.TestCase, + multiprocess_test.MultiProcessTest, +): + + def setUp(self): + super().setUp() + self.assertEqual(jax.device_count(), 8) + self.assertEqual(jax.process_count(), 4) + self.assertEqual(jax.local_device_count(), 2) + self.test_dir = epath.Path( + self.multiprocess_create_tempdir(name="test_dir") + ) + + devices = jax.devices() + mesh_shape = (len(devices) // 2, 2) + self.mesh = Mesh( + np.array(devices).reshape(mesh_shape), ("data", "model") + ) + test_utils.sync_global_processes("setUp") + + def tearDown(self): + super().tearDown() + test_utils.sync_global_processes("tearDown") + + @parameterized.product( + mesh_config=[ + {"shape": (4, 2), "axes": ("data", "model")}, + {"shape": (2, 4), "axes": ("data", "model")}, + {"shape": (8, 1), "axes": ("data", "model")}, + {"shape": (1, 8), "axes": ("data", "model")}, + {"shape": (1, 8, 1), "axes": ("d1", "d2", "d3")}, + {"shape": (1, 2, 4, 1), "axes": ("d1", "d2", "d3", "d4")}, + ], + array_shape=[ + (), + (1,), + (16,), + (8, 8), + (4, 4, 4), + ], + sharding_type=[ + "fully_replicated", + "fully_sharded", + "one_axis_sharded", + ], + ) + async def test_sharding_scenarios( + self, mesh_config, array_shape, sharding_type + ): + # We are skipping tests that attempt to construct an invalid sharding spec. 
+ # In the next test, we validate that we get the expected error message. + mesh_shape, mesh_axes = mesh_config["shape"], mesh_config["axes"] + sharding_spec = _get_partition_spec(mesh_config, array_shape, sharding_type) + if sharding_spec is None: + self.skipTest("Invalid sharding spec.") + + # Create the tensor to save + if not array_shape: + tensor_to_save = np.float32(1.0) + else: + num_elements = np.prod(array_shape) + tensor_to_save = np.arange(num_elements, dtype=np.float32).reshape( + array_shape + ) + + tensor_data = {"params.tensor": tensor_to_save} + mesh = Mesh(np.array(jax.devices()).reshape(mesh_shape), mesh_axes) + st_path = self.test_dir / f"{self.id()}.safetensors" + if jax.process_index() == 0: + np_save_file(tensor_data, st_path) + test_utils.sync_global_processes(self.id()) + + abstract_sharding = NamedSharding(mesh, sharding_spec) + abstract_state = { + "params.tensor": jax.ShapeDtypeStruct( + shape=array_shape, dtype=np.float32, sharding=abstract_sharding + ), + } + expected_tensor = jax.device_put(tensor_to_save, abstract_sharding) + + layout = SafetensorsLayout() + restore_fn = await layout.load_pytree( + st_path, abstract_pytree=abstract_state + ) + restored_tensor = await restore_fn + restored_tensor = restored_tensor["params.tensor"] + + self.assertEqual(restored_tensor.sharding, expected_tensor.sharding) + test_utils.assert_array_equal(self, expected_tensor, restored_tensor) + + async def test_load_without_global_reshard_single_tensor(self): + """Tests loading with ignore_load_sharding=True with a single tensor.""" + array_shape = (4, 4) + tensor_to_save = np.arange(16, dtype=np.float32).reshape(array_shape) + tensor_data = {"params.tensor": tensor_to_save} + + st_path = self.test_dir / f"{self.id()}.safetensors" + if jax.process_index() == 0: + np_save_file(tensor_data, st_path) + test_utils.sync_global_processes(self.id()) + + abstract_sharding = NamedSharding(self.mesh, PartitionSpec("data", "model")) + abstract_state = { + 
"params.tensor": jax.ShapeDtypeStruct( + shape=array_shape, dtype=np.float32, sharding=abstract_sharding + ), + } + + layout = SafetensorsLayout() + with context_lib.Context( + safetensors_options=options_lib.SafetensorsOptions( + ignore_load_sharding=True + ) + ): + restore_fn = await layout.load_pytree( + st_path, abstract_pytree=abstract_state + ) + restored_pytree = await restore_fn + restored_tensor = restored_pytree["params.tensor"] + + self.assertEqual(restored_tensor.shape, array_shape) + + if len(restored_tensor.addressable_shards) == 1: + np.testing.assert_array_equal( + restored_tensor.addressable_shards[0].data, tensor_to_save + ) + else: + self.assertEmpty(restored_tensor.addressable_shards) + + async def test_load_without_global_reshard_multi_tensor(self): + """Tests loading with ignore_load_sharding=True with multiple tensors.""" + array_shape = (4, 4) + tensor_data = { + f"params.tensor_{i}": ( + np.arange(16, dtype=np.float32).reshape(array_shape) + i + ) + for i in range(4) + } + + st_path = self.test_dir / f"{self.id()}.safetensors" + if jax.process_index() == 0: + np_save_file(tensor_data, st_path) + test_utils.sync_global_processes(self.id()) + + abstract_sharding = NamedSharding(self.mesh, PartitionSpec("data", "model")) + abstract_state = { + f"params.tensor_{i}": jax.ShapeDtypeStruct( + shape=array_shape, dtype=np.float32, sharding=abstract_sharding + ) + for i in range(4) + } + + layout = SafetensorsLayout() + with context_lib.Context( + safetensors_options=options_lib.SafetensorsOptions( + ignore_load_sharding=True + ) + ): + restore_fn = await layout.load_pytree( + st_path, abstract_pytree=abstract_state + ) + restored_pytree = await restore_fn + + # Tensors are expected to be distributed among hosts. + # With 4 hosts and 4 equal sized tensors, each host should own one. 
+ for i in range(4): + tensor_name = f"params.tensor_{i}" + restored_tensor = restored_pytree[tensor_name] + self.assertEqual(restored_tensor.shape, array_shape) + + if len(restored_tensor.addressable_shards) == 1: + np.testing.assert_array_equal( + restored_tensor.addressable_shards[0].data, tensor_data[tensor_name] + ) + else: + self.assertEmpty(restored_tensor.addressable_shards) + + async def test_load_multi_host_memory_efficiency(self): + """Verifies that non-owner hosts don't materialize full zero buffers.""" + num_tensors = 100 + tensor_shape = (1024, 1024) # 1M elements = 4MB for float32 + # Total logical size for 100 tensors = 400MB. + num_elements = np.prod(tensor_shape) + bytes_per_tensor = num_elements * np.dtype(np.float32).itemsize + + abstract_sharding = NamedSharding(self.mesh, PartitionSpec("data", "model")) + + abstract_pytree = { + f"tensor_{i}": jax.ShapeDtypeStruct( + shape=tensor_shape, dtype=np.float32, sharding=abstract_sharding + ) + for i in range(num_tensors) + } + + file_path = self.test_dir / "dummy.safetensors" + + if jax.process_index() == 0: + tensors = { + f"tensor_{i}": np.zeros(tensor_shape, dtype=np.float32) + for i in range(num_tensors) + } + safetensors.numpy.save_file(tensors, file_path) + del tensors + + gc.collect() + + test_utils.sync_global_processes(self.id()) + + layout = SafetensorsLayout() + + tracemalloc.start() + + restore_fn = await layout.load_pytree( + file_path, + abstract_pytree=abstract_pytree, + ) + pytree = await restore_fn + + jax.block_until_ready(pytree) + + unused_current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Peak memory should be dominated by owned tensors (approx 100MB). + # Will also contain a single zero buffer of 4MB to be used in place of all + # non-owned tensors. + # If non-owned tensors were materialized, it would be 400MB. 
+ tensors_per_host = num_tensors // jax.process_count() + expected_peak = bytes_per_tensor * (tensors_per_host + 1) + fudge_factor = 1.2 # Account for overhead, Python objects, etc. + + self.assertLess(peak, fudge_factor * expected_peak) + + async def test_load_without_global_reshard_memory_efficiency(self): + """Verifies that non-owner hosts don't materialize full zero buffers when ignore_load_sharding=True.""" + num_tensors = 100 + tensor_shape = (1024, 1024) # 1M elements = 4MB for float32 + # Total logical size for 100 tensors = 400MB. + num_elements = np.prod(tensor_shape) + bytes_per_tensor = num_elements * np.dtype(np.float32).itemsize + + abstract_sharding = NamedSharding(self.mesh, PartitionSpec("data", "model")) + + abstract_pytree = { + f"tensor_{i}": jax.ShapeDtypeStruct( + shape=tensor_shape, dtype=np.float32, sharding=abstract_sharding + ) + for i in range(num_tensors) + } + + file_path = self.test_dir / "dummy_no_reshard.safetensors" + + if jax.process_index() == 0: + tensors = { + f"tensor_{i}": np.zeros(tensor_shape, dtype=np.float32) + for i in range(num_tensors) + } + safetensors.numpy.save_file(tensors, file_path) + del tensors + gc.collect() + + test_utils.sync_global_processes(self.id()) + + layout = SafetensorsLayout() + + tracemalloc.start() + + with context_lib.Context( + safetensors_options=options_lib.SafetensorsOptions( + ignore_load_sharding=True + ) + ): + restore_fn = await layout.load_pytree( + file_path, + abstract_pytree=abstract_pytree, + ) + pytree = await restore_fn + + jax.block_until_ready(pytree) + + unused_current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Peak memory should be dominated by owned tensors (approx 100MB). + # If non-owned tensors were materialized, it would be 400MB. + tensors_per_host = num_tensors // jax.process_count() + expected_peak = bytes_per_tensor * (tensors_per_host + 1) + fudge_factor = 1.2 # Account for overhead, Python objects, etc. 
+ + self.assertLess(peak, fudge_factor * expected_peak) + + def test_sharding_fails_when_divisibility_check_fails(self): + """Tests that JAX errors when an array dim is not divisible by a mesh dim.""" + mesh_config = {"shape": (8, 1), "axes": ("data", "model")} + array_shape = ( + 4, + 4, + 4, + ) # Dimension size 4 is not divisible by mesh axis data (size 8) + sharding_spec = PartitionSpec("data", None, None) + + mesh = Mesh( + np.array(jax.devices()).reshape(mesh_config["shape"]), + mesh_config["axes"], + ) + abstract_sharding = NamedSharding(mesh, sharding_spec) + tensor_to_save = np.zeros(array_shape, dtype=np.float32) + + with self.assertRaisesRegex( + ValueError, "partitioned 8 times, but the dimension size is 4" + ): + jax.device_put(tensor_to_save, abstract_sharding) + + def test_sharding_fails_with_scalar(self): + """Tests that JAX errors when attempting to shard a scalar.""" + mesh_config = {"shape": (4, 2), "axes": ("data", "model")} + sharding_spec = PartitionSpec("data") + + mesh = Mesh( + np.array(jax.devices()).reshape(mesh_config["shape"]), + mesh_config["axes"], + ) + abstract_sharding = NamedSharding(mesh, sharding_spec) + tensor_to_save = np.float32(1.0) + with self.assertRaisesRegex( + ValueError, "For scalars the PartitionSpec should be P()" + ): + jax.device_put(tensor_to_save, abstract_sharding) + + def test_sharding_fails_with_non_existent_axes(self): + """Tests that JAX errors when the PartitionSpec references non-existent mesh axes.""" + mesh_config = {"shape": (4, 2), "axes": ("data", "model")} + # d3 does not exist in the mesh axes + sharding_spec = PartitionSpec("data", "model", "d3") + + mesh = Mesh( + np.array(jax.devices()).reshape(mesh_config["shape"]), + mesh_config["axes"], + ) + + with self.assertRaisesRegex(ValueError, "is not found in mesh"): + _ = NamedSharding(mesh, sharding_spec) + + async def test_load_sharded_fails_with_nested_abstract_pytree(self): + """Tests that loading fails if the abstract pytree is nested.""" + 
st_path = self.test_dir / "nested_fail.safetensors" + if jax.process_index() == 0: + np_save_file({"a": np.arange(8)}, st_path) + test_utils.sync_global_processes( + "test_load_sharded_fails_with_nested_abstract_pytree" + ) + layout = SafetensorsLayout() + + nested_abstract_pytree = { + "params": { + "a": jax.ShapeDtypeStruct( + shape=(8,), + dtype=np.int32, + sharding=NamedSharding(self.mesh, PartitionSpec()), + ), + } + } + with self.assertRaisesRegex( + ValueError, "The PyTree is not a flat dictionary." + ): + test_awaitable = await layout.load_pytree( + st_path, abstract_pytree=nested_abstract_pytree + ) + await test_awaitable + + async def test_load_sharded_fails_with_wrong_key_abstract_pytree(self): + """Tests that loading fails if a key in the abstract pytree is not in the file.""" + st_path = self.test_dir / "wrong_key_fail.safetensors" + if jax.process_index() == 0: + np_save_file({"a": np.arange(8)}, st_path) + test_utils.sync_global_processes( + "test_load_sharded_fails_with_wrong_key_abstract_pytree" + ) + + layout = SafetensorsLayout() + + wrong_key_abstract_pytree = { + "a": jax.ShapeDtypeStruct( + shape=(8,), + dtype=np.int32, + sharding=NamedSharding(self.mesh, PartitionSpec()), + ), + "c": jax.ShapeDtypeStruct(shape=(3,), dtype=np.float32), # Wrong key + } + + with self.assertRaisesRegex( + KeyError, "not found in Safetensors checkpoint" + ): + test_awaitable = await layout.load_pytree( + st_path, abstract_pytree=wrong_key_abstract_pytree + ) + await test_awaitable + + +if __name__ == "__main__": + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/serialization_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/serialization_test.py new file mode 100644 index 000000000..3dc0f703e --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/serialization_test.py @@ -0,0 +1,72 @@ +# Copyright 2026 The Orbax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +from orbax.checkpoint.experimental.v1._src.metadata import serialization + + +class SerializationTest( + parameterized.TestCase, unittest.IsolatedAsyncioTestCase +): + + def setUp(self): + super().setUp() + self.directory = epath.Path( + self.create_tempdir(name='serialization_test').full_path + ) + + @parameterized.parameters( + ({'property': 'value'},), + ({'property': 123},), + ({'property': True},), + ({'property': None},), + ({'property': []},), + ({'property': {}},), + ({'property': [1, 2, 3]},), + ({'property': {'a': 1, 'b': 2, 'c': 3}},), + ({},), + ) + async def test_write_and_read(self, d): + await serialization.write(self.directory / 'metadata.json', d) + result = await serialization.read(self.directory / 'metadata.json') + self.assertDictEqual(result, d) + + @parameterized.parameters( + ({'property': b'123'},), + ([],), + ) + async def test_not_writeable(self, d): + with self.assertRaises(TypeError): + await serialization.write( + self.directory / 'metadata.json', + d + ) + + async def test_write_no_parent(self): + with self.assertRaises(FileNotFoundError): + await serialization.write( + epath.Path('/foo/bar/metadata.json'), {'property': 'value'} + ) + + async def test_read_no_file(self): + self.assertIsNone( + await serialization.read(self.directory / 'metadata.json') + ) + + +if __name__ == 
'__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py new file mode 100644 index 000000000..9e6369f82 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py @@ -0,0 +1,324 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from absl import flags +from absl.testing import parameterized +from etils import epath +import jax +import jax.numpy as jnp +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint import utils +from orbax.checkpoint._src.serialization import ocdbt_utils +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils +from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import array_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils +from orbax.checkpoint.experimental.v1._src.testing import path_utils as 
path_test_utils + +from orbax.checkpoint._src.testing.oss import multiprocess_test + +FLAGS = flags.FLAGS +jax.config.update('jax_enable_x64', True) + + +def _get_serialization_params(pytree): + return [ + array_leaf_handler.ArraySerializationParam( + keypath=keypath, + value=array, + ) + for keypath, array in jax.tree.flatten_with_path(pytree)[0] + ] + + +def _get_deserialization_params(pytree): + ret = [] + for keypath, array in jax.tree.flatten_with_path(pytree)[0]: + shapedtype = tree_utils.to_shape_dtype_struct(array) + assert isinstance(shapedtype, jax.ShapeDtypeStruct) + ret.append( + array_leaf_handler.ArrayDeserializationParam( + keypath=keypath, + value=shapedtype, + ) + ) + return ret + + +def _get_metadata_params(pytree): + return [ + types.DeserializationParam[None]( + keypath=keypath, + ) + for keypath, _ in jax.tree.flatten_with_path(pytree)[0] + ] + + +class ArrayLeafHandlerTest( + unittest.IsolatedAsyncioTestCase, parameterized.TestCase +): + + def setUp(self): + super().setUp() + mesh = jax.sharding.Mesh( + jax.devices(), + ('x',), + axis_types=(jax.sharding.AxisType.Auto,) * len(('x',)), + ) + sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) + replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + self.pytree = { + 'replicated': array_test_utils.create_sharded_array( + np.arange(16), replicated + ), + 'sharded': array_test_utils.create_sharded_array( + np.arange(32), sharded + ), + } + + if not utils.is_pathways_backend(): + self.pytree.update({ + 'rand0': jax.random.key( + jnp.array(1, device=replicated), impl='threefry2x32' + ), + 'rand1': jax.random.key(jnp.array(2, device=replicated), impl='rbg'), + }) + + test_utils.sync_global_processes('setUp') + + def tearDown(self): + test_utils.sync_global_processes('tearDown') + super().tearDown() + + async def _test_simple_checkpoint_impl( + self, + use_ocdbt: bool = True, + use_zarr3: bool = True, + use_replica_parallel: bool = True, + 
enable_replica_parallel_separate_folder: bool = False, + enable_pinned_host_transfer: bool = False, + save_concurrent_bytes: int | None = None, + load_concurrent_bytes: int | None = None, + use_compression: bool = True, + min_slice_bytes_for_replica_parallel: int | None = None, + max_replicas_for_replica_parallel: int | None = None, + ): + # make unit with different tests + parent_dir = epath.Path( + self.create_tempdir(f'tmp_{self._testMethodName}').full_path + ) + + init_context = context_lib.Context( + array_options=options_lib.ArrayOptions( + saving=options_lib.ArrayOptions.Saving( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_replica_parallel=use_replica_parallel, + enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder, + enable_pinned_host_transfer=enable_pinned_host_transfer, + use_compression=use_compression, + min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel, + max_replicas_for_replica_parallel=max_replicas_for_replica_parallel, + ), + loading=options_lib.ArrayOptions.Loading(), + ), + memory_options=options_lib.MemoryOptions( + write_concurrent_bytes=save_concurrent_bytes, + read_concurrent_bytes=load_concurrent_bytes, + ), + ) + + with context_lib.get_context(init_context) as context: + + handler = array_leaf_handler.ArrayLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await 
ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + + # deserialization + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + deserialization_task = await handler.deserialize( + _get_deserialization_params(self.pytree), + deserialization_context=deserialization_context, + ) + + restored = await deserialization_task + + for p, restored_array in zip(serialization_params, restored): + test_utils.assert_array_equal(self, p.value, restored_array) + + # validate whether compression used + self.assertEqual( + test_utils.is_compression_used( + parent_dir, + serialization_params[0].name, + use_zarr3, + use_ocdbt, + ), + use_compression, + ) + + @parameterized.product( + use_ocdbt=(True, False), + use_zarr3=(True, False), + enable_pinned_host_transfer=(True, False), + save_concurrent_bytes=(None, 64), + load_concurrent_bytes=(None, 64), + use_compression=(True, False), + ) + async def test_simple_checkpoint( + self, + use_ocdbt: bool, + use_zarr3: bool, + enable_pinned_host_transfer: bool, + save_concurrent_bytes: int | None, + load_concurrent_bytes: int | None, + use_compression: bool, + ): + await self._test_simple_checkpoint_impl( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + enable_pinned_host_transfer=enable_pinned_host_transfer, + save_concurrent_bytes=save_concurrent_bytes, + load_concurrent_bytes=load_concurrent_bytes, + use_compression=use_compression, + ) + + @parameterized.product( + use_replica_parallel=(True, False), + enable_replica_parallel_separate_folder=(True, False), + min_slice_bytes_for_replica_parallel=(None, 1024), + max_replicas_for_replica_parallel=(None, 2), + ) + async def 
test_simple_checkpoint_for_replica_parallel( + self, + use_replica_parallel: bool, + enable_replica_parallel_separate_folder: bool, + min_slice_bytes_for_replica_parallel: int | None, + max_replicas_for_replica_parallel: int | None, + ): + await self._test_simple_checkpoint_impl( + use_replica_parallel=use_replica_parallel, + enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder, + min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel, + max_replicas_for_replica_parallel=max_replicas_for_replica_parallel, + ) + + async def test_metadata(self): + # make unit with different tests + parent_dir = epath.Path(self.create_tempdir('test_metadata').full_path) + + with context_lib.get_context() as context: + + handler = array_leaf_handler.ArrayLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + test_utils.print_directory(parent_dir) + + # load the metadata + use_ocdbt = context.array_options.saving.use_ocdbt + use_zarr3 = context.array_options.saving.use_zarr3 + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + 
ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + + metadata = await handler.metadata( + _get_metadata_params(self.pytree), + deserialization_context=deserialization_context, + ) + + # validate metadata + for p, m in zip(serialization_params, metadata): + expected_v = p.value + if jax.dtypes.issubdtype(p.value.dtype, jax.dtypes.prng_key): + expected_v = jax.random.key_data(p.value) + + self.assertEqual(expected_v.shape, m.shape) + self.assertEqual(expected_v.dtype, m.dtype) + expected_sharding = expected_v.sharding + self.assertEqual(expected_sharding, m.sharding) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py new file mode 100644 index 000000000..6b8d579c6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py @@ -0,0 +1,250 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unittests for NumpyLeafHandler.""" + +import unittest + +from absl import flags +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.serialization import ocdbt_utils +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import numpy_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from orbax.checkpoint.experimental.v1._src.testing import path_utils as path_test_utils + +from orbax.checkpoint._src.testing.oss import multiprocess_test + +FLAGS = flags.FLAGS + + +def _get_serialization_params(pytree): + return [ + numpy_leaf_handler.NumpySerializationParam( + keypath=keypath, + value=nparray, + ) + for keypath, nparray in jax.tree.flatten_with_path(pytree)[0] + ] + + +def _get_deserialization_params(pytree): + ret = [] + for keypath, nparray in jax.tree.flatten_with_path(pytree)[0]: + + ret.append( + numpy_leaf_handler.NumpyDeserializationParam( + keypath=keypath, + value=numpy_leaf_handler.NumpyShapeDtype( + shape=nparray.shape, dtype=nparray.dtype + ), + ) + ) + return ret + + +def _get_metadata_params(pytree): + return [ + types.DeserializationParam[None]( + keypath=keypath, + ) + for keypath, _ in jax.tree.flatten_with_path(pytree)[0] + ] + + +class NumpyLeafHandlerTest( + unittest.IsolatedAsyncioTestCase, parameterized.TestCase +): + + def setUp(self): + super().setUp() + self.pytree = { + 'np1': np.arange(16, dtype=np.float32), + 'np2': np.arange(32, dtype=np.int32), + 'np3': np.arange(1, dtype=np.float64), + } + + test_utils.sync_global_processes('setUp') + + def tearDown(self): + 
test_utils.sync_global_processes('tearDown') + super().tearDown() + + @parameterized.product( + use_ocdbt=(True, False), + use_zarr3=(True, False), + save_concurrent_bytes=(None, 64), + load_concurrent_bytes=(None, 64), + use_compression=(True, False), + ) + async def test_simple_checkpoint( + self, + use_ocdbt: bool, + use_zarr3: bool, + save_concurrent_bytes: int | None, + load_concurrent_bytes: int | None, + use_compression: bool, + ): + # make unit with different tests + parent_dir = epath.Path( + self.create_tempdir(f'tmp_{self._testMethodName}').full_path + ) + + init_context = context_lib.Context( + array_options=options_lib.ArrayOptions( + saving=options_lib.ArrayOptions.Saving( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_compression=use_compression, + ), + loading=options_lib.ArrayOptions.Loading(), + ), + memory_options=options_lib.MemoryOptions( + write_concurrent_bytes=save_concurrent_bytes, + read_concurrent_bytes=load_concurrent_bytes, + ), + ) + + with context_lib.get_context(init_context) as context: + + handler = numpy_leaf_handler.NumpyLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + 
f'{self._testMethodName}_finalize_complete' + ) + + # deserialization + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + deserialization_task = await handler.deserialize( + _get_deserialization_params(self.pytree), + deserialization_context=deserialization_context, + ) + + restored = await deserialization_task + + for p, restored_array in zip(serialization_params, restored): + test_utils.assert_array_equal(self, p.value, restored_array) + + # validate whether compression used + self.assertEqual( + test_utils.is_compression_used( + parent_dir, + serialization_params[0].name, + use_zarr3, + use_ocdbt, + ), + use_compression, + ) + + async def test_metadata(self): + # make unit with different tests + parent_dir = epath.Path(self.create_tempdir('test_metadata').full_path) + + with context_lib.get_context() as context: + + handler = numpy_leaf_handler.NumpyLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + 
test_utils.print_directory(parent_dir) + + # load the metadata + use_ocdbt = context.array_options.saving.use_ocdbt + use_zarr3 = context.array_options.saving.use_zarr3 + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + + metadata = await handler.metadata( + _get_metadata_params(self.pytree), + deserialization_context=deserialization_context, + ) + + # validate metadata + for p, m in zip(serialization_params, metadata): + expected_v = p.value + self.assertEqual(expected_v.shape, m.shape) + self.assertEqual(expected_v.dtype, m.dtype) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py new file mode 100644 index 000000000..72e3f0dff --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py @@ -0,0 +1,269 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unittests for ScalarLeafHandler.""" + +import unittest + +from absl import flags +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.serialization import ocdbt_utils +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from orbax.checkpoint.experimental.v1._src.testing import path_utils as path_test_utils + +from orbax.checkpoint._src.testing.oss import multiprocess_test + +FLAGS = flags.FLAGS + + +def _get_serialization_params(pytree): + return [ + scalar_leaf_handler.ScalarSerializationParam( + keypath=keypath, + value=scalar, + ) + for keypath, scalar in jax.tree.flatten_with_path(pytree)[0] + ] + + +def _get_deserialization_params( + pytree, cast_to: type[int | float] | None = None, pass_scalar=False +): + ret = [] + for keypath, scalar in jax.tree.flatten_with_path(pytree)[0]: + + ret.append( + scalar_leaf_handler.ScalarDeserializationParam( + keypath=keypath, + value=scalar if pass_scalar else cast_to, + ) + ) + return ret + + +def _get_metadata_params(pytree): + return [ + types.DeserializationParam[None]( + keypath=keypath, + ) + for keypath, _ in jax.tree.flatten_with_path(pytree)[0] + ] + + +class ScalarLeafHandlerTest( + unittest.IsolatedAsyncioTestCase, parameterized.TestCase +): + + def setUp(self): + super().setUp() + self.pytree = { + 'int_value': 0, + 'float_value': 1.1, + 'np_int32_value': np.int32(2), + 'np_int64_value': np.int64(3), + 'np_float32_value': np.float32(4.4), + 'np_float64_value': np.float64(5.5), + } 
+ + test_utils.sync_global_processes('setUp') + + def tearDown(self): + test_utils.sync_global_processes('tearDown') + super().tearDown() + + @parameterized.product( + use_ocdbt=(True, False), + use_zarr3=(True, False), + save_concurrent_bytes=(None, 64), + load_concurrent_bytes=(None, 64), + cast_to=(None, int, float, 'scalar'), + use_compression=(True, False), + ) + async def test_simple_checkpoint( + self, + use_ocdbt: bool, + use_zarr3: bool, + save_concurrent_bytes: int | None, + load_concurrent_bytes: int | None, + cast_to: type[int | float] | str | None, + use_compression: bool, + ): + # make unit with different tests + parent_dir = epath.Path( + self.create_tempdir(f'tmp_{self._testMethodName}').full_path + ) + + init_context = context_lib.Context( + array_options=options_lib.ArrayOptions( + saving=options_lib.ArrayOptions.Saving( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_compression=use_compression, + ), + loading=options_lib.ArrayOptions.Loading(), + ), + memory_options=options_lib.MemoryOptions( + write_concurrent_bytes=save_concurrent_bytes, + read_concurrent_bytes=load_concurrent_bytes, + ), + ) + + with context_lib.get_context(init_context) as context: + + handler = scalar_leaf_handler.ScalarLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + 
ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + + # deserialization + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + + deserialization_task = await handler.deserialize( + _get_deserialization_params( + self.pytree, cast_to=cast_to, pass_scalar=(cast_to == 'scalar') + ), + deserialization_context=deserialization_context, + ) + + restored = await deserialization_task + + self._validate(serialization_params, restored, cast_to=cast_to) + + # validate whether compression used + self.assertEqual( + test_utils.is_compression_used( + parent_dir, + serialization_params[0].name, + use_zarr3, + use_ocdbt, + ), + use_compression, + ) + + def _validate(self, serialization_params, restored, cast_to=None): + for p, restored_scalar in zip(serialization_params, restored): + expected_value = p.value + if cast_to in (int, float): + expected_value = cast_to(p.value) + self.assertEqual(expected_value, restored_scalar) + + async def test_metadata(self): + # make unit with different tests + parent_dir = epath.Path(self.create_tempdir('test_metadata').full_path) + + with context_lib.get_context() as context: + + handler = scalar_leaf_handler.ScalarLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + 
f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + test_utils.print_directory(parent_dir) + + # load the metadata + use_ocdbt = context.array_options.saving.use_ocdbt + use_zarr3 = context.array_options.saving.use_zarr3 + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + + metadata = await handler.metadata( + _get_metadata_params(self.pytree), + deserialization_context=deserialization_context, + ) + + # validate metadata + for p, m in zip(serialization_params, metadata): + expected_v = p.value + if isinstance(expected_v, (int, np.integer)): + expected_type = int + elif isinstance(expected_v, (float, np.floating)): + expected_type = float + else: + raise ValueError(f'Unsupported type: {type(expected_v)}') + self.assertIsInstance(m, expected_type) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py new file mode 100644 index 000000000..78085bcda --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py @@ -0,0 +1,210 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unittests for StringLeafHandler.""" + +from typing import Type +import unittest + +from absl import flags +from absl.testing import parameterized +from etils import epath +import jax +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.serialization import ocdbt_utils +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.serialization import string_leaf_handler +from orbax.checkpoint.experimental.v1._src.serialization import types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from orbax.checkpoint.experimental.v1._src.testing import path_utils as path_test_utils + +from orbax.checkpoint._src.testing.oss import multiprocess_test + +FLAGS = flags.FLAGS + + +def _get_serialization_params(pytree): + return [ + string_leaf_handler.StringSerializationParam( + keypath=keypath, + value=string, + ) + for keypath, string in jax.tree.flatten_with_path(pytree)[0] + ] + + +def _get_deserialization_params(pytree, abstract_leaf=None): + return [ + string_leaf_handler.StringDeserializationParam( + keypath=keypath, + value=abstract_leaf, + ) + for keypath, _ in jax.tree.flatten_with_path(pytree)[0] + ] + + +def _get_metadata_params(pytree): + return [ + types.DeserializationParam[None]( + keypath=keypath, + ) + for keypath, _ in jax.tree.flatten_with_path(pytree)[0] + ] + + +class StringLeafHandlerTest( + unittest.IsolatedAsyncioTestCase, parameterized.TestCase +): + + def 
setUp(self): + super().setUp() + self.pytree = { + 'a': 'some_string1', + 'b': 'some_string2', + 'c': '123', + } + + test_utils.sync_global_processes('setUp') + + def tearDown(self): + test_utils.sync_global_processes('tearDown') + super().tearDown() + + @parameterized.product( + abstract_leaf=(None, str), # should have no effects. + ) + async def test_simple_checkpoint( + self, + abstract_leaf: Type[str] | None, + ): + # Use different tests path for each test case. + parent_dir = epath.Path( + self.create_tempdir(f'tmp_{self._testMethodName}').full_path + ) + + init_context = context_lib.Context() + + with context_lib.get_context(init_context) as context: + + handler = string_leaf_handler.StringLeafHandler() + + # Serialize the self.pytree. + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # Try finalize the checkpoint. + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + + # Deserialize the self.pytree from the stored checkpoint. 
+ deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=False, + ) + + deserialization_task = await handler.deserialize( + _get_deserialization_params(self.pytree, abstract_leaf=abstract_leaf), + deserialization_context=deserialization_context, + ) + + restored = await deserialization_task + + for p, restored_string in zip(serialization_params, restored): + expected_value = p.value + self.assertEqual(expected_value, restored_string) + + async def test_metadata(self): + # make unit with different tests + parent_dir = epath.Path(self.create_tempdir('test_metadata').full_path) + + with context_lib.get_context() as context: + + handler = string_leaf_handler.StringLeafHandler() + + # serialization + use_ocdbt = context.array_options.saving.use_ocdbt + serialization_context = types.SerializationContext( + parent_dir=path_test_utils.PathAwaitingCreationWrapper(parent_dir), + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ) + serialization_params = _get_serialization_params(self.pytree) + task = await handler.serialize( + params=serialization_params, + serialization_context=serialization_context, + ) + await task + test_utils.sync_global_processes( + f'{self._testMethodName}_serialize_complete' + ) + + # try finalize + if use_ocdbt and multihost.is_primary_host( + context.multiprocessing_options.primary_host + ): + await ocdbt_utils.merge_ocdbt_per_process_files( + parent_dir, + ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + use_zarr3=context.array_options.saving.use_zarr3, + ) + + test_utils.sync_global_processes( + f'{self._testMethodName}_finalize_complete' + ) + test_utils.print_directory(parent_dir) + + # load the metadata + use_ocdbt = context.array_options.saving.use_ocdbt + use_zarr3 = context.array_options.saving.use_zarr3 + deserialization_context = types.DeserializationContext( + parent_dir=parent_dir, + 
ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), + ocdbt_checkpoint=use_ocdbt, + zarr3_checkpoint=use_zarr3, + ) + + metadata = await handler.metadata( + _get_metadata_params(self.pytree), + deserialization_context=deserialization_context, + ) + + self.assertEqual(metadata, ['string'] * len(self.pytree)) + + +if __name__ == '__main__': + multiprocess_test.main()