diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 4e834f2fd..8db916037 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -17,6 +17,10 @@ effort in maintaining. Safetensors is instead the recommended conversion case. - #v1 Allow a context to be default-configured for all `Checkpointer` operations. +### Removed + +- #v1 Remove `LeafHandler` as a user-exposed layer (it remains as an internal layer). + ## [0.11.39] - 2026-05-06 ### Added diff --git a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py index f2d092cf9..2d4c1e9f0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py @@ -26,7 +26,6 @@ from orbax.checkpoint.experimental.v1 import handlers from orbax.checkpoint.experimental.v1 import partial from orbax.checkpoint.experimental.v1 import path -from orbax.checkpoint.experimental.v1 import serialization from orbax.checkpoint.experimental.v1 import training from orbax.checkpoint.experimental.v1 import tree from orbax.checkpoint.experimental.v1._src.synchronization.types import ( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index c25fd9502..72d3c02dc 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -182,13 +182,9 @@ class PyTreeOptions: This dataclass defines the configuration parameters for creating and managing PyTree saving and loading on disk. - # TODO: Include an example of registering a custom LeafHandler. - Attributes: saving: Options for saving PyTrees. loading: Options for loading PyTrees. - leaf_handler_registry: Optional Leaf Handler Registry. If provided, it will - override the default Leaf Handler Registry. """ @dataclasses.dataclass(frozen=True, kw_only=True) @@ -214,7 +210,6 @@ class Loading: saving: Saving = dataclasses.field(default_factory=Saving) loading: Loading = dataclasses.field(default_factory=Loading) - leaf_handler_registry: serialization_types.LeafHandlerRegistry | None = None @dataclasses.dataclass(frozen=True, kw_only=True) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index 9eda5edda..b1ec35e73 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -293,6 +293,9 @@ def __init__( array_metadata_validator: array_metadata_store_lib.Validator = ( array_metadata_store_lib.Validator() ), + leaf_handler_registry: ( + serialization_types.LeafHandlerRegistry | None + ) = None, partial_save_mode: bool = False, ): context = context_lib.get_context(context) @@ -301,9 +304,7 @@ def __init__( self._partial_save_mode = partial_save_mode self._leaf_handler_registry = ( - self._context.pytree_options.leaf_handler_registry - if self._context.pytree_options.leaf_handler_registry is not None - else registry.StandardLeafHandlerRegistry() + leaf_handler_registry or registry.StandardLeafHandlerRegistry() ) type_handler_registry = compatibility.get_v0_type_handler_registry( @@ -437,13 +438,10 @@ async def load( abstract_checkpointable: The abstract checkpointable to load into. If None, the handler will attempt to load the entire checkpoint using the recorded metadata. Otherwise, the `abstract_checkpointable` is expected - to be a PyTree of abstract leaves. See - :py:class:`~.v1.serialization.LeafHandler` for more details. The - abstract leaf may be a value of type `AbstractLeaf`, - `Type[AbstractLeaf]`, or `None`. E.g. if the `AbstractLeaf` is - `AbstractFoo`, it is always valid to pass `AbstractFoo()` or - `AbstractFoo` or `None`. Passing the latter two indicates that metadata - should be used to restore the leaf. + to be a PyTree of abstract leaves. The abstract leaf may be a value of + type :py:class:`~.v1.tree.AbstractLeaf`, + `Type[AbstractLeaf]`, or `None`. Passing the latter two indicates that + the metadata should be used to restore the leaf. Returns: A awaitable which can be awaited to complete the load operation and diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py index ce87dff86..792bd2b65 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py @@ -300,12 +300,13 @@ def handler_with_options( loading=options_lib.PyTreeOptions.Loading( partial_load=partial_load, ), - leaf_handler_registry=leaf_handler_registry, ), ) handler = handler_test_utils.create_test_handler( - pytree_handler.PyTreeHandler, context=context + pytree_handler.PyTreeHandler, + context=context, + leaf_handler_registry=leaf_handler_registry, ) try: @@ -2288,6 +2289,63 @@ def _as_abstract_type(x): array_metadata_store=array_metadata_store, ) + def test_custom_array_type(self): + # Set up local context with custom registry. + custom_registry = registry.StandardLeafHandlerRegistry() + custom_registry.add( + handler_test_utils.LazyArray, + handler_test_utils.AbstractLazyArray, + handler_test_utils.LazyArrayHandler, + ) + + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + lazy_arr = handler_test_utils.LazyArray( + create_sharded_array(np.arange(16), sharding) + ) + pytree = {'a': lazy_arr} + + with handler_with_options( + use_ocdbt=False, leaf_handler_registry=custom_registry + ) as handler: + handler.save(self.directory, pytree) + + # Attempt to load without context (using global default registry), which + # should fail + with handler_with_options(use_ocdbt=False) as handler: + with self.assertRaisesRegex(ValueError, 'TypeHandler lookup failed'): + handler.load(self.directory) + + # Load with the custom registry context + with handler_with_options( + use_ocdbt=False, leaf_handler_registry=custom_registry + ) as handler: + loaded = handler.load(self.directory) + self.assertEqual(loaded['a'].array.shape, lazy_arr.array.shape) + np.testing.assert_array_equal(loaded['a'].array, lazy_arr.array) + + # Load custom array directly as jax.Array by mapping secondary_typestr + custom_registry2 = registry.StandardLeafHandlerRegistry() + # Override the default jax.Array handler with LazyArray typestr, + # ensuring that the serialized jax.array annotated with original LazyArray + # typestr is loaded as a jax.Array. + custom_registry2.add( + jax.Array, + serialization_types.AbstractShardedArray, + array_leaf_handler.ArrayLeafHandler, + secondary_typestrs=[ + serialization_types.typestr(handler_test_utils.LazyArrayHandler) + ], + override=True, + ) + + with handler_with_options( + use_ocdbt=False, leaf_handler_registry=custom_registry2 + ) as handler: + loaded_as_jax_array = handler.load(self.directory) + self.assertIsInstance(loaded_as_jax_array['a'], jax.Array) + np.testing.assert_array_equal(loaded_as_jax_array['a'], lazy_arr.array) + def test_abstract_array_loading(self): replicated_sharding = jax.sharding.NamedSharding( jax.sharding.Mesh(jax.devices(), ('x',)), diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index 35947c684..f36a01801 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -198,10 +198,10 @@ def save_pytree_async( Args: path: The path to save the checkpoint to. pytree: The PyTree to save. This may be any JAX PyTree (including custom - objects registered as PyTrees) consisting of supported leaf types. Default - supported leaf types include `jax.Array`, `np.ndarray`, simple types like - `int`, `float`, `str`, and empty nodes. Support for custom leaves is also - possible by implementing a :py:class:`.LeafHandler`. + objects registered as PyTrees) consisting of supported leaf types (see + :py:class:`~.v1.tree.Leaf`). Default supported leaf types include + `jax.Array`, `np.ndarray`, simple types like + `int`, `float`, `str`, and empty nodes. custom_metadata: User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py index a7d090102..322b2d0b2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py @@ -74,4 +74,3 @@ def resolve_storage_options( chunk_byte_size=resolved_chunk_byte_size, shard_axes=resolved_shard_axes if resolved_shard_axes is not None else (), ) - diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py index 252828beb..de7ca9c99 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py @@ -17,13 +17,10 @@ import dataclasses from typing import Any, Awaitable, Generic, Protocol, Sequence, Tuple, Type, TypeVar -import jax -import jax.experimental.layout as jax_layout -import numpy as np -from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import types as serialization_types from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint.experimental.v1._src.arrays import types as array_types from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types import tensorstore as ts @@ -32,66 +29,34 @@ Leaf = TypeVar('Leaf') AbstractLeaf = TypeVar('AbstractLeaf') -Shape = arrays_types.Shape -DType = arrays_types.DType +Shape = array_types.Shape +DType = array_types.DType PLACEHOLDER = ... IsPrioritizedKeyFn = serialization_types.IsPrioritizedKeyFn -Scalar = int | float | np.number +### STANDARD PYTREE LEAF TYPES + +### SCALAR +Scalar = array_types.Scalar # Optional type hint for a scalar leaf handler. If provided, the restored scalar # will be cast to this type. Only casting to int or float is supported. AbstractScalar = Scalar -AbstractString = str - -if jax.__version_info__ >= (0, 6, 2): - Format = jax_layout.Format -else: - Format = jax_layout.Layout - - -class AbstractArray(Protocol): - """Abstract representation of an array. - - This is a protocol for an abstract array that can be used to represent - the metadata belonging to an array. - - shape: - Tuple of integers describing the array shape. - dtype: - Dtype of array elements. - """ - - shape: Shape | None - dtype: DType | None - - -class AbstractShardedArray(Protocol): - """Abstract representation of an array. - This is a protocol for an abstract array that can be used to represent various - metadata types such as :py:class:`jax.ShapeDtypeStruct` and - :py:class:`~orbax.checkpoint.metadata.ArrayMetadata`. - - #TODO(dnlng): All attributes are made optional to support the case where - # the ArrayMetadata is passed into the metadata() call to pass only the - # `write_shape`. Optional attributes are not needed once write_shape is - # refactored. +### STRING +# str +AbstractString = str +### ARRAY +# np.ndarray +AbstractArray = array_types.AbstractArray - shape: - Tuple of integers describing the array shape. - dtype: - Dtype of array elements. - Sharding: - Sharding to indicate how the array is sharded. This can be jax's Sharding or - Layout or None. - """ +### SHARDED ARRAY +# jax.Array +AbstractShardedArray = array_types.AbstractShardedArray - shape: Shape | None - dtype: DType | None - sharding: jax.sharding.Sharding | Format | None = None # pytype: disable=invalid-annotation +### def is_placeholder(value: Any) -> bool: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index edcc5e006..8987b9385 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -42,16 +42,13 @@ from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.path import async_utils from orbax.checkpoint.experimental.v1._src.path import types as path_types -from orbax.checkpoint.experimental.v1._src.serialization import array_leaf_handler from orbax.checkpoint.experimental.v1._src.serialization import registry as serialization_registry -from orbax.checkpoint.experimental.v1._src.serialization import types from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils from orbax.checkpoint.experimental.v1._src.testing import handler_utils from orbax.checkpoint.experimental.v1._src.testing import tree_utils as tree_test_utils from orbax.checkpoint.experimental.v1._src.tree import types as tree_types - PyTree = tree_types.PyTree Path = path_types.Path InvalidLayoutError = checkpoint_layout.InvalidLayoutError @@ -408,66 +405,6 @@ def test_leaf_change_type(self): ), ) - def test_custom_array_type(self): - # Set up local context with custom registry. - custom_registry = serialization_registry.StandardLeafHandlerRegistry() - custom_registry.add( - handler_utils.LazyArray, - handler_utils.AbstractLazyArray, - handler_utils.LazyArrayHandler, - ) - - custom_context = ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - leaf_handler_registry=custom_registry - ) - ) - - mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) - sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - lazy_arr = handler_utils.LazyArray( - create_sharded_array(np.arange(16), sharding) - ) - pytree = {'a': lazy_arr} - - with custom_context: - ocp.save_pytree(self.directory, pytree) - - # Attempt to load without context (using global default registry), which - # should fail - with self.assertRaisesRegex(ValueError, 'TypeHandler lookup failed'): - ocp.load_pytree(self.directory) - - # Load with the custom registry context - with custom_context: - loaded = ocp.load_pytree(self.directory) - self.assertEqual(loaded['a'].array.shape, lazy_arr.array.shape) - np.testing.assert_array_equal(loaded['a'].array, lazy_arr.array) - - # Load custom array directly as jax.Array by mapping secondary_typestr - custom_registry2 = serialization_registry.StandardLeafHandlerRegistry() - # Override the default jax.Array handler with LazyArray typestr, - # ensuring that the serialized jax.array annotated with original LazyArray - # typestr is loaded as a jax.Array. - custom_registry2.add( - jax.Array, - types.AbstractShardedArray, - array_leaf_handler.ArrayLeafHandler, - secondary_typestrs=[types.typestr(handler_utils.LazyArrayHandler)], - override=True, - ) - custom_context2 = ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - leaf_handler_registry=custom_registry2 - ) - ) - with custom_context2: - loaded_as_jax_array = ocp.load_pytree(self.directory) - self.assertIsInstance(loaded_as_jax_array['a'], jax.Array) - np.testing.assert_array_equal( - loaded_as_jax_array['a'], lazy_arr.array - ) - def test_empty_array(self): value = np.ones(shape=(0,)) with self.assertRaisesRegex(ValueError, 'zero size'): @@ -735,9 +672,7 @@ def test_missing_keys(self): test_utils.assert_tree_equal(self, self.numpy_pytree, loaded) with self.subTest('load_checkpointables'): - with self.assertRaisesRegex( - KeyError, 'Requested checkpointables:' - ): + with self.assertRaisesRegex(KeyError, 'Requested checkpointables:'): ocp.load_checkpointables( self.directory, {'foo': handler_utils.AbstractFoo()} ) @@ -847,9 +782,7 @@ def test_save_checkpointables_deleted(self): loaded = ocp.load_checkpointables(self.directory) self.assertSameElements(['two'], loaded.keys()) - with self.assertRaisesRegex( - KeyError, 'Requested checkpointables:' - ): + with self.assertRaisesRegex(KeyError, 'Requested checkpointables:'): ocp.load_checkpointables(self.directory, {'one': None}) @@ -1242,6 +1175,7 @@ def test_subchunking(self): self.assertEqual(metadata[k].storage_metadata.chunk_shape, (2,)) with self.subTest('per_key_setting'): + def scoped_storage_options_creator(key, value): del value if 'a' in tree_utils.str_keypath(key): @@ -1251,6 +1185,7 @@ def scoped_storage_options_creator(key, value): return ocp.options.ArrayOptions.Saving.StorageOptions( chunk_byte_size=8, # force divide in 2 subchunks ) + with ocp.Context( array_options=ocp.options.ArrayOptions( saving=ocp.options.ArrayOptions.Saving( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/serialization.py b/checkpoint/orbax/checkpoint/experimental/v1/serialization.py deleted file mode 100644 index 50a3e956f..000000000 --- a/checkpoint/orbax/checkpoint/experimental/v1/serialization.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2026 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public API for Serialization.""" - -# pylint: disable=g-importing-member, g-multiple-import, unused-import, g-bad-import-order - -from orbax.checkpoint.experimental.v1._src.serialization.registry import ( - BaseLeafHandlerRegistry, - StandardLeafHandlerRegistry, -) - -from orbax.checkpoint.experimental.v1._src.serialization.types import ( - LeafHandler, - SerializationParam, - DeserializationParam, - SerializationContext, - DeserializationContext, -) diff --git a/docs/api_reference/checkpoint.v1.serialization.rst b/docs/api_reference/checkpoint.v1.serialization.rst deleted file mode 100644 index 52a50207b..000000000 --- a/docs/api_reference/checkpoint.v1.serialization.rst +++ /dev/null @@ -1,27 +0,0 @@ -``ocp.v1.serialization`` module -============================================================================ - -.. currentmodule:: orbax.checkpoint.v1.serialization - -.. automodule:: orbax.checkpoint.experimental.v1.serialization - :members: - -Registry ------------------------------------------------------------- -.. autoclass:: BaseLeafHandlerRegistry - :members: -.. autoclass:: StandardLeafHandlerRegistry - :members: - -Types ------------------------------------------------------------- -.. autoclass:: LeafHandler - :members: -.. autoclass:: SerializationParam - :members: -.. autoclass:: DeserializationParam - :members: -.. autoclass:: SerializationContext - :members: -.. autoclass:: DeserializationContext - :members: diff --git a/docs/guides/checkpoint/v1/customization.ipynb b/docs/guides/checkpoint/v1/customization.ipynb index 4dddc226e..363f2ecc8 100644 --- a/docs/guides/checkpoint/v1/customization.ipynb +++ b/docs/guides/checkpoint/v1/customization.ipynb @@ -15,9 +15,8 @@ }, "cell_type": "markdown", "source": [ - "Orbax allows users to specify their own logic for dealing with custom objects.\n", - "Customization can occur at two levels - the level of a \"checkpointable\", and the\n", - "level of a \"PyTree leaf\"." + "Orbax allows users to specify their own logic for dealing with custom\n", + "\"Checkpointables\"." ] }, { @@ -788,475 +787,6 @@ }, "cell_type": "markdown", "source": [] - }, - { - "metadata": { - "id": "rpeAEZuE7cqv" - }, - "cell_type": "markdown", - "source": [ - "## Custom Leaf Handler" - ] - }, - { - "metadata": { - "id": "Z2V5i7qB7fxZ" - }, - "cell_type": "markdown", - "source": [ - "This is an advanced topic. Make sure you are familar with [the guide on checkpointing PyTrees](checkpointing_pytrees.ipynb) before reading this notebook.\n", - "\n", - "PyTrees are a common tree structure used to represent training states. LeafHandlers are responsible for serializing and deserializing each leaf node. Different leaf object types require specific LeafHandlers. Orbax includes standard LeafHandlers for common types including jax.Array, np.ndarray, int, float, and str. Before creating a custom LeafHandler, always check the options available in {py:class}`ocp.options.PyTreeOptions ` and {py:class}`ocp.options.ArrayOptions ` to ensure no existing options can meet your needs." - ] - }, - { - "metadata": { - "id": "BlmvH1qR6pVf" - }, - "cell_type": "markdown", - "source": [ - "One of common reasons to have a custom LeafHandler is to support a custom type that is not supported by Orbax. I will use the `Point` class from above as the example. Let's say you need to checkpoint many Point objects in a nested tree structure. It might make sense to store it within a Pytree along with your train state. Then you would need to write a PointLeafHandler and register it with the LeafHandlerRegistry." - ] - }, - { - "metadata": { - "id": "fdx96Jt06pVf" - }, - "cell_type": "code", - "source": [ - "import dataclasses\n", - "import json\n", - "from typing import Awaitable, Type\n", - "from etils import epath\n", - "import numpy as np\n", - "from orbax.checkpoint import multihost\n", - "import orbax.checkpoint.experimental.v1 as ocp\n", - "from orbax.checkpoint.experimental.v1 import serialization\n", - "\n", - "\n", - "@dataclasses.dataclass\n", - "class Point:\n", - " x: int | float\n", - " y: int | float" - ], - "outputs": [], - "execution_count": 13 - }, - { - "metadata": { - "id": "NzNJmUiZ6pVf" - }, - "cell_type": "markdown", - "source": [ - "For LeafHandler, we need to define a AbtractPoint class as well. This is required for two reasons:\n", - "1. The AbstractPoint class is used during restoration to indicate what type of a leaf object will be restored as.\n", - "2. In addition, metadata of a leaf node will be returned as AbstractPoint, avoid the need to restore the actual leaf object.\n", - "\n", - "In following example of AbstractPoint, we just define it as the type of data members without actual values." - ] - }, - { - "metadata": { - "id": "gk9azkLt6pVf" - }, - "cell_type": "code", - "source": [ - "@dataclasses.dataclass\n", - "class AbstractPoint:\n", - " x: Type[int|float]\n", - " y: Type[int|float]\n", - "\n", - " @classmethod\n", - " def from_point(cls, point):\n", - " return cls(x=type(point.x), y=type(point.y))\n" - ], - "outputs": [], - "execution_count": 14 - }, - { - "metadata": { - "id": "5hNKzRSx6pVg" - }, - "cell_type": "markdown", - "source": [ - "Next we will define the actual PointLeafHandler. See the comments below which explain what functions are required." - ] - }, - { - "metadata": { - "id": "X-pq0VT06pVg" - }, - "cell_type": "code", - "source": [ - "from typing import Sequence\n", - "import asyncio\n", - "import aiofiles" - ], - "outputs": [], - "execution_count": 15 - }, - { - "metadata": { - "id": "A9r2nGmW6pVg" - }, - "cell_type": "code", - "source": [ - "class PointLeafHandler(serialization.LeafHandler[Point, AbstractPoint]):\n", - " \"\"\"A custom leaf handler for testing.\"\"\"\n", - "\n", - " def __init__(self, context: ocp.Context | None = None):\n", - " \"\"\"Required Initializer.\n", - "\n", - " This initializer is initialized lazily during checkpoint operations. If the\n", - " signature is not matched, an exception will be raised during initialization.\n", - "\n", - " Args:\n", - " context: The context for the leaf handler. The leaf handler can\n", - " initialize and operate according to the context. In this example, we do\n", - " not utilize it though. For more examples, see ArrayLeafHandler.\n", - " \"\"\"\n", - " del context\n", - "\n", - " async def serialize(\n", - " self,\n", - " params: Sequence[serialization.SerializationParam[Point]],\n", - " serialization_context: serialization.SerializationContext,\n", - " ) -> Awaitable[None]:\n", - " \"\"\"Required Serialize function.\n", - "\n", - " This function writes the specified leaves of a checkpointable to a storage\n", - " location. A couple of notes here:\n", - " 1. This function is called on all hosts, but in this example, only the\n", - " primary host will write.\n", - " 2. we use `await await_creation()` to ensure the parent directory is created\n", - " before writing.\n", - " \"\"\"\n", - "\n", - " async def _background_serialize(params, serialization_context):\n", - " # make sure the parent directory is created\n", - " await serialization_context.parent_dir.await_creation()\n", - "\n", - " # only the primary host writes\n", - " if multihost.is_primary_host(0):\n", - " for param in params:\n", - " # save the value\n", - " async with aiofiles.open(\n", - " serialization_context.parent_dir.path / f'{param.name}.txt',\n", - " 'w',\n", - " ) as f:\n", - " await f.write(json.dumps(dataclasses.asdict(param.value)))\n", - "\n", - " # save the metadata\n", - " async with aiofiles.open(\n", - " serialization_context.parent_dir.path\n", - " / f'{param.name}.metadata.txt',\n", - " 'w',\n", - " ) as abstract_f:\n", - " contents = json.dumps({\n", - " k: type(v).__name__\n", - " for k, v in dataclasses.asdict(param.value).items()\n", - " })\n", - " await abstract_f.write(contents)\n", - "\n", - " return _background_serialize(params, serialization_context)\n", - "\n", - " async def deserialize(\n", - " self,\n", - " params: Sequence[serialization.DeserializationParam[AbstractPoint]],\n", - " deserialization_context: serialization.DeserializationContext,\n", - " ) -> Awaitable[Sequence[Point]]:\n", - " \"\"\"Required Deserialize function.\n", - "\n", - " Returns sequence of leaves from a stored checkpointable location. Note that\n", - " we use asyncio.to_thread to ensure the deserialization is performed in a\n", - " background thread immediately before returning this call.\n", - " \"\"\"\n", - "\n", - " async def _deserialize_impl():\n", - " ret = []\n", - " for param in params:\n", - " async with aiofiles.open(\n", - " deserialization_context.parent_dir / f'{param.name}.txt',\n", - " 'r',\n", - " ) as f:\n", - " ret.append(Point(**json.loads(await f.read())))\n", - "\n", - " return ret\n", - "\n", - " return _deserialize_impl()\n", - "\n", - " async def metadata(\n", - " self,\n", - " params: Sequence[serialization.DeserializationParam[None]],\n", - " deserialization_context: serialization.DeserializationContext,\n", - " ) -> Sequence[AbstractPoint]:\n", - " \"\"\"Required Metadata function.\n", - "\n", - " Returns a sequence of metadata that helps to describe the available leaves\n", - " in this checkpoint location.\n", - " \"\"\"\n", - "\n", - " ret = []\n", - " for param in params:\n", - " async with aiofiles.open(\n", - " deserialization_context.parent_dir / f'{param.name}.metadata.txt', 'r'\n", - " ) as f:\n", - " contents = json.loads(await f.read())\n", - " ret.append(\n", - " AbstractPoint(\n", - " **{k: getattr(__builtins__, v) for k, v in contents.items()}\n", - " )\n", - " )\n", - " return ret" - ], - "outputs": [], - "execution_count": 16 - }, - { - "metadata": { - "id": "_Pd9zWTE6pVg" - }, - "cell_type": "markdown", - "source": [ - "Next, we will define a train_state for demonstration purpose. In this train_state, it has some common types as well as some Points that are nested inside the PyTree." - ] - }, - { - "metadata": { - "id": "E0NwLMXD6pVg" - }, - "cell_type": "code", - "source": [ - "# define a PyTree Train State\n", - "\n", - "train_state = {\n", - " 'a': np.arange(16),\n", - " 'b': np.ones(16),\n", - " 'scalar': 123.0,\n", - " 'mixed': {\n", - " 'a': np.arange(16),\n", - " 'b': np.ones(16),\n", - " 'scalar': 123.0,\n", - " 'Point': Point(0, 0.5),\n", - " },\n", - " 'Points': {\n", - " 'level1': {\n", - " 'point_int': Point(1, 2),\n", - " 'point_float': Point(3.0, 4.0),\n", - " 'level2': {\n", - " 'point_mixed1': Point(5, 6.0),\n", - " 'point_mixed2': Point(7.0, 8),\n", - " 'point_int': Point(9, 10),\n", - " 'point_float': Point(11.0, 12.0),\n", - " },\n", - " }\n", - " },\n", - "}" - ], - "outputs": [], - "execution_count": 17 - }, - { - "metadata": { - "id": "CSF-iSZI6pVg" - }, - "cell_type": "markdown", - "source": [ - "Next, we will prepare a LeafHandlerRegistry. In this registry, the type and its abstract type will map with a LeafHandler. In the following example, we create a `StandardLeafHandler` first. This is the same as the registry used by default. Then PointLeafHandler is added along its type Point and abstract type AbstractPoint. Note that only the `PointLeafHandler` type is registered, not the handler instance. The instance will be created lazily depending on checkpoint operations." - ] - }, - { - "metadata": { - "id": "nQk4Iqh_6pVg" - }, - "cell_type": "code", - "source": [ - "# Create LeafHandlerRegistry\n", - "registry = serialization.StandardLeafHandlerRegistry() # with standard handlers\n", - "registry.add(Point, AbstractPoint, PointLeafHandler) # add custom handler" - ], - "outputs": [], - "execution_count": 18 - }, - { - "metadata": { - "id": "qFimu0XB6pVg" - }, - "cell_type": "code", - "source": [ - "# prepare the checkpoint directory\n", - "path = epath.Path('/tmp/customization/with_points')\n", - "path.rmtree(missing_ok=True)" - ], - "outputs": [], - "execution_count": 19 - }, - { - "metadata": { - "id": "d1MEggBk6pVg" - }, - "cell_type": "markdown", - "source": [ - "Now, we are ready to save the `train_state`. To customize context and pass the custom registry, you can use the `ocp.Context` as below." - ] - }, - { - "metadata": { - "id": "yGbwDdpC6pVg" - }, - "cell_type": "code", - "source": [ - "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " leaf_handler_registry=registry\n", - " )\n", - "):\n", - " ocp.save_pytree(path, train_state)" - ], - "outputs": [], - "execution_count": 20 - }, - { - "metadata": { - "id": "GHarzRaB6pVg" - }, - "cell_type": "markdown", - "source": [ - "After saving, let's load the checkpoint back to see if we can get back the expected Point objects. We will again create a ocp.Context with our custom registry." - ] - }, - { - "metadata": { - "id": "HMgdqu7U6pVg" - }, - "cell_type": "code", - "source": [ - "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " leaf_handler_registry=registry\n", - " )\n", - "):\n", - " restored_train_state = ocp.load_pytree(path)" - ], - "outputs": [], - "execution_count": 21 - }, - { - "metadata": { - "id": "d4hTdQBu6pVg" - }, - "cell_type": "code", - "source": [ - "import pprint\n", - "pprint.pprint(restored_train_state)" - ], - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'Points': {'level1': {'level2': {'point_float': Point(x=11.0, y=12.0),\n", - " 'point_int': Point(x=9, y=10),\n", - " 'point_mixed1': Point(x=5, y=6.0),\n", - " 'point_mixed2': Point(x=7.0, y=8)},\n", - " 'point_float': Point(x=3.0, y=4.0),\n", - " 'point_int': Point(x=1, y=2)}},\n", - " 'a': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),\n", - " 'b': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", - " 'mixed': {'Point': Point(x=0, y=0.5),\n", - " 'a': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),\n", - " 'b': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", - " 'scalar': 123.0},\n", - " 'scalar': 123.0}\n" - ] - } - ], - "execution_count": 22 - }, - { - "metadata": { - "id": "nh_huXzr6pVg" - }, - "cell_type": "markdown", - "source": [ - "We can see the restored_train_state looks exactly the same as the original train_state.\n", - "\n", - "Finally, we also want to see if we can read the expected metadata. Similarly, we will use ocp.Context to use our registry with the custom PointLeafHandler." - ] - }, - { - "metadata": { - "id": "B8GnatTB6pVg" - }, - "cell_type": "code", - "source": [ - "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " leaf_handler_registry=registry\n", - " )\n", - "):\n", - " restored_metadata = ocp.pytree_metadata(path)" - ], - "outputs": [], - "execution_count": 23 - }, - { - "metadata": { - "id": "BtB3Sqzc6pVg" - }, - "cell_type": "markdown", - "source": [ - "We can see the AbstractPoints are returned for Point leaves." - ] - }, - { - "metadata": { - "id": "VjlZgYk76pVg" - }, - "cell_type": "code", - "source": [ - "pprint.pprint(restored_metadata.metadata)" - ], - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'Points': {'level1': {'level2': {'point_float': AbstractPoint(x=,\n", - " y=),\n", - " 'point_int': AbstractPoint(x=,\n", - " y=),\n", - " 'point_mixed1': AbstractPoint(x=,\n", - " y=),\n", - " 'point_mixed2': AbstractPoint(x=,\n", - " y=)},\n", - " 'point_float': AbstractPoint(x=,\n", - " y=),\n", - " 'point_int': AbstractPoint(x=,\n", - " y=)}},\n", - " 'a': NumpyMetadata(shape=(16,),\n", - " dtype=dtype('int64'),\n", - " storage_metadata=StorageMetadata(chunk_shape=(16,),\n", - " write_shape=None)),\n", - " 'b': NumpyMetadata(shape=(16,),\n", - " dtype=dtype('float64'),\n", - " storage_metadata=StorageMetadata(chunk_shape=(16,),\n", - " write_shape=None)),\n", - " 'mixed': {'Point': AbstractPoint(x=, y=),\n", - " 'a': NumpyMetadata(shape=(16,),\n", - " dtype=dtype('int64'),\n", - " storage_metadata=StorageMetadata(chunk_shape=(16,),\n", - " write_shape=None)),\n", - " 'b': NumpyMetadata(shape=(16,),\n", - " dtype=dtype('float64'),\n", - " storage_metadata=StorageMetadata(chunk_shape=(16,),\n", - " write_shape=None)),\n", - " 'scalar': 0.0},\n", - " 'scalar': 0.0}\n" - ] - } - ], - "execution_count": 24 } ], "metadata": {