Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions checkpoint/orbax/checkpoint/_src/engine/async_io_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AsyncIoEngine module.

Provides the `AsyncIoEngine` class and supporting helper functions responsible
for managing concurrent I/O execution, thread-pooling, and performance telemetry
collection during PyTree saving and restoration workflows.
"""

from __future__ import annotations

import asyncio
import contextlib
import dataclasses
import sys
import threading
import time
from typing import Any, Awaitable, List, Optional, Sequence, Tuple, Union

from absl import logging
import humanize
import jax
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.serialization import memory_regulator
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types

# Local aliases for frequently referenced serialization types, matching the
# names used throughout the rest of this module.
TypeHandler = types.TypeHandler
ParamInfo = types.ParamInfo
SaveArgs = type_handlers.SaveArgs
RestoreArgs = type_handlers.RestoreArgs


def _default_sizeof_values(values: Sequence[Any]) -> Sequence[int]:
return [sys.getsizeof(v) for v in values]


def get_batch_memory_size(
    handler: TypeHandler, values: Sequence[Any]
) -> Tuple[int, int]:
  """Gets memory size for a batch of leaf values.

  Args:
    handler: The `TypeHandler` responsible for `values`. Its `memory_size`
      is expected to return one `(write_size, read_size)` pair per value.
    values: The batch of leaf values to measure.

  Returns:
    A `(total_write_size, total_read_size)` tuple in bytes.

  Raises:
    AssertionError: If `handler.memory_size` returns a different number of
      entries than `len(values)`.
  """
  if not values:
    # `zip(*[])` produces nothing and cannot be unpacked into two tuples, so
    # short-circuit empty batches explicitly.
    return 0, 0
  try:
    write_sizes, read_sizes = zip(*handler.memory_size(values))
  except NotImplementedError:
    logging.warning(
        '`memory_size` is not implemented for `TypeHandler` of type: %s. Using'
        ' a default implementation to measure value memory consumption that'
        ' may result in inaccurate estimation.',
        type(handler),
    )
    write_sizes = read_sizes = _default_sizeof_values(values)
  # Use an explicit raise rather than `assert`, which is stripped under `-O`.
  if len(write_sizes) != len(values) or len(read_sizes) != len(values):
    raise AssertionError(
        '`memory_size` must return one (write, read) entry per value.'
    )
  return sum(write_sizes), sum(read_sizes)


def log_io_metrics(
    size: int,
    start_time: float,
    gbytes_per_sec_metric: str,
    gbytes_metric: Optional[str] = None,
):
  """Logs and records I/O throughput (and optionally total size) metrics.

  Args:
    size: Number of bytes transferred.
    start_time: `time.time()` timestamp when the transfer began.
    gbytes_per_sec_metric: Metric name for the throughput scalar.
    gbytes_metric: Optional metric name for the total-size scalar.
  """
  elapsed = time.time() - start_time
  # Guard against division by zero when the operation completed within
  # timer resolution; NaN mirrors the "unknown rate" case.
  if elapsed == 0:
    throughput = float('nan')
  else:
    throughput = float(size) / elapsed
  logging.info(
      '[process=%d] %s: %s/s (total gbytes: %s) (time elapsed: %s s) (%s)',
      multihost.process_index(),
      gbytes_per_sec_metric,
      humanize.naturalsize(throughput, binary=True, format='%.3f'),
      humanize.naturalsize(size, binary=True),
      elapsed,
      'per-host',
  )
  gib = 1024**3
  jax.monitoring.record_scalar(gbytes_per_sec_metric, value=throughput / gib)
  if gbytes_metric is not None:
    jax.monitoring.record_scalar(gbytes_metric, value=size / gib)


async def logging_serialize(
    handler: TypeHandler,
    serialize: Awaitable[Sequence[future.Future]],
) -> Sequence[future.Future]:
  """Awaits `serialize` and logs the time taken to initiate serialization.

  Note: `asyncio` does not export a `Coroutine` type; the awaitable is
  annotated with `typing.Awaitable` instead. (The old annotation only went
  unnoticed because `from __future__ import annotations` keeps annotations
  unevaluated at runtime.)

  Args:
    handler: The `TypeHandler` whose `serialize` coroutine is being awaited;
      used only for logging its fully qualified name.
    serialize: The awaitable returned by `handler.serialize(...)`.

  Returns:
    The commit futures produced by the serialization call.
  """
  start = time.time()
  commit_futures = await serialize
  handler_name = f'{type(handler).__module__}.{type(handler).__qualname__}'
  logging.info(
      '[process=%s][thread=%s] Initiated %s.serialize. Time taken: %fs',
      multihost.process_index(),
      threading.current_thread().name,
      f'"{handler_name}"',
      time.time() - start,
  )
  return commit_futures


@dataclasses.dataclass
class BatchRequest:
  """Represents a request for batched serialization or deserialization.

  Attributes:
    handler: Used to serialize or deserialize the parameters.
    keys: Used to identify the original tree keys so that the PyTree can be
      reconstructed.
    values: Values to serialize.
    infos: ParamInfos.
    args: List of SaveArgs or RestoreArgs.
  """

  handler: TypeHandler
  keys: List[str]
  values: List[Any]
  infos: List[ParamInfo]
  args: List[Union[SaveArgs, RestoreArgs]]

  def __post_init__(self):
    """Validates that `keys`, `values`, `infos` and `args` have equal length."""
    length = len(self.values)
    if not all((
        length == len(self.infos),
        length == len(self.args),
        length == len(self.keys),
    )):
      # Error message previously referenced a stale private name
      # (`_BatchRequest`); use the actual public class name.
      raise AssertionError('Found `BatchRequest` with mismatched parameters.')


@contextlib.contextmanager
def memory_profiler_context():
  """Context manager for memory_regulator profiler.

  Starts the memory_regulator profiler on entry and always stops it on exit,
  whether the wrapped block completes normally or raises.
  """
  memory_regulator.profiler_start()
  try:
    yield
  finally:
    # `finally` guarantees the profiler (which the original author notes runs
    # a background thread) is stopped even if the wrapped block raises.
    memory_regulator.profiler_end()


class AsyncIoEngine:
  """Encapsulates concurrency, thread-pooling, and I/O telemetry logic."""

  async def execute_save(
      self, batch_requests: Sequence[BatchRequest]
  ) -> Tuple[List[Any], int]:
    """Executes save requests asynchronously with I/O telemetry.

    Args:
      batch_requests: One `BatchRequest` per handler to serialize.

    Returns:
      A tuple of (per-batch commit futures, total write size in bytes).
    """
    serialize_ops = []
    tree_memory_size = 0
    # Keep the gather inside the profiler context so peak memory covers the
    # actual serialization work, not just coroutine construction.
    with memory_profiler_context():
      for request in batch_requests:
        serialize_coro = request.handler.serialize(
            request.values, request.infos, request.args
        )
        serialize_ops.append(
            logging_serialize(request.handler, serialize_coro)
        )
        batch_write_size, _ = get_batch_memory_size(
            request.handler, request.values
        )
        tree_memory_size += batch_write_size

      commit_futures = await asyncio.gather(*serialize_ops)

    logging.info(
        'MemoryRegulated: Peak usage: %f GiB',
        memory_regulator.profiler_peak_usage_gib(),
    )
    return commit_futures, tree_memory_size

  async def execute_restore(
      self, batch_requests: Sequence[BatchRequest]
  ) -> Tuple[List[Any], int]:
    """Executes restore requests asynchronously with I/O telemetry.

    Args:
      batch_requests: One `BatchRequest` per handler to deserialize.

    Returns:
      A tuple of (per-batch deserialized values, total read size in bytes).
    """
    pending = [
        request.handler.deserialize(request.infos, request.args)
        for request in batch_requests
    ]
    deserialized_batches = await asyncio.gather(*pending)

    tree_memory_size = sum(
        get_batch_memory_size(request.handler, deserialized)[1]
        for request, deserialized in zip(batch_requests, deserialized_batches)
    )
    return deserialized_batches, tree_memory_size
158 changes: 158 additions & 0 deletions checkpoint/orbax/checkpoint/_src/engine/async_io_engine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
from unittest import mock

from absl.testing import absltest
from orbax.checkpoint._src.engine import async_io_engine
from orbax.checkpoint._src.serialization import types

AsyncIoEngine = async_io_engine.AsyncIoEngine
BatchRequest = async_io_engine.BatchRequest


class AsyncIoEngineTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Unit tests for AsyncIoEngine and its helper functions."""

  def _mock_handler(self):
    """Returns an autospec'd TypeHandler mock instance."""
    return mock.create_autospec(types.TypeHandler, instance=True)

  def _single_item_request(self, handler, key, value):
    """Builds a BatchRequest containing exactly one parameter."""
    return BatchRequest(
        handler=handler,
        keys=[key],
        values=[value],
        infos=[mock.Mock()],
        args=[mock.Mock()],
    )

  def test_get_batch_memory_size_success(self):
    handler = self._mock_handler()
    handler.memory_size.return_value = [(10, 20), (30, 40)]

    write_size, read_size = async_io_engine.get_batch_memory_size(
        handler, ['a', 'b']
    )
    # Write sizes sum column-wise (10 + 30); read sizes likewise (20 + 40).
    self.assertEqual(write_size, 40)
    self.assertEqual(read_size, 60)

  def test_get_batch_memory_size_not_implemented(self):
    handler = self._mock_handler()
    handler.memory_size.side_effect = NotImplementedError()

    values = ['dummy1', 'dummy2']
    write_size, read_size = async_io_engine.get_batch_memory_size(
        handler, values
    )
    # Falls back to shallow sys.getsizeof for both write and read estimates.
    expected_size = sum(map(sys.getsizeof, values))
    self.assertEqual(write_size, expected_size)
    self.assertEqual(read_size, expected_size)

  def test_batch_request_validation_success(self):
    request = BatchRequest(
        handler=self._mock_handler(),
        keys=['k1', 'k2'],
        values=['v1', 'v2'],
        infos=[mock.Mock(), mock.Mock()],
        args=[mock.Mock(), mock.Mock()],
    )
    self.assertLen(request.values, 2)

  def test_batch_request_validation_mismatch(self):
    # One key but two values/infos/args must fail validation.
    with self.assertRaises(AssertionError):
      BatchRequest(
          handler=self._mock_handler(),
          keys=['k1'],
          values=['v1', 'v2'],
          infos=[mock.Mock(), mock.Mock()],
          args=[mock.Mock(), mock.Mock()],
      )

  async def test_execute_save(self):
    engine = AsyncIoEngine()
    handler_a = self._mock_handler()
    handler_b = self._mock_handler()

    async def serialize_a(*unused_args, **unused_kwargs):
      return ['fut1', 'fut2']

    async def serialize_b(*unused_args, **unused_kwargs):
      return ['fut3']

    handler_a.serialize.side_effect = serialize_a
    handler_b.serialize.side_effect = serialize_b
    handler_a.memory_size.return_value = [(100, 0)]
    handler_b.memory_size.return_value = [(200, 0)]

    requests = [
        self._single_item_request(handler_a, 'k1', 'v1'),
        self._single_item_request(handler_b, 'k2', 'v2'),
    ]
    commit_futures, tree_memory_size = await engine.execute_save(requests)

    # Commit futures are grouped per batch; memory size sums write sizes.
    self.assertEqual(commit_futures, [['fut1', 'fut2'], ['fut3']])
    self.assertEqual(tree_memory_size, 300)

  async def test_execute_restore(self):
    engine = AsyncIoEngine()
    handler_a = self._mock_handler()
    handler_b = self._mock_handler()

    async def deserialize_a(*unused_args, **unused_kwargs):
      return ['restored1']

    async def deserialize_b(*unused_args, **unused_kwargs):
      return ['restored2']

    handler_a.deserialize.side_effect = deserialize_a
    handler_b.deserialize.side_effect = deserialize_b
    handler_a.memory_size.return_value = [(0, 50)]
    handler_b.memory_size.return_value = [(0, 150)]

    requests = [
        self._single_item_request(handler_a, 'k1', 'v1'),
        self._single_item_request(handler_b, 'k2', 'v2'),
    ]
    deserialized_batches, tree_memory_size = await engine.execute_restore(
        requests
    )

    # Restored values are grouped per batch; memory size sums read sizes.
    self.assertEqual(deserialized_batches, [['restored1'], ['restored2']])
    self.assertEqual(tree_memory_size, 200)


# Allows running this test module directly, e.g. `python async_io_engine_test.py`.
if __name__ == '__main__':
  absltest.main()
Loading
Loading