Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
)
from orbax.checkpoint.experimental.v1 import multihost
from orbax.checkpoint.experimental.v1.handlers import (
Checkpointable,
AbstractCheckpointable,
CheckpointableHandler,
StatefulCheckpointable,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,4 @@

from orbax.checkpoint._src.arrays import abstract_arrays

ArrayLike = abstract_arrays.ArrayLike

to_shape_dtype_struct = abstract_arrays.to_shape_dtype_struct
80 changes: 80 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Array type definitions."""

from typing import Protocol, TypeAlias

import jax
from jax import numpy as jnp
import jax.experimental.layout as jax_layout
import numpy as np

# Tuple of integers describing an array's shape.
Shape = tuple[int, ...]
# Dtype of array elements, as exposed by either `jax.numpy` or `numpy`.
DType = jnp.dtype | np.dtype

# Bind a single `Format` name regardless of the installed JAX version: for
# JAX >= 0.6.2 it is `jax.experimental.layout.Format`, otherwise it falls back
# to `jax.experimental.layout.Layout`.
if jax.__version_info__ >= (0, 6, 2):
  Format = jax_layout.Format
else:
  Format = jax_layout.Layout


# Union of the Python and NumPy scalar types treated as scalars here.
Scalar: TypeAlias = int | float | np.number | bytes | bool
# The abstract scalar type is the very same alias as `Scalar`.
AbstractScalar = Scalar


class AbstractArray(Protocol):
  """Abstract representation of an array.

  This is a protocol for an abstract array that can be used to represent
  the metadata belonging to an array.

  Attributes:
    shape: Tuple of integers describing the array shape, or None.
    dtype: Dtype of array elements, or None.
  """

  shape: Shape | None
  dtype: DType | None


class AbstractShardedArray(Protocol):
  """Abstract representation of a sharded array.

  This is a protocol for an abstract array that can be used to represent various
  metadata types such as :py:class:`jax.ShapeDtypeStruct` and
  :py:class:`~orbax.checkpoint.metadata.ArrayMetadata`.

  Attributes:
    shape: Tuple of integers describing the array shape, or None.
    dtype: Dtype of array elements, or None.
    sharding: Indicates how the array is sharded. This can be a
      `jax.sharding.Sharding`, a `Format`, or None.
  """

  # TODO(dnlng): All attributes are made optional to support the case where
  # the ArrayMetadata is passed into the metadata() call to pass only the
  # `write_shape`. Optional attributes are not needed once write_shape is
  # refactored.
  shape: Shape | None
  dtype: DType | None
  sharding: jax.sharding.Sharding | Format | None = None  # pytype: disable=invalid-annotation


# Anything that satisfies either the plain or the sharded array protocol.
ArrayLike: TypeAlias = AbstractArray | AbstractShardedArray
98 changes: 73 additions & 25 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,59 @@

"""Defines types for `CheckpointableHandler`."""

import typing
from typing import Any, Awaitable, Protocol, Type, TypeVar, runtime_checkable
from orbax.checkpoint.experimental.v1._src.path import types as path_types


T = TypeVar('T')
AbstractT = TypeVar('AbstractT')
# Static type checkers see both names as plain `Any` aliases. At runtime they
# are empty classes instead, so that each name has a docstring to carry its
# documentation.
if typing.TYPE_CHECKING:
  Checkpointable = Any
  AbstractCheckpointable = Any
else:

  class Checkpointable:
    """A logical piece of a checkpoint that is separable and distinct from other pieces.

    For example, model state ('params', 'opt_state') vs. dataset_iterator vs.
    embeddings vs. other custom states. Each is handled differently in training
    code and is often represented by a distinct container or object type. Each
    is also used (or not used) in different contexts (training, evaluations,
    inference).

    See also :py:type:`.AbstractCheckpointable`.

    Checkpointables are typically:

    (1) separable; they may or may not be loaded concurrently and some may be
    omitted from the checkpoint entirely (e.g. the dataset iterator
    checkpointable is not needed for evals);

    (2) different types, with correspondingly different on-disk representations
    (e.g. the dataset iterator checkpointable is a lightweight index, while
    model weights are large distributed arrays).
    """

    pass

  class AbstractCheckpointable:
    """An "abstract checkpointable" is the abstract form of a :py:type:`.Checkpointable`.

    The abstract form is used to:

    (1) Customize properties of a loaded checkpointable (e.g. specify a tree of
    `jax.ShapeDtypeStruct` to load a tree of `jax.Array` with desired
    shardings).

    (2) Represent metadata returned by metadata accessor functions.
    """

    pass


_Checkpointable = TypeVar('_Checkpointable')
_AbstractCheckpointable = TypeVar('_AbstractCheckpointable')


@runtime_checkable
Expand All @@ -37,21 +84,13 @@ async def load(self, directory: path_types.Path) -> Awaitable[None]:
...


class CheckpointableHandler(Protocol[T, AbstractT]):
class CheckpointableHandler(Protocol[_Checkpointable, _AbstractCheckpointable]):
"""An interface that defines save/load logic for a `checkpointable` object.

NOTE: Prefer to use
:py:class:`.StatefulCheckpointable`
interface when possible.

A "checkpointable" is a fundamental concept in Orbax. A “checkpointable”
refers to a logical piece of the checkpoint that is distinct in some way from
other pieces. Checkpointables are separable; they may or may not be loaded
concurrently and some may be omitted from the checkpoint entirely.
Checkpointables are often represented by different types, and have different
representations on disk. The quintessential example is model params vs.
dataset.

A PyTree of arrays, representing model parameters, is the most basic
"checkpointable". A singular array is also a checkpointable.

Expand Down Expand Up @@ -92,8 +131,9 @@ class CheckpointableHandler(Protocol[T, AbstractT]):

To create a custom handler, you must define a class that implements the
methods defined in this Protocol. The class should be generic over the
concrete type `T` (the object being saved/loaded) and the abstract type
`AbstractT` (the lightweight metadata representation).
concrete type `Checkpointable` (the object being saved/loaded) and the
abstract type
`AbstractCheckpointable` (the lightweight metadata representation).

Crucially, once implemented, the handler must be registered with the
global registry or a context-local registry so that `save_checkpointables`
Expand Down Expand Up @@ -143,7 +183,7 @@ def is_abstract_handleable(
)

In many cases, no information is needed for loading. In this case,
`AbstractT` may be defined as `None`. For example::
`AbstractCheckpointable` may be defined as `None`. For example::

class FooHandler(CheckpointableHandler[Foo, None]):

Expand All @@ -155,7 +195,9 @@ def is_abstract_handleable(self, abstract_checkpointable: None) -> bool:
"""

async def save(
self, directory: path_types.PathAwaitingCreation, checkpointable: T
self,
directory: path_types.PathAwaitingCreation,
checkpointable: _Checkpointable,
) -> Awaitable[None]:
"""Saves the given `checkpointable` to the given `directory`.

Expand Down Expand Up @@ -204,8 +246,8 @@ async def save(
async def load(
self,
directory: path_types.Path,
abstract_checkpointable: AbstractT | None = None,
) -> Awaitable[T]:
abstract_checkpointable: _AbstractCheckpointable | None = None,
) -> Awaitable[_Checkpointable]:
"""Loads the checkpointable from the given `directory`.

Args:
Expand All @@ -214,31 +256,35 @@ async def load(
checkpointable to load. If provided, this is used to provide properties
to guide the restoration logic of the checkpoint. In the case of arrays,
for example, this conveys properties like shape and dtype, for casting
and reshaping. In some cases, no information is needed, and `AbstractT`
may always be None. In other cases, the abstract representation may be a
hard requirement for loading.
and reshaping. In some cases, no information is needed, and
`AbstractCheckpointable` may always be None. In other cases, the
abstract representation may be a hard requirement for loading.

Returns:
An `Awaitable` that continues to load the checkpointable in the background
and returns the loaded checkpointable when complete.
"""
...

async def metadata(self, directory: path_types.Path) -> AbstractT:
async def metadata(
self, directory: path_types.Path
) -> _AbstractCheckpointable:
"""Returns the metadata for the given `directory`.

The logic in this method must be executed fully in the main thread; metadata
access is expected to be cheap and fast.

In many cases it is desirable to return additional metadata properties
beyond the limited set in `AbstractT`. In this case, `AbstractT` should
beyond the limited set in `AbstractCheckpointable`. In this case,
`AbstractCheckpointable` should
be subclassed, and this subclass can be returned from `metadata`.

Args:
directory: The directory where the checkpoint is located.

Returns:
AbstractT: The metadata is an `AbstractT`, which is the abstract
AbstractCheckpointable: The metadata is an `AbstractCheckpointable`, which is the
abstract
representation of the checkpointable.
"""
...
Expand All @@ -262,10 +308,12 @@ def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
"""Returns whether the handler can handle the abstract checkpointable.

The method should return `True` if it is possible to use the given
`abstract_checkpointable` for loading a concrete `T`. Note that `None` is
`abstract_checkpointable` for loading a concrete `Checkpointable`. Note that
`None` is
always considered handleable for loading, so this method does not need to
check for it. If an implementation defines
`AbstractT` as `None`, then this method should only return True for values
`AbstractCheckpointable` as `None`, then this method should only return True
for values
of `None`.

See class docstring for more details.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ async def load_pytree(
path: Path,
checkpointable_name: str | None = None,
abstract_pytree: (
tree_types.PyTreeOf[tree_types.AbstractLeafType] | None
tree_types.PyTreeOf[tree_types.AbstractLeaf] | None
) = None,
) -> Awaitable[Any]:
"""Loads pytree specified by `checkpointable_name`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def load_pytree(
path: Path,
checkpointable_name: str | None = None,
abstract_pytree: (
tree_types.PyTreeOf[tree_types.AbstractLeafType] | None
tree_types.PyTreeOf[tree_types.AbstractLeaf] | None
) = None,
) -> Awaitable[Any]:
"""Loads a V0 PyTree checkpoint.
Expand Down
Loading
Loading