93 changes: 54 additions & 39 deletions training/ironwood/llama3.1-405b/8k-bf16-tpu7x-4x8x8/README.md
# Pretrain llama3-1-405b workload on Ironwood GKE clusters with XPK

This recipe outlines the steps for running a llama3-1-405b
[MaxText](https://github.com/AI-Hypercomputer/maxtext) pretraining workload on
[Ironwood GKE clusters](https://cloud.google.com/kubernetes-engine) by using
[XPK](https://github.com/AI-Hypercomputer/xpk).


## Workload Details

This workload is configured with the following details:

- Sequence Length: 8192
- Precision: bfloat16
- Chips: 256 (4x8x8 topology; see the sketch below)
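As a quick sanity check, the chip count follows from the topology. A minimal
sketch (the `TOPOLOGY` variable here is illustrative, not used by the recipe):

```bash
# Derive the chip count from an AxBxC topology string.
TOPOLOGY="4x8x8"
IFS='x' read -r A B C <<< "${TOPOLOGY}"
echo "$(( A * B * C )) chips"  # prints "256 chips"
```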

## Prerequisites
To run this recipe, you need the following:

- **Docker:** Follow the steps
  in the [Install XPK and dependencies](#install-xpk-and-dependencies) section
  to install Docker.
- **Python 3.11 Virtual Environment:** A Python
  3.11 virtual environment is required. Instructions
  for setting this up are also in the
  [Install XPK and dependencies](#install-xpk-and-dependencies) section.
- **XPK and Dependencies:** Follow the steps in the
[Install XPK and dependencies](#install-xpk-and-dependencies) section to
install XPK, `kubectl`, `kubectl-kueue`, and `kubectl-kjob`.


## Install XPK and dependencies

### XPK and Dependency Installation
```bash
# Install uv (used to manage the Python virtual environment)
curl -LsSf https://astral.sh/uv/install.sh -o install-uv.sh
chmod +x install-uv.sh
./install-uv.sh
rm install-uv.sh
source ${HOME}/.local/bin/env

# Set up and Activate Python 3.11 virtual environment
uv venv --seed ${HOME}/.local/bin/venv --python 3.11 --clear
source ${HOME}/.local/bin/venv/bin/activate
pip install --upgrade pip
```
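Optionally, verify the environment is active before continuing; a quick check,
assuming the venv path used above:

```bash
python --version    # should report Python 3.11.x
command -v python   # should resolve under ${HOME}/.local/bin/venv
```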

Install XPK and necessary tools:

```bash
# Ensure to log in to your gcloud

# Install latest xpk
pip install xpk==1.2.0

# Install xpk pre-reqs kubectl-kueue and kjob (if you installed xpk via pip)
curl -LsSf https://raw.githubusercontent.com/AI-Hypercomputer/xpk/refs/tags/v1.2.0/tools/install-xpk.sh -o install-xpk.sh
chmod +x install-xpk.sh
sudo ./install-xpk.sh
rm install-xpk.sh

# Follow https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#install_plugin to install gke-gcloud-auth-plugin
sudo usermod -aG docker $USER ## relaunch the terminal and make sure you have the right permissions
docker run hello-world # Test docker
```
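Optionally, confirm the tools are installed; exact subcommands vary between
releases, so treat these checks as a sketch:

```bash
pip show xpk | grep -i '^version'   # should report 1.2.0
kubectl version --client
kubectl kueue version               # kubectl-kueue plugin
kubectl kjob --help > /dev/null && echo "kjob OK"
```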


## Orchestration and deployment tools

For this recipe, the following setup is used:
- **Pretraining job configuration and deployment** - XPK is used to configure
and deploy the
[Kubernetes Jobset](https://kubernetes.io/blog/2025/03/23/introducing-jobset)
  resource, which manages the execution of the llama3-1-405b workload (see the
  sketch below).
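XPK fronts the JobSet API, so once a workload has been submitted you can
observe the underlying resources directly. A minimal sketch, assuming the
default namespace and the standard JobSet label key:

```bash
kubectl get jobsets                                            # JobSets created by XPK
kubectl get pods -l jobset.sigs.k8s.io/jobset-name=${WORKLOAD_NAME}
```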


## Test environment

The following environment variables are used across all commands and
configurations (a filled-in example follows the list).

- `PROJECT_ID`: Your GCP project name.
- `CLUSTER_NAME`: The target cluster name.
- `ZONE`: The zone for your cluster (e.g., `us-central1-c`).
- `BASE_OUTPUT_DIR`: Output directory for model training (e.g.,
  `"gs://<your_gcs_bucket>"`).
- `CONTAINER_REGISTRY`: The container registry to use (e.g., `gcr.io`).
- `WORKLOAD_IMAGE`: The Docker image for the workload. This is set in
  `run_recipe.sh` to
  `${CONTAINER_REGISTRY}/${PROJECT_ID}/${USER}-llama3-1-405b-runner` by
  default, matching the image built in the
  [Docker container image](#docker-container-image) section.
- `WORKLOAD_NAME`: A unique name for your workload. This is set in
  `run_recipe.sh` using the following command:
  `export WORKLOAD_NAME="$(printf "%.26s" "${USER//_/-}-llama3-1-405b-8192-4x8x8")-$(date +%Y%m%d-%H%M)"`
- `GKE_VERSION`: The GKE version, `1.34.0-gke.2201000` or later.
- `ACCELERATOR_TYPE`: The TPU type (e.g., `tpu7x-4x8x8`). See topologies
[here](https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus#configuration).
- `RESERVATION_NAME`: Your TPU reservation name. Use the reservation name if
within the same project. For a shared project, use
`"projects/<project_number>/reservations/<reservation_name>"`.

If you don't have a GCS bucket, create one with this command:

```bash
# Make sure BASE_OUTPUT_DIR is set in run_recipe.sh before running this.
gsutil mb ${BASE_OUTPUT_DIR}
```

Create the cluster with XPK. A minimal sketch follows; the `--reservation`
flag comes from this recipe, while the remaining flags are assumptions
reconstructed from the variables above, so verify them against your XPK
version:

```bash
xpk cluster create \
  --cluster=${CLUSTER_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --num-slices=1 \
  --gke-version=${GKE_VERSION} \
  --reservation=${RESERVATION_NAME}
```
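After creation, you can confirm the cluster is up with standard `gcloud`
commands:

```bash
gcloud container clusters list --project ${PROJECT_ID} --filter="name=${CLUSTER_NAME}"
gcloud container clusters describe ${CLUSTER_NAME} --project ${PROJECT_ID} \
  --zone ${ZONE} --format="value(status)"  # expect RUNNING
```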


## Docker container image

To build your own image, follow the steps linked in this section. If you don't
have Docker installed yet, complete the
[Install XPK and dependencies](#install-xpk-and-dependencies) section first,
which installs XPK and its dependencies. Docker installation is part of this
process.

### Steps for building workload image

**Warning:** If any of the software versions below show as "N/A", you *must*
fill in the correct versions. To find the missing versions (e.g., the MaxText
commit hash and the Libtpu and Jax/Jaxlib versions), you may need to:

1. Pull the Docker image from the workload that this recipe is based on.
2. Start the Docker container.
3. Run commands within the container to get the specific versions, as sketched
   below. For example, to find the MaxText commit, you can use
   `git rev-parse HEAD` inside the cloned MaxText repository within the
   container. For Python package versions, use `pip show <package_name>`.
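A sketch of that process; the image reference and the repository path inside
the container are assumptions, so substitute the image this recipe is based
on:

```bash
# Inspect a runner image for its pinned versions (paths are assumptions).
docker run --rm ${WORKLOAD_IMAGE} bash -c '
  pip show libtpu jax jaxlib | grep -iE "^(name|version)"
  cd /deps 2>/dev/null || cd ~/maxtext 2>/dev/null || exit 0
  echo "MaxText commit: $(git rev-parse --short HEAD)"
'
```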

The following software versions are used:

- Libtpu version: 0.0.37.dev20260224+nightly
- Jax version: 0.9.1.dev20260225
- Maxtext version: bf174d6
- Python: 3.11
- XPK: 1.2.0

Docker Image Building Command:

```bash
export CONTAINER_REGISTRY="" # Initialize with your registry
export CLOUD_IMAGE_NAME="${USER}-llama3-1-405b-runner"
export WORKLOAD_IMAGE="${CONTAINER_REGISTRY}/${PROJECT_ID}/${CLOUD_IMAGE_NAME}"

# Set up and Activate Python 3.12 virtual environment for Docker build
uv venv --seed ${HOME}/.local/bin/venv-docker --python 3.12 --clear
source ${HOME}/.local/bin/venv-docker/bin/activate
pip install --upgrade pip

# Make sure you're running in a virtual environment with Python 3.12
if [[ "$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null)" == "3.12" ]]; then { echo "You have the correct Python version 3.12"; } else { >&2 echo "Error: Python version must be 3.12."; false; } fi

# Clone MaxText repository and check out the recipe commit
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext
git checkout bf174d6

# Build and upload the docker image
bash dependencies/scripts/docker_build_dependency_image.sh \
MODE=nightly \
JAX_VERSION=0.9.1.dev20260225 \
LIBTPU_VERSION=0.0.37.dev20260224+nightly
bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}

# Deactivate the virtual environment
deactivate
```

Get credentials for your cluster so that `kubectl` can reach it (XPK typically
does this for you already):

```bash
gcloud container clusters get-credentials ${CLUSTER_NAME} --project ${PROJECT_ID} --zone ${ZONE}
```

### Run llama3-1-405b Pretraining Workload

The `run_recipe.sh` script contains all the necessary environment variables and
configurations to launch the llama3-1-405b pretraining workload.

To run the benchmark, first make the script executable and then run it:

```bash
chmod +x run_recipe.sh
./run_recipe.sh
```

To check the status of your workload, list the workloads on the cluster:

```bash
xpk workload list --cluster ${CLUSTER_NAME} --project ${PROJECT_ID} --zone ${ZONE}
```
For more in-depth debugging, use `xpk inspector`:

```bash
xpk inspector --cluster ${CLUSTER_NAME} --project ${PROJECT_ID} --zone ${ZONE} [--workload ${WORKLOAD_NAME}]
```


### Delete resources

#### Delete a specific workload

```bash
xpk workload delete --workload ${WORKLOAD_NAME} --cluster ${CLUSTER_NAME} --project ${PROJECT_ID} --zone ${ZONE}
# Or filter and delete:
xpk workload delete --cluster ${CLUSTER_NAME} --project ${PROJECT_ID} --zone ${ZONE} --filter-by-job=${USER}
```
#### Delete the cluster

```bash
xpk cluster delete --cluster ${CLUSTER_NAME} --zone ${ZONE} --project ${PROJECT_ID}
```


## Check results

After the job completes, you can check the results by:
- Inspecting the logs and artifacts in the GCS bucket specified by the
  `${BASE_OUTPUT_DIR}` variable in your `run_recipe.sh` (see the example
  below).
- Reviewing metrics in Cloud Monitoring, if configured.
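For example, to pull the per-worker training log that `run_recipe.sh` copies
to GCS and skim the step metrics (the exact log wording may differ between
MaxText versions):

```bash
gsutil ls -r ${BASE_OUTPUT_DIR}/${WORKLOAD_NAME}/
gsutil cp ${BASE_OUTPUT_DIR}/${WORKLOAD_NAME}/logs/train-0.log .
grep -iE "step|loss" train-0.log | tail -n 20
```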


## Next steps: deeper exploration and customization

This recipe is designed to provide a simple, reproducible "0-to-1" experience
for pretraining llama3-1-405b on Ironwood. From here, you can adjust the
MaxText configuration and XLA flags in `run_recipe.sh` to explore other
settings.
59 changes: 38 additions & 21 deletions training/ironwood/llama3.1-405b/8k-bf16-tpu7x-4x8x8/run_recipe.sh

```bash
source "${UV_VENV_PATH}/bin/activate"
# Check if xpk is installed in the venv
if ! pip show xpk &> /dev/null; then
echo "xpk not found in the virtual environment. Please install it by running:"
echo "pip install xpk==0.16.0"
echo "pip install xpk==1.2.0"
exit 1
fi
# --- End Environment Setup ---
export PROJECT_ID=""
export CLUSTER_NAME=""
export ZONE=""
export BASE_OUTPUT_DIR=""
export ARTIFACT_DIR="${BASE_OUTPUT_DIR}"
export WORKLOAD_IMAGE=""
export WORKLOAD_NAME="$(printf "%.26s" "${USER//_/-}-llama3-1-405b-8192-4x4x4")-$(date +%Y%m%d-%H%M)"
export WORKLOAD_NAME="$(printf "%.26s" "${USER//_/-}-llama3-1-405b-8192-4x8x8")-$(date +%Y%m%d-%H%M)"

# XLA Flags
XLA_FLAGS=" \
--xla_tpu_bf16_emission_mode=NATIVE_EMISSION \
--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
--xla_tpu_use_single_sparse_core_for_all_gather_offload=true \
--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
--xla_tpu_prefer_async_allgather_to_allreduce=true \
--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_sparse_core_offload_queuing_in_lhs=true \
--xla_tpu_enable_ici_ar_pipelining=true "
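# The flags above offload collectives (all-gather, all-reduce, reduce-scatter)
# to SparseCore and tune asynchronous collective scheduling; they reach the
# TPU runtime via LIBTPU_INIT_ARGS in the workload command below.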

# MaxText Workload Overrides
MAXTEXT_ARGS="\
profile_periodically_period=10000 \
async_checkpointing=False \
enable_checkpointing=False \
use_iota_embed=True \
remat_policy=custom \
decoder_layer_input=offload \
ici_fsdp_parallelism=-1 \
use_max_logit_estimate=30 \
mlpwo=offload \
key_proj=device \
value_proj=device \
attention=flash \
sa_block_q=1024 \
sa_block_kv=2048 \
sa_block_kv_compute=1024 \
sa_block_q_dkv=2048 \
sa_block_kv_dkv=2048 \
sa_block_kv_dkv_compute=512 \
sa_use_fused_bwd_kernel=True \
sa_q_layout=SEQ_MINOR \
sa_k_layout=SEQ_MINOR \
sa_v_layout=HEAD_DIM_MINOR \
use_splash_scheduler=True \
use_tokamax_splash=True \
dataset_type=synthetic \
opt_type=adamw \
mu_dtype=bfloat16 \
num_vocab_tiling=4 \
max_target_length=8192 \
profiler=xplane \
skip_first_n_steps_for_profiler=8 \
profiler_steps=1 \
steps=30 \
run_name=${WORKLOAD_NAME} \
output_dir=${BASE_OUTPUT_DIR}"
xpk workload create \
--cluster=$CLUSTER_NAME \
Expand All @@ -94,8 +106,13 @@ xpk workload create \
--num-slices=1 \
--docker-image="${WORKLOAD_IMAGE}" \
--enable-debug-logs \
  --workload="${WORKLOAD_NAME}" \
  --command="set -e && set -o pipefail && export ENABLE_PATHWAYS_PERSISTENCE='1' && \
export LIBTPU_INIT_ARGS='${XLA_FLAGS}' && \
export ARTIFACT_DIR='${ARTIFACT_DIR}' && \
export JAX_PLATFORMS='tpu,cpu' && export ENABLE_PJRT_COMPATIBILITY='true' && \
python3 -m MaxText.train maxtext/configs/base.yml ${MAXTEXT_ARGS} | tee train.log && \
gsutil cp train.log ${BASE_OUTPUT_DIR}/${WORKLOAD_NAME}/logs/train-\${TPU_WORKER_ID}.log"
```