Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ effort in maintaining. Safetensors is instead the recommended conversion case.
- #v1 Allow a context to be default-configured for all `Checkpointer`
operations.

### Removed

- #v1 Remove `LeafHandler` as a user-exposed layer (it remains as an internal layer).

## [0.11.39] - 2026-05-06

### Added
Expand Down
1 change: 0 additions & 1 deletion checkpoint/orbax/checkpoint/experimental/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from orbax.checkpoint.experimental.v1 import handlers
from orbax.checkpoint.experimental.v1 import partial
from orbax.checkpoint.experimental.v1 import path
from orbax.checkpoint.experimental.v1 import serialization
from orbax.checkpoint.experimental.v1 import training
from orbax.checkpoint.experimental.v1 import tree
from orbax.checkpoint.experimental.v1._src.synchronization.types import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,9 @@ class PyTreeOptions:
This dataclass defines the configuration parameters for creating and managing
PyTree saving and loading on disk.

# TODO: Include an example of registering a custom LeafHandler.

Attributes:
saving: Options for saving PyTrees.
loading: Options for loading PyTrees.
leaf_handler_registry: Optional Leaf Handler Registry. If provided, it will
override the default Leaf Handler Registry.
"""

@dataclasses.dataclass(frozen=True, kw_only=True)
Expand All @@ -214,7 +210,6 @@ class Loading:

saving: Saving = dataclasses.field(default_factory=Saving)
loading: Loading = dataclasses.field(default_factory=Loading)
leaf_handler_registry: serialization_types.LeafHandlerRegistry | None = None


@dataclasses.dataclass(frozen=True, kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def __init__(
array_metadata_validator: array_metadata_store_lib.Validator = (
array_metadata_store_lib.Validator()
),
leaf_handler_registry: (
serialization_types.LeafHandlerRegistry | None
) = None,
partial_save_mode: bool = False,
):
context = context_lib.get_context(context)
Expand All @@ -301,9 +304,7 @@ def __init__(
self._partial_save_mode = partial_save_mode

self._leaf_handler_registry = (
self._context.pytree_options.leaf_handler_registry
if self._context.pytree_options.leaf_handler_registry is not None
else registry.StandardLeafHandlerRegistry()
leaf_handler_registry or registry.StandardLeafHandlerRegistry()
)

type_handler_registry = compatibility.get_v0_type_handler_registry(
Expand Down Expand Up @@ -437,13 +438,10 @@ async def load(
abstract_checkpointable: The abstract checkpointable to load into. If
None, the handler will attempt to load the entire checkpoint using the
recorded metadata. Otherwise, the `abstract_checkpointable` is expected
to be a PyTree of abstract leaves. See
:py:class:`~.v1.serialization.LeafHandler` for more details. The
abstract leaf may be a value of type `AbstractLeaf`,
`Type[AbstractLeaf]`, or `None`. E.g. if the `AbstractLeaf` is
`AbstractFoo`, it is always valid to pass `AbstractFoo()` or
`AbstractFoo` or `None`. Passing the latter two indicates that metadata
should be used to restore the leaf.
to be a PyTree of abstract leaves. The abstract leaf may be a value of
type :py:class:`~.v1.tree.AbstractLeaf`,
`Type[AbstractLeaf]`, or `None`. Passing the latter two indicates that
the metadata should be used to restore the leaf.

Returns:
An awaitable which can be awaited to complete the load operation and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,13 @@ def handler_with_options(
loading=options_lib.PyTreeOptions.Loading(
partial_load=partial_load,
),
leaf_handler_registry=leaf_handler_registry,
),
)

handler = handler_test_utils.create_test_handler(
pytree_handler.PyTreeHandler, context=context
pytree_handler.PyTreeHandler,
context=context,
leaf_handler_registry=leaf_handler_registry,
)

try:
Expand Down Expand Up @@ -2288,6 +2289,63 @@ def _as_abstract_type(x):
array_metadata_store=array_metadata_store,
)

def test_custom_array_type(self):
"""Saves a `LazyArray` leaf and loads it back using different registries."""
# Build a registry that also knows the `LazyArray` leaf type. It is handed
# to the handler under test through `handler_with_options`.
custom_registry = registry.StandardLeafHandlerRegistry()
custom_registry.add(
handler_test_utils.LazyArray,
handler_test_utils.AbstractLazyArray,
handler_test_utils.LazyArrayHandler,
)

# Wrap a sharded array spanning all available devices in a `LazyArray`.
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
lazy_arr = handler_test_utils.LazyArray(
create_sharded_array(np.arange(16), sharding)
)
pytree = {'a': lazy_arr}

# Save with a handler that received the custom registry.
with handler_with_options(
use_ocdbt=False, leaf_handler_registry=custom_registry
) as handler:
handler.save(self.directory, pytree)

# Attempt to load with a handler that was not given the custom registry,
# which should fail.
with handler_with_options(use_ocdbt=False) as handler:
with self.assertRaisesRegex(ValueError, 'TypeHandler lookup failed'):
handler.load(self.directory)

# Load with a handler that received the custom registry.
with handler_with_options(
use_ocdbt=False, leaf_handler_registry=custom_registry
) as handler:
loaded = handler.load(self.directory)
self.assertEqual(loaded['a'].array.shape, lazy_arr.array.shape)
np.testing.assert_array_equal(loaded['a'].array, lazy_arr.array)

# Load custom array directly as jax.Array by mapping secondary_typestr
custom_registry2 = registry.StandardLeafHandlerRegistry()
# Override the default jax.Array handler with LazyArray typestr,
# ensuring that the serialized jax.array annotated with original LazyArray
# typestr is loaded as a jax.Array.
custom_registry2.add(
jax.Array,
serialization_types.AbstractShardedArray,
array_leaf_handler.ArrayLeafHandler,
secondary_typestrs=[
serialization_types.typestr(handler_test_utils.LazyArrayHandler)
],
override=True,
)

# The checkpoint written above is reused; only the registry differs.
with handler_with_options(
use_ocdbt=False, leaf_handler_registry=custom_registry2
) as handler:
loaded_as_jax_array = handler.load(self.directory)
self.assertIsInstance(loaded_as_jax_array['a'], jax.Array)
np.testing.assert_array_equal(loaded_as_jax_array['a'], lazy_arr.array)

def test_abstract_array_loading(self):
replicated_sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ('x',)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def save_pytree_async(
Args:
path: The path to save the checkpoint to.
pytree: The PyTree to save. This may be any JAX PyTree (including custom
objects registered as PyTrees) consisting of supported leaf types. Default
supported leaf types include `jax.Array`, `np.ndarray`, simple types like
`int`, `float`, `str`, and empty nodes. Support for custom leaves is also
possible by implementing a :py:class:`.LeafHandler`.
objects registered as PyTrees) consisting of supported leaf types (see
:py:class:`~.v1.tree.Leaf`). Default supported leaf types include
`jax.Array`, `np.ndarray`, simple types like
`int`, `float`, `str`, and empty nodes.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,3 @@ def resolve_storage_options(
chunk_byte_size=resolved_chunk_byte_size,
shard_axes=resolved_shard_axes if resolved_shard_axes is not None else (),
)

Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@
import dataclasses
from typing import Any, Awaitable, Generic, Protocol, Sequence, Tuple, Type, TypeVar

import jax
import jax.experimental.layout as jax_layout
import numpy as np
from orbax.checkpoint._src.arrays import types as arrays_types
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import types as serialization_types
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint.experimental.v1._src.arrays import types as array_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
import tensorstore as ts
Expand All @@ -32,66 +29,34 @@
Leaf = TypeVar('Leaf')
AbstractLeaf = TypeVar('AbstractLeaf')

Shape = arrays_types.Shape
DType = arrays_types.DType
Shape = array_types.Shape
DType = array_types.DType

PLACEHOLDER = ...

IsPrioritizedKeyFn = serialization_types.IsPrioritizedKeyFn

Scalar = int | float | np.number
### STANDARD PYTREE LEAF TYPES

### SCALAR
Scalar = array_types.Scalar
# Optional type hint for a scalar leaf handler. If provided, the restored scalar
# will be cast to this type. Only casting to int or float is supported.
AbstractScalar = Scalar
AbstractString = str

if jax.__version_info__ >= (0, 6, 2):
Format = jax_layout.Format
else:
Format = jax_layout.Layout


class AbstractArray(Protocol):
"""Abstract representation of an array.

This is a protocol for an abstract array that can be used to represent
the metadata belonging to an array.

shape:
Tuple of integers describing the array shape.
dtype:
Dtype of array elements.
"""

shape: Shape | None
dtype: DType | None


class AbstractShardedArray(Protocol):
"""Abstract representation of an array.

This is a protocol for an abstract array that can be used to represent various
metadata types such as :py:class:`jax.ShapeDtypeStruct` and
:py:class:`~orbax.checkpoint.metadata.ArrayMetadata`.

#TODO(dnlng): All attributes are made optional to support the case where
# the ArrayMetadata is passed into the metadata() call to pass only the
# `write_shape`. Optional attributes are not needed once write_shape is
# refactored.
### STRING
# str
AbstractString = str

### ARRAY
# np.ndarray
AbstractArray = array_types.AbstractArray

shape:
Tuple of integers describing the array shape.
dtype:
Dtype of array elements.
Sharding:
Sharding to indicate how the array is sharded. This can be jax's Sharding or
Layout or None.
"""
### SHARDED ARRAY
# jax.Array
AbstractShardedArray = array_types.AbstractShardedArray

shape: Shape | None
dtype: DType | None
sharding: jax.sharding.Sharding | Format | None = None # pytype: disable=invalid-annotation
###


def is_placeholder(value: Any) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.path import async_utils
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import array_leaf_handler
from orbax.checkpoint.experimental.v1._src.serialization import registry as serialization_registry
from orbax.checkpoint.experimental.v1._src.serialization import types
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils
from orbax.checkpoint.experimental.v1._src.testing import handler_utils
from orbax.checkpoint.experimental.v1._src.testing import tree_utils as tree_test_utils
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


PyTree = tree_types.PyTree
Path = path_types.Path
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
Expand Down Expand Up @@ -408,66 +405,6 @@ def test_leaf_change_type(self):
),
)

def test_custom_array_type(self):
"""Saves a `LazyArray` leaf and loads it back under different Contexts."""
# Set up local context with custom registry.
custom_registry = serialization_registry.StandardLeafHandlerRegistry()
custom_registry.add(
handler_utils.LazyArray,
handler_utils.AbstractLazyArray,
handler_utils.LazyArrayHandler,
)

# The registry reaches save/load through `PyTreeOptions` on the Context.
custom_context = ocp.Context(
pytree_options=ocp.options.PyTreeOptions(
leaf_handler_registry=custom_registry
)
)

# Wrap a sharded array spanning all available devices in a `LazyArray`.
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
lazy_arr = handler_utils.LazyArray(
create_sharded_array(np.arange(16), sharding)
)
pytree = {'a': lazy_arr}

# Save while the custom context is active.
with custom_context:
ocp.save_pytree(self.directory, pytree)

# Attempt to load without context (using global default registry), which
# should fail
with self.assertRaisesRegex(ValueError, 'TypeHandler lookup failed'):
ocp.load_pytree(self.directory)

# Load with the custom registry context
with custom_context:
loaded = ocp.load_pytree(self.directory)
self.assertEqual(loaded['a'].array.shape, lazy_arr.array.shape)
np.testing.assert_array_equal(loaded['a'].array, lazy_arr.array)

# Load custom array directly as jax.Array by mapping secondary_typestr
custom_registry2 = serialization_registry.StandardLeafHandlerRegistry()
# Override the default jax.Array handler with LazyArray typestr,
# ensuring that the serialized jax.array annotated with original LazyArray
# typestr is loaded as a jax.Array.
custom_registry2.add(
jax.Array,
types.AbstractShardedArray,
array_leaf_handler.ArrayLeafHandler,
secondary_typestrs=[types.typestr(handler_utils.LazyArrayHandler)],
override=True,
)
# A second Context carries the overriding registry for the final load.
custom_context2 = ocp.Context(
pytree_options=ocp.options.PyTreeOptions(
leaf_handler_registry=custom_registry2
)
)
with custom_context2:
loaded_as_jax_array = ocp.load_pytree(self.directory)
self.assertIsInstance(loaded_as_jax_array['a'], jax.Array)
np.testing.assert_array_equal(
loaded_as_jax_array['a'], lazy_arr.array
)

def test_empty_array(self):
value = np.ones(shape=(0,))
with self.assertRaisesRegex(ValueError, 'zero size'):
Expand Down Expand Up @@ -735,9 +672,7 @@ def test_missing_keys(self):
test_utils.assert_tree_equal(self, self.numpy_pytree, loaded)

with self.subTest('load_checkpointables'):
with self.assertRaisesRegex(
KeyError, 'Requested checkpointables:'
):
with self.assertRaisesRegex(KeyError, 'Requested checkpointables:'):
ocp.load_checkpointables(
self.directory, {'foo': handler_utils.AbstractFoo()}
)
Expand Down Expand Up @@ -847,9 +782,7 @@ def test_save_checkpointables_deleted(self):
loaded = ocp.load_checkpointables(self.directory)
self.assertSameElements(['two'], loaded.keys())

with self.assertRaisesRegex(
KeyError, 'Requested checkpointables:'
):
with self.assertRaisesRegex(KeyError, 'Requested checkpointables:'):
ocp.load_checkpointables(self.directory, {'one': None})


Expand Down Expand Up @@ -1242,6 +1175,7 @@ def test_subchunking(self):
self.assertEqual(metadata[k].storage_metadata.chunk_shape, (2,))

with self.subTest('per_key_setting'):

def scoped_storage_options_creator(key, value):
del value
if 'a' in tree_utils.str_keypath(key):
Expand All @@ -1251,6 +1185,7 @@ def scoped_storage_options_creator(key, value):
return ocp.options.ArrayOptions.Saving.StorageOptions(
chunk_byte_size=8, # force divide in 2 subchunks
)

with ocp.Context(
array_options=ocp.options.ArrayOptions(
saving=ocp.options.ArrayOptions.Saving(
Expand Down
Loading
Loading