Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,27 @@ def is_valid(self) -> bool:

@property
def context(self) -> ocp.Context:
return ocp.Context(
array_options=ocp.options.ArrayOptions(
saving=ocp.options.ArrayOptions.Saving(
storage_options=ocp.options.ArrayOptions.Saving.StorageOptions(
chunk_byte_size=self.chunk_byte_size,
),
use_ocdbt=self.use_ocdbt,
use_zarr3=self.use_zarr3,
use_replica_parallel=self.use_replica_parallel,
use_compression=self.use_compression,
enable_replica_parallel_separate_folder=self.enable_replica_parallel_separate_folder,
),
loading=ocp.options.ArrayOptions.Loading(
use_load_and_broadcast=self.use_load_and_broadcast,
),
),
memory_options=ocp.options.MemoryOptions(
write_concurrent_bytes=self.save_concurrent_gb * 1024**3
if self.save_concurrent_gb is not None
else None,
read_concurrent_bytes=self.restore_concurrent_gb * 1024**3
if self.restore_concurrent_gb is not None
else None,
),
ctx = ocp.Context()
ctx.array.saving.storage_options.chunk_byte_size = self.chunk_byte_size
ctx.array.saving.use_ocdbt = self.use_ocdbt
ctx.array.saving.use_zarr3 = self.use_zarr3
ctx.array.saving.use_replica_parallel = self.use_replica_parallel
ctx.array.saving.use_compression = self.use_compression
ctx.array.saving.enable_replica_parallel_separate_folder = (
self.enable_replica_parallel_separate_folder
)
ctx.array.loading.use_load_and_broadcast = self.use_load_and_broadcast
ctx.memory.write_concurrent_bytes = (
self.save_concurrent_gb * 1024**3
if self.save_concurrent_gb is not None
else None
)
ctx.memory.read_concurrent_bytes = (
self.restore_concurrent_gb * 1024**3
if self.restore_concurrent_gb is not None
else None
)
return ctx


def clear_pytree(pytree: Any) -> Any:
Expand Down Expand Up @@ -159,7 +155,7 @@ def test_fn(
logging.info("Benchmark options: %s", pprint.pformat(options))
metrics_to_measure = get_metrics_to_measure(options)

with ocp.Context(context=options.context):
with ocp.Context(options.context):
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_save")
if options.async_enabled:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ def test_fn(
reference_sharding_path=reference_sharding_path,
)

with ocp.Context(context=options.context):
with ocp.Context(options.context):
loaded_pytree = ocp.load_pytree(
reference_checkpoint_path, abstract_pytree
)

for step in range(options.num_savings):
logging.info("ReplicaParallelMultislice: Starting Step: %s", step)
save_path = context.path / "ckpt" / str(step)
with ocp.Context(context=options.context):
with ocp.Context(options.context):
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_save")
if options.async_enabled:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_fn(
options.reference_sharding_path
)

with ocp.Context(context=options.context):
with ocp.Context(options.context):
metadata = ocp.pytree_metadata(reference_checkpoint_path)
abstract_pytree = (
checkpoint_generation.get_abstract_state_from_sharding_config(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_fn(
reference_sharding_path=reference_sharding_path,
)

with ocp.Context(context=options.context):
with ocp.Context(options.context):
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_load")
with metrics.measure("load", metrics_to_measure):
Expand Down
203 changes: 107 additions & 96 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from collections.abc import Iterable
import contextvars
import copy
import typing

from absl import logging
Expand Down Expand Up @@ -45,14 +46,13 @@ class Context(epy.ContextManager):
"""Context for customized checkpointing.

This class manages the configuration options (e.g., async, multiprocessing,
array handling) used during Orbax checkpoint operations.
array handling) used during Orbax checkpoint operations. Options are exposed
as mutable attribute namespaces (e.g. ``ctx.pytree.loading.partial_load``).

Creating a new :py:class:`.Context` within an existing :py:class:`.Context`
sets all parameters from scratch by default. To inherit properties from a
parent :py:class:`.Context`, you must explicitly pass the parent context as
the first argument. The new context will inherit the parent's properties,
except for any options explicitly provided as keyword arguments to the child
context.
Creating a new :py:class:`.Context` initializes every option to its default
value. To inherit properties from a parent :py:class:`.Context`, pass the
parent context as the first positional argument. The new context starts with
deep copies of the parent's options, so it can be mutated independently
without affecting the parent.

WARNING: The context is thread-local and is not shared across threads. The
entire context block must be executed within the same thread. If you dispatch
Expand All @@ -69,116 +69,131 @@ class Context(epy.ContextManager):
Example:
Basic usage and explicit inheritance::

import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp

# Basic usage
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()):
ocp.save_pytree(directory, tree)
ctx = ocp.Context()
ctx.pytree.loading.partial_load = True
with ctx:
ocp.load_pytree(directory, tree)

# Inheriting properties from an existing context
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx:
# inner_ctx inherits pytree_options, but overrides/adds array_options
with ocp.Context(outer_ctx,
array_options=ocp.options.ArrayOptions()
) as inner_ctx:
ocp.save_pytree(directory, tree)
ctx1 = ocp.Context()
ctx1.pytree.loading.partial_load = True
with ctx1 as outer_ctx:
# inner_ctx inherits partial_load, but mutates array saving
ctx2 = ocp.Context(outer_ctx)
ctx2.array.saving.use_zarr3 = False
with ctx2 as inner_ctx:
ocp.load_pytree(directory, tree)

Context is not shared across threads::

from concurrent.futures import ThreadPoolExecutor
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp

executor = ThreadPoolExecutor(max_workers=1)
with ocp.Context(
pytree_options=ocp.options.PyTreeOptions()
): # Thread #1 creates Context.
# The following save_pytree call is executed in Thread #2, which sees
ctx = ocp.Context()
ctx.pytree.loading.partial_load = True
with ctx: # Thread #1 creates Context.
# The following load_pytree call is executed in Thread #2, which sees
# a "default" Context, NOT the one created above.
executor.submit(ocp.save_pytree, directory, tree)

executor.submit(ocp.load_pytree, directory, tree)

Attributes:
pytree_options: Options for PyTree checkpointing. See
pytree: Options for PyTree checkpointing. See
:class:`~orbax.checkpoint.experimental.v1.options.PyTreeOptions`.
array_options: Options for saving and loading array (and array-like
objects). See
array: Options for saving and loading array (and array-like objects). See
:class:`~orbax.checkpoint.experimental.v1.options.ArrayOptions`.
async_options: Options for controlling asynchronous behavior. See
asynchronous: Options for controlling asynchronous behavior. See
:class:`~orbax.checkpoint.experimental.v1.options.AsyncOptions`.
multiprocessing_options: Options for multiprocessing behavior. See
multiprocessing: Options for multiprocessing behavior. See
:class:`~orbax.checkpoint.experimental.v1.options.MultiprocessingOptions`.
file_options: Options for working with the file system. See
file: Options for working with the file system. See
:class:`~orbax.checkpoint.experimental.v1.options.FileOptions`.
checkpointables_options: Options for controlling checkpointables behavior.
See
checkpointables: Options for controlling checkpointables behavior. See
:class:`~orbax.checkpoint.experimental.v1.options.CheckpointablesOptions`.
pathways_options: Options for Pathways checkpointing. See
pathways: Options for Pathways checkpointing. See
:class:`~orbax.checkpoint.experimental.v1.options.PathwaysOptions`.
checkpoint_layout: The layout of the checkpoint. Defaults to ORBAX. See
:class:`~orbax.checkpoint.experimental.v1.options.CheckpointLayout`.
deletion_options: Options for controlling deletion behavior. See
deletion: Options for controlling deletion behavior. See
:class:`~orbax.checkpoint.experimental.v1.options.DeletionOptions`.
memory_options: Options for controlling memory limits during save / load.
See :class:`~orbax.checkpoint.experimental.v1.options.MemoryOptions`.
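memory: Options for controlling memory limits during save / load. See
:class:`~orbax.checkpoint.experimental.v1.options.MemoryOptions`.
safetensors: Safetensors-related options. See
:class:`~orbax.checkpoint.experimental.v1.options.SafetensorsOptions`.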
"""

def __init__(
self,
context: Context | None = None,
*,
pytree_options: options_lib.PyTreeOptions | None = None,
array_options: options_lib.ArrayOptions | None = None,
async_options: options_lib.AsyncOptions | None = None,
multiprocessing_options: options_lib.MultiprocessingOptions | None = None,
file_options: options_lib.FileOptions | None = None,
checkpointables_options: options_lib.CheckpointablesOptions | None = None,
pathways_options: options_lib.PathwaysOptions | None = None,
checkpoint_layout: options_lib.CheckpointLayout | None = None,
deletion_options: options_lib.DeletionOptions | None = None,
memory_options: options_lib.MemoryOptions | None = None,
safetensors_options: options_lib.SafetensorsOptions | None = None,
):
self._pytree_options = pytree_options or (
context.pytree_options if context else options_lib.PyTreeOptions()
)
self._array_options = array_options or (
context.array_options if context else options_lib.ArrayOptions()
)
self._async_options = async_options or (
context.async_options if context else options_lib.AsyncOptions()
)
self._multiprocessing_options = multiprocessing_options or (
context.multiprocessing_options
if context
else options_lib.MultiprocessingOptions()
)
self._file_options = file_options or (
context.file_options if context else options_lib.FileOptions()
)
self._checkpointables_options = checkpointables_options or (
context.checkpointables_options
if context
else options_lib.CheckpointablesOptions()
)
self._pathways_options = pathways_options or (
context.pathways_options if context else options_lib.PathwaysOptions()
)
self._checkpoint_layout = checkpoint_layout or (
context.checkpoint_layout
if context
else options_lib.CheckpointLayout.ORBAX
)
self._deletion_options = deletion_options or (
context.deletion_options if context else options_lib.DeletionOptions()
)
self._memory_options = memory_options or (
context.memory_options if context else options_lib.MemoryOptions()
)
self._safetensors_options = safetensors_options or (
context.safetensors_options
if context
else options_lib.SafetensorsOptions()
)
def __init__(self, context: Context | None = None):
  """Initializes the context, optionally inheriting from a parent context.

  Args:
    context: Parent context to inherit from. When provided, every instance
      attribute of the parent whose name ends in ``_options`` or ``_layout``
      is deep-copied onto this instance, so later mutations of this context
      do not touch the parent. When ``None``, each option group is built with
      its no-argument constructor and the checkpoint layout is set to
      ``CheckpointLayout.ORBAX``.
  """
  if context is None:
    # No parent: start every option group from its defaults.
    self._pytree_options = options_lib.PyTreeOptions()
    self._array_options = options_lib.ArrayOptions()
    self._async_options = options_lib.AsyncOptions()
    self._multiprocessing_options = options_lib.MultiprocessingOptions()
    self._file_options = options_lib.FileOptions()
    self._checkpointables_options = options_lib.CheckpointablesOptions()
    self._pathways_options = options_lib.PathwaysOptions()
    self._checkpoint_layout = options_lib.CheckpointLayout.ORBAX
    self._deletion_options = options_lib.DeletionOptions()
    self._memory_options = options_lib.MemoryOptions()
    self._safetensors_options = options_lib.SafetensorsOptions()
    return

  # Parent given: copy its option attributes so the two contexts never share
  # mutable state.
  inherited_suffixes = ('_options', '_layout')
  for attr_name, attr_value in vars(context).items():
    if attr_name.endswith(inherited_suffixes):
      setattr(self, attr_name, copy.deepcopy(attr_value))

# --- Short-name properties for mutable user configuration ---

@property
def array(self) -> options_lib.ArrayOptions:
  """Returns the array options object stored on this context (not a copy)."""
  return self._array_options

@property
def asynchronous(self) -> options_lib.AsyncOptions:
  """Returns the async options object stored on this context (not a copy)."""
  return self._async_options

@property
def pytree(self) -> options_lib.PyTreeOptions:
  """Returns the PyTree options object stored on this context (not a copy)."""
  return self._pytree_options

@property
def file(self) -> options_lib.FileOptions:
  """Returns the file options object stored on this context (not a copy)."""
  return self._file_options

@property
def multiprocessing(self) -> options_lib.MultiprocessingOptions:
  """Returns the multiprocessing options object stored on this context (not a copy)."""
  return self._multiprocessing_options

@property
def checkpointables(self) -> options_lib.CheckpointablesOptions:
  """Returns the checkpointables options object stored on this context (not a copy)."""
  return self._checkpointables_options

@property
def pathways(self) -> options_lib.PathwaysOptions:
  """Returns the Pathways options object stored on this context (not a copy)."""
  return self._pathways_options

@property
def deletion(self) -> options_lib.DeletionOptions:
  """Returns the deletion options object stored on this context (not a copy)."""
  return self._deletion_options

@property
def memory(self) -> options_lib.MemoryOptions:
  """Returns the memory options object stored on this context (not a copy)."""
  return self._memory_options

@property
def safetensors(self) -> options_lib.SafetensorsOptions:
  """Returns the Safetensors options object stored on this context (not a copy)."""
  return self._safetensors_options

@property
def checkpoint_layout(self) -> options_lib.CheckpointLayout:
  """Returns the checkpoint layout currently stored on this context."""
  return self._checkpoint_layout

@checkpoint_layout.setter
def checkpoint_layout(self, value: options_lib.CheckpointLayout) -> None:
  """Replaces the checkpoint layout stored on this context with ``value``."""
  self._checkpoint_layout = value

# TODO: b/513156122 - Migrate internal read sites to short-hand properties and
# remove legacy aliases in the next refactor.
# --- Legacy aliases for internal read access compatibility ---

@property
def pytree_options(self) -> options_lib.PyTreeOptions:
Expand Down Expand Up @@ -208,10 +223,6 @@ def checkpointables_options(self) -> options_lib.CheckpointablesOptions:
def pathways_options(self) -> options_lib.PathwaysOptions:
return self._pathways_options

@property
def checkpoint_layout(self) -> options_lib.CheckpointLayout:
return self._checkpoint_layout

@property
def deletion_options(self) -> options_lib.DeletionOptions:
return self._deletion_options
Expand Down
Loading
Loading