Skip to content

Commit c8be6df

Browse files
ChromeHeartsOrbax Authors
authored and committed
Add support for loading ref checkpoint with colocated Python, and add PathwaysColocatedPythonGuide.md
PiperOrigin-RevId: 912169322
1 parent 146b6b1 commit c8be6df

5 files changed

Lines changed: 125 additions & 0 deletions

File tree

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
- Tensorstore non-atomic file I/O locking on local/POSIX-compatible filesystems
1515
to avoid unnecessary renames.
16+
- Add loading source checkpoint with Pathways colocated-Python in Benchmark
17+
suite.
1618

1719
## [0.11.38] - 2026-05-05
1820

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# The name for the entire test suite run.
2+
suite_name: "Llama 3.1 70B v5p-64"
3+
num_repeats: 5
4+
5+
mesh_configs:
6+
- mesh_axes: ["data", "fsdp", "tensor"]
7+
ici_parallelism: {"data": 1, "fsdp": 32, "tensor": 1}
8+
9+
# The checkpoint configuration, shared across all generated benchmarks.
10+
checkpoint_configs:
11+
- path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items"
12+
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-64-data-1-fsdp-32-tensor-1/abstract_state.json"
13+
load_with_colocated_python: true
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.pytree_checkpoint_benchmark.PyTreeCheckpointBenchmark"
17+
options:
18+
async_enabled: true
19+
use_ocdbt: true
20+
use_zarr3: true
21+
use_replica_parallel: false
22+
use_compression: true
23+
use_colocated_python: true

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/checkpoint_generation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
import jax.numpy as jnp
2424
import numpy as np
2525
from orbax.checkpoint import checkpoint_utils
26+
from orbax.checkpoint import pathways
2627
from orbax.checkpoint._src.arrays import abstract_arrays
2728
from orbax.checkpoint._src.arrays import sharding as sharding_utils
2829
from orbax.checkpoint._src.checkpointers import checkpointer
2930
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
31+
from orbax.checkpoint._src.multihost import multihost
3032
from orbax.checkpoint._src.serialization import type_handlers
3133
from orbax.checkpoint._src.testing.benchmarks.core import configs
3234
from orbax.checkpoint._src.tree import utils as tree_utils
@@ -214,6 +216,16 @@ def load_checkpoint(config: configs.CheckpointConfig) -> Any:
214216
abstract_state = _get_abstract_state(config, use_ocdbt=use_ocdbt)
215217
restore_args = checkpoint_utils.construct_restore_args(abstract_state)
216218

219+
if multihost.is_pathways_backend():
220+
checkpointing_impl = pathways.CheckpointingImpl.from_options(
221+
use_colocated_python=config.load_with_colocated_python,
222+
)
223+
pathways.register_type_handlers(
224+
checkpointing_impl=checkpointing_impl,
225+
use_replica_parallel=False,
226+
enable_replica_parallel_separate_folder=False,
227+
)
228+
217229
with checkpointer.Checkpointer(
218230
pytree_checkpoint_handler.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt)
219231
) as ckptr:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,15 @@ class CheckpointConfig:
8181
1024], 'sharding': ['data', 'model'] # PartitionSpec }, 'step': 'int' }
8282
sharding_config_path: A path to a file containing sharding specifications,
8383
used alongside `path`. See above.
84+
load_with_colocated_python: If True, the checkpoint will be loaded with
85+
colocated Python. Only effective on Pathways and when `path` is specified.
8486
"""
8587

8688
path: str | None = None
8789
random_seed: int = 0
8890
spec: dict[str, Any] | None = None
8991
sharding_config_path: str | None = None
92+
load_with_colocated_python: bool = False
9093

9194
def __post_init__(self):
9295
if self.path is not None and self.spec is not None:
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Orbax Benchmark on Pathways Colocated Python Guide
2+
3+
## Introduction
4+
5+
This guide provides step-by-step instructions for running Orbax benchmarks
6+
on **Pathways with Colocated Python**. Before you start, make sure you are
7+
familiar with [README.md](README.md) on how to run `xpk`. This setup allows
8+
running benchmarks where the Python code executes directly on the TPU worker
9+
nodes (colocated with the Pathways server), which is useful for reducing data
10+
transfer between Pathways Client and Workers.
11+
12+
## Step-by-step Instructions
13+
14+
### Create a Pathways cluster
15+
16+
Use the `launch_xpk.py --create_cluster` option or the command below:
17+
18+
```shell
19+
xpk cluster create-pathways \
20+
--cluster orbax-benchmark-v5p-64-pw \
21+
--tpu-type=v5p-64 \
22+
--num-slices=1 \
23+
--spot \
24+
--zone=us-east5-a
25+
```
26+
27+
### Install latest XPK
28+
29+
```shell
30+
pip install --upgrade xpk
31+
```
32+
33+
### Build both benchmark & sidecar images.
34+
35+
- Use `--jax-version 0.10.0` for Benchmark image build. This has to match
36+
with `Dockerfile.sidecar`'s source image.
37+
- Use `--base-image python:3.12-slim` for the Benchmark image build, the python
38+
version also has to match `Dockerfile.sidecar`'s source image.
39+
- Use `--build-sidecar true` to build sidecar image as well.
40+
41+
This is a full example of `build_image.sh` command.
42+
```shell
43+
export DATE_STR=$(date +"%Y%m%d-%H%M%Z")
44+
export IMG_TAG=orbax-benchmark-local-tpu-$DATE_STR
45+
46+
bash build_image.sh \
47+
--project orbax-checkpoint \
48+
--local-repo ./ \
49+
--tag $IMG_TAG \
50+
--device tpu \
51+
--jax-version 0.10.0 \
52+
--base-image python:3.12-slim \
53+
--no-cache \
54+
--build-benchmark true \
55+
--build-sidecar true
56+
```
57+
58+
### Run Benchmark
59+
60+
The following is a full example of running an Orbax benchmark on Pathways Colocated
61+
Python. Make sure to set the sidecar docker_image in `--pathways_sidecar_image`
62+
to the image you have just built above.
63+
64+
```shell
65+
MODEL=llama-70b-v5p-64-pw
66+
OUTPUT_DIR=gs://orbax-benchmarks/${USER}/${DATE_STR}/$MODEL # don't end with slash
67+
CONFIG_DIR=$OUTPUT_DIR
68+
69+
python3 launch_xpk.py \
70+
--verbose \
71+
--enable_pathways \
72+
--cluster_name orbax-benchmark-v5p-64-pw \
73+
--create_cluster=False \
74+
--delete_cluster_on_completion=False \
75+
--spot \
76+
--tpu_type v5p-64 \
77+
--num_slices 1 \
78+
--zone us-east5-a \
79+
--config_file=orbax/checkpoint/_src/testing/benchmarks/configs/llama-70b-v5p-64-pw-example.yaml \
80+
--docker_image=gcr.io/orbax-checkpoint/orbax-benchmarks:$IMG_TAG \
81+
--pathways_sidecar_image=gcr.io/orbax-checkpoint/orbax-benchmarks/sidecar:$IMG_TAG \
82+
--output_directory=$OUTPUT_DIR \
83+
--config_directory=$CONFIG_DIR
84+
85+
```

0 commit comments

Comments
 (0)