diff --git a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh index 7feb5b3d9..e0aedf5ba 100755 --- a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh +++ b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh @@ -32,9 +32,15 @@ warmup_ratio="${warmup_ratio:-0.1}" batch_size="${batch_size:-128}" mini_batch_size="${mini_batch_size:-128}" max_response_length="${max_response_length:-8192}" +#TODO(b/510820709) - find the optimal mesh configuration +total_tpus="${total_tpus:-8}" + trainer_mesh="${trainer_mesh:-(4,1)}" rollout_mesh="${rollout_mesh:-(4,1)}" +source "$(dirname "$0")/../tpu_utils.sh" +validate_mesh_allocation "$total_tpus" "$trainer_mesh" "$rollout_mesh" "null" || exit 1 + checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then diff --git a/examples/deepswe/run_deepswe_disagg_v5p_32.sh b/examples/deepswe/run_deepswe_disagg_v5p_32.sh index 5b73f063e..66fdb4a46 100755 --- a/examples/deepswe/run_deepswe_disagg_v5p_32.sh +++ b/examples/deepswe/run_deepswe_disagg_v5p_32.sh @@ -42,10 +42,15 @@ rollout_micro_batch_size="${rollout_micro_batch_size:-1}" num_generations="${num_generations:-2}" max_response_length="${max_response_length:-8192}" - +#TODO(b/510820709) - find the optimal mesh configuration. 
+total_tpus="${total_tpus:-32}" trainer_mesh="${trainer_mesh:-(8,2)}" rollout_mesh="${rollout_mesh:-(2,8)}" + +source "$(dirname "$0")/../tpu_utils.sh" +validate_mesh_allocation "$total_tpus" "$trainer_mesh" "$rollout_mesh" "null" || exit 1 + checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then diff --git a/examples/rl/grpo/gsm8k/run_gemma3_12b.sh b/examples/rl/grpo/gsm8k/run_gemma3_12b.sh index f900f6c61..8eda350f9 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_12b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_12b.sh @@ -54,8 +54,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma3_1b.sh b/examples/rl/grpo/gsm8k/run_gemma3_1b.sh index 0cad80455..a5ec7fd8a 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_1b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_1b.sh @@ -54,8 +54,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + 
reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma3_4b.sh b/examples/rl/grpo/gsm8k/run_gemma3_4b.sh index ef89e2b8a..b9e16495c 100755 --- a/examples/rl/grpo/gsm8k/run_gemma3_4b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma3_4b.sh @@ -54,8 +54,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_gemma_7b.sh b/examples/rl/grpo/gsm8k/run_gemma_7b.sh index 75f9f3f7a..8f6072ca0 100755 --- a/examples/rl/grpo/gsm8k/run_gemma_7b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma_7b.sh @@ -55,8 +55,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + 
rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="/tmp/models/gemma-7b/models/google/gemma/flax/7b-it/2/tokenizer.model" \ tokenizer_config.tokenizer_type="sentencepiece" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh b/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh index b91b04025..24ad05f33 100755 --- a/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh +++ b/examples/rl/grpo/gsm8k/run_llama3.1_8b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.1-8B" \ tokenizer_config.tokenizer_type="huggingface" \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh b/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh index 63546bf0d..8035d458f 100755 --- a/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh +++ b/examples/rl/grpo/gsm8k/run_llama3.2_1b.sh @@ -53,8 +53,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.2-1B" \ tokenizer_config.tokenizer_type="huggingface" \ 
tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh index 163767ed9..6490d97af 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3.sh @@ -45,8 +45,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh similarity index 97% rename from examples/rl/grpo/gsm8k/run_qwen3_8b.sh rename to examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh index dd5c2c173..4c2d5b2d2 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh @@ -41,10 +41,13 @@ rollout_micro_batch_size="${rollout_micro_batch_size:-8}" compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}" num_generations="${num_generations:-4}" - +total_tpus="${total_tpus:-16}" train_mesh="${train_mesh:-(8,1)}" rollout_mesh="${rollout_mesh:-(1,8)}" +source "$(dirname "$0")/../../../tpu_utils.sh" +validate_mesh_allocation "$total_tpus" "$train_mesh" "$rollout_mesh" "null" || exit 1 + checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b_maxtext.sh 
b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh similarity index 97% rename from examples/rl/grpo/gsm8k/run_qwen3_8b_maxtext.sh rename to examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh index b094bd8bf..1cc85b569 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b_maxtext.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh @@ -39,10 +39,13 @@ rollout_micro_batch_size="${rollout_micro_batch_size:-8}" compute_logps_micro_batch_size="${compute_logps_micro_batch_size:-1}" num_generations="${num_generations:-4}" - +total_tpus="${total_tpus:-16}" train_mesh="${train_mesh:-(8,1)}" rollout_mesh="${rollout_mesh:-(1,8)}" +source "$(dirname "$0")/../../../tpu_utils.sh" +validate_mesh_allocation "$total_tpus" "$train_mesh" "$rollout_mesh" "null" || exit 1 + checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh index d655dfefd..b28f4aad3 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh @@ -54,8 +54,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ actor_model_config.mesh.shape="(2,4)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - rollout_model_config.mesh.shape="(2,4)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path=Qwen/${model_name} \ tokenizer_config.tokenizer_type=huggingface \ tokenizer_config.add_bos=false \ diff --git a/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh 
b/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh index 59b9b2493..18a350149 100755 --- a/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh +++ b/examples/rl/grpo/gsm8k/verl_compatible/run_llama3.2_1b.sh @@ -51,8 +51,10 @@ python3 -m tunix.cli.grpo_main \ actor_model_config.mesh.shape="(4,1)" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ actor_model_config.lora_config={} \ - rollout_model_config.mesh.shape="(4,1)" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ + reference_model_config.mesh=null \ + reference_model_config.same_mesh_as="actor" \ + rollout_model_config.mesh=null \ + rollout_model_config.same_mesh_as="actor" \ tokenizer_config.tokenizer_path="meta-llama/Llama-3.2-1B-Instruct" \ tokenizer_config.tokenizer_type="huggingface" \ tokenizer_config.add_bos=false \ diff --git a/examples/tpu_utils.sh b/examples/tpu_utils.sh new file mode 100644 index 000000000..8654c57e8 --- /dev/null +++ b/examples/tpu_utils.sh @@ -0,0 +1,71 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+#!/bin/bash
+# tpu_utils.sh
+
+# calc_mesh_tpus "(4,1)" -> product of the mesh dimensions (prints 0 for empty or "null").
+calc_mesh_tpus() {
+  local input="$1"
+  if [[ -z "$input" || "$input" == "null" ]]; then
+    echo 0
+    return 0
+  fi
+  # Strip parens/spaces, then turn commas into word separators.
+  local dims="${input//[() ]/}"
+  dims="${dims//,/ }"
+  local product=1
+  for d in $dims; do  # $dims intentionally unquoted: split on spaces
+    if [[ -n "$d" ]]; then
+      product=$(( product * d ))
+    fi
+  done
+  echo "$product"
+}
+
+# Fails (non-zero) when the requested meshes need more TPUs than are available.
+# Usage: validate_mesh_allocation <total_tpus> <trainer_mesh> <rollout_mesh> <reference_mesh>
+validate_mesh_allocation() {
+  if [[ $# -ne 4 ]]; then
+    echo "Error: validate_mesh_allocation requires exactly 4 arguments: <total_tpus> <trainer_mesh> <rollout_mesh> <reference_mesh>" >&2
+    echo "Got: $*" >&2
+    return 1
+  fi
+
+  local total_tpus="$1"
+  local trainer_mesh="$2"
+  local rollout_mesh="$3"
+  local reference_mesh="$4"
+
+  local trainer_tpus
+  local rollout_tpus
+  local reference_tpus
+  local required_tpus
+
+  trainer_tpus=$(calc_mesh_tpus "$trainer_mesh")
+  rollout_tpus=$(calc_mesh_tpus "$rollout_mesh")
+  reference_tpus=$(calc_mesh_tpus "$reference_mesh")
+  required_tpus=$(( trainer_tpus + rollout_tpus + reference_tpus ))
+
+  if (( required_tpus > total_tpus )); then
+    # Diagnostics go to stderr so they survive stdout redirection.
+    echo "Error: Required TPUs ($required_tpus) exceeds total_tpus ($total_tpus)." >&2
+    echo "  Trainer needs: $trainer_tpus (mesh: $trainer_mesh)" >&2
+    echo "  Rollout needs: $rollout_tpus (mesh: $rollout_mesh)" >&2
+    echo "  Reference needs: $reference_tpus (mesh: $reference_mesh)" >&2
+    return 1 # Return failure so the caller can handle it
+  fi
+
+  return 0 # Return success
+}