diff --git a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py index b29ca1668..f2d092cf9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py @@ -34,6 +34,8 @@ ) from orbax.checkpoint.experimental.v1 import multihost from orbax.checkpoint.experimental.v1.handlers import ( + Checkpointable, + AbstractCheckpointable, CheckpointableHandler, StatefulCheckpointable, ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/abstract_arrays.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/abstract_arrays.py index 899b2ec84..f1329c85f 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/abstract_arrays.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/abstract_arrays.py @@ -16,6 +16,4 @@ from orbax.checkpoint._src.arrays import abstract_arrays -ArrayLike = abstract_arrays.ArrayLike - to_shape_dtype_struct = abstract_arrays.to_shape_dtype_struct diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/types.py new file mode 100644 index 000000000..2f1981800 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/arrays/types.py @@ -0,0 +1,80 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Array type definitions.""" + +from typing import Protocol, TypeAlias + +import jax +from jax import numpy as jnp +import jax.experimental.layout as jax_layout +import numpy as np + +Shape = tuple[int, ...] +DType = jnp.dtype | np.dtype + +if jax.__version_info__ >= (0, 6, 2): + Format = jax_layout.Format +else: + Format = jax_layout.Layout + + +Scalar: TypeAlias = int | float | np.number | bytes | bool +AbstractScalar = Scalar + + +class AbstractArray(Protocol): + """Abstract representation of an array. + + This is a protocol for an abstract array that can be used to represent + the metadata belonging to an array. + + shape: + Tuple of integers describing the array shape. + dtype: + Dtype of array elements. + """ + + shape: Shape | None + dtype: DType | None + + +class AbstractShardedArray(Protocol): + """Abstract representation of a sharded array. + + This is a protocol for an abstract array that can be used to represent various + metadata types such as :py:class:`jax.ShapeDtypeStruct` and + :py:class:`~orbax.checkpoint.metadata.ArrayMetadata`. + + #TODO(dnlng): All attributes are made optional to support the case where + # the ArrayMetadata is passed into the metadata() call to pass only the + # `write_shape`. Optional attributes are not needed once write_shape is + # refactored. + + + shape: + Tuple of integers describing the array shape. + dtype: + Dtype of array elements. + sharding: + Sharding to indicate how the array is sharded. This can be jax's Sharding or + Layout or None. 
+ """ + + shape: Shape | None + dtype: DType | None + sharding: jax.sharding.Sharding | Format | None = None # pytype: disable=invalid-annotation + + +ArrayLike: TypeAlias = AbstractArray | AbstractShardedArray diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py index 4c81e7f3b..49301179d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py @@ -14,12 +14,59 @@ """Defines types for `CheckpointableHandler`.""" +import typing from typing import Any, Awaitable, Protocol, Type, TypeVar, runtime_checkable from orbax.checkpoint.experimental.v1._src.path import types as path_types -T = TypeVar('T') -AbstractT = TypeVar('AbstractT') +if typing.TYPE_CHECKING: + Checkpointable = Any + AbstractCheckpointable = Any +else: + + class Checkpointable: + """A logical piece of a checkpoint that is separable and distinct from other pieces. + + For example, model state ('params', 'opt_state') vs. dataset_iterator vs. + embeddings vs. other custom states. Each is handled differently in training + code and is often represented by a distinct container or object type. Each + is also used (or not used) in different contexts (training, + evaluations, inference). + + See also :py:type:`.AbstractCheckpointable`. + + Checkpointables are typically: + + (1) separable; they may or may not be loaded + concurrently and some may be omitted from the checkpoint entirely (e.g. the + dataset + iterator checkpointable is not needed for evals); + + (2) different types, with correspondingly different on-disk representations + (e.g. the dataset + iterator checkpointable is a lightweight index, while model weights are + large distributed arrays). + """ + + pass + + class AbstractCheckpointable: + """An "abstract checkpointable" is the abstract form of a :py:type:`.Checkpointable`. 
+ + The abstract form is used to: + + (1) Customize properties of a loaded checkpointable (e.g. specify a tree of + `jax.ShapeDtypeStruct` to + load a tree of `jax.Array` with desired shardings). + + (2) Represent metadata returned by metadata accessor functions. + """ + + pass + + +_Checkpointable = TypeVar('_Checkpointable') +_AbstractCheckpointable = TypeVar('_AbstractCheckpointable') @runtime_checkable @@ -37,21 +84,13 @@ async def load(self, directory: path_types.Path) -> Awaitable[None]: ... -class CheckpointableHandler(Protocol[T, AbstractT]): +class CheckpointableHandler(Protocol[_Checkpointable, _AbstractCheckpointable]): """An interface that defines save/load logic for a `checkpointable` object. NOTE: Prefer to use :py:class:`.StatefulCheckpointable` interface when possible. - A "checkpointable" is a fundamental concept in Orbax. A “checkpointable” - refers to a logical piece of the checkpoint that is distinct in some way from - other pieces. Checkpointables are separable; they may or may not be loaded - concurrently and some may be omitted from the checkpoint entirely. - Checkpointables are often represented by different types, and have different - representations on disk. The quintessential example is model params vs. - dataset. - A PyTree of arrays, representing model parameters, is the most basic "checkpointable". A singular array is also a checkpointable. @@ -92,8 +131,9 @@ class CheckpointableHandler(Protocol[T, AbstractT]): To create a custom handler, you must define a class that implements the methods defined in this Protocol. The class should be generic over the - concrete type `T` (the object being saved/loaded) and the abstract type - `AbstractT` (the lightweight metadata representation). + concrete type `Checkpointable` (the object being saved/loaded) and the + abstract type + `AbstractCheckpointable` (the lightweight metadata representation). 
Crucially, once implemented, the handler must be registered with the global registry or a context-local registry so that `save_checkpointables` @@ -143,7 +183,7 @@ def is_abstract_handleable( ) In many cases, no information is needed for loading. In this case, - `AbstractT` may be defined as `None`. For example:: + `AbstractCheckpointable` may be defined as `None`. For example:: class FooHandler(CheckpointableHandler[Foo, None]): @@ -155,7 +195,9 @@ def is_abstract_handleable(self, abstract_checkpointable: None) -> bool: """ async def save( - self, directory: path_types.PathAwaitingCreation, checkpointable: T + self, + directory: path_types.PathAwaitingCreation, + checkpointable: _Checkpointable, ) -> Awaitable[None]: """Saves the given `checkpointable` to the given `directory`. @@ -204,8 +246,8 @@ async def save( async def load( self, directory: path_types.Path, - abstract_checkpointable: AbstractT | None = None, - ) -> Awaitable[T]: + abstract_checkpointable: _AbstractCheckpointable | None = None, + ) -> Awaitable[_Checkpointable]: """Loads the checkpointable from the given `directory`. Args: @@ -214,9 +256,9 @@ async def load( checkpointable to load. If provided, this is used to provide properties to guide the restoration logic of the checkpoint. In the case of arrays, for example, this conveys properties like shape and dtype, for casting - and reshaping. In some cases, no information is needed, and `AbstractT` - may always be None. In other cases, the abstract representation may be a - hard requirement for loading. + and reshaping. In some cases, no information is needed, and + `AbstractCheckpointable` may always be None. In other cases, the + abstract representation may be a hard requirement for loading. Returns: An `Awaitable` that continues to load the checkpointable in the background @@ -224,21 +266,25 @@ async def load( """ ... 
- async def metadata(self, directory: path_types.Path) -> AbstractT: + async def metadata( + self, directory: path_types.Path + ) -> _AbstractCheckpointable: """Returns the metadata for the given `directory`. The logic in this method must be executed fully in the main thread; metadata access is expected to be cheap and fast. In many cases it is desirable to return additional metadata properties - beyond the limited set in `AbstractT`. In this case, `AbstractT` should + beyond the limited set in `AbstractCheckpointable`. In this case, + `AbstractCheckpointable` should be subclasses, and this subclass can be returned from `metadata`. Args: directory: The directory where the checkpoint is located. Returns: - AbstractT: The metadata is an `AbstractT`, which is the abstract + AbstractCheckpointable: The metadata is an `AbstractCheckpointable`, + which is the abstract representation of the checkpointable. """ ... @@ -262,10 +308,12 @@ def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None: """Returns whether the handler can handle the abstract checkpointable. The method should return `True` if it is possible to use the given - `abstract_checkpointable` for loading a concrete `T`. Note that `None` is + `abstract_checkpointable` for loading a concrete `Checkpointable`. Note that + `None` is always considered handleable for loading, so this method does not need to check for it. If an implementation defines - `AbstractT` as `None`, then this method should only return True for values + `AbstractCheckpointable` as `None`, then this method should only return True + for values of `None`. See class docstring for more details. 
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py index 44fee0893..75421d413 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py @@ -376,7 +376,7 @@ async def load_pytree( path: Path, checkpointable_name: str | None = None, abstract_pytree: ( - tree_types.PyTreeOf[tree_types.AbstractLeafType] | None + tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, ) -> Awaitable[Any]: """Loads pytree specified by `checkpointable_name`. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py index f207bcf75..25d92ef8a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py @@ -309,7 +309,7 @@ async def load_pytree( path: Path, checkpointable_name: str | None = None, abstract_pytree: ( - tree_types.PyTreeOf[tree_types.AbstractLeafType] | None + tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, ) -> Awaitable[Any]: """Loads a V0 PyTree checkpoint. 
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py index c7fc3f7e1..f22ff71fd 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py @@ -26,6 +26,7 @@ from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry @@ -40,10 +41,13 @@ PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY AUTO_CHECKPOINTABLE_KEY = checkpoint_layout.AUTO_CHECKPOINTABLE_KEY -AbstractPyTree = tree_types.PyTreeOf[tree_types.AbstractLeafType] +AbstractPyTree = tree_types.PyTreeOf[tree_types.AbstractLeaf] CheckpointMetadata = metadata_types.CheckpointMetadata PLACEHOLDER = ... +Checkpointable = handler_types.Checkpointable +AbstractCheckpointable = handler_types.AbstractCheckpointable + AsyncResponse = async_types.AsyncResponse @@ -105,7 +109,7 @@ def load_pytree( ) = None, *, checkpointable_name: str | None = AUTO_CHECKPOINTABLE_KEY, -) -> tree_types.PyTreeOf[tree_types.LeafType]: +) -> tree_types.PyTreeOf[tree_types.Leaf]: """Loads a PyTree. Loads from a `PyTree` checkpoint. A `PyTree` checkpoint must be a path @@ -135,8 +139,9 @@ def load_pytree( 2. The leaves of the restored tree will be restored with the properties indicated by the abstract leaves. 
For example, if a leaf in `abstract_pytree` is a `jax.ShapeDtypeStruct`, the restored leaf will be a `jax.Array` with the - same shape and `dtype`. Each `AbstractLeafType` has a corresponding `LeafType` - that is restored. + same shape and `dtype`. Each `AbstractLeaf` has a corresponding `Leaf` + that is restored. See `orbax.checkpoint.v1.tree` for a table + of standard supported leaf types. Example Usage: @@ -175,8 +180,8 @@ def load_pytree( dynamically discovers and resolves a pytree checkpointable. It prioritizes the standard 'pytree' checkpointable name if present, then sorts any other valid pytree checkpointable names alphabetically and returns the first - valid one, and ultimately falls back to interpreting the path as a flat - V0 root layout if no standard pytree exists. + valid one, and ultimately falls back to interpreting the path as a flat V0 + root layout if no standard pytree exists. Returns: The restored `PyTree`. @@ -216,9 +221,11 @@ def load_pytree( def load_checkpointables( path: path_types.PathLike, abstract_checkpointables: ( - dict[str, Any] | CheckpointMetadata[dict[str, Any]] | None + dict[str, AbstractCheckpointable] + | CheckpointMetadata[dict[str, AbstractCheckpointable]] + | None ) = None, -) -> dict[str, Any]: +) -> dict[str, Checkpointable]: """Loads checkpointables. See documentation for :py:func:`.save_checkpointables` for more context on @@ -352,7 +359,7 @@ def _load_impl( path: path_types.Path, load_fn: LoadFn, start_time: float, -) -> dict[str, Any] | tree_types.PyTreeOf[tree_types.LeafType]: +) -> dict[str, Checkpointable] | tree_types.PyTreeOf[tree_types.Leaf]: """Implementation of loading logic for both :py:func:`.load_checkpointables` and :py:func:`.load_pytree`. 
Args: @@ -370,7 +377,7 @@ def _load_impl( ctx = context_lib.get_context() - async def _load() -> Any: + async def _load() -> Checkpointable: load_awaitable = await load_fn() event_tracking.OperationRecorder( path, @@ -399,16 +406,14 @@ async def _load() -> Any: return result -class _LoadPyTreeResponse( - AsyncResponse[tree_types.PyTreeOf[tree_types.LeafType]] -): +class _LoadPyTreeResponse(AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]): """An :py:class:`.AsyncResponse` for :py:func:`.load_pytree_async`.""" def __init__( self, operation_id: str, path: path_types.Path, - background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.LeafType]], + background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.Leaf]], *, start_time: float, context: context_lib.Context, @@ -419,13 +424,13 @@ def __init__( self._start_time = start_time self._context = context self._thread_runner = thread_utils.BackgroundThreadRunner[ - tree_types.PyTreeOf[tree_types.LeafType] + tree_types.PyTreeOf[tree_types.Leaf] ](self._finalize_load()) @classmethod def create( cls, - background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.LeafType]], + background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.Leaf]], path: path_types.Path, start_time: float, *, @@ -446,7 +451,7 @@ def create( context=context, ) - async def _finalize_load(self) -> tree_types.PyTreeOf[tree_types.LeafType]: + async def _finalize_load(self) -> tree_types.PyTreeOf[tree_types.Leaf]: logging.info( '[process=%s] Waiting for background load operations', multihost.process_index(), @@ -478,7 +483,7 @@ async def _finalize_load(self) -> tree_types.PyTreeOf[tree_types.LeafType]: def result( self, timeout: float | None = None - ) -> tree_types.PyTreeOf[tree_types.LeafType]: + ) -> tree_types.PyTreeOf[tree_types.Leaf]: return self._thread_runner.result(timeout=timeout) @@ -489,7 +494,7 @@ def load_pytree_async( ) = None, *, checkpointable_name: str | None = PYTREE_CHECKPOINTABLE_KEY, -) -> 
async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.LeafType]]: +) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]: """Loads a PyTree asynchronously. Currently has limited support.""" start_time = time.time() event_tracking.OperationRecorder( @@ -532,9 +537,11 @@ async def _blocking_load() -> Any: def load_checkpointables_async( path: path_types.PathLike, abstract_checkpointables: ( - dict[str, Any] | CheckpointMetadata[dict[str, Any]] | None + dict[str, AbstractCheckpointable] + | CheckpointMetadata[dict[str, AbstractCheckpointable]] + | None ) = None, -) -> async_types.AsyncResponse[dict[str, Any]]: +) -> async_types.AsyncResponse[dict[str, Checkpointable]]: """Loads checkpointables asynchronously. Not yet implemented.""" del path, abstract_checkpointables raise NotImplementedError('Asynchronous loading is not yet supported.') diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py index 112bb0ae2..b7641fd6e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py @@ -14,11 +14,10 @@ """Functions for loading metadata from a checkpoint.""" -from typing import Any - from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint.experimental.v1 import errors from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry @@ -33,6 +32,8 @@ PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY 
+AbstractCheckpointable = handler_types.AbstractCheckpointable + def pytree_metadata( path: path_types.PathLike, @@ -124,7 +125,7 @@ def _get_abstract_array(arr): def checkpointables_metadata( path: path_types.PathLike, -) -> CheckpointMetadata[dict[str, Any]]: +) -> CheckpointMetadata[dict[str, AbstractCheckpointable]]: """Loads all checkpointables metadata from a checkpoint. This function is a more general version of `pytree_metadata`. The same @@ -160,11 +161,11 @@ def checkpointables_metadata( def _checkpointables_metadata_impl( layout: checkpoint_layout.CheckpointLayout, path: path_types.Path, -) -> CheckpointMetadata[dict[str, Any]]: +) -> CheckpointMetadata[dict[str, AbstractCheckpointable]]: """Shared implementation for checkpointables_metadata.""" async def _load_metadata() -> ( - metadata_types.CheckpointMetadata[dict[str, Any]] + metadata_types.CheckpointMetadata[dict[str, AbstractCheckpointable]] ): return await layout.metadata(path) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py index 501f060cc..8ef7e6207 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py @@ -19,7 +19,7 @@ import datetime import pprint import typing -from typing import Any, Generic, TypeAlias, TypeVar +from typing import Any, Generic, TypeVar from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types @@ -29,8 +29,8 @@ SerializedMetadata = TypeVar('SerializedMetadata', bound=dict[str, Any]) -PyTreeMetadata: TypeAlias = tree_types.PyTreeOf[tree_types.AbstractLeafType] -PyTreeMetadata.__doc__ = """ +PyTreeMetadata = tree_types.PyTreeOf[tree_types.AbstractLeaf] +""" Metadata describing a `PyTree` checkpoint. 
A serialized `PyTree` structure with the same structure as the checkpointed @@ -53,7 +53,7 @@ class CheckpointMetadata(Generic[CheckpointableMetadataT]): Note that this class has a generic type `CheckpointableMetadataT`. This will typically be either :py:data:`.PyTreeMetadata` (see above), or - `dict[str, Any]`. + `dict[str, AbstractCheckpointable]`. `CheckpointMetadata` can be accessed via one of two metadata methods. Please see :py:func:`.pytree_metadata` and :py:func:`.checkpointables_metadata` for diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index 9b7c88887..35947c684 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -62,7 +62,7 @@ async def load(self, directory: path_types.Path) -> Awaitable[None]: def save_pytree( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], *, custom_metadata: tree_types.JsonType | None = None, ): @@ -133,7 +133,7 @@ def save_pytree( def save_pytree_async( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], *, custom_metadata: tree_types.JsonType | None = None, ) -> async_types.AsyncResponse[None]: @@ -201,7 +201,7 @@ def save_pytree_async( objects registered as PyTrees) consisting of supported leaf types. Default supported leaf types include `jax.Array`, `np.ndarray`, simple types like `int`, `float`, `str`, and empty nodes. Support for custom leaves is also - possible by implementing a :py:class:`.LeafTypeHandler`. + possible by implementing a :py:class:`.LeafHandler`. custom_metadata: User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax. 
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py index 883c7019a..7775a345a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py @@ -14,14 +14,13 @@ """Defines free-function interface for saving.""" -from typing import Any - from orbax.checkpoint._src.checkpointers import async_checkpointer from orbax.checkpoint._src.handlers import composite_checkpoint_handler from orbax.checkpoint._src.handlers import handler_registration as legacy_handler_registration from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.handlers import compatibility as handler_compatibility from orbax.checkpoint.experimental.v1._src.handlers import registration as handler_registration +from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.path import types as path_types @@ -31,11 +30,12 @@ from orbax.checkpoint.experimental.v1._src.tree import types as tree_types PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +Checkpointable = handler_types.Checkpointable def save_pytree( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], *, checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, overwrite: bool = False, @@ -62,10 +62,9 @@ def save_pytree( Args: path: The path to save the checkpoint to. pytree: The `PyTree` to save. This may be any JAX `PyTree` (including custom - objects registered as `PyTrees`) consisting of supported leaf types. 
- Default supported leaf types include `jax.Array`, `np.ndarray`, simple - types like `int`, `float`, `str`, and empty nodes. Support for custom - leaves is also possible by implementing a `LeafTypeHandler`. + objects registered as `PyTrees`) consisting of supported leaf types. See + `orbax.checkpoint.v1.tree` for a table of standard supported + leaf types. checkpointable_name: The name of the checkpointable to save a pytree under. Defaults to 'pytree'. overwrite: If True, fully overwrites an existing checkpoint in `path`. @@ -85,7 +84,7 @@ def save_pytree( def save_checkpointables( path: path_types.PathLike, - checkpointables: dict[str, Any], + checkpointables: dict[str, Checkpointable], *, overwrite: bool = False, custom_metadata: tree_types.JsonType | None = None, @@ -150,7 +149,7 @@ def save_checkpointables( # save operation is scheduled. def save_pytree_async( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], *, checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, overwrite: bool = False, @@ -193,10 +192,8 @@ def save_pytree_async( Args: path: The path to save the checkpoint to. pytree: The `PyTree` to save. This may be any JAX `PyTree` (including custom - objects registered as `PyTrees`) consisting of supported leaf types. - Default supported leaf types include `jax.Array`, `np.ndarray`, simple - types like `int`, `float`, `str`, and empty nodes. Support for custom - leaves is also possible by implementing a `LeafTypeHandler`. + objects registered as `PyTrees`) consisting of supported leaf types. See + `orbax.checkpoint.v1.tree` for a table of standard supported leaf types. checkpointable_name: The name of the checkpointable to save a pytree under. Defaults to 'pytree'. overwrite: If True, fully overwrites an existing checkpoint in `path`. 
@@ -220,7 +217,7 @@ def save_pytree_async( def save_checkpointables_async( path: path_types.PathLike, - checkpointables: dict[str, Any], + checkpointables: dict[str, Checkpointable], *, overwrite: bool = False, custom_metadata: tree_types.JsonType | None = None, @@ -296,7 +293,7 @@ def save_checkpointables_async( def get_v0_checkpointer_and_args( - checkpointables: dict[str, Any], + checkpointables: dict[str, Checkpointable], *, metrics: tree_types.JsonType | None = None, ) -> tuple[ diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py index 5b426e938..a7d090102 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py @@ -20,7 +20,7 @@ def resolve_storage_options( keypath: tree_types.PyTreeKeyPath, - value: tree_types.LeafType, + value: tree_types.Leaf, array_saving_options: options_lib.ArrayOptions.Saving, ) -> options_lib.ArrayOptions.Saving.StorageOptions: """Resolves storage options using a global default and a per-leaf creator. 
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index 81f079942..3a21929fd 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -282,7 +282,7 @@ def should_save(self, step: int) -> bool: def save_pytree( self, step: int, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], *, checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, force: bool = False, @@ -465,7 +465,7 @@ def save_checkpointables( def save_pytree_async( self, step: int, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], *, checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, force: bool = False, @@ -586,11 +586,11 @@ def load_pytree( self, step: int | CheckpointMetadata | None = None, abstract_pytree: ( - tree_types.PyTreeOf[tree_types.AbstractLeafType] | None + tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, *, checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, - ) -> tree_types.PyTreeOf[tree_types.LeafType]: + ) -> tree_types.PyTreeOf[tree_types.Leaf]: """Loads a PyTree checkpoint at the given step. 
This method behaves similarly to the standalone free function @@ -772,9 +772,9 @@ def load_pytree_async( self, step: int | CheckpointMetadata | None = None, abstract_pytree: ( - tree_types.PyTreeOf[tree_types.AbstractLeafType] | None + tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, - ) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.LeafType]]: + ) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]: """Not yet supported.""" raise NotImplementedError() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index fe4870dbe..68f3e8474 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -86,7 +86,7 @@ def save_pytree( self, checkpointer: Checkpointer, step: int, - pytree: tree_types.PyTreeOf[tree_types.LeafType], + pytree: tree_types.PyTreeOf[tree_types.Leaf], metrics: tree_types.JsonType | None = None, custom_metadata: tree_types.JsonType | None = None, ) -> bool: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/tree/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/tree/types.py index b856a5c0b..0c96d99b1 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/tree/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/tree/types.py @@ -18,9 +18,10 @@ from typing import Any, Generic, TypeVar import jax import numpy as np -from orbax.checkpoint._src.tree import types +from orbax.checkpoint._src.tree import types as tree_types +from orbax.checkpoint.experimental.v1._src.arrays import types as array_types -JsonType = types.JsonType +JsonType = tree_types.JsonType T = TypeVar("T") @@ -59,13 +60,17 @@ class PyTree: pass -PyTreeKey = types.PyTreeKey -PyTreeKeyPath = types.PyTreePath +PyTreeKey = tree_types.PyTreeKey +PyTreeKeyPath = 
tree_types.PyTreePath -ScalarType = int | float | bool -LeafType = jax.Array | np.ndarray | str | ScalarType | Any -AbstractLeafType = Any # TODO(cpgaffney): Add a type for abstract leaves. +Leaf = jax.Array | np.ndarray | array_types.Scalar | str +AbstractLeaf = ( + array_types.AbstractArray + | array_types.AbstractShardedArray + | array_types.AbstractScalar + | str +) -JsonType = types.JsonType +JsonType = tree_types.JsonType PLACEHOLDER = ... diff --git a/checkpoint/orbax/checkpoint/experimental/v1/arrays.py b/checkpoint/orbax/checkpoint/experimental/v1/arrays.py index 95097d37b..b201f5800 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/arrays.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/arrays.py @@ -18,5 +18,13 @@ from orbax.checkpoint.experimental.v1._src.arrays.abstract_arrays import ( to_shape_dtype_struct, +) +from orbax.checkpoint.experimental.v1._src.arrays.types import ( ArrayLike, + AbstractArray, + AbstractShardedArray, + Scalar, + AbstractScalar, + Shape, + DType, ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/handlers.py b/checkpoint/orbax/checkpoint/experimental/v1/handlers.py index e44597712..327173572 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/handlers.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/handlers.py @@ -18,6 +18,8 @@ import orbax.checkpoint.experimental.v1._src.handlers.global_registration from orbax.checkpoint.experimental.v1._src.handlers.types import ( + Checkpointable, + AbstractCheckpointable, CheckpointableHandler, StatefulCheckpointable, ) @@ -32,7 +34,6 @@ JsonHandler, ) - from orbax.checkpoint.experimental.v1._src.handlers.registration import ( CheckpointableHandlerRegistry, global_registry, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/tree.py b/checkpoint/orbax/checkpoint/experimental/v1/tree.py index 45e519efd..c376e12ba 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/tree.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/tree.py @@ -12,7 +12,24 @@ # 
See the License for the specific language governing permissions and # limitations under the License. -"""Public symbols for tree module.""" +"""Public symbols for tree module. + +Standard supported leaf types are described by the table below. +See +https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/checkpointing_pytrees.html#standard-leaf-types +for more information. + +| `Leaf` Type | `AbstractLeaf` Type | Properties | +:------- | :-------- | :-------- | +|`jax.Array`|`ocp.arrays.AbstractShardedArray` (`jax.ShapeDtypeStruct`) +|`shape`, `dtype`, +`sharding`| +|`np.ndarray`|`ocp.arrays.AbstractArray` (`np.ndarray`) |`shape`, `dtype`| +|`int`|`int`| | +|`float`|`float`| | +|`bytes`|`bytes`| | +|`str`|`str`| | +""" # pylint: disable=g-importing-member, g-multiple-import, g-bad-import-order, unused-import @@ -26,3 +43,7 @@ PyTreeKeyPath, JsonType, ) +from orbax.checkpoint.experimental.v1._src.tree.types import ( + Leaf, + AbstractLeaf, +) diff --git a/docs/api_reference/checkpoint.v1.arrays.rst b/docs/api_reference/checkpoint.v1.arrays.rst index 3f76b4352..522509bb3 100644 --- a/docs/api_reference/checkpoint.v1.arrays.rst +++ b/docs/api_reference/checkpoint.v1.arrays.rst @@ -7,3 +7,7 @@ .. autofunction:: to_shape_dtype_struct .. autodata:: ArrayLike +.. autoclass:: AbstractArray +.. autoclass:: AbstractShardedArray +.. autotype:: Scalar +.. autotype:: AbstractScalar diff --git a/docs/api_reference/checkpoint.v1.handlers.rst b/docs/api_reference/checkpoint.v1.handlers.rst index f3ab2bb57..a7bcd5210 100644 --- a/docs/api_reference/checkpoint.v1.handlers.rst +++ b/docs/api_reference/checkpoint.v1.handlers.rst @@ -8,6 +8,8 @@ Types ------------------------------------------------------------ +.. autotype:: Checkpointable +.. autotype:: AbstractCheckpointable .. autoclass:: CheckpointableHandler .. 
autoclass:: StatefulCheckpointable diff --git a/docs/api_reference/checkpoint.v1.rst b/docs/api_reference/checkpoint.v1.rst index ee3f8ae3e..88eae462f 100644 --- a/docs/api_reference/checkpoint.v1.rst +++ b/docs/api_reference/checkpoint.v1.rst @@ -25,6 +25,12 @@ Submodules Top-level Symbols ----------------- +Types +~~~~~~~ +.. autotype:: Checkpointable +.. autotype:: AbstractCheckpointable +.. autodata:: PLACEHOLDER + Loading ~~~~~~~ .. autofunction:: load_pytree @@ -50,10 +56,6 @@ Path Utilities ~~~~~~~~~~~~~~ .. autofunction:: is_orbax_checkpoint -Constants -~~~~~~~~~ -.. autodata:: PLACEHOLDER - Synchronization ~~~~~~~~~~~~~~~ .. autoclass:: AsyncResponse diff --git a/docs/api_reference/checkpoint.v1.tree.rst b/docs/api_reference/checkpoint.v1.tree.rst index 15712a08d..a1ad8f197 100644 --- a/docs/api_reference/checkpoint.v1.tree.rst +++ b/docs/api_reference/checkpoint.v1.tree.rst @@ -10,6 +10,8 @@ Types ------------------------------------------------------------ .. autodata:: PyTree .. autodata:: PyTreeOf +.. autodata:: Leaf +.. autodata:: AbstractLeaf .. autodata:: PyTreeKey .. autodata:: PyTreeKeyPath .. autodata:: JsonType