diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py index 2f964a0c0..08249e65a 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py @@ -89,31 +89,27 @@ def is_valid(self) -> bool: @property def context(self) -> ocp.Context: - return ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - storage_options=ocp.options.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=self.chunk_byte_size, - ), - use_ocdbt=self.use_ocdbt, - use_zarr3=self.use_zarr3, - use_replica_parallel=self.use_replica_parallel, - use_compression=self.use_compression, - enable_replica_parallel_separate_folder=self.enable_replica_parallel_separate_folder, - ), - loading=ocp.options.ArrayOptions.Loading( - use_load_and_broadcast=self.use_load_and_broadcast, - ), - ), - memory_options=ocp.options.MemoryOptions( - write_concurrent_bytes=self.save_concurrent_gb * 1024**3 - if self.save_concurrent_gb is not None - else None, - read_concurrent_bytes=self.restore_concurrent_gb * 1024**3 - if self.restore_concurrent_gb is not None - else None, - ), + ctx = ocp.Context() + ctx.array.saving.storage_options.chunk_byte_size = self.chunk_byte_size + ctx.array.saving.use_ocdbt = self.use_ocdbt + ctx.array.saving.use_zarr3 = self.use_zarr3 + ctx.array.saving.use_replica_parallel = self.use_replica_parallel + ctx.array.saving.use_compression = self.use_compression + ctx.array.saving.enable_replica_parallel_separate_folder = ( + self.enable_replica_parallel_separate_folder ) + ctx.array.loading.use_load_and_broadcast = self.use_load_and_broadcast + ctx.memory.write_concurrent_bytes = ( + self.save_concurrent_gb * 1024**3 + if self.save_concurrent_gb is not None + else None + ) + ctx.memory.read_concurrent_bytes = ( + self.restore_concurrent_gb * 1024**3 + if self.restore_concurrent_gb is not None + else None + ) + return ctx def clear_pytree(pytree: Any) -> Any: @@ -159,7 +155,7 @@ def test_fn( logging.info("Benchmark options: %s", pprint.pformat(options)) metrics_to_measure = get_metrics_to_measure(options) - with ocp.Context(context=options.context): + with ocp.Context(options.context): if options.enable_trace: jax.profiler.start_trace(context.path / "trace_save") if options.async_enabled: diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py index 9501b7444..a8e3e8ce3 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py @@ -115,7 +115,7 @@ def test_fn( reference_sharding_path=reference_sharding_path, ) - with ocp.Context(context=options.context): + with ocp.Context(options.context): loaded_pytree = ocp.load_pytree( reference_checkpoint_path, abstract_pytree ) @@ -123,7 +123,7 @@ def test_fn( for step in range(options.num_savings): logging.info("ReplicaParallelMultislice: Starting Step: %s", step) save_path = context.path / "ckpt" / str(step) - with ocp.Context(context=options.context): + with ocp.Context(options.context): if options.enable_trace: jax.profiler.start_trace(context.path / "trace_save") if options.async_enabled: diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py index 936df3e76..ce15b6e4a 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py @@ -95,7 +95,7 @@ def test_fn( options.reference_sharding_path ) - with ocp.Context(context=options.context): + with ocp.Context(options.context): metadata = ocp.pytree_metadata(reference_checkpoint_path) abstract_pytree = ( checkpoint_generation.get_abstract_state_from_sharding_config( diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py index 66b16d542..6968f3249 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py @@ -118,7 +118,7 @@ def test_fn( reference_sharding_path=reference_sharding_path, ) - with ocp.Context(context=options.context): + with ocp.Context(options.context): if options.enable_trace: jax.profiler.start_trace(context.path / "trace_load") with metrics.measure("load", metrics_to_measure): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py index eb27f71df..59eb77ee6 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py @@ -18,6 +18,7 @@ from collections.abc import Iterable import contextvars +import copy import typing from absl import logging @@ -45,14 +46,13 @@ class Context(epy.ContextManager): """Context for customized checkpointing. This class manages the configuration options (e.g., async, multiprocessing, - array handling) used during Orbax checkpoint operations. + array handling) used during Orbax checkpoint operations using a mutable + namespace pattern. - Creating a new :py:class:`.Context` within an existing :py:class:`.Context` - sets all parameters from scratch by default. To inherit properties from a - parent :py:class:`.Context`, you must explicitly pass the parent context as - the first argument. The new context will inherit the parent's properties, - except for any options explicitly provided as keyword arguments to the child - context. + Creating a new :py:class:`.Context` sets all parameters from absolute defaults + by default. To inherit properties from a parent :py:class:`.Context`, pass the + parent context as the first positional argument. The new context will inherit + the parent's properties but can be mutated independently. WARNING: The context is thread-local and is not shared across threads. The entire context block must be executed within the same thread. If you dispatch @@ -69,116 +69,131 @@ class Context(epy.ContextManager): Example: Basic usage and explicit inheritance:: - import orbax.checkpoint as ocp + from orbax.checkpoint import v1 as ocp # Basic usage - with ocp.Context(pytree_options=ocp.options.PyTreeOptions()): - ocp.save_pytree(directory, tree) + ctx = ocp.Context() + ctx.pytree.loading.partial_load = True + with ctx: + ocp.load_pytree(directory, tree) # Inheriting properties from an existing context - with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx: - # inner_ctx inherits pytree_options, but overrides/adds array_options - with ocp.Context(outer_ctx, - array_options=ocp.options.ArrayOptions() - ) as inner_ctx: - ocp.save_pytree(directory, tree) + ctx1 = ocp.Context() + ctx1.pytree.loading.partial_load = True + with ctx1 as outer_ctx: + # inner_ctx inherits partial_load, but mutates array saving + ctx2 = ocp.Context(outer_ctx) + ctx2.array.saving.use_zarr3 = False + with ctx2 as inner_ctx: + ocp.load_pytree(directory, tree) Context is not shared across threads:: from concurrent.futures import ThreadPoolExecutor - import orbax.checkpoint as ocp + from orbax.checkpoint import v1 as ocp executor = ThreadPoolExecutor(max_workers=1) - with ocp.Context( - pytree_options=ocp.options.PyTreeOptions() - ): # Thread #1 creates Context. - # The following save_pytree call is executed in Thread #2, which sees + ctx = ocp.Context() + ctx.pytree.loading.partial_load = True + with ctx: # Thread #1 creates Context. + # The following load_pytree call is executed in Thread #2, which sees # a "default" Context, NOT the one created above. - executor.submit(ocp.save_pytree, directory, tree) - + executor.submit(ocp.load_pytree, directory, tree) Attributes: - pytree_options: Options for PyTree checkpointing. See + pytree: Options for PyTree checkpointing. See :class:`~orbax.checkpoint.experimental.v1.options.PyTreeOptions`. - array_options: Options for saving and loading array (and array-like - objects). See + array: Options for saving and loading array (and array-like objects). See :class:`~orbax.checkpoint.experimental.v1.options.ArrayOptions`. - async_options: Options for controlling asynchronous behavior. See + asynchronous: Options for controlling asynchronous behavior. See :class:`~orbax.checkpoint.experimental.v1.options.AsyncOptions`. - multiprocessing_options: Options for multiprocessing behavior. See + multiprocessing: Options for multiprocessing behavior. See :class:`~orbax.checkpoint.experimental.v1.options.MultiprocessingOptions`. - file_options: Options for working with the file system. See + file: Options for working with the file system. See :class:`~orbax.checkpoint.experimental.v1.options.FileOptions`. - checkpointables_options: Options for controlling checkpointables behavior. - See + checkpointables: Options for controlling checkpointables behavior. See :class:`~orbax.checkpoint.experimental.v1.options.CheckpointablesOptions`. - pathways_options: Options for Pathways checkpointing. See + pathways: Options for Pathways checkpointing. See :class:`~orbax.checkpoint.experimental.v1.options.PathwaysOptions`. checkpoint_layout: The layout of the checkpoint. Defaults to ORBAX. See :class:`~orbax.checkpoint.experimental.v1.options.CheckpointLayout`. - deletion_options: Options for controlling deletion behavior. See + deletion: Options for controlling deletion behavior. See :class:`~orbax.checkpoint.experimental.v1.options.DeletionOptions`. - memory_options: Options for controlling memory limits during save / load. - See :class:`~orbax.checkpoint.experimental.v1.options.MemoryOptions`. + memory: Options for controlling memory limits during save / load. See + :class:`~orbax.checkpoint.experimental.v1.options.MemoryOptions`. """ - def __init__( - self, - context: Context | None = None, - *, - pytree_options: options_lib.PyTreeOptions | None = None, - array_options: options_lib.ArrayOptions | None = None, - async_options: options_lib.AsyncOptions | None = None, - multiprocessing_options: options_lib.MultiprocessingOptions | None = None, - file_options: options_lib.FileOptions | None = None, - checkpointables_options: options_lib.CheckpointablesOptions | None = None, - pathways_options: options_lib.PathwaysOptions | None = None, - checkpoint_layout: options_lib.CheckpointLayout | None = None, - deletion_options: options_lib.DeletionOptions | None = None, - memory_options: options_lib.MemoryOptions | None = None, - safetensors_options: options_lib.SafetensorsOptions | None = None, - ): - self._pytree_options = pytree_options or ( - context.pytree_options if context else options_lib.PyTreeOptions() - ) - self._array_options = array_options or ( - context.array_options if context else options_lib.ArrayOptions() - ) - self._async_options = async_options or ( - context.async_options if context else options_lib.AsyncOptions() - ) - self._multiprocessing_options = multiprocessing_options or ( - context.multiprocessing_options - if context - else options_lib.MultiprocessingOptions() - ) - self._file_options = file_options or ( - context.file_options if context else options_lib.FileOptions() - ) - self._checkpointables_options = checkpointables_options or ( - context.checkpointables_options - if context - else options_lib.CheckpointablesOptions() - ) - self._pathways_options = pathways_options or ( - context.pathways_options if context else options_lib.PathwaysOptions() - ) - self._checkpoint_layout = checkpoint_layout or ( - context.checkpoint_layout - if context - else options_lib.CheckpointLayout.ORBAX - ) - self._deletion_options = deletion_options or ( - context.deletion_options if context else options_lib.DeletionOptions() - ) - self._memory_options = memory_options or ( - context.memory_options if context else options_lib.MemoryOptions() - ) - self._safetensors_options = safetensors_options or ( - context.safetensors_options - if context - else options_lib.SafetensorsOptions() - ) + def __init__(self, context: Context | None = None): + if context is not None: + for k, v in context.__dict__.items(): + if k.endswith('_options') or k.endswith('_layout'): + setattr(self, k, copy.deepcopy(v)) + else: + self._pytree_options = options_lib.PyTreeOptions() + self._array_options = options_lib.ArrayOptions() + self._async_options = options_lib.AsyncOptions() + self._multiprocessing_options = options_lib.MultiprocessingOptions() + self._file_options = options_lib.FileOptions() + self._checkpointables_options = options_lib.CheckpointablesOptions() + self._pathways_options = options_lib.PathwaysOptions() + self._checkpoint_layout = options_lib.CheckpointLayout.ORBAX + self._deletion_options = options_lib.DeletionOptions() + self._memory_options = options_lib.MemoryOptions() + self._safetensors_options = options_lib.SafetensorsOptions() + + # --- Short-name properties for mutable user configuration --- + + @property + def array(self) -> options_lib.ArrayOptions: + return self._array_options + + @property + def asynchronous(self) -> options_lib.AsyncOptions: + return self._async_options + + @property + def pytree(self) -> options_lib.PyTreeOptions: + return self._pytree_options + + @property + def file(self) -> options_lib.FileOptions: + return self._file_options + + @property + def multiprocessing(self) -> options_lib.MultiprocessingOptions: + return self._multiprocessing_options + + @property + def checkpointables(self) -> options_lib.CheckpointablesOptions: + return self._checkpointables_options + + @property + def pathways(self) -> options_lib.PathwaysOptions: + return self._pathways_options + + @property + def deletion(self) -> options_lib.DeletionOptions: + return self._deletion_options + + @property + def memory(self) -> options_lib.MemoryOptions: + return self._memory_options + + @property + def safetensors(self) -> options_lib.SafetensorsOptions: + return self._safetensors_options + + @property + def checkpoint_layout(self) -> options_lib.CheckpointLayout: + return self._checkpoint_layout + + @checkpoint_layout.setter + def checkpoint_layout(self, value: options_lib.CheckpointLayout) -> None: + self._checkpoint_layout = value + + # TODO: b/513156122 - Migrate internal read sites to short-hand properties and + # remove legacy aliases in the next refactor. + # --- Legacy aliases for internal read access compatibility --- @property def pytree_options(self) -> options_lib.PyTreeOptions: @@ -208,10 +223,6 @@ def checkpointables_options(self) -> options_lib.CheckpointablesOptions: def pathways_options(self) -> options_lib.PathwaysOptions: return self._pathways_options - @property - def checkpoint_layout(self) -> options_lib.CheckpointLayout: - return self._checkpoint_layout - @property def deletion_options(self) -> options_lib.DeletionOptions: return self._deletion_options diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py index 29c3d3296..816f8bd24 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py @@ -44,12 +44,11 @@ def test_default_context(self): self.assertEqual(ctx.array_options, ArrayOptions()) def test_get_context_with_default(self): - default_ctx = ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_ocdbt=False)) - ) - custom_ctx = ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) - ) + default_ctx = ocp.Context() + default_ctx.array.saving.use_ocdbt = False + + custom_ctx = ocp.Context() + custom_ctx.array.saving.use_zarr3 = False with self.subTest("no context set, no default provided"): ctx = context_lib.get_context() @@ -75,53 +74,30 @@ def test_get_context_with_default(self): self.assertEqual(ctx.array_options, ArrayOptions()) def test_custom_context(self): - with ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) - ): - ctx = fake_checkpoint_operation() - self.assertEqual( - ctx.array_options, - ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)), - ) - - context = ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) - ) - with context: - ctx = fake_checkpoint_operation() + ctx = ocp.Context() + ctx.array.saving.use_zarr3 = False + with ctx: + ctx_2 = fake_checkpoint_operation() self.assertEqual( - ctx.array_options, + ctx_2.array_options, ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)), ) def test_custom_context_in_separate_thread_becomes_default(self): with futures.ThreadPoolExecutor(max_workers=1) as executor: - with ocp.Context( - array_options=ArrayOptions( - saving=ArrayOptions.Saving(use_zarr3=False) - ) - ): - future = executor.submit(fake_checkpoint_operation) - ctx = future.result() - self.assertEqual(ctx.array_options, ArrayOptions()) - - with ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) - ): - with futures.ThreadPoolExecutor(max_workers=1) as executor: + ctx = ocp.Context() + ctx.array.saving.use_zarr3 = False + with ctx: future = executor.submit(fake_checkpoint_operation) - ctx = future.result() - self.assertEqual(ctx.array_options, ArrayOptions()) + active = future.result() + self.assertEqual(active.array.saving.use_zarr3, True) def test_custom_context_in_same_thread_remains_custom(self): def test_fn(): - with ocp.Context( - array_options=ArrayOptions( - saving=ArrayOptions.Saving(use_zarr3=False) - ) - ): - ctx = fake_checkpoint_operation() - return ctx + context = ocp.Context() + context.array.saving.use_zarr3 = False + with context: + return fake_checkpoint_operation() with futures.ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(test_fn) @@ -132,92 +108,45 @@ def test_fn(): ) def test_nested_contexts(self): - with ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) - ): - ctx = fake_checkpoint_operation() + ctx1 = ocp.Context() + ctx1.array.saving.use_zarr3 = False + with ctx1: self.assertEqual( - ctx.array_options, - ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)), + fake_checkpoint_operation().array.saving.use_zarr3, False ) - with ocp.Context( - array_options=ArrayOptions( - saving=ArrayOptions.Saving(use_ocdbt=False) - ) - ): - ctx = fake_checkpoint_operation() - self.assertEqual( - ctx.array_options, - ArrayOptions( - saving=ArrayOptions.Saving(use_zarr3=True, use_ocdbt=False) - ), - ) - - ctx = fake_checkpoint_operation() - self.assertEqual( - ctx.array_options, - ArrayOptions( - saving=ArrayOptions.Saving(use_zarr3=False, use_ocdbt=True) - ), - ) + ctx2 = ocp.Context() # absolute default slate by default + ctx2.array.saving.use_ocdbt = False + with ctx2: + active = fake_checkpoint_operation() + self.assertEqual(active.array.saving.use_zarr3, True) + self.assertEqual(active.array.saving.use_ocdbt, False) def test_nested_contexts_with_inheritance(self): - default_ctx = fake_checkpoint_operation() - self.assertEqual( - default_ctx.array_options, - ArrayOptions( - saving=ArrayOptions.Saving(use_ocdbt=True, use_zarr3=True) - ), - ) - self.assertEqual( - default_ctx.file_options, - FileOptions(path_permission_mode=None), - ) - - with ocp.Context( - array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)), - file_options=FileOptions(path_permission_mode=0o750), - ): - ctx = fake_checkpoint_operation() - self.assertEqual( - ctx.array_options, - ArrayOptions( - saving=ArrayOptions.Saving(use_ocdbt=True, use_zarr3=False) - ), - ) - self.assertEqual( - ctx.file_options, - FileOptions(path_permission_mode=0o750), - ) - with ocp.Context( - ctx, - array_options=ArrayOptions( - saving=ArrayOptions.Saving(use_ocdbt=False) - ), - ): - ctx = fake_checkpoint_operation() - self.assertEqual( - ctx.array_options, - ArrayOptions( - saving=ArrayOptions.Saving(use_ocdbt=False, use_zarr3=True) - ), - ) - self.assertEqual( - ctx.file_options, - FileOptions(path_permission_mode=0o750), - ) - ctx = fake_checkpoint_operation() - self.assertEqual( - ctx.array_options, - ArrayOptions( - saving=ArrayOptions.Saving(use_ocdbt=True, use_zarr3=False) - ), - ) - self.assertEqual( - ctx.file_options, - FileOptions(path_permission_mode=0o750), - ) + ctx1 = ocp.Context() + ctx1.array.saving.use_zarr3 = False + ctx1.file.path_permission_mode = 0o750 + with ctx1: + active1 = fake_checkpoint_operation() + self.assertEqual(active1.array.saving.use_zarr3, False) + self.assertEqual(active1.file.path_permission_mode, 0o750) + + ctx2 = ocp.Context(active1) + ctx2.array.saving.use_ocdbt = False + with ctx2: + active2 = fake_checkpoint_operation() + self.assertEqual(active2.array.saving.use_zarr3, False) # inherited + self.assertEqual(active2.array.saving.use_ocdbt, False) # mutated + self.assertEqual(active2.file.path_permission_mode, 0o750) # inherited + + active3 = fake_checkpoint_operation() + self.assertEqual(active3.array.saving.use_zarr3, False) + self.assertEqual(active3.array.saving.use_ocdbt, True) + self.assertEqual(active3.file.path_permission_mode, 0o750) + + def test_legacy_context_kwargs_fail(self): + with self.assertRaises(TypeError): + ocp.Context(array_options=None) # pytype: disable=wrong-keyword-args if __name__ == "__main__": diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index c25fd9502..3e593eac8 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -16,9 +16,10 @@ from __future__ import annotations +import copy import dataclasses import enum -from typing import Any, Callable, Protocol, Type +from typing import Any, Callable, Protocol from etils import epath import numpy as np @@ -28,14 +29,13 @@ from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.serialization import pathways_types from orbax.checkpoint.experimental.v1._src.handlers import registration -from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class AsyncOptions: """Options used to configure async behavior. @@ -75,8 +75,15 @@ def v0(self) -> v0_options_lib.AsyncOptions: create_directories_asynchronously=self.create_directories_asynchronously, ) + def __deepcopy__(self, memo): + return AsyncOptions( + timeout_secs=self.timeout_secs, + post_finalization_callback=self.post_finalization_callback, + create_directories_asynchronously=self.create_directories_asynchronously, + ) + -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class MultiprocessingOptions: """Options used to configure multiprocessing behavior. @@ -129,7 +136,7 @@ def v0(self) -> v0_options_lib.MultiprocessingOptions: # pyformat: disable -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class FileOptions: """Options used to configure checkpoint directories and files. @@ -171,11 +178,19 @@ def v0(self) -> v0_options_lib.FileOptions: path_permission_mode=self.path_permission_mode, ) + def __deepcopy__(self, memo): + res = copy.copy(self) + if getattr(res, 'cns_file_options', None) is not None: + res.cns_file_options = copy.deepcopy(res.cns_file_options, memo) + if getattr(res, 'tfhub_file_options', None) is not None: + res.tfhub_file_options = copy.deepcopy(res.tfhub_file_options, memo) + return res + # pyformat: enable -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class PyTreeOptions: """Options used to configure PyTree saving and loading. @@ -191,7 +206,7 @@ class PyTreeOptions: override the default Leaf Handler Registry. """ - @dataclasses.dataclass(frozen=True, kw_only=True) + @dataclasses.dataclass(kw_only=True) class Saving: """Options for saving PyTrees. @@ -202,7 +217,7 @@ class Saving: dataclasses.field(default_factory=tree_metadata.PyTreeMetadataOptions) ) - @dataclasses.dataclass(frozen=True, kw_only=True) + @dataclasses.dataclass(kw_only=True) class Loading: """Options for loading PyTrees. @@ -216,68 +231,60 @@ class Loading: loading: Loading = dataclasses.field(default_factory=Loading) leaf_handler_registry: serialization_types.LeafHandlerRegistry | None = None + def __deepcopy__(self, memo): + return PyTreeOptions( + saving=copy.deepcopy(self.saving, memo), + loading=copy.deepcopy(self.loading, memo), + leaf_handler_registry=self.leaf_handler_registry, + ) + -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class ArrayOptions: """Options used to configure array saving and loading. This dataclass defines the high-level configuration parameters for array checkpointing operations within the Orbax framework. Because it is defined - as a frozen, keyword-only dataclass, instances are strictly immutable once - created, and all parameters must be explicitly specified by their keyword - names during initialization. + as a keyword-only dataclass, instances map mutable option dimensions. Example: To configure array options with specific saving formats and loading behaviors we can do so like this:: - from orbax.checkpoint.v1.options import ArrayOptions + from orbax.checkpoint import v1 as ocp - options = ArrayOptions( - saving=ArrayOptions.Saving( - use_zarr3=True, - use_compression=False, - ), - loading=ArrayOptions.Loading( - enable_padding_and_truncation=True - ) + ctx = ocp.Context() + ctx.array.saving.use_zarr3 = True + ctx.array.saving.use_compression = False + ctx.array.loading.enable_padding_and_truncation = True To save certain leaves in float16, while others in float32, we can use `scoped_storage_options_creator` like so:: import jax import jax.numpy as jnp - from orbax.checkpoint.v1 import options as ocp_options + from orbax.checkpoint import v1 as ocp - def create_opts_fn(keypath, value): + def create_opts_fn(keypath, value, storage): if 'small' in jax.tree_util.keystr(keypath): - return ocp_options.ArrayOptions.Saving.StorageOptions( - dtype=jnp.float16 - ) - return None # Fall back to global `storage_options` - - array_options = ocp_options.ArrayOptions( - saving=ocp_options.ArrayOptions.Saving( - storage_options=ocp_options.ArrayOptions.Saving.StorageOptions( - dtype=jnp.float32 - ), - scoped_storage_options_creator=create_opts_fn - ) + storage.dtype = jnp.float16 - ) + ctx = ocp.Context() + ctx.array.saving.storage_options.dtype = jnp.float32 + ctx.array.saving.scoped_storage_options_creator = create_opts_fn Attributes: saving: Options for saving arrays. loading: Options for loading arrays. """ - @dataclasses.dataclass(frozen=True, kw_only=True) + @dataclasses.dataclass(kw_only=True) class Saving: """Options for saving arrays. Attributes: - storage_options: Options used to customize array storage behavior for - all leaves at a global level. See below. + storage_options: Options used to customize array storage behavior for all + leaves at a global level. See below. use_ocdbt: Enables OCDBT format. use_zarr3: If True, use Zarr3 format. use_compression: If True, use ZSTD compression. @@ -306,23 +313,10 @@ class Saving: array_metadata_store: Store to manage per host ArrayMetadata. To disable ArrayMetadata persistence, set it to None. scoped_storage_options_creator: A function that, when dealing with - PyTrees, is applied to every leaf. If it returns an - :py:class:`ArrayOptions.Saving.StorageOptions`, its fields take - precedence when merging if they are set to non-None or non-default - values with respect to `storage_options`. If it returns `None`, - `storage_options` is used as a default for all fields. It is called - similar to: `jax.tree.map_with_path(scoped_storage_options_creator, - pytree_to_save)`. + PyTrees, is applied to every leaf to mutate storage options in-place. """ - class ScopedStorageOptionsCreator(Protocol): - - def __call__( - self, key: tree_types.PyTreeKeyPath, value: Any - ) -> ArrayOptions.Saving.StorageOptions | None: - ... - - @dataclasses.dataclass(frozen=True, kw_only=True) + @dataclasses.dataclass(kw_only=True) class StorageOptions: """Options used to customize array storage behavior for individual leaves. @@ -347,6 +341,16 @@ class StorageOptions: chunk_byte_size: int | None = None shard_axes: tuple[int, ...] | None = None + class ScopedStorageOptionsCreator(Protocol): + + def __call__( + self, + key: tree_types.PyTreeKeyPath, + value: Any, + storage: ArrayOptions.Saving.StorageOptions, + ) -> ArrayOptions.Saving.StorageOptions | None: + ... + storage_options: StorageOptions = dataclasses.field( default_factory=StorageOptions ) @@ -366,7 +370,13 @@ class StorageOptions: ) scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None - @dataclasses.dataclass(frozen=True, kw_only=True) + def __deepcopy__(self, memo): + res = copy.copy(self) + if hasattr(res, 'storage_options') and res.storage_options is not None: + res.storage_options = copy.deepcopy(res.storage_options, memo) + return res + + @dataclasses.dataclass(kw_only=True) class Loading: """Options for loading arrays. @@ -388,7 +398,7 @@ class Loading: replicas. """ - @dataclasses.dataclass(frozen=True, kw_only=True) + @dataclasses.dataclass(kw_only=True) class LoadAndBroadcastOptions: """Used to configure load-and-broadcast behavior in multi-replica loading. @@ -415,41 +425,33 @@ class LoadAndBroadcastOptions: default_factory=LoadAndBroadcastOptions ) + def __deepcopy__(self, memo): + res = copy.copy(self) + if ( + hasattr(res, 'load_and_broadcast_options') + and res.load_and_broadcast_options is not None + ): + res.load_and_broadcast_options = copy.deepcopy( + res.load_and_broadcast_options, memo + ) + return res + saving: Saving = dataclasses.field(default_factory=Saving) loading: Loading = dataclasses.field(default_factory=Loading) + def __deepcopy__(self, memo): + return ArrayOptions( + saving=copy.deepcopy(self.saving, memo), + loading=copy.deepcopy(self.loading, memo), + ) + -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class CheckpointablesOptions: """Options used to configure `checkpointables` save/load behavior. Primarily intended for registering custom :py:class:`.CheckpointableHandler` - classes. You can specify a registry directly, or use `create_with_handlers`. - For example:: - - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - FooHandler(), - bar=BarHandler(), - ) - ) - with ocp.Context(checkpointables_options=checkpointables_options)): - ocp.save_checkpointables(directory, dict(foo=Foo(...), bar=Bar(...))) - - In this example, `FooHandler` is registered generically, which means that any - checkpointable that is handleable by `FooHandler` can be saved/loaded (a - `Foo` object in this case). In contrast, `BarHandler` is explicitly tied to - the name `bar`, which means that only a checkpointable that is both handleable - by `BarHandler` and has the name `bar` can handled by this `BarHandler`. - - Recall that a global registry also exists, containing core handlers like - :py:class:`.PyTreeHandler` and :py:class:`.JsonHandler`. Use - `ocp.handlers.register_handler` to register a handler globally. - - Note that registration order matters. For example, if saving a dict containing - only strings, both :py:class:`.JsonHandler` and :py:class:`.PyTreeHandler` are - capable of handling this object, but :py:class:`.JsonHandler` will be selected - first because it is registered first. + classes via direct registry binding. Attributes: registry: A :py:class:`.CheckpointableHandlerRegistry` that is used to @@ -463,21 +465,11 @@ class CheckpointablesOptions: ) ) - @classmethod - def create_with_handlers( - cls, - *handlers: Type[handler_types.CheckpointableHandler], - **named_handlers: Type[handler_types.CheckpointableHandler], - ) -> CheckpointablesOptions: - registry = registration.local_registry(include_global_registry=True) - for handler in handlers: - registry.add(handler, checkpointable_name=None) - for name, handler in named_handlers.items(): - registry.add(handler, checkpointable_name=name) - return cls(registry=registry) - - -@dataclasses.dataclass(frozen=True, kw_only=True) + def __deepcopy__(self, memo): + return CheckpointablesOptions(registry=self.registry) + + +@dataclasses.dataclass(kw_only=True) class PathwaysOptions: """Options used to configure Pathways saving and loading. @@ -488,7 +480,7 @@ class PathwaysOptions: checkpointing_impl: pathways_types.CheckpointingImpl | None = None -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class DeletionOptions: """Options used to configure checkpoint deletion behavior. @@ -496,7 +488,7 @@ class DeletionOptions: gcs_deletion_options: Deletion options specific to GCS. """ - @dataclasses.dataclass(frozen=True, kw_only=True) + @dataclasses.dataclass(kw_only=True) class GcsDeletionOptions: """Deletion options specific to GCS. @@ -526,7 +518,7 @@ class GcsDeletionOptions: -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class MemoryOptions: """Options for configuring memory limits for save / load. @@ -573,8 +565,16 @@ class MemoryOptions: transfer_concurrent_bytes: int | None = None is_prioritized_key_fn: serialization_types.IsPrioritizedKeyFn | None = None + def __deepcopy__(self, memo): + return MemoryOptions( + write_concurrent_bytes=self.write_concurrent_bytes, + read_concurrent_bytes=self.read_concurrent_bytes, + transfer_concurrent_bytes=self.transfer_concurrent_bytes, + is_prioritized_key_fn=self.is_prioritized_key_fn, + ) + -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class SafetensorsOptions: """Options for configuring Safetensors loading. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py index fecb1a529..98913f940 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py @@ -51,14 +51,12 @@ def is_prioritized_key_fn(path): del path return True - memory_options = ocp_options.MemoryOptions( - write_concurrent_bytes=1024, - read_concurrent_bytes=2048, - transfer_concurrent_bytes=512, - is_prioritized_key_fn=is_prioritized_key_fn, - ) - - with context_lib.Context(memory_options=memory_options): + ctx = context_lib.Context() + ctx.memory.write_concurrent_bytes = 1024 + ctx.memory.read_concurrent_bytes = 2048 + ctx.memory.transfer_concurrent_bytes = 512 + ctx.memory.is_prioritized_key_fn = is_prioritized_key_fn + with ctx: with mock.patch( 'orbax.checkpoint._src.handlers.base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler', autospec=True, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py index 8e3e430a0..6b6e89dad 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py @@ -72,20 +72,20 @@ class JsonHandler(CheckpointableHandler[JsonType, None]): config = {'learning_rate': 0.01, 'batch_size': 32} - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - experiment_config=ocp.handlers.JsonHandler( - filename='experiment_config.json' - ) - ) + registry = ocp.handlers.local_registry() + registry.add( + ocp.handlers.JsonHandler(filename='experiment_config.json'), + checkpointable_name='experiment_config', ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, dict(experiment_config=config)) Attributes: filename: An optional specific filename to use for saving and loading the - JSON data. If not provided, the handler will fall back to a default set - of supported JSON filenames. + JSON data. If not provided, the handler will fall back to a default set of + supported JSON filenames. """ def __init__(self, filename: str | None = None): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py index 3d8feb845..64cbe768a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py @@ -63,20 +63,20 @@ class ProtoHandler( # Assuming MyProtoMessage is your compiled protobuf class my_proto_msg = MyProtoMessage(config_field="value") - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - proto_config=ocp.handlers.ProtoHandler( - filename="model_config.pbtxt" - ) - ) + registry = ocp.handlers.local_registry() + registry.add( + ocp.handlers.ProtoHandler(filename="model_config.pbtxt"), + checkpointable_name="proto_config", ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, dict(proto_config=my_proto_msg)) Attributes: filename (str): An optional filename used for saving and loading the - protobuf data. If not provided, it defaults to a standard internal - default filename. + protobuf data. If not provided, it defaults to a standard internal default + filename. """ def __init__( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index 9eda5edda..222371452 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -271,12 +271,13 @@ class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]): state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0} - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - model_state=ocp.handlers.PyTreeHandler() - ) + registry = ocp.handlers.local_registry() + registry.add( + ocp.handlers.PyTreeHandler(), checkpointable_name='model_state' ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, dict(model_state=state_pytree)) Attributes: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py index 321304373..0200c7e5c 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py @@ -39,10 +39,9 @@ # to a new v1 handler class. registry.add(BazHandler, secondary_typestrs=['OldBazHandlerTypestr']) - checkpointables_options = ocp.options.CheckpointablesOptions( - registry=registry - ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(...) Handler resolution for saving/loading follows this logic: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py index 8c83cb258..6d4c4685b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py @@ -105,7 +105,8 @@ async def get_checkpoint_layout( f"Could not recognize the checkpoint at {path} as a valid" f" {layout_enum.value} checkpoint. If you are trying to load a" " checkpoint that does not conform to the standard Orbax format, use" - " `ocp.Context(layout=...)` to specify the expected checkpoint layout." + " `ctx.checkpoint_layout = ...` to specify the expected checkpoint" + " layout." ) from e diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py index 36da779fe..94cb60cf8 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py @@ -75,9 +75,9 @@ def setUp(self): ) def test_load_safetensors_checkpoint(self): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx_builder: pytree = loading.load_pytree(self.safetensors_path) self.assertIsInstance(pytree, dict) np.testing.assert_array_equal(pytree['a'], self.object_to_save['a']) @@ -97,7 +97,9 @@ def test_load_orbax_checkpointables_checkpoint(self): ) def test_load_bad_path_orbax_ckpt(self, layout_enum): # User provides a directory of Orbax checkpoints, not specific one. - with context_lib.Context(checkpoint_layout=layout_enum): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = layout_enum + with ctx_builder: with self.assertRaises(InvalidLayoutError): loading.load_pytree( epath.Path(self.test_dir.full_path), @@ -108,7 +110,9 @@ def test_load_bad_path_orbax_ckpt(self, layout_enum): ) def test_load_bad_path_safetensors_ckpt(self, layout_enum): # User provides a empty directory of SafeTensors checkpoints, not a file. - with context_lib.Context(checkpoint_layout=layout_enum): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = layout_enum + with ctx_builder: with self.assertRaises(InvalidLayoutError): loading.load_pytree( epath.Path(self.test_dir_safetensors.full_path), @@ -118,9 +122,9 @@ def test_load_safetensors_ckpt_from_dir(self): safetensors_dir = epath.Path(self.test_dir_safetensors.full_path) safetensors_path = safetensors_dir / 'model.safetensors' np_save_file(self.object_to_save, safetensors_path) - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx_builder: pytree = loading.load_pytree(safetensors_dir) self.assertIsInstance(pytree, dict) np.testing.assert_array_equal(pytree['a'], self.object_to_save['a']) @@ -180,7 +184,9 @@ async def sleep_and_load(*args, **kwargs): else: directory = self.orbax_pytree_path - with context_lib.Context(checkpoint_layout=layout): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = layout + with ctx_builder: if layout != options_lib.CheckpointLayout.SAFETENSORS: with self.assertRaises(NotImplementedError): loading.load_pytree_async(directory) @@ -197,9 +203,9 @@ async def sleep_and_load(*args, **kwargs): # TODO(b/431045454): Add tests for abstract_checkpointables. def test_load_auto_resolution_mode_orbax(self): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.ORBAX - ): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = options_lib.CheckpointLayout.ORBAX + with ctx_builder: loaded_orbax = loading.load_pytree( self.orbax_pytree_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, @@ -207,9 +213,9 @@ def test_load_auto_resolution_mode_orbax(self): test_utils.assert_tree_equal(self, self.object_to_save, loaded_orbax) def test_load_auto_resolution_mode_safetensors(self): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx_builder: loaded_safe = loading.load_pytree( self.safetensors_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, @@ -227,9 +233,9 @@ def test_load_auto_multiple_checkpointables_priority(self): saving.save_checkpointables(multiple_path, checkpointables) # Triggering AUTO loading mode should prioritize resolving 'pytree'. - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.ORBAX - ): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = options_lib.CheckpointLayout.ORBAX + with ctx_builder: loaded = loading.load_pytree(multiple_path) test_utils.assert_tree_equal(self, checkpointables['pytree'], loaded) @@ -242,9 +248,9 @@ def test_load_auto_non_pytree_fallback(self): fallback_path = epath.Path(self.test_dir.full_path) / 'fallback_checkpoint' saving.save_checkpointables(fallback_path, custom_checkpointables) - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.ORBAX - ): + ctx_builder = context_lib.Context() + ctx_builder.checkpoint_layout = options_lib.CheckpointLayout.ORBAX + with ctx_builder: loaded = loading.load_pytree( fallback_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py index 18e44b50f..8eabdef35 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py @@ -146,9 +146,9 @@ def test_pytree_metadata_safetensors(self): 'y': jax.ShapeDtypeStruct(shape=(3,), dtype=np.int64), } - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ckpt_metadata = ocp.pytree_metadata(st_path) self.assertIsInstance(ckpt_metadata, metadata_types.CheckpointMetadata) @@ -165,9 +165,9 @@ def test_pytree_metadata_safetensors(self): # Test invalid path with self.assertRaises(ocp.errors.InvalidLayoutError): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ocp.pytree_metadata(self.directory) @@ -193,15 +193,12 @@ class CheckpointablesMetadataTest(absltest.TestCase): def setUp(self): super().setUp() self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' - checkpointables_options = ( - options_lib.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - context_lib.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = context_lib.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { 'foo': Foo(1, 'foo'), 'bar': Bar(2, 'bar'), @@ -247,9 +244,9 @@ def test_checkpointables_metadata_safetensors(self): 'item2': jax.ShapeDtypeStruct(shape=(1,), dtype=np.int32), } - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ckpt_metadata = ocp.checkpointables_metadata(st_path) self.assertIsInstance(ckpt_metadata, metadata_types.CheckpointMetadata) @@ -268,9 +265,9 @@ def test_checkpointables_metadata_safetensors(self): # Test invalid path with self.assertRaises(ocp.errors.InvalidLayoutError): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ocp.checkpointables_metadata(self.directory) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py index 23bb157a5..591f1e755 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py @@ -119,7 +119,9 @@ def _create_v0_savearg( return type_handlers_v0.SaveArgs( dtype=jnp.dtype(storage_options.dtype) if storage_options.dtype else None, chunk_byte_size=storage_options.chunk_byte_size, - shard_axes=storage_options.shard_axes, + shard_axes=storage_options.shard_axes + if storage_options.shard_axes is not None + else tuple(), ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py index 1474b28b9..44fe6fe75 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py @@ -106,7 +106,9 @@ def _create_v0_savearg( return type_handlers_v0.SaveArgs( dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None, chunk_byte_size=storage_options.chunk_byte_size, - shard_axes=storage_options.shard_axes, + shard_axes=storage_options.shard_axes + if storage_options.shard_axes is not None + else tuple(), ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py index 5b426e938..5b9eaefa7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py @@ -14,6 +14,8 @@ """Utility functions for serialization.""" +import copy + from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.tree import types as tree_types @@ -26,10 +28,8 @@ def resolve_storage_options( """Resolves storage options using a global default and a per-leaf creator. When dealing with PyTrees, `scoped_storage_options_creator` is applied to - every leaf. Its fields take precedence when merging if they are set to - non-None or non-default values with respect to the global `storage_options`. - If the creator returns `None`, the global `storage_options` is used for all - fields. + every leaf to mutate its fields in-place on an isolated copy of the global + `storage_options`. Args: keypath: The PyTree keypath of the array being saved. @@ -40,38 +40,17 @@ def resolve_storage_options( The resolved StorageOptions containing storage options. """ global_opts = array_saving_options.storage_options - if global_opts is None: - global_opts = options_lib.ArrayOptions.Saving.StorageOptions() - - fn = array_saving_options.scoped_storage_options_creator - individual_opts = None - if fn is not None: - individual_opts = fn(keypath, value) + resolved = ( + copy.copy(global_opts) + if global_opts is not None + else options_lib.ArrayOptions.Saving.StorageOptions() + ) - if individual_opts is not None: - resolved_dtype = ( - individual_opts.dtype - if individual_opts.dtype is not None - else global_opts.dtype - ) - resolved_chunk_byte_size = ( - individual_opts.chunk_byte_size - if individual_opts.chunk_byte_size is not None - else global_opts.chunk_byte_size + if array_saving_options.scoped_storage_options_creator is not None: + ret = array_saving_options.scoped_storage_options_creator( + keypath, value, resolved ) - resolved_shard_axes = ( - individual_opts.shard_axes - if individual_opts.shard_axes is not None - else global_opts.shard_axes - ) - else: - resolved_dtype = global_opts.dtype - resolved_chunk_byte_size = global_opts.chunk_byte_size - resolved_shard_axes = global_opts.shard_axes - - return options_lib.ArrayOptions.Saving.StorageOptions( - dtype=resolved_dtype, - chunk_byte_size=resolved_chunk_byte_size, - shard_axes=resolved_shard_axes if resolved_shard_axes is not None else (), - ) + if ret is not None: + resolved = ret + return resolved diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py index 897b959f9..a2eaa88d9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py @@ -24,14 +24,30 @@ from orbax.checkpoint.experimental.v1._src.serialization import options_resolution +def cb_overriding_global(k, v, s): + s.dtype = np.int16 + + +def cb_overriding_all(k, v, s): + s.dtype = np.float32 + s.chunk_byte_size = 32_000_000 + s.shard_axes = (1,) + + +def cb_jnp_converter(k, v, s): + s.dtype = jnp.bfloat16 + + +def cb_empty_axes(k, v, s): + s.shard_axes = () + + class OptionsResolutionTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name='callback_overriding_global', - callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( - dtype=np.int16 - ), + callback=cb_overriding_global, expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( dtype=np.int16, chunk_byte_size=16_000_000, @@ -40,11 +56,7 @@ class OptionsResolutionTest(parameterized.TestCase): ), dict( testcase_name='callback_overriding_all', - callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( - dtype=np.float32, - chunk_byte_size=32_000_000, - shard_axes=(1,), - ), + callback=cb_overriding_all, expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( dtype=np.float32, chunk_byte_size=32_000_000, @@ -62,9 +74,7 @@ class OptionsResolutionTest(parameterized.TestCase): ), dict( testcase_name='jnp_dtype_converter', - callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( - dtype=jnp.bfloat16, - ), + callback=cb_jnp_converter, expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( dtype=jnp.bfloat16, chunk_byte_size=16_000_000, @@ -73,9 +83,7 @@ class OptionsResolutionTest(parameterized.TestCase): ), dict( testcase_name='empty_shard_axes_overrides_to_empty', - callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( - shard_axes=(), - ), + callback=cb_empty_axes, expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( dtype=np.int32, chunk_byte_size=16_000_000, @@ -88,30 +96,22 @@ def test_resolve_storage_options( callback, expected_storage_options, ): - # Global options global_storage = options_lib.ArrayOptions.Saving.StorageOptions( dtype=np.int32, chunk_byte_size=16_000_000, shard_axes=(0,), ) - context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - storage_options=global_storage, - scoped_storage_options_creator=callback, - ) - ), - ) + ctx = context_lib.Context() + ctx.array.saving.storage_options = global_storage + ctx.array.saving.scoped_storage_options_creator = callback - # Dummy param keypath = (jax.tree_util.DictKey(key='foo'),) value = np.ones((2, 2)) resolved_options = options_resolution.resolve_storage_options( - keypath, value, context.array_options.saving + keypath, value, ctx.array.saving ) - self.assertEqual(resolved_options, expected_storage_options) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py index 139fd304a..f6703a233 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py @@ -77,7 +77,9 @@ def _create_v0_savearg( return type_handlers_v0.SaveArgs( dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None, chunk_byte_size=storage_options.chunk_byte_size, - shard_axes=storage_options.shard_axes, + shard_axes=storage_options.shard_axes + if storage_options.shard_axes is not None + else tuple(), ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py index 54f601de9..2311ef4b3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py @@ -121,11 +121,9 @@ def test_checkpointables_metadata_compatibility( ) ) - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.checkpointables_metadata(path) # If the state checpointable is missing pytree metadata, then we expect diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py index c31af89d3..fb07c4558 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py @@ -85,11 +85,9 @@ def generate_v1_checkpoint(path: epath.Path) -> None: registry = registration.local_registry() registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='state') registry.add(ocp.handlers.JsonHandler, checkpointable_name='metadata') - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, checkpointables) (path / 'descriptor').rmtree() # GOOGLE_INTERNAL diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py index ffe275ea0..73d2537b6 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py @@ -246,11 +246,9 @@ def test_load_checkpointables_compatibility( else: abstract_checkpointables = None - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.load_checkpointables( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py index c9a63def2..223fe1b45 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py @@ -257,11 +257,9 @@ def test_load_pytree_compatibility( self.abstract_state if abstract_pytree_provided else None ) - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.load_pytree( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py index 986c4633e..bfc22b1a1 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py @@ -186,11 +186,9 @@ def test_pytree_metadata_compatibility( is_pytree, ) - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.pytree_metadata( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index edcc5e006..53c3d4342 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -204,9 +204,8 @@ async def mock_finalize(self_handler, directory): ) ) - context = ocp.Context( - async_options=ocp.options.AsyncOptions(timeout_secs=timeout_secs) - ) + context = ocp.Context() + context.asynchronous.timeout_secs = timeout_secs self.enter_context(context) start = time.time() @@ -417,11 +416,8 @@ def test_custom_array_type(self): handler_utils.LazyArrayHandler, ) - custom_context = ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - leaf_handler_registry=custom_registry - ) - ) + custom_context = ocp.Context() + custom_context.pytree.leaf_handler_registry = custom_registry mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) @@ -456,11 +452,8 @@ def test_custom_array_type(self): secondary_typestrs=[types.typestr(handler_utils.LazyArrayHandler)], override=True, ) - custom_context2 = ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - leaf_handler_registry=custom_registry2 - ) - ) + custom_context2 = ocp.Context() + custom_context2.pytree.leaf_handler_registry = custom_registry2 with custom_context2: loaded_as_jax_array = ocp.load_pytree(self.directory) self.assertIsInstance(loaded_as_jax_array['a'], jax.Array) @@ -603,18 +596,15 @@ def test_casting(self, original_dtype, save_dtype, load_dtype): 'numpy_array': np.arange(len(jax.devices()), dtype=load_dtype), } - scoped_storage_options_creator = ( - lambda k, v: ocp.options.ArrayOptions.Saving.StorageOptions( - dtype=save_dtype - ) + def scoped_storage_options_creator(k, v, s): + del k, v + s.dtype = save_dtype + + ctx = ocp.Context() + ctx.array.saving.scoped_storage_options_creator = ( + scoped_storage_options_creator ) - with ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - scoped_storage_options_creator=scoped_storage_options_creator - ) - ) - ): + with ctx: ocp.save_pytree(self.directory, tree) with self.subTest('with_abstract_tree'): @@ -743,15 +733,12 @@ def test_missing_keys(self): ) def test_custom_checkpointables(self): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { 'pytree': self.numpy_pytree, 'foo': Foo(123, 'hi'), @@ -821,14 +808,11 @@ def test_save_checkpointables_ambiguous_resolution(self): 'two': {'c': 3, 'd': 4}, } directory = self.directory - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - one=handler_utils.DictHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.DictHandler, checkpointable_name='one') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) ocp.save_checkpointables(directory, checkpointables) self.assertTrue((directory / 'one' / 'data.txt').exists()) self.assertFalse((directory / 'two' / 'data.txt').exists()) @@ -876,15 +860,12 @@ def test_abstract_pytree_types(self): test_utils.assert_tree_equal(self, self.pytree, loaded) def test_abstract_checkpointables_types(self): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { 'foo': Foo(123, 'hi'), 'bar': Bar(456, 'bye'), @@ -911,14 +892,11 @@ def test_abstract_checkpointables_types(self): self.assertEqual(checkpointables, loaded) def test_async_directory_creation(self): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) self.enter_context( mock.patch.object( async_utils, '_create_paths', _sleep_and_create_paths @@ -1098,13 +1076,9 @@ def test_partial_restore_omission(self): 'y': self.pytree['y'], } - with ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - loading=ocp.options.PyTreeOptions.Loading( - partial_load=True, - ) - ) - ): + ctx = ocp.Context() + ctx.pytree.loading.partial_load = True + with ctx: loaded = ocp.load_pytree(self.directory, reference_pytree) test_utils.assert_tree_equal(self, expected, loaded) @@ -1185,13 +1159,9 @@ def test_load_and_broadcast(self): self.assertEqual( sharding.shard_shape((4, 32)), (4, 32 // partition_count) ) - with ocp.Context( - array_options=ocp.options.ArrayOptions( - loading=ocp.options.ArrayOptions.Loading( - use_load_and_broadcast=True, - ) - ) - ): + ctx = ocp.Context() + ctx.array.loading.use_load_and_broadcast = True + with ctx: ocp.save_pytree(self.directory, [arr]) with self.subTest('with_abstract_pytree'): loaded = ocp.load_pytree( @@ -1223,15 +1193,9 @@ def test_subchunking(self): } with self.subTest('global_setting'): - with ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - storage_options=ocp.options.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=8, # force divide in two subchunks - ) - ) - ) - ): + ctx = ocp.Context() + ctx.array.saving.storage_options.chunk_byte_size = 8 + with ctx: ocp.save_pytree(self.directory / 'global_setting', pytree) metadata = ocp.pytree_metadata( self.directory / 'global_setting' @@ -1242,22 +1206,19 @@ def test_subchunking(self): self.assertEqual(metadata[k].storage_metadata.chunk_shape, (2,)) with self.subTest('per_key_setting'): - def scoped_storage_options_creator(key, value): + + def scoped_storage_options_creator(key, value, storage): del value if 'a' in tree_utils.str_keypath(key): - return ocp.options.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=4, # force divide in 4 subchunks - ) - return ocp.options.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=8, # force divide in 2 subchunks - ) - with ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - scoped_storage_options_creator=scoped_storage_options_creator - ) - ), - ): + storage.chunk_byte_size = 4 + else: + storage.chunk_byte_size = 8 + + ctx = ocp.Context() + ctx.array.saving.scoped_storage_options_creator = ( + scoped_storage_options_creator + ) + with ctx: ocp.save_pytree(self.directory / 'per_key_setting', pytree) metadata = ocp.pytree_metadata( self.directory / 'per_key_setting' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 3d952f84d..34c01bcda 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -449,13 +449,12 @@ def test_custom_checkpointables(self): with self.subTest('load_with_free_function'): if multihost.is_pathways_backend(): self.skipTest('Sharding metadata not present in Pathways.') - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - foo=handler_utils.FooHandler, - bar=handler_utils.BarHandler, - ) - ) - with ocp.Context(checkpointables_options=checkpointables_options): + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler, checkpointable_name='foo') + registry.add(handler_utils.BarHandler, checkpointable_name='bar') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: loaded = ocp.load_checkpointables(self.directory / '0') self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) test_utils.assert_tree_equal( @@ -686,12 +685,9 @@ def test_preservation_metrics(self, policy, expected_steps): checkpointer.close() def test_gcs_deletion_options(self): - deletion_options = ocp.options.DeletionOptions( - gcs_deletion_options=ocp.options.DeletionOptions.GcsDeletionOptions( - todelete_full_path='trash' - ) - ) - with ocp.Context(deletion_options=deletion_options): + ctx = ocp.Context() + ctx.deletion.gcs_deletion_options.todelete_full_path = 'trash' + with ctx: checkpointer = Checkpointer(self.directory) self.assertEqual( checkpointer._manager._options.todelete_full_path, 'trash' @@ -699,14 +695,9 @@ def test_gcs_deletion_options(self): def test_context_constructor_override(self): - ctx1 = ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving(use_ocdbt=False) - ), - pytree_options=ocp.options.PyTreeOptions( - loading=ocp.options.PyTreeOptions.Loading(partial_load=True) - ), - ) + ctx1 = ocp.Context() + ctx1.array.saving.use_ocdbt = False + ctx1.pytree.loading.partial_load = True checkpointer = Checkpointer(self.directory, context=ctx1) self.enter_context(checkpointer) self.save_pytree(checkpointer, 0, self.pytree) @@ -732,11 +723,8 @@ def test_context_constructor_override(self): with self.subTest('local_context_override'): # Override with local context setting use_ocdbt=True - ctx2 = ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving(use_ocdbt=True) - ) - ) + ctx2 = ocp.Context() + ctx2.array.saving.use_ocdbt = True with ctx2: self.save_pytree(checkpointer, 1, self.pytree) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/v0v1_compatibility_checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/v0v1_compatibility_checkpointer_test_base.py index 0cff9c4be..e3175e67d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/v0v1_compatibility_checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/v0v1_compatibility_checkpointer_test_base.py @@ -171,41 +171,32 @@ def test_root_metadata(self, reinitialize_checkpointer): def test_custom_checkpointables(self): # Use named handler to override v0 checkpoint_handlers. - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - foo=handler_utils.FooHandler, - bar=handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler, checkpointable_name='foo') + registry.add(handler_utils.BarHandler, checkpointable_name='bar') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) super().test_custom_checkpointables() def test_load_with_switched_abstract_checkpointables(self): # Use named handler to override v0 checkpoint_handlers. - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - bar=handler_utils.FooHandler, - foo=handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler, checkpointable_name='bar') + registry.add(handler_utils.BarHandler, checkpointable_name='foo') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) super().test_load_with_switched_abstract_checkpointables() def test_different_custom_checkpointables(self): # Use named handler to override v0 checkpoint_handlers. - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - foo=handler_utils.FooHandler, - bar=handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler, checkpointable_name='foo') + registry.add(handler_utils.BarHandler, checkpointable_name='bar') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) super().test_different_custom_checkpointables() def test_custom_save_decision_policy(self):