Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
--ignore=orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py \
--ignore=orbax/checkpoint/_src/testing/multiprocess_test.py \
--ignore=orbax/checkpoint/_src/testing/oss/multiprocess_test.py \
$(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/oss/tagged_tests.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))")
$(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))")
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
Expand Down Expand Up @@ -350,13 +350,13 @@ jobs:
env:
TEST_TMPDIR: /tmp/orbax_test
run: |
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=2
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml --processes=2
- name: Run 4 multiprocess tests
run: |
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml --processes=4
- name: Run single process tests
run: |
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=1
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml --processes=1
- name: Report success or failure as github status
if: always()
shell: bash
Expand Down
59 changes: 59 additions & 0 deletions .github/workflows/multiprocess_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,62 @@ jobs:
"description": "'$status'",
"context": "github-actions/build"
}'

multiprocess-unit-tests:
name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
runs-on: linux-x86-ct5lp-224-8tpu
container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
defaults:
run:
working-directory: checkpoint
strategy:
matrix:
python-version: ["3.11"]
jax-version: ["newest"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install -e .
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip uninstall -y orbax
pip install gcsfs
pip install portpicker pytest chex pyyaml tensorboard
if [ "${{ matrix.jax-version }}" = "newest" ]; then
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
else
pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
fi
- name: Run 2 multiprocess tests
env:
TEST_TMPDIR: /tmp/orbax_test
run: |
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml --processes=2
- name: Run 4 multiprocess tests
run: |
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml --processes=4
- name: Run single process tests
run: |
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml --processes=1
- name: Report success or failure as github status
if: always()
shell: bash
run: |
status="${{ job.status }}"
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
curl -sS --request POST \
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
--header 'content-type: application/json' \
--data '{
"state": "'$lowercase_status'",
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
"description": "'$status'",
"context": "github-actions/build"
}'
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@
]
EXCLUDED_PATHS = [
'orbax/checkpoint/experimental/model_surgery',
'orbax/checkpoint/experimental/v1',
'orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py',
'orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py',
'orbax/checkpoint/google',
]
EXCLUDED_PATHS_FOR_PRESUBMIT = [
'orbax/checkpoint/experimental/v1',
'orbax/checkpoint/_src/checkpointers',
]


def get_kwargs(call_node):
Expand Down Expand Up @@ -111,21 +114,22 @@ def get_build_targets(build_file_path):
yield name, tags, srcs, args


def run(root_dir, output_file):
def run(root_dir, output_file, extra_excluded_paths=None):
"""Runs the script to generate tagged tests file."""
tests_by_tag = collections.defaultdict(list)
all_excluded_paths = EXCLUDED_PATHS + (extra_excluded_paths or [])

count = 0
for dirpath, dirnames, filenames in os.walk(root_dir):
if any(dirpath.startswith(p) for p in EXCLUDED_PATHS):
if any(dirpath.startswith(p) for p in all_excluded_paths):
dirnames[:] = []
continue

original_dirs = list(dirnames)
dirnames[:] = []
for d in original_dirs:
if not any(
os.path.join(dirpath, d).startswith(p) for p in EXCLUDED_PATHS
os.path.join(dirpath, d).startswith(p) for p in all_excluded_paths
):
dirnames.append(d)

Expand All @@ -137,7 +141,8 @@ def run(root_dir, output_file):
if not any(tag in TAG_MAPPING for tag in tags):
continue
if srcs and any(
os.path.join(dirpath, srcs[0]).startswith(p) for p in EXCLUDED_PATHS
os.path.join(dirpath, srcs[0]).startswith(p)
for p in all_excluded_paths
):
continue
target_path = f'{package_path}:{name}'
Expand Down Expand Up @@ -167,5 +172,15 @@ def run(root_dir, output_file):
if 'BUILD_WORKING_DIRECTORY' in os.environ:
os.chdir(os.environ['BUILD_WORKING_DIRECTORY'])
orbax_dir = 'orbax/checkpoint'
output = 'orbax/checkpoint/_src/testing/oss/tagged_tests.yaml'
run(orbax_dir, output)

# Generate whole suite file
output_whole = 'orbax/checkpoint/_src/testing/oss/tagged_tests_whole_suite.yaml'
run(orbax_dir, output_whole)

# Generate presubmit file
output_presubmit = 'orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml'
run(
orbax_dir,
output_presubmit,
extra_excluded_paths=EXCLUDED_PATHS_FOR_PRESUBMIT,
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ processes:1:
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_replicator_checkpoint_manager_test
- orbax/checkpoint:single_host_test
processes:2:
- orbax/checkpoint/_src/checkpointers:async_checkpointer_test
- orbax/checkpoint/_src/checkpointers:checkpointer_test
- orbax/checkpoint/_src/handlers:array_checkpoint_handler_test
- orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test
- orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# DO NOT EDIT!
processes:1:
- orbax/checkpoint/_src/futures:future_test
- orbax/checkpoint/_src/metadata:sharding_tpu_test
- orbax/checkpoint/_src/multihost:multislice_test
- orbax/checkpoint/_src/serialization:colocated_pathways_local_type_handlers_test
- orbax/checkpoint/_src/serialization:colocated_pathways_memory_usage_test
- orbax/checkpoint/_src/serialization:pathways_local_type_handlers_test
- orbax/checkpoint/_src/serialization:pathways_memory_usage_test
- orbax/checkpoint/_src/serialization:replica_slices_test
- orbax/checkpoint/_src/serialization:serialization_test
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_process_metadata_checkpoint_handler_test
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_replicator_checkpoint_manager_test
- orbax/checkpoint/experimental/v1/_src/emergency:deleter_test
- orbax/checkpoint/experimental/v1/_src/emergency:path_utils_test
- orbax/checkpoint/experimental/v1/_src/testing/pathways:array_leaf_handler_test_multi_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:array_leaf_handler_test_single_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:pytree_handler_test_multi_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:pytree_handler_test_single_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:save_load_test_multi_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:save_load_test_single_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:v1_compatibility_load_test_multi_worker
- orbax/checkpoint/experimental/v1/_src/testing/pathways:v1_compatibility_load_test_single_worker
- orbax/checkpoint/experimental/v1/_src/training/pathways:checkpointer_test_multi_worker
- orbax/checkpoint/experimental/v1/_src/training/pathways:checkpointer_test_single_worker
- orbax/checkpoint/experimental/v1/_src/training/pathways:snapshotter_test
- orbax/checkpoint/experimental/v1/_src/training/pathways:v0v1_compatibility_checkpointer_test_single_worker
- orbax/checkpoint:single_host_test
processes:2:
- orbax/checkpoint/_src/checkpointers:async_checkpointer_test
- orbax/checkpoint/_src/checkpointers:checkpointer_test
- orbax/checkpoint/_src/handlers:array_checkpoint_handler_test
- orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test
- orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test
- orbax/checkpoint/_src/serialization:local_type_handlers_test
- orbax/checkpoint/_src/serialization:type_handlers_test
- orbax/checkpoint/experimental/emergency/p2p:checkpoint_manager_multiprocess_test
- orbax/checkpoint/experimental/emergency/p2p:local_multiprocess_test
- orbax/checkpoint/experimental/emergency/p2p:persistent_multiprocess_test
- orbax/checkpoint/experimental/v1/_src/handlers:pytree_handler_partial_save_test
- orbax/checkpoint/experimental/v1/_src/handlers:pytree_handler_test
- orbax/checkpoint/experimental/v1/_src/layout:orbax_layout_multiprocess_test
- orbax/checkpoint/experimental/v1/_src/partial:saving_multihost_test
- orbax/checkpoint/experimental/v1/_src/serialization:array_leaf_handler_test
- orbax/checkpoint/experimental/v1/_src/serialization:numpy_leaf_handler_test
- orbax/checkpoint/experimental/v1/_src/serialization:scalar_leaf_handler_test
- orbax/checkpoint/experimental/v1/_src/serialization:string_leaf_handler_test
- orbax/checkpoint/experimental/v1/_src/testing:save_load_test
- orbax/checkpoint/experimental/v1/_src/testing:v1_compatibility_load_test_multiprocess
- orbax/checkpoint/experimental/v1/_src/training:checkpointer_test
- orbax/checkpoint/experimental/v1/_src/training:v0v1_compatibility_checkpointer_test
processes:4:
- orbax/checkpoint/_src/multihost:multihost_test
- orbax/checkpoint/_src/testing/tree_verity:checkpoint_manager_test
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:process_metadata_checkpoint_handler_test
- orbax/checkpoint/experimental/emergency:local_checkpoint_data_debugging_test
- orbax/checkpoint/experimental/emergency:local_checkpoint_manager_test
- orbax/checkpoint/experimental/emergency:single_slice_checkpoint_manager_test
- orbax/checkpoint/experimental/v1/_src/layout:safetensors_layout_multiprocess_test
- orbax/checkpoint/experimental/v1/_src/synchronization:multihost_test
- orbax/checkpoint/testing:local_path_test
- orbax/checkpoint:checkpoint_manager_slice_test
- orbax/checkpoint:checkpoint_manager_test
Loading
Loading