diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index dc0cc7cbc9..c81945d40e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,8 @@ .github/CODEOWNERS @fzyzcjy @Ying1123 .github/workflows/ @yushengsu-thu /miles/ @fzyzcjy @yueming-yuan +/miles/backends/ @fzyzcjy @yueming-yuan @maocheng23 +/miles/ray/ @fzyzcjy @yueming-yuan @maocheng23 +/miles/rollout/ @fzyzcjy @yueming-yuan @guapisolo +/miles/router/ @fzyzcjy @yueming-yuan @guapisolo +/miles/utils/ @fzyzcjy @yueming-yuan @guapisolo @maocheng23 diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index d34c823aa3..eb2e20b9c8 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,8 +25,8 @@ concurrency: jobs: - e2e-test-short: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) + fast: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request) runs-on: self-hosted container: image: radixark/miles:latest @@ -47,7 +47,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] + info: [{"num_gpus": 0, "test_file": "fast"}] defaults: run: working-directory: ${{ github.workspace }} @@ -81,19 +81,68 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() + unit-test: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-unit-test')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 2, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes shell: bash run: | pkill -9 -f 'ray::' 2>/dev/null || true pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 - e2e-test-fsdp: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-fsdp')) + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + e2e-test-sglang: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-sglang')) runs-on: self-hosted container: image: radixark/miles:latest @@ -114,7 +163,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}] + info: [{"num_gpus": 1, "test_file": "e2e/sglang_patch/test_chat_input_ids_equivalence.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -148,19 +197,68 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() + e2e-test-short: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 4, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes shell: bash run: | pkill -9 -f 'ray::' 2>/dev/null || true pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 - e2e-test-megatron: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-megatron')) + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + e2e-test-fsdp: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-fsdp')) runs-on: self-hosted container: image: radixark/miles:latest @@ -181,7 +279,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}] + info: [{"num_gpus": 2, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 4, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -217,14 +315,63 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() + e2e-test-megatron: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-megatron')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes shell: bash run: | pkill -9 -f 'ray::' 2>/dev/null || true pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-precision: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-precision')) @@ -248,7 +395,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -284,15 +431,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - e2e-test-ckpt: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-ckpt')) runs-on: self-hosted @@ -315,7 +453,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}] + info: [{"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}] defaults: run: working-directory: ${{ github.workspace }} @@ -351,15 +489,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - e2e-test-long: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-long')) runs-on: self-hosted @@ -382,7 +511,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 2, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -418,20 +547,11 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - e2e-test-image: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-image')) runs-on: self-hosted container: - image: radixark/miles-test:latest + image: radixark/miles:latest options: > --gpus all --ipc=host @@ -449,7 +569,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 4, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -484,12 +604,3 @@ jobs: - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 37b6fa4463..5fdcc201f2 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,76 +1,95 @@ <% set jobs = { + 'fast': { + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'fast', 'num_gpus': 0}, + ], + }, + 'unit-test': { + 'label': 'run-unit-test', + 'tests': [ + {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2} + ], + }, + 'e2e-test-sglang': { + 'label': 'run-ci-sglang', + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'e2e/sglang_patch/test_chat_input_ids_equivalence.py', 'num_gpus': 1}, + ], + }, 'e2e-test-short': { 'label': 'run-ci-short', 'tests': [ - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, + {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, ], }, 'e2e-test-fsdp': { 'label': 'run-ci-fsdp', 'tests': [ - {'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, + {'test_file': 'e2e/fsdp/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'test_file': 'e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, ], }, 'e2e-test-megatron': { 'label': 'run-ci-megatron', 'tests': [ - {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1'}, - {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, - {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8}, - {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8}, - {'test_file': 'test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_quick_start_glm4_9B.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1'}, + {'test_file': 'e2e/megatron/test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, + {'test_file': 'e2e/megatron/test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'e2e/megatron/test_qwen3_4B_ppo.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_moonlight_16B_A3B.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'e2e/megatron/test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, ], }, 'e2e-test-precision': { 'label': 'run-ci-precision', 'tests': [ - {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'e2e/precision/test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, ], }, 'e2e-test-ckpt': { 'label': 'run-ci-ckpt', 'tests': [ - {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + {'test_file': 'e2e/ckpt/test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'e2e/ckpt/test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, ], }, 'e2e-test-long': { 'label': 'run-ci-long', 'tests': [ - {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, + {'test_file': 'e2e/long/test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, + {'test_file': 'e2e/long/test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, ], }, 'e2e-test-image': { 'label': 'run-ci-image', - 'image': 'radixark/miles-test:latest', + 'image': 'radixark/miles:latest', 'tests': [ - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, - {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8}, - {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8}, - {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_quick_start_glm4_9B.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_30B_A3B.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_4B_ppo.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_moonlight_16B_A3B.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'e2e/image/test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, ], }, } %> @@ -98,7 +117,7 @@ concurrency: jobs: <% for job_name, config in jobs.items() %> << job_name >>: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>')) + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>) runs-on: self-hosted container: image: << config.image if config.image else 'radixark/miles:latest' >> @@ -153,14 +172,5 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true -<% endfor %> \ No newline at end of file + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} +<% endfor %> diff --git a/README.md b/README.md index 809f78471b..1d5bd5d4a8 100644 --- a/README.md +++ b/README.md @@ -1,212 +1,160 @@ -
-logo +
-[![GitHub Repo](https://img.shields.io/badge/github-radixark%2Fmiles-black?logo=github)](https://github.com/radixark/miles) - - -
- - -> A journey of a thousand miles is made one small step at a time. +Miles Logo -**Miles** is an enterprise-facing reinforcement learning framework for **large-scale MoE post-training and production workloads**, forked from and co-evolving with **[slime](https://github.com/THUDM/slime)**. +### **Enterprise-Grade Reinforcement Learning for Large-Scale Model Training** +### **High-Performance Rollout • Low Precision Training • Production Stability** -Miles keeps slime’s lightweight, modular design, but focuses on: - -- New hardware support (e.g., GB300 and beyond) -- Stable, controllable RL for large MoE models -- Production-grade features - - -## News +[![GitHub Repo](https://img.shields.io/badge/github-radixark%2Fmiles-black?logo=github)](https://github.com/radixark/miles) +[![License](https://img.shields.io/github/license/radixark/miles)](LICENSE) +[![Slack](https://img.shields.io/badge/slack-join-brightgreen.svg)](https://slack.sglang.ai) -- [2025/12] Support FSDP2 as A Training Backend for Miles ([blog](https://lmsys.org/blog/2025-12-03-miles-fsdp/)). -- [2025/11] Unified FP8: Moving Beyond Mixed Precision for Stable and Accelerated MoE RL ([blog](https://lmsys.org/blog/2025-11-25-fp8-rl/)). -- [2025/11] Power Up Speculative Decoding In Reinforcement Learning ([blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/spec/readme-en.md)). -- [2025/11] Introduce Miles - born after slime towards enterprise RL training ([blog](https://lmsys.org/blog/2025-11-19-miles/)). +[**Latest Updates**](#latest-updates) | [**Quick Start**](#quick-start) | [**Key Features**](#key-features) | [**Documentation**](docs/en/get_started/quick_start.md) +
--- -## Table of Contents -- [Quick Start](#quick-start) -- [Arguments Walkthrough](#arguments-walkthrough) -- [Developer Guide](#developer-guide) -- [Recent Updates](#recent-updates) -- [Roadmap](#roadmap) -- [Architecture Overview](#architecture-overview) -- [FAQ & Acknowledgements](#faq--acknowledgements) ---- +## Latest Updates -## Quick Start +* **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog](https://lmsys.org/blog/2026-01-26-int4-qat/) +* **[2026/01]** 💎 **Unified VLM/LLM Multi-Turn Training**: We provided an implementation for the VLM multi-turn sampling paradigm. Developers only need to write a customized `rollout` function to easily start multi-turn RL for VLM, just like training LLM. [Blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/vlm-multi-turn/readme-en.md) +* **[2026/01]** 🤖 **Multi-Agent Co-Evolution**: Miles now supports **MrlX**, a novel asynchronous co-evolutionary framework for Multi-Agent RL. Achieve superior performance in complex tasks like Doctor-Patient simulations and DeepResearch pipelines by enabling specialized agents to evolve together symbiotically. [[Link]](https://github.com/AQ-MedAI/MrlX) +* **[2025/12]** 🔄 **Rollout Routing Replay (R3)**: In collaboration with SGLang, we've launched R3 to solve MoE RL instability. R3 records inference routing decisions and replays them during training, effectively eliminating the "training-inference mismatch" and preventing training collapse in large MoE models like Qwen3 and DeepSeek-V3. [[Paper]](https://arxiv.org/pdf/2510.11370) [[Docs]](docs/en/advanced/miles-router.md#22-rollout-routing-replay-r3-for-moe) +* **[2025/11]** 🔥 **Unified FP8 Release**: Solves the stability issues in MoE RL by ensuring training and inference use the exact same FP8 quantization logic. [[Blog]](https://lmsys.org/blog/2025-11-25-fp8-rl/) +* **[2025/11]** ⚡ **Speculative Decoding in RL**: Integrated speculative rollout with online SFT for draft models, achieving massive throughput gains. [[Blog]](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/spec/readme-en.md) +* **[2025/11]** 🎉 **Miles Project Launch**: A joint effort by InfiXAI, Ant Group, SGLang RL Team, and the Miles community. [[Announcement]](https://lmsys.org/blog/2025-11-19-miles/) -> **Note:** Miles is under active development. Commands and examples may evolve; please check the repo for the latest instructions. +## What is Miles? -For a comprehensive quick start guide covering environment setup, data preparation, training startup, and key code analysis, please refer to: -- [Quick Start Guide](./docs/en/get_started/quick_start.md) +**Miles** is a high-performance, enterprise-ready reinforcement learning (RL) framework specifically optimized for **Large-Scale model Post-Training**. Built as a powerful fork of **[slime](https://github.com/THUDM/slime)**, Miles bridges the gap between research-grade RL and production-grade reliability by integrating **SGLang** for high-throughput rollout and **Megatron-LM** for scalable training. -We also provide examples for some use cases not covered in the quick start guide; please check [examples](examples/). +> *"A journey of a thousand miles begins with a single step."* — Miles focuses on the low-level system optimizations that make large-scale RL stable, efficient, and reproducible. --- -## Arguments Walkthrough - -Arguments in Miles follow the same three-layer pattern as slime: - -1. **Megatron arguments**: Megatron arguments are exposed unchanged, e.g. `--tensor-model-parallel-size 2`. -2. **SGLang arguments**: All SGLang arguments are exposed with a prefix `--sglang-`, e.g. `--mem-fraction-static` → `--sglang-mem-fraction-static`. +## Key Features -3. **Miles-specific arguments*: Please refer to [`miles/utils/arguments.py`](miles/utils/arguments.py) for a full list +### 🌪️ Advanced MoE & Low-Precision Training -For more detailed usage, please refer to the documentation and example configs in the repo as they become available. - +* **Unified FP8 Pipeline**: The first framework to implement end-to-end FP8 sampling and training. By unifying precision across rollout and training, Miles eliminates the quantization-induced discrepancy that causes RL collapse in large MoE models. +* **Rollout Routing Replay (R3)**: Records expert routing decisions during SGLang inference and replays them during training to ensure bit-wise expert alignment. +* **INT4 QAT Support**: Recommendation for 1TB+ models to enable single-machine (e.g., H200) deployment by significantly reducing memory footprint. +### 🛡️ Eliminating Train-Inference Mismatch -## Recent Updates +* **Bit-wise Identical Training and Inference Log Probs**: System-level solution achieving deterministic forward/backward passes through kernel-level optimization (FlashAttention-3, DeepGEMM). +* **Algorithmic Correction (TIS/MIS)**: When mismatch is unavoidable, Miles provides **Truncated Importance Sampling (TIS)** and **Masked Importance Sampling (MIS)** to mitigate off-policy bias and prevent training divergence. -Miles starts from slime’s proven backbone and adds a series of upgrades for production environments. The recent PRs and changes have also been synced to slime side. +### ⚡ Extreme Performance & Efficiency -### ✅ True On-Policy +* **Speculative RL Training**: Achieve **25%+ rollout speedup** by using an **Online SFT Draft Model**. Unlike frozen draft models, Miles updates the draft policy during RL to prevent policy drift. +* **Zero-Copy Weight Sync**: Optimized weight refit via **CUDA IPC zero-copy mapping**, async tensor gathering, and bucketed flattening. Sync time reduced by 50% compared to standard HTTP/RPC transfers. +* **Partial Rollout & Over-Sampling**: Handles the "Long-Tail Effect" in multi-turn RL by over-sampling requests and recycling half-finished trajectories to maximize GPU utilization. -Miles extends slime’s deterministic training and supports **infrastructure-level true on-policy support** for SGLang + FSDP: +## Model Support & Training Diversity -- Keeps the mismatch between **training** and **inference** effectively at **zero** -- Aligns numerical behavior end-to-end between training and deployment -- Uses: - - FlashAttention-3 - - DeepGEMM - - Batch-invariant kernels from Thinking Machines Lab - - `torch.compile` and careful alignment of numeric operations +### 🏗️ Supported Models +Miles supports a wide range of state-of-the-art architectures, with a special emphasis on **DeepSeek, Qwen, Llama** and mainstream models. -This makes Miles suitable for **high-stakes experiments** where repeatability, auditability, and production debugging matter. +| Family | Supported Models | +| :--- | :--- | +| **DeepSeek** | **R1, V3, V3.2** | +| **Qwen** | **Qwen 2, 2.5, 3** | +| **Llama** | **Llama 3, 3.1, 3.3, 4** | +| **Gemma** | **Gemma 2, 3, 3N** | +| **GLM** | **GLM-4.5, GLM-4.6, GLM-4.7** | +| **MiniMax** | **M2, M2.1** | +| **Others** | **Mistral, Mixtral, Phi, gpt-oss and any model supported by SGLang and Megatron** | -### 🧮 Memory Robustness & Efficiency +### 🧩 Diverse Training Scenarios +Miles is designed to handle the complexity of modern RL workloads across various dimensions: +* **Multi-Turn Interaction**: Optimized for complex, multi-round conversations and tool-use scenarios. +* **VLM & LLM Support**: Unified framework for both Vision-Language and pure Text models. +* **Reasoning & Coding**: Specific recipes and optimizations for **Reasoning (Math/Logic)** and **Coding Agent** tasks. +* **Multi-Agent Training**: Support for advanced co-training and collaborative multi-agent reinforcement learning. -To fully utilize precious GPU memory **without** constant OOM failures, Miles includes: - -- Graceful handling of benign OOMs via error propagation -- Memory margins to avoid NCCL-related OOM issues -- Fixes for FSDP excessive memory usage -- Support for move-based and partial offloading -- Host peak memory savings for smoother multi-node training - -The goal is to let large MoE jobs run **closer to the hardware limit** while staying stable. +--- -### ⚡ Speculative Training +## Quick Start -Miles adds **speculative training** support tailored for RL: +### Installation -- Performs **online SFT on the draft model during RL**, instead of freezing it -- Avoids draft policy drift away from the target model -- Achieves **25%+ rollout speedup** vs. frozen MTP, especially in later training stages -- Includes: - - MTP with sequence packing + CP - - Proper loss masking and edge-case handling - - LM head / embedding gradient isolation - - Weight sync flows between Megatron and SGLang +We recommend using our official Docker image for the best performance and compatibility: -### 🧱 Hardware & Examples +```bash +# Pull the latest image +docker pull radixark/miles:latest -Miles actively tracks new hardware and provides usable examples: +# Or install from source +pip install -r requirements.txt +pip install -e . +``` -- GB300 training support, with more recipes coming -- A **formal mathematics (Lean)** example with SFT / RL scripts, showcasing Miles in a verifiable environment setting +### Launch Training -### 🛠 Miscellaneous Improvements +Miles provides a unified entry point for complex RL tasks. Here is an example of FP8 GRPO training for Qwen3: -Additional engineering improvements include: +```bash +python train.py \ + --advantage-estimator grpo \ + --model-name qwen3-30b-a3b \ + --hf-checkpoint /path/to/qwen3-30b-a3b-hf \ + --rollout-batch-size 512 \ + --n-samples-per-prompt 8 +``` -- Enhanced FSDP training backend -- Option to deploy the **rollout subsystem independently** outside the main framework -- Better debugging & profiling: more metrics, post-hoc analyzers, and profiler integration -- Gradual refactoring for clarity and maintainability +For comprehensive guides on environment setup and custom reward functions, see the [Quick Start Guide](docs/en/get_started/quick_start.md). --- ## Roadmap -We are actively evolving Miles toward a **production-ready RL engine** for large-scale MoE and multimodal workloads. Current roadmap items include: +### ✅ Completed -- **Large-scale MoE RL recipes** on new hardware (e.g., GB300 and successors) -- **Multimodal training** support -- **Rollout accelerations** - - Compatibility with SGLang spec v2 for improved performance - - More advanced speculative training schemes (e.g., EAGLE3-style, multi-spec layers) -- **Elasticity & fault tolerance** - - More robust handling of GPU / node failures in long-running jobs -- **Resource scheduling for async training** - - Balancing training and serving in large-scale asynchronous RL systems +- [x] **Unified FP8** E2E Training & Rollout +- [x] **INT4 Quantization-Aware Training (QAT)**: Single-machine 1TB models +- [x] **Speculative RL** with Online SFT +- [x] **Multi-Agent RL** (Co-evolutionary frameworks like [MrlX](https://github.com/AQ-MedAI/MrlX)) +- [x] **Support DeepSeek V3.2 Models** +- [x] **VLM Multi-Turn Training** +- [x] **Aligning SGLang with Megatron in Dense Models** +- [x] **Rollout Routing Replay (R3)** -We’ll continue to iterate based on feedback from users across research labs, startups, and enterprise teams. - ---- - -## Architecture Overview - -Miles inherits slime’s core architecture as below. - - -![arch](./imgs/arch.png) +### 🏗️ In Progress & Planned +- [ ] **Zero mismatch for MoE RL** +- [ ] **Aligning SGLang with Megatron in MoE Models** +- [ ] **Diffusion RL** Support +- [ ] **Omni RL** Support +- [ ] **Diffusion LLM RL** Support +- [ ] **Elastic Resource Scheduling**: Dynamic scaling of rollout vs. training workers -**Module overview:** -- **training (Megatron)** - Main training loop. Reads data from the Data Buffer and synchronizes parameters to the rollout subsystem after updates. - -- **rollout (SGLang + router)** - Generates new samples, including rewards / verifier outputs, and writes them back to the Data Buffer. - -- **data buffer** - Manages prompt initialization, custom data sources, and rollout generation strategies. Serves as the bridge between training and rollout. - -This decoupled design lets you: - -- Swap in different algorithms / reward functions without touching rollout code -- Customize rollout engines independently from training -- Scale rollouts and training differently depending on hardware and deployment constraints --- +## Acknowledgements -## Developer Guide - -* **Contributions welcome!** - We’re especially interested in: - - * New hardware backends & tuning - * MoE RL recipes - * Stability / determinism improvements - * Multimodal & speculative training use cases +Miles is built upon the shoulders of giants in the LLM infrastructure ecosystem: +* **[slime](https://github.com/THUDM/slime)**: The core modular architecture and inspiration. +* **[SGLang](https://github.com/sgl-project/sglang)**: The high-performance inference engine. +* **[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)**: Robust large-scale training components. -* We recommend using [pre-commit](https://pre-commit.com/) to keep style consistent: - -```bash -apt install pre-commit -y -pre-commit install - -# run pre-commit to ensure code style consistency -pre-commit run --all-files --show-diff-on-failure --color=always -``` - -* For debugging tips, performance tuning, and internal architecture notes, see the `docs/` and `developer_guide/` folders (coming soon). - ---- - -## FAQ & Acknowledgements - -* For FAQs, please see `docs/en/get_started/qa.md` (to be added as the project matures). -* **Huge thanks** to the **slime** authors and community — Miles would not exist without slime’s design and ecosystem. -* We also acknowledge and rely on the broader LLM infra ecosystem, including SGLang, Megatron-LM, and related tools. +Special thanks to **InfiXAI Team**, **Ant Group AQ Team**, **SGLang RL Team**, and the **Miles Team**. We also thank **DataCrunch** for compute sponsorship and **NVIDIA** for technical support on Transformer Engine (TE). --- ## Links -* **Miles GitHub**: [https://github.com/radixark/miles](https://github.com/radixark/miles) -* **slime GitHub**: [https://github.com/THUDM/slime](https://github.com/THUDM/slime) +* **GitHub**: [https://github.com/radixark/miles](https://github.com/radixark/miles) +* **Slime Project**: [https://github.com/THUDM/slime](https://github.com/THUDM/slime) +* **Developer Guide**: Check the `docs/` and `examples/` directories for in-depth technical notes. -We’re excited to see what you build — whether you choose **slime**, **Miles**, or both in different parts of your stack. 🚀 +
+**Give Miles a ⭐️ Star if it helps your RL journey!** + +
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 8dafd4cd55..41c5e93563 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -263,7 +263,7 @@ COPY amd_patch/sglv0.5.0rc0 /workspace/patch RUN pip uninstall -y megatron-core && \ git clone https://github.com/NVIDIA/Megatron-LM && \ cd Megatron-LM && \ - git checkout 48406695c4efcf1026a7ed70bb390793918dd97b && \ + git checkout 3714d81d418c9f1bca4594fc35f9e8289f652862 && \ git apply /workspace/patch/amd_megatron_fused_kernels_init.patch && \ pip install -vvv -e . && \ cd /workspace/ diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 index dd32f32c5e..7db29a7517 100644 --- a/docker/Dockerfile.rocm_MI350-5 +++ b/docker/Dockerfile.rocm_MI350-5 @@ -1,8 +1,23 @@ #### Use the base image for ROCm 7 / gfx950 (MI355) -# The Docker image built with this Dockerfile: +# ===================================================================== +# Docker Image Version Information (Updated: Feb 5, 2026) +# ===================================================================== # Base image: ROCm 7 with vllm pre-built for gfx950 # Target GPU: MI355 (gfx950) +# +# Key Dependencies: +# - sglang: v0.5.7 +# - sgl_kernel: 0.3.20 (built from sglang v0.5.7) +# - Megatron-LM: commit 3714d81d418c9f1bca4594fc35f9e8289f652862 +# - TransformerEngine: commit 90c04bcdc3c109505b318f40a39680263af55edf +# - aiter: v0.1.7.post2 +# - Ray: 2.47.1 +# +# Patches: amd_patch/sglv0.5.7/ +# - sglang.patch +# - megatron.patch, amd_megatron_fused_kernels_init.patch +# ===================================================================== FROM rocm/sgl-dev:rocm7-vllm-20250904 @@ -70,7 +85,7 @@ RUN pip uninstall -y megatron-core || true RUN rm -rf Megatron-LM RUN git clone https://github.com/NVIDIA/Megatron-LM \ && cd Megatron-LM \ - && git checkout 48406695c4efcf1026a7ed70bb390793918dd97b \ + && git checkout 3714d81d418c9f1bca4594fc35f9e8289f652862 \ && pip install -e . ######################################### ######################################### @@ -99,7 +114,7 @@ RUN pip install "ray[data,train,tune,serve]==2.47.1" ######################################### ###6. Install torch_memory_saver######### ######################################### -RUN pip install torch_memory_saver +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@64a92e1d7fb822ea4af5579c8cebb162692c531c --no-cache-dir --force-reinstall ######################################### ######################################### @@ -167,7 +182,7 @@ RUN pip uninstall -y sgl_kernel sglang || true RUN rm -rf sglang RUN git clone https://github.com/sgl-project/sglang.git \ && cd sglang \ - && git checkout v0.5.6 + && git checkout v0.5.7 # Build sgl-kernel for gfx950 RUN cd sglang/sgl-kernel \ @@ -194,8 +209,8 @@ RUN python -m pip cache purge #### APPLY PATCHES (gfx950/MI355) ######### ########################################### -# Copy patches from miles repo -COPY amd_patch/latest /app/patch +# Copy patches from miles repo (sglang v0.5.7 specific) +COPY amd_patch/sglv0.5.7 /app/patch # Apply Megatron patches RUN cd /app/Megatron-LM \ @@ -209,7 +224,7 @@ RUN cd /app/Megatron-LM \ # Apply SGLang patch RUN cd /app/sglang \ - && git apply /app/patch/sglang.patch || echo "Check patch compatibility with v0.5.6" \ + && git apply /app/patch/sglang.patch \ && if grep -R -n '^<<<<<<< ' .; then \ echo "Patch failed to apply cleanly. Please resolve conflicts." && \ exit 1; \ diff --git a/docker/amd_patch/latest/megatron.patch b/docker/amd_patch/latest/megatron.patch index c840133cef..b9e6a61d7c 100644 --- a/docker/amd_patch/latest/megatron.patch +++ b/docker/amd_patch/latest/megatron.patch @@ -1,5 +1,5 @@ diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b4..4451f277 100644 +index fe26e8b43..4451f2776 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads @@ -19,7 +19,7 @@ index fe26e8b4..4451f277 100644 + if hasattr(custom_fsdp, 'MegatronFSDP'): + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index 99c3edc0..26ea5cb4 100644 +index 99c3edc05..26ea5cb4b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear): @@ -31,7 +31,7 @@ index 99c3edc0..26ea5cb4 100644 # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index 002edb92..f7273488 100755 +index 002edb925..f72734885 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec( @@ -56,7 +56,7 @@ index 002edb92..f7273488 100755 "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index df9adc3e..2f4f544a 100644 +index df9adc3ef..2f4f544a7 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -443,7 +443,7 @@ class GPTModel(LanguageModule): @@ -69,7 +69,7 @@ index df9adc3e..2f4f544a 100644 input_ids=input_ids, position_ids=position_ids, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index 57332ac3..f3abd642 100644 +index 57332ac39..f2d0fa9c8 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -9,6 +9,7 @@ from typing import Callable, List, Optional @@ -80,222 +80,8 @@ index 57332ac3..f3abd642 100644 from .utils import GlobalMemoryBuffer, is_torch_min_version -@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs): - return None - - -+old_new_group = None -+ -+ -+def monkey_patch_torch_dist(): -+ print("Applying monkey patch to torch.distributed", flush=True) -+ global old_new_group -+ if old_new_group is not None: -+ return -+ -+ old_new_group = dist.new_group -+ -+ def new_group(*args, **kwargs): -+ group = old_new_group(*args, **kwargs) -+ # skip none nccl group. -+ if ( -+ len(args) >= 3 and args[2] == "gloo" or -+ "backend" in kwargs and kwargs["backend"] == "gloo" -+ ): -+ return group -+ -+ # Get ranks from arguments -+ if len(args) >= 1 and args[0] is not None: -+ ranks = args[0] -+ elif "ranks" in kwargs and kwargs["ranks"] is not None: -+ ranks = kwargs["ranks"] -+ else: -+ # If no ranks specified, use all ranks in world -+ ranks = list(range(dist.get_world_size())) -+ -+ if len(ranks) == 1: -+ return group -+ -+ group = ReloadableProcessGroup(group, ranks) -+ return group -+ -+ dist.new_group = new_group -+ -+ def get_new_function(func): -+ def new_function(*args, **kwargs): -+ args = ( -+ arg.group if isinstance(arg, ReloadableProcessGroup) else arg -+ for arg in args -+ ) -+ kwargs = { -+ k: (v.group if isinstance(v, ReloadableProcessGroup) else v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.get_rank = get_new_function(dist.get_rank) -+ dist.get_world_size = get_new_function(dist.get_world_size) -+ dist.get_backend = get_new_function(dist.get_backend) -+ dist.get_global_rank = get_new_function(dist.get_global_rank) -+ dist.get_group_rank = get_new_function(dist.get_group_rank) -+ dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks) -+ -+ dist.all_reduce = get_new_function(dist.all_reduce) -+ dist.all_gather = get_new_function(dist.all_gather) -+ dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor) -+ dist.all_gather_object = get_new_function(dist.all_gather_object) -+ dist.all_to_all = get_new_function(dist.all_to_all) -+ dist.all_to_all_single = get_new_function(dist.all_to_all_single) -+ dist.broadcast = get_new_function(dist.broadcast) -+ dist.reduce = get_new_function(dist.reduce) -+ dist.reduce_scatter = get_new_function(dist.reduce_scatter) -+ dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor) -+ dist.scatter = get_new_function(dist.scatter) -+ dist.gather = get_new_function(dist.gather) -+ dist.barrier = get_new_function(dist.barrier) -+ dist.send = get_new_function(dist.send) -+ dist.recv = get_new_function(dist.recv) -+ dist._coalescing_manager = get_new_function(dist._coalescing_manager) -+ -+ # p2p -+ old_isend = dist.isend -+ old_irecv = dist.irecv -+ -+ dist.isend = get_new_function(dist.isend) -+ dist.irecv = get_new_function(dist.irecv) -+ -+ def get_new_p2pop_function(func): -+ def new_function(*args, **kwargs): -+ def convert(arg): -+ if isinstance(arg, ReloadableProcessGroup): -+ return arg.group -+ elif arg == dist.isend: -+ arg = old_isend -+ elif arg == dist.irecv: -+ arg = old_irecv -+ return arg -+ -+ args = (convert(arg) for arg in args) -+ kwargs = { -+ k: convert(v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__) -+ dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__) -+ -+ -+ -+class ReloadableProcessGroup(torch.distributed.ProcessGroup): -+ GROUPS = [] -+ -+ def __init__(self, group, ranks): -+ super().__init__( -+ rank=dist.get_rank(group), -+ size=dist.get_world_size(group), -+ ) -+ #print(f"Creating ReloadableProcessGroup with ranks: {ranks}", flush=True) -+ self.group = group -+ self.group_info = { -+ "ranks": ranks, -+ } -+ ReloadableProcessGroup.GROUPS.append(self) -+ -+ def __getattr__(self, name): -+ return getattr(self.group, name) -+ -+ @staticmethod -+ def destroy_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is None: -+ continue -+ #print(f"Destroying process group: {reloadable_group.group_info['ranks']}") -+ dist.destroy_process_group(reloadable_group.group) -+ del reloadable_group.group -+ reloadable_group.group = None -+ -+ @staticmethod -+ def reload_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is not None: -+ continue -+ #print(f"Reloading process group: {reloadable_group.group_info['ranks']}") -+ group = old_new_group( -+ ranks=reloadable_group.group_info["ranks"], -+ backend="nccl" -+ ) -+ reloadable_group.group = group -+ -+ def rank(self) -> int: return self.group.rank() -+ def size(self) -> int: return self.group.size() -+ def name(self) -> str: return self.group.name() -+ -+ def shutdown(self) -> None: -+ if self.group is not None: -+ self.group.shutdown() -+ -+ def abort(self) -> None: -+ if self.group is not None: -+ self.group.abort() -+ -+ def _fwd(self, method, *args, **kwargs): -+ inner = self.group -+ if inner is None: -+ raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") -+ return getattr(inner, method)(*args, **kwargs) -+ -+ def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw) -+ def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw) -+ def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw) -+ def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw) -+ def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw) -+ def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw) -+ def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw) -+ def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw) -+ def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw) -+ def gather(self, *a, **kw): return self._fwd("gather", *a, **kw) -+ def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw) -+ def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw) -+ def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw) -+ def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw) -+ def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw) -+ def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw) -+ def send(self, *a, **kw): return self._fwd("send", *a, **kw) -+ def recv(self, *a, **kw): return self._fwd("recv", *a, **kw) -+ def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw) -+ -+ def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw) -+ def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw) -+ def _get_backend_name(self): return self._fwd("_get_backend_name") -+ def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw) -+ def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw) -+ @property -+ def bound_device_id(self): return self.group.bound_device_id -+ @bound_device_id.setter -+ def bound_device_id(self, dev): self.group.bound_device_id = dev -+ -+ -+def destroy_process_groups(): -+ """Destroy all reloadable process groups.""" -+ ReloadableProcessGroup.destroy_process_groups() -+ -+ -+def reload_process_groups(): -+ """Reload all reloadable process groups.""" -+ ReloadableProcessGroup.reload_process_groups() -+ -+ -+monkey_patch_torch_dist() -+ -+ - def create_group( - ranks=None, - timeout=None, diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index 63ee9d1f..b90b744c 100644 +index 63ee9d1f5..b90b744c1 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( @@ -326,7 +112,7 @@ index 63ee9d1f..b90b744c 100644 ops.append(recv_next_op) if len(ops) > 0: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 6f557e1f..b295fd35 100644 +index 6f557e1f5..b295fd351 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig): @@ -340,7 +126,7 @@ index 6f557e1f..b295fd35 100644 """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 84f22bde..b4807d26 100644 +index 84f22bdea..b4807d261 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -224,6 +224,7 @@ class TransformerLayerSubmodules: @@ -412,7 +198,7 @@ index 84f22bde..b4807d26 100644 # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 24ba8926..4f039fd4 100644 +index 24ba89263..4f039fd43 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None): diff --git a/docker/amd_patch/sglv0.5.0rc0/megatron.patch b/docker/amd_patch/sglv0.5.0rc0/megatron.patch index c840133cef..b129959aff 100644 --- a/docker/amd_patch/sglv0.5.0rc0/megatron.patch +++ b/docker/amd_patch/sglv0.5.0rc0/megatron.patch @@ -1,5 +1,56 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b4..4451f277 100644 +index fe26e8b43..4451f2776 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads @@ -19,10 +70,10 @@ index fe26e8b4..4451f277 100644 + if hasattr(custom_fsdp, 'MegatronFSDP'): + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index 99c3edc0..26ea5cb4 100644 +index acb93ef78..d239db4ab 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py -@@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear): +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): ) for param in self.parameters(): @@ -30,49 +81,418 @@ index 99c3edc0..26ea5cb4 100644 if is_expert: # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index 002edb92..f7273488 100755 +index e21127b87..712793853 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec( - use_te_op_fuser: Optional[bool] = False, +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( use_kitchen: bool = False, use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). -@@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec( - ), - ), - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map={ - "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", - "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index df9adc3e..2f4f544a 100644 +index a1230568c..1fd52f65a 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py -@@ -443,7 +443,7 @@ class GPTModel(LanguageModule): +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - - if mtp_in_postprocess: -+ if mtp_in_postprocess and labels is not None: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index 57332ac3..f3abd642 100644 +index a273002b9..4f821cfd5 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py -@@ -9,6 +9,7 @@ from typing import Callable, List, Optional +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional import numpy as np import torch @@ -80,222 +500,8 @@ index 57332ac3..f3abd642 100644 from .utils import GlobalMemoryBuffer, is_torch_min_version -@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs): - return None - - -+old_new_group = None -+ -+ -+def monkey_patch_torch_dist(): -+ print("Applying monkey patch to torch.distributed", flush=True) -+ global old_new_group -+ if old_new_group is not None: -+ return -+ -+ old_new_group = dist.new_group -+ -+ def new_group(*args, **kwargs): -+ group = old_new_group(*args, **kwargs) -+ # skip none nccl group. -+ if ( -+ len(args) >= 3 and args[2] == "gloo" or -+ "backend" in kwargs and kwargs["backend"] == "gloo" -+ ): -+ return group -+ -+ # Get ranks from arguments -+ if len(args) >= 1 and args[0] is not None: -+ ranks = args[0] -+ elif "ranks" in kwargs and kwargs["ranks"] is not None: -+ ranks = kwargs["ranks"] -+ else: -+ # If no ranks specified, use all ranks in world -+ ranks = list(range(dist.get_world_size())) -+ -+ if len(ranks) == 1: -+ return group -+ -+ group = ReloadableProcessGroup(group, ranks) -+ return group -+ -+ dist.new_group = new_group -+ -+ def get_new_function(func): -+ def new_function(*args, **kwargs): -+ args = ( -+ arg.group if isinstance(arg, ReloadableProcessGroup) else arg -+ for arg in args -+ ) -+ kwargs = { -+ k: (v.group if isinstance(v, ReloadableProcessGroup) else v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.get_rank = get_new_function(dist.get_rank) -+ dist.get_world_size = get_new_function(dist.get_world_size) -+ dist.get_backend = get_new_function(dist.get_backend) -+ dist.get_global_rank = get_new_function(dist.get_global_rank) -+ dist.get_group_rank = get_new_function(dist.get_group_rank) -+ dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks) -+ -+ dist.all_reduce = get_new_function(dist.all_reduce) -+ dist.all_gather = get_new_function(dist.all_gather) -+ dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor) -+ dist.all_gather_object = get_new_function(dist.all_gather_object) -+ dist.all_to_all = get_new_function(dist.all_to_all) -+ dist.all_to_all_single = get_new_function(dist.all_to_all_single) -+ dist.broadcast = get_new_function(dist.broadcast) -+ dist.reduce = get_new_function(dist.reduce) -+ dist.reduce_scatter = get_new_function(dist.reduce_scatter) -+ dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor) -+ dist.scatter = get_new_function(dist.scatter) -+ dist.gather = get_new_function(dist.gather) -+ dist.barrier = get_new_function(dist.barrier) -+ dist.send = get_new_function(dist.send) -+ dist.recv = get_new_function(dist.recv) -+ dist._coalescing_manager = get_new_function(dist._coalescing_manager) -+ -+ # p2p -+ old_isend = dist.isend -+ old_irecv = dist.irecv -+ -+ dist.isend = get_new_function(dist.isend) -+ dist.irecv = get_new_function(dist.irecv) -+ -+ def get_new_p2pop_function(func): -+ def new_function(*args, **kwargs): -+ def convert(arg): -+ if isinstance(arg, ReloadableProcessGroup): -+ return arg.group -+ elif arg == dist.isend: -+ arg = old_isend -+ elif arg == dist.irecv: -+ arg = old_irecv -+ return arg -+ -+ args = (convert(arg) for arg in args) -+ kwargs = { -+ k: convert(v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__) -+ dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__) -+ -+ -+ -+class ReloadableProcessGroup(torch.distributed.ProcessGroup): -+ GROUPS = [] -+ -+ def __init__(self, group, ranks): -+ super().__init__( -+ rank=dist.get_rank(group), -+ size=dist.get_world_size(group), -+ ) -+ #print(f"Creating ReloadableProcessGroup with ranks: {ranks}", flush=True) -+ self.group = group -+ self.group_info = { -+ "ranks": ranks, -+ } -+ ReloadableProcessGroup.GROUPS.append(self) -+ -+ def __getattr__(self, name): -+ return getattr(self.group, name) -+ -+ @staticmethod -+ def destroy_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is None: -+ continue -+ #print(f"Destroying process group: {reloadable_group.group_info['ranks']}") -+ dist.destroy_process_group(reloadable_group.group) -+ del reloadable_group.group -+ reloadable_group.group = None -+ -+ @staticmethod -+ def reload_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is not None: -+ continue -+ #print(f"Reloading process group: {reloadable_group.group_info['ranks']}") -+ group = old_new_group( -+ ranks=reloadable_group.group_info["ranks"], -+ backend="nccl" -+ ) -+ reloadable_group.group = group -+ -+ def rank(self) -> int: return self.group.rank() -+ def size(self) -> int: return self.group.size() -+ def name(self) -> str: return self.group.name() -+ -+ def shutdown(self) -> None: -+ if self.group is not None: -+ self.group.shutdown() -+ -+ def abort(self) -> None: -+ if self.group is not None: -+ self.group.abort() -+ -+ def _fwd(self, method, *args, **kwargs): -+ inner = self.group -+ if inner is None: -+ raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") -+ return getattr(inner, method)(*args, **kwargs) -+ -+ def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw) -+ def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw) -+ def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw) -+ def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw) -+ def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw) -+ def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw) -+ def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw) -+ def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw) -+ def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw) -+ def gather(self, *a, **kw): return self._fwd("gather", *a, **kw) -+ def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw) -+ def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw) -+ def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw) -+ def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw) -+ def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw) -+ def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw) -+ def send(self, *a, **kw): return self._fwd("send", *a, **kw) -+ def recv(self, *a, **kw): return self._fwd("recv", *a, **kw) -+ def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw) -+ -+ def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw) -+ def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw) -+ def _get_backend_name(self): return self._fwd("_get_backend_name") -+ def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw) -+ def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw) -+ @property -+ def bound_device_id(self): return self.group.bound_device_id -+ @bound_device_id.setter -+ def bound_device_id(self, dev): self.group.bound_device_id = dev -+ -+ -+def destroy_process_groups(): -+ """Destroy all reloadable process groups.""" -+ ReloadableProcessGroup.destroy_process_groups() -+ -+ -+def reload_process_groups(): -+ """Reload all reloadable process groups.""" -+ ReloadableProcessGroup.reload_process_groups() -+ -+ -+monkey_patch_torch_dist() -+ -+ - def create_group( - ranks=None, - timeout=None, diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index 63ee9d1f..b90b744c 100644 +index ac839c21f..f18309217 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( @@ -325,13 +531,148 @@ index 63ee9d1f..b90b744c 100644 ) ops.append(recv_next_op) if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..48c9c1a25 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..3e95858a6 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 6f557e1f..b295fd35 100644 +index e2705bd9f..a0aa109b5 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py -@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig): - qk_layernorm: bool = False - """Whether to apply `normalization` type of normalization to the query and key embeddings.""" +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + post_self_attn_layernorm: bool = False + post_mlp_layernorm: bool = False @@ -340,10 +681,10 @@ index 6f557e1f..b295fd35 100644 """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 84f22bde..b4807d26 100644 +index 3ea405770..5a42001b9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py -@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: input_layernorm: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -351,7 +692,7 @@ index 84f22bde..b4807d26 100644 pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -359,7 +700,7 @@ index 84f22bde..b4807d26 100644 # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): # [Module 3: BiasDropoutFusion] self.self_attn_bda = build_module(submodules.self_attn_bda) @@ -369,14 +710,13 @@ index 84f22bde..b4807d26 100644 + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) -+ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn self.pre_cross_attn_layernorm = build_module( submodules.pre_cross_attn_layernorm, -@@ -399,6 +409,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): - # [Module 9: BiasDropoutFusion] - self.mlp_bda = build_module(submodules.mlp_bda) +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + self.post_mlp_layernorm = build_module( + submodules.post_mlp_layernorm, @@ -388,19 +728,18 @@ index 84f22bde..b4807d26 100644 self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False -@@ -535,6 +552,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): attention_output_with_bias[0] ) + attention_output, attention_output_bias = attention_output_with_bias + attention_output = self.post_self_attn_layernorm(attention_output) + attention_output_with_bias = (attention_output, attention_output_bias) -+ + # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") -@@ -635,6 +657,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) @@ -412,12 +751,12 @@ index 84f22bde..b4807d26 100644 # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 24ba8926..4f039fd4 100644 +index b267c8a81..83736acdc 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py -@@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None): - if args.is_hybrid_model: - kw_args['is_hybrid_model'] = args.is_hybrid_model +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm @@ -425,7 +764,7 @@ index 24ba8926..4f039fd4 100644 # handle quantization config # NOTE: Kitchen arguments are only added to the namespace when # Kitchen library is available. -@@ -1481,6 +1484,10 @@ def _add_network_size_args(parser): +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): action='store_true', help='If set, use original BERT residula connection ' 'ordering.') @@ -433,6 +772,21 @@ index 24ba8926..4f039fd4 100644 + help='If set, use post self attention layernorm.') + group.add_argument('--post-mlp-layernorm', action='store_true', + help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') group.add_argument('--openai-gelu', action='store_true', help='Use OpenAIs GeLU implementation. This option' 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch b/docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch new file mode 100644 index 0000000000..f6efca346d --- /dev/null +++ b/docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch @@ -0,0 +1,51 @@ +diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py +index 87cceac3..ac686d74 100644 +--- a/megatron/legacy/fused_kernels/__init__.py ++++ b/megatron/legacy/fused_kernels/__init__.py +@@ -3,6 +3,7 @@ + import os + import pathlib + import subprocess ++import torch + + from torch.utils import cpp_extension + +@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + + def load(args): +- +- # Check if cuda 11 is installed for compute capability 8.0 +- cc_flag = [] +- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( +- cpp_extension.CUDA_HOME +- ) +- if int(bare_metal_major) >= 11: +- cc_flag.append('-gencode') +- cc_flag.append('arch=compute_80,code=sm_80') +- if int(bare_metal_minor) >= 8: ++ if torch.cuda.is_available() and torch.version.cuda: ++ # Check if cuda 11 is installed for compute capability 8.0 ++ cc_flag = [] ++ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( ++ cpp_extension.CUDA_HOME ++ ) ++ if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') +- cc_flag.append('arch=compute_90,code=sm_90') ++ cc_flag.append('arch=compute_80,code=sm_80') ++ if int(bare_metal_minor) >= 8: ++ cc_flag.append('-gencode') ++ cc_flag.append('arch=compute_90,code=sm_90') + +- # Build path +- srcpath = pathlib.Path(__file__).parent.absolute() +- buildpath = srcpath / "build" +- _create_build_dir(buildpath) ++ # Build path ++ srcpath = pathlib.Path(__file__).parent.absolute() ++ buildpath = srcpath / "build" ++ _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.7/megatron.patch b/docker/amd_patch/sglv0.5.7/megatron.patch new file mode 100644 index 0000000000..b129959aff --- /dev/null +++ b/docker/amd_patch/sglv0.5.7/megatron.patch @@ -0,0 +1,792 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py +index fe26e8b43..4451f2776 100644 +--- a/megatron/core/distributed/__init__.py ++++ b/megatron/core/distributed/__init__.py +@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads + from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel + from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig ++ ++# Backward compatibility patch for FSDP module reorganization ++import sys ++import importlib.util ++ ++spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') ++if spec: ++ custom_fsdp = importlib.util.module_from_spec(spec) ++ spec.loader.exec_module(custom_fsdp) ++ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp ++ if hasattr(custom_fsdp, 'MegatronFSDP'): ++ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index acb93ef78..d239db4ab 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index e21127b87..712793853 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index a1230568c..1fd52f65a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() +- if mtp_in_postprocess: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a273002b9..4f821cfd5 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index ac839c21f..f18309217 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..48c9c1a25 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..3e95858a6 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index e2705bd9f..a0aa109b5 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 3ea405770..5a42001b9 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index b267c8a81..83736acdc 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.7/sglang.patch b/docker/amd_patch/sglv0.5.7/sglang.patch new file mode 100644 index 0000000000..e1d6562e16 --- /dev/null +++ b/docker/amd_patch/sglv0.5.7/sglang.patch @@ -0,0 +1,36 @@ +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index 8e3429dec..494a754b3 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1849,7 +1849,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py +index b38a83d57..a492e3ef8 100644 +--- a/python/sglang/srt/models/qwen3_next.py ++++ b/python/sglang/srt/models/qwen3_next.py +@@ -45,13 +45,14 @@ from sglang.srt.utils import ( + LazyValue, + add_prefix, + is_cuda, ++ is_cuda_alike, + is_npu, + make_layers, + set_weight_attrs, + ) + + logger = logging.getLogger(__name__) +-_is_cuda = is_cuda() ++_is_cuda = is_cuda_alike() + _is_npu = is_npu() + + diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch index 42d23ed659..ea44e24be1 100644 --- a/docker/patch/v0.5.7/sglang.patch +++ b/docker/patch/v0.5.7/sglang.patch @@ -74,6 +74,213 @@ index 0478526ef..cfb1aa669 100644 def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py +index 34aa364cf..da5d0d6b6 100644 +--- a/python/sglang/srt/entrypoints/openai/protocol.py ++++ b/python/sglang/srt/entrypoints/openai/protocol.py +@@ -81,6 +81,7 @@ class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) ++ token_ids: List[int] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +@@ -92,6 +93,7 @@ class TopLogprob(BaseModel): + + class ChatCompletionTokenLogprob(BaseModel): + token: str ++ token_id: int + bytes: List[int] + logprob: float + top_logprobs: List[TopLogprob] +@@ -501,6 +503,7 @@ class ChatCompletionRequest(BaseModel): + top_k: Optional[int] = None + min_p: Optional[float] = None + min_tokens: int = 0 ++ logprob_start_len: Optional[int] = None + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: Optional[float] = None +@@ -536,6 +539,9 @@ class ChatCompletionRequest(BaseModel): + + # For data parallel rank routing + data_parallel_rank: Optional[int] = None ++ ++ # Input ids, if provided, it will override the message input. ++ input_ids: Optional[Union[List[List[int]], List[int]]] = None + + # OpenAI/SGLang default sampling parameters + _DEFAULT_SAMPLING_PARAMS = { +@@ -618,8 +624,8 @@ class ChatCompletionRequest(BaseModel): + + def to_sampling_params( + self, +- stop: List[str], + model_generation_config: Dict[str, Any], ++ stop: Optional[List[str]] = None, + tool_call_constraint: Optional[ToolCallConstraint] = None, + ) -> Dict[str, Any]: + """ +@@ -706,6 +712,7 @@ class ChatMessage(BaseModel): + class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage ++ input_token_ids: Optional[List[int]] = None + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: Optional[ + Literal[ +diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py +index cb0c084a3..daa0db7bb 100644 +--- a/python/sglang/srt/entrypoints/openai/serving_chat.py ++++ b/python/sglang/srt/entrypoints/openai/serving_chat.py +@@ -146,6 +146,62 @@ class OpenAIServingChat(OpenAIServingBase): + + return None + ++ def _convert_chat_completion_with_input_ids_to_internal_request( ++ self, ++ request: ChatCompletionRequest, ++ raw_request: Request = None, ++ ) -> tuple[GenerateReqInput, ChatCompletionRequest]: ++ ++ # Notice: currently, if input_ids is provided, the stop token is not used. ++ sampling_params = request.to_sampling_params( ++ model_generation_config=self.default_sampling_params ++ ) ++ ++ prompt_kwargs = {"input_ids": request.input_ids} ++ ++ # Extract custom labels from raw request headers ++ custom_labels = self.extract_custom_labels(raw_request) ++ ++ # Resolve LoRA adapter from model parameter or explicit lora_path ++ lora_path = self._resolve_lora_path(request.model, request.lora_path) ++ if lora_path: ++ first_adapter = ( ++ lora_path ++ if isinstance(lora_path, str) ++ else next((a for a in lora_path if a), None) ++ ) ++ if first_adapter: ++ self._validate_lora_enabled(first_adapter) ++ ++ logprob_start_len = ( ++ request.logprob_start_len if request.logprob_start_len is not None else -1 ++ ) ++ ++ adapted_request = GenerateReqInput( ++ **prompt_kwargs, ++ sampling_params=sampling_params, ++ return_logprob=request.logprobs, ++ logprob_start_len=logprob_start_len, ++ top_logprobs_num=request.top_logprobs or 0, ++ stream=request.stream, ++ return_text_in_logprobs=True, ++ lora_path=lora_path, ++ bootstrap_host=request.bootstrap_host, ++ bootstrap_port=request.bootstrap_port, ++ bootstrap_room=request.bootstrap_room, ++ data_parallel_rank=request.data_parallel_rank, ++ return_hidden_states=request.return_hidden_states, ++ rid=request.rid, ++ extra_key=self._compute_extra_key(request), ++ require_reasoning=self._get_reasoning_from_request(request), ++ priority=request.priority, ++ custom_labels=custom_labels, ++ custom_logit_processor=request.custom_logit_processor, ++ ) ++ ++ return adapted_request, request ++ ++ + def _convert_to_internal_request( + self, + request: ChatCompletionRequest, +@@ -162,6 +218,9 @@ class OpenAIServingChat(OpenAIServingBase): + """Convert OpenAI chat completion request to internal format""" + is_multimodal = self.tokenizer_manager.model_config.is_multimodal + ++ if request.input_ids: ++ return self._convert_chat_completion_with_input_ids_to_internal_request(request, raw_request) ++ + # Process messages and apply chat template + processed_messages = self._process_messages(request, is_multimodal) + +@@ -195,6 +254,10 @@ class OpenAIServingChat(OpenAIServingBase): + if first_adapter: + self._validate_lora_enabled(first_adapter) + ++ logprob_start_len = ( ++ request.logprob_start_len if request.logprob_start_len is not None else -1 ++ ) ++ + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=processed_messages.image_data, +@@ -202,7 +265,7 @@ class OpenAIServingChat(OpenAIServingBase): + audio_data=processed_messages.audio_data, + sampling_params=sampling_params, + return_logprob=request.logprobs, +- logprob_start_len=-1, ++ logprob_start_len=logprob_start_len, + top_logprobs_num=request.top_logprobs or 0, + stream=request.stream, + return_text_in_logprobs=True, +@@ -768,8 +831,13 @@ class OpenAIServingChat(OpenAIServingBase): + for idx, ret_item in enumerate(ret): + # Process logprobs + choice_logprobs = None ++ input_token_ids = None + if request.logprobs: + choice_logprobs = self._process_response_logprobs(ret_item) ++ input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] ++ input_token_ids = [ ++ token_id for _, token_id, _ in input_token_logprobs ++ ] + + # Handle hidden states + hidden_states = process_hidden_states_from_ret(ret_item, request) +@@ -824,6 +892,7 @@ class OpenAIServingChat(OpenAIServingBase): + tool_calls=tool_calls, + reasoning_content=reasoning_text if reasoning_text else None, + ), ++ input_token_ids=input_token_ids, + logprobs=choice_logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( +@@ -865,6 +934,7 @@ class OpenAIServingChat(OpenAIServingBase): + for token_idx, (token, logprob) in enumerate( + zip(logprobs.tokens, logprobs.token_logprobs) + ): ++ token_id = logprobs.token_ids[token_idx] + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: +@@ -885,6 +955,7 @@ class OpenAIServingChat(OpenAIServingBase): + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, ++ token_id=token_id, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, +diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py +index 94ac5458d..e718ddb2a 100644 +--- a/python/sglang/srt/entrypoints/openai/utils.py ++++ b/python/sglang/srt/entrypoints/openai/utils.py +@@ -19,9 +19,10 @@ def to_openai_style_logprobs( + ret_logprobs = LogProbs() + + def append_token_logprobs(token_logprobs): +- for logprob, _, token_text in token_logprobs: ++ for logprob, token_id, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) ++ ret_logprobs.token_ids.append(token_id) + + # Not supported yet + ret_logprobs.text_offset.append(-1) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index b07164c53..8e6722ce0 100644 --- a/python/sglang/srt/layers/layernorm.py diff --git a/docs/en/advanced/miles-router.md b/docs/en/advanced/miles-router.md new file mode 100644 index 0000000000..1eeb42b659 --- /dev/null +++ b/docs/en/advanced/miles-router.md @@ -0,0 +1,93 @@ +# Miles Router + +miles includes an optional Miles Router used during rollout / data generation. It is a lightweight HTTP router/proxy that sits in front of one or more SGLang worker servers and adds training-oriented capabilities that are not the main goal of serving-focused routers. + +--- + +## 1. What is Miles Router? + +Miles Router is a small FastAPI service that: + +- Registers workers (SGLang HTTP servers) into a local pool +- Routes requests to a selected worker (simple least-inflight load balancing) +- Proxies arbitrary paths to the selected worker (e.g. `/generate`) +- Runs periodic health checks and quarantines unhealthy workers +- Supports middleware plugins (via `--miles-router-middleware-paths`) to implement rollout-specific processing (e.g. caching, request/response transforms) + +In miles's architecture, the router is part of the rollout system ("SGLang + router") that generates samples and pushes them into the data buffer. + +### How it is launched + +In distributed training, miles will start a router automatically when `--sglang-router-ip` is not provided: + +- If `--use-miles-router` is set, miles starts Miles Router +- Otherwise, miles starts SGLang Model Gateway + +--- + +## 2. Why we need Miles Router + +Unlike production inference, RL rollout needs to capture additional metadata for training: token-level logprobs, loss masks, and (for MoE models) expert routing decisions. Miles Router provides these capabilities through its middleware system and passthrough proxy design. + +### 2.1 Radix-tree cache (transparent token management) + +> Use this when your rollout pipeline is text-in/text-out and you cannot reliably persist token IDs; if you already control token-in/token-out (e.g. search r1, multiturn VLM examples), you likely don't need the radix-tree cache. + +Text-in text-out interfaces can cause token retokenization mismatches - re-tokenizing text at training time may produce different token sequences than rollout, breaking per-token alignment needed for PPO/GRPO losses. + +The radix-tree cache solves this transparently: it intercepts text-based requests, tokenizes them, and stores trajectories (text, token IDs, logprobs, loss masks) keyed by the text prefix. After rollout finishes, calling `/retrieve_from_text` returns the exact token sequence with aligned metadata, without requiring any changes to your rollout code. + +Implementation-wise, the radix-tree cache: + +- Accepts text plus tokens/metadata and stores them in a radix tree +- Uses longest-prefix matching to reuse cached token sequences (enabling token-in/token-out downstream) +- Allows insertion of new text continuations as rollout proceeds (multiple trajectories per prompt, e.g. GRPO) +- Periodically cleans up stale nodes to control memory usage + +Use the radix cache when you have text-based rollout code and want token-level precision without rewriting, or when running GRPO with multiple trajectories sharing the same prompt prefix. + +### 2.2 Rollout routing replay (R3) for MoE + +For MoE models, miles supports rollout routing replay (R3): record expert routing decisions during rollout and replay them during training to improve stability. + +#### SGLang side + +SGLang provides expert routing capture via: + +- `--enable-return-routed-experts`: server argument to enable routing capture +- `RoutedExpertsCapturer`: captures `topk_ids` (selected expert IDs) at each MoE layer during forward pass +- `return_routed_experts`: request parameter to retrieve routing data +- Returns `routed_experts` in response `meta_info` - a `[seq_len - 1, num_layers, top_k]` tensor of expert IDs + +#### miles side + +miles consumes the routing data and replays it during training: + +- `--use-miles-router --use-rollout-routing-replay`: both flags required to enable R3 +- Rollout sends `return_routed_experts=True` and stores results in `sample.rollout_routed_experts` +- Training calls `fill_routing_replay()` to load routing data into `RoutingReplay` objects +- During forward pass, recorded routing decisions are replayed instead of recomputed + +#### Why Miles Router is needed + +We need Miles Router because the SGLang worker returns routed experts in the response (`meta_info.routed_experts`) when the request sets `return_routed_experts=true`, and Miles Router preserves this field end-to-end. SGLang Model Gateway may drop this extra metadata when it reconstructs responses with a fixed schema (see section 3). + +--- + +## 3. Differences vs SGLang Model Gateway + +Miles Router and SGLang Model Gateway can both route requests to workers, but they are optimized for different goals. + +### Key differences + +Miles Router is a lightweight Python/FastAPI proxy that acts as a passthrough to SGLang workers. This passthrough design enables RL-specific features like radix-tree trajectory caching and R3 (which require preserving raw response metadata like `routed_experts`). + +SGLang Model Gateway is a high-performance Rust-based router optimized for large-scale inference: async non-blocking routing, advanced fault tolerance (retries, circuit breakers), multiple load balancing policies (including cache-aware routing), and PD disaggregation support. However, it reconstructs responses with a fixed schema, so it does not preserve the metadata needed for miles's R3 flow. + +For more details on SGLang Model Gateway, see the [official documentation](https://docs.sglang.io/advanced_features/sgl_model_gateway.html). + +### When to use which + +- Use Miles Router when you need R3 or radix-tree caching +- Use SGLang Model Gateway for everything else (recommended default) + diff --git a/docs/en/advanced/miles_server_args.md b/docs/en/advanced/miles_server_args.md new file mode 100644 index 0000000000..c02de1545c --- /dev/null +++ b/docs/en/advanced/miles_server_args.md @@ -0,0 +1,492 @@ +# Miles Server Arguments + +This document provides a detailed list of command-line arguments used to configure Miles for RL training and inference. These arguments enable precise control over cluster resources, training backends (Megatron/FSDP), inference optimization via SGLang, and RL algorithmic hyperparameters. + +You can find all arguments by running: +```bash +python3 train.py --help +``` + +Note that this document is based on commit `a93d484` and was last updated on 02/09/2026. We try our best to ensure the quality and accuracy of these documents. Even so, it's hard to accurately describe all the hundreds of parameters' effect on such complex RL scenarios. This doc is for reference and may contain some tiny errors. + +## Argument Sources + +Miles acts as an orchestrator that integrates multiple frameworks. To help identify where an argument is directed, we follow these prefix conventions: + +* **`--sglang-*`**: Arguments passed directly to the **SGLang** rollout. +* **`--router-*`**: Arguments directed to the **SGLang Model Gateway/Router**. +* **No Prefix**: Default arguments corresponding to **Megatron-LM** (when using the Megatron backend) or **Miles native** configuration. +* **`--fsdp-*`**: Specific arguments for the experimental **FSDP** backend. + +**Note** that Arguments labeled as **Megatron-LM (Reset by Miles)** are native Megatron-LM parameters where Miles has modified the default value or behavior to better suit RL training workflows. + +## Table of Contents + +1. [Cluster and Resource Management](#cluster-and-resource-management) +2. [Training Backend](#training-backend) +3. [Rollout Management](#rollout-management) +4. [Sampling and Filtering](#sampling-and-filtering) +5. [Data Arguments](#data-arguments) +6. [Evaluation Arguments](#evaluation-arguments) +7. [Checkpointing and Resuming](#checkpointing-and-resuming) +8. [Algorithm and RL Arguments](#algorithm-and-rl-arguments) +9. [Logging and Monitoring](#logging-and-monitoring) +10. [Fault Tolerance](#fault-tolerance) +11. [Miles Router](#miles-router) +12. [Reward Model Arguments](#reward-model-arguments) +13. [Rollout Buffer Management](#rollout-buffer-management) +14. [Multi-Token Prediction (MTP) Arguments](#multi-token-prediction-mtp-arguments) +15. [SGLang Backend Arguments](#sglang-backend-arguments) +16. [Megatron Specific Arguments](#megatron-specific-arguments) +17. [FSDP Specific Arguments](#fsdp-specific-arguments) +18. [Debug and Profiling](#debug-and-profiling) +19. [Environment Variables](#environment-variables) +20. [Multi-Turn and Agentic Arguments](#multi-turn-and-agentic-arguments) +21. [Advanced Developer Hooks and CI](#advanced-developer-hooks-and-ci) +22. [Miscellaneous and System](#miscellaneous-and-system) + +## Cluster and Resource Management + +Arguments for configuring Ray cluster resources and GPU allocation. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--actor-num-nodes` | Number of nodes for training the Actor. | `1` | Type: int | Miles Native | +| `--actor-num-gpus-per-node` | Number of GPUs per node for training the Actor. | `8` | Type: int | Miles Native | +| `--critic-num-nodes` | Number of nodes for the Critic. Defaults to `--actor-num-nodes`. | `None` | Type: int | Miles Native | +| `--critic-num-gpus-per-node` | Number of GPUs per node for the Critic. Defaults to `--actor-num-gpus-per-node`. | `None` | Type: int | Miles Native | +| `--rollout-num-gpus` | Total number of GPUs required for rollout (inference). In `--colocate` mode, this is ignored and set to `actor-num-gpus-per-node * actor-num-nodes` (plus critic GPUs if enabled). | `None` | Type: int | Miles Native | +| `--rollout-num-gpus-per-engine` | Number of GPUs per inference engine, same as `tp_size` in SGLang. For multi-node serving, this should be the total GPU count / `tp_size` for each SGLang instance. | `1` | Type: int | Miles Native | +| `--num-gpus-per-node` | Total GPUs per node on the physical machine. This informs the Ray scheduler of the hardware capacity. In **Colocate mode**, it is required that the machine has fewer than 8 GPUs to calculate correct VRAM offsets. In **Disaggregated mode**, it ensures SGLang engines are distributed correctly across nodes without exceeding per-node GPU limits. | `8` | Type: int | Miles Native | +| `--colocate` | Deploy training and rollout on the same GPUs. This mode automatically enables `--offload-train` and `--offload-rollout` to facilitate weight-swapping between the training actor and inference engine. **Note:** The offload parameters are currently only used for AMD GPUs and will be removed soon. **Memory Tip:** When colocating, it is highly recommended to set `--sglang-mem-fraction-static` to **0.8** (especially on **NVIDIA Blackwell B200/B300** GPUs). This leaves sufficient VRAM (~20%) for Megatron to initialize its structures before the first weight offload to CPU occurs. On GB200/GB300, values up to 0.75 are safer for long-running jobs to prevent potential OOMs. #TODO: Verify optimal fraction for Blackwell in production | `False` | bool flag (set to enable) | Miles Native | +| `--prefill-num-servers` | Number of dedicated prefill servers for PD disaggregation. | `None` | Type: int | Miles Native | +| `--distributed-backend` | Backend for distributed communication. | `nccl` | `nccl`, `gloo` | Megatron-LM (Reset by Miles) | +| `--distributed-timeout-minutes` | Timeout for distributed operations in minutes. | `10` | Type: int | Megatron-LM (Reset by Miles) | + +Note that most use cases do not need to consider offload parameters, including `--offload-rollout, --no-offload-rollout, --offload-train, --no-offload-train`. They are used only on AMD GPUs and will eventually be removed. + +## Training Backend + +Arguments for configuring the training engine (Megatron or FSDP). + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--train-backend` | The backend for training. Highly suggest Megatron for numerical stability and efficiency. | `"megatron"` | `megatron`, `fsdp` | Miles Native | +| `--qkv-format` | Whether to pack variable-length sequences into the token dimension for training. `thd` (T-H-D, a.k.a. varlen / packed sequence) concatenates sequences and uses `cu_seqlens` to avoid padding; it is the default and is usually faster by reducing padding overhead. `bshd` (B-S-H-D) uses fixed-shape padded batches; use it for newer models with novel attention architectures (e.g., sparse attention, attention sink) where the training backend does not support `thd`. | `"thd"` | `thd`, `bshd` | Miles Native | +| `--optimizer` | Optimizer type. | `adam` | `adam`, `sgd` | Megatron-LM & FSDP | +| `--lr` | Learning rate for the Actor. | `1e-6` | Type: float | Megatron-LM (Reset by Miles) & FSDP | +| `--lr-warmup-init` | Initial learning rate for warmup. | `0.0` | Type: float | Megatron-LM & FSDP | +| `--min-lr` | Minimum learning rate after decay. | `0.0` | Type: float | Megatron-LM & FSDP | +| `--lr-decay-style` | Learning rate decay style. | `constant`(FSDP), `linear`(Megatron) | Type: str | Megatron-LM & FSDP | +| `--lr-warmup-iters` | Number of iterations for warmup. | `0` | Type: int | Megatron-LM & FSDP | +| `--lr-decay-iters` | Number of iterations for learning rate decay. | `None` | Type: int | Megatron-LM & FSDP | +| `--lr-warmup-fraction` | Fraction of total steps to warmup. | `None` | Type: float | Megatron-LM & FSDP | +| `--adam-beta1` | Beta1 for Adam optimizer. | `0.9` | Type: float | Megatron-LM & FSDP | +| `--adam-beta2` | Beta2 for Adam optimizer. | `0.95` | Type: float | Megatron-LM & FSDP | +| `--adam-eps` | Epsilon for Adam optimizer. | `1e-8` | Type: float | Megatron-LM & FSDP | +| `--true-on-policy-mode` | Strictly align SGLang's log probs and training engine's log probs to bit-wise equal. This parameter is only used for FSDP right now. [Ref](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/mismatch/blog-en.md#truly-on-policy-training) | `False` | bool flag (set to enable) | Miles Native | +| `--train-env-vars` | Extra environment variables for training process, e.g., PyTorch memory management ones. | `{}` | Type: JSON / Dict | Miles Native | +| `--train-memory-margin-bytes` | Reserved memory margin for training in bytes. Defaults to 1GB. | `1073741824` | Type: int | Miles Native | +| `--disable-weights-backuper` | Applies to `megatron` training backend only. Disables the system that backs up model weights (Actor, Ref, Old Actor) to CPU RAM. Disabling saves significant host memory but prevents features that rely on weight-swapping, such as computing the KL-divergence against a reference model. **Note**: do not set `--ref-load` and `--keep-old-actor` if disable weights backuper. | `False` | bool flag (set to disable) | Miles Native | +| `--custom-model-provider-path` | Path to a custom function that replaces the default model provider. [Ref](../get_started/customization.md#20-model-provider---custom-model-provider-path) | `None` | Type: str | Miles Native | +| `--recompute-loss-function` | Enable recomputing the loss function to save memory during training. | `False` | bool flag (set to enable) | Miles Native | +| `--log-probs-chunk-size` | Specifies the chunk size for logprobs computation to reduce peak memory usage. Processing logits in smaller batches, it prevents CUDA OOM errors during long-context prefilling or re-computation. Set to `-1` to disable chunking. [Ref](https://github.com/sgl-project/sglang/pull/6318) | `-1` | Type: int | Miles Native | +| `--keep-old-actor` | Maintains a "Model Queue" (Actor, Rollout Actor, Old Actor) to ensure importance sampling ratios are calculated against the exact policy version that generated the data. Essential for asynchronous RL where training and inference are decoupled, preventing mathematical incorrectness due to model staleness. It consumes additional Host Memory (extra ~1x model size for `update_weights_interval > 1` or 2x for `update_weights_interval == 1`) depending on update interval. | `False` | bool flag (set to enable) | Miles Native | +| `--update-weight-buffer-size` | Buffer size for updating weights, in bytes. [Ref](https://hebiao064.github.io/rl-weight-sync#42-optimizing-sglang-server-calls-with-tensor-bucketing-from-50s-to-30s) | `536870912` | Type: int | Miles Native | +| `--update-weights-interval` | Interval (in rollout rounds) for syncing weights to inference engines. Set to `>1` for async RL. | `1` | Type: int | Miles Native | +| `--fp16` | Enable FP16 mixed precision. | `False` | bool flag (set to enable) | Megatron-LM & FSDP | +| `--context-parallel-size` | Size of context parallelism. | `1` | Type: int | Megatron-LM & FSDP | +| `--deterministic-mode` | Enable deterministic mode for reproducibility. [Ref](https://lmsys.org/blog/2025-09-22-sglang-deterministic/) | `False` | bool flag (set to enable) | Megatron-LM & FSDP | + +## Rollout Management + +Arguments for configuring the rollout (inference) process and custom rollout logic. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--hf-checkpoint` | Path to the HuggingFace checkpoint used to initialize SGLang and provide the tokenizer. | `None` | Type: str | Miles Native | +| `--model-name` | The name of the model that is used to convert the Megatron weights into HuggingFace format. If not set, we will use `type(AutoConfig.from_pretrained(args.hf_checkpoint)).__name__.lower()` as `model_name`. Providing this argument can also help in cases where transformers cannot find certain models. | `None` | Type: str | Miles Native | +| `--rollout-function-path` | Path to the rollout generation function. Use this to inject custom logic (e.g., for multi-turn or tool use). [Ref](../get_started/customization.md#1-rollout-function---rollout-function-path) | `miles.rollout.sglang_rollout.generate_rollout` (or `miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn` when `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1`) | Type: str | Miles Native | +| `--rollout-temperature` | Sampling temperature for the inference engine during rollout. | `1.0` | Type: float | Miles Native | +| `--rollout-top-p` | Top-p (nucleus) sampling threshold during rollout. | `1.0` | Type: float | Miles Native | +| `--rollout-top-k` | Top-k sampling threshold during rollout. `-1` means disabled. | `-1` | Type: int | Miles Native | +| `--rollout-max-context-len` | The maximum context size for the inference engine during rollout. It should not exceed the `max_position_embeddings` in the HuggingFace model's `config.json`. **Note:** This acts as a hard cap for the total tokens (Prompt + Response). | `None` | Type: int | Miles Native | +| `--rollout-max-prompt-len` | Maximum length of the prompt. Longer prompts are filtered during dataset initialization. This is not recommended if the dataset is large. **Note:** Defaults to `rollout-max-context-len - 1` if not set, ensuring at least one token can be generated. | `None` | Type: int | Miles Native | +| `--rollout-max-response-len` | Maximum length of the response (`max_tokens` in SGLang). **Note:** Generation will stop when either this limit is reached or the total session length hits `rollout-max-context-len`. | `None` | Type: int | Miles Native | +| `--rollout-skip-special-tokens` | Skip special tokens (e.g., `<\|im_end\|>`, `<\|endoftext\|>`) in the decoded response string. **Critical for Multi-Turn RL:** Ensures that when a response is appended to the conversation history for the next turn, it doesn't include terminal special tokens that would interfere with chat template formatting or cause early termination in subsequent turns. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-stop` | A list of strings that trigger termination of generation if they appear in the output (e.g., `"\nUser:"`). | `None` | Type: List[str] | Miles Native | +| `--rollout-stop-token-ids` | A list of numerical token IDs that trigger termination. This is the token-level equivalent of `--rollout-stop` and is preferred for special control tokens that are difficult to input as strings. | `None` | Type: List[int] | Miles Native | +| `--rollout-shuffle` | Shuffle the prompts during rollout. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-seed` | Seed for the random number generator during rollout (used for shuffling and sampling). | `42` | Type: int | Miles Native | +| `--rollout-external` | Use external SGLang instances instead of launching them inside the framework. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-external-engine-addrs` | Addresses and ports of the external engines. | `None` | Type: List[str] | Miles Native | +| `--custom-generate-function-path` | Path to override only the `generate` step within the default rollout function. If your custom `generate` returns `list[Sample]` (multi-sample), make sure your rollout pipeline can handle it; the default rollout expects a flat `list[Sample]` of length `--n-samples-per-prompt` for each prompt group. [Ref](../get_started/customization.md#2-custom-generate-function---custom-generate-function-path) | `None` | Type: str | Miles Native | +| `--custom-rollout-log-function-path` | Path to a custom function for logging training rollout data. [Ref](../get_started/customization.md#14-logging-functions) | `None` | Type: str | Miles Native | +| `--custom-eval-rollout-log-function-path` | Path to a custom function for logging evaluation rollout data. [Ref](../get_started/customization.md#14-logging-functions) | `None` | Type: str | Miles Native | +| `--rollout-data-postprocess-path` | Path to a function called after all rollout data (including log probs) is ready. [Ref](../get_started/customization.md#8-rollout-data-postprocess---rollout-data-postprocess-path) | `None` | Type: str | Miles Native | + +## Sampling and Filtering + +Arguments for sampling strategies and data filtering during rollout and buffer management. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--over-sampling-batch-size` | Number of prompts requested in each **oversampling** round when **dynamic sampling** is enabled. Miles samples `over_sampling_batch_size` prompts, generates `--n-samples-per-prompt` responses per prompt asynchronously, and then keeps/discards each prompt group via `--dynamic-sampling-filter-path`. If filtering is strict and the remaining accepted batch size drops below the target `--rollout-batch-size`, Miles automatically triggers another oversampling round of the same size. If unset, defaults to `--rollout-batch-size`. See [Dynamic Sampling](../get_started/quick_start.md#dynamic-sampling). | `None` | Type: int | Miles Native | +| `--dynamic-sampling-filter-path` | Path to the filter function for dynamic sampling. [Ref](../get_started/customization.md#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | `None` | Type: str | Miles Native | +| `--partial-rollout` | Enable partial rollout for **dynamic sampling**: cache partially generated (aborted/unfinished) samples and resume generation in later rollout steps, reducing wasted compute for long responses. Cached samples are stored in the rollout buffer and can be prioritized/selected via `--buffer-filter-path` (default FIFO behavior). See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native | +| `--mask-offpolicy-in-partial-rollout` | When using partial rollout, mask the previously generated (cached) response tokens so they do not contribute to the loss; only tokens generated after resuming are used for training. This helps avoid training on a cached prefix produced by an older policy version. See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native | +| `--buffer-filter-path` | Path to the function to filter or sort samples in the rollout buffer before training. [Ref](../get_started/customization.md#5-buffer-filter---buffer-filter-path) | `None` | Type: str | Miles Native | +| `--rollout-sample-filter-path` | Path to the function that marks individual samples to be excluded from loss calculation. [Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path) | `None` | Type: str | Miles Native | +| `--rollout-all-samples-process-path` | Path to the function to process all samples (including filtered ones) after rollout. [Ref](../get_started/customization.md#7-rollout-all-samples-process---rollout-all-samples-process-path) | `None` | Type: str | Miles Native | + +## Data Arguments + +Arguments for dataset configuration, prompt mapping, and training batch sizes. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--prompt-data` | Path to the prompt dataset (JSONL format), and each line should contain `--input-key` and `--label-key`, which will be used as the prompt and the label, respectively. | `None` | Type: str | Miles Native | +| `--disable-rollout-global-dataset` | Disable the global dataset for rollout. By default, Miles loads `--prompt-data` into a global dataset and samples from it for rollout. Setting this flag turns off this behavior. Use this flag only when providing a custom `--rollout-function-path` (and usually a custom `--data-source-path`) that handles data loading independently. | `False` | bool flag (set to disable) | Miles Native | +| `--data-source-path` | Path to a custom Python class for the rollout data source. [Ref](../get_started/customization.md#15-data-source---data-source-path) | `miles.rollout.data_source.RolloutDataSourceWithBuffer` | Type: str | Miles Native | +| `--input-key` | Key in the JSONL data representing the user input/prompt. | `"input"` | Type: str | Miles Native | +| `--label-key` | Key in the JSONL data representing the label/ground truth. | `None` | Type: str | Miles Native | +| `--metadata-key` | When adding tools during `apply_chat_template`, provide the key for the tools to the prompt dataset. | `"metadata"` | Type: str | Miles Native | +| `--multimodal-keys` | JSON string for multimodal data mapping media types to data keys. Example: `'{"image": "image_file"}'` | `None` | Type: str | Miles Native | +| `--tool-key` | JSON key for tool definitions in the prompt dataset (used when applying chat templates). | `"tools"` | Type: str | Miles Native | +| `--apply-chat-template` | Whether to apply the chat template to the input prompt. The input should be the same structure as an OpenAI message, e.g., `[{'role': 'user', 'content': 'blabla'}]`. | `False` | bool flag (set to enable) | Miles Native | +| `--apply-chat-template-kwargs` | Extra arguments for the chat template processing (JSON string). | `"{}"` | Type: str | Miles Native | +| `--num-rollout` | Number of rollout steps. If not set, Miles will calculate the number of rollout steps from the dataset size. **Note:** This value will be overwritten if `--num-epoch` is also set. | `None` | Type: int | Miles Native | +| `--num-epoch` | Number of epochs for the training. If set, `num_rollout` is calculated as `(num_epoch * dataset_size) // rollout_batch_size`. **Note:** This argument takes precedence and will overwrite `--num-rollout` if both are specified. | `None` | Type: int | Miles Native | +| `--rollout-batch-size` | Number of prompts per rollout batch. The total data returned should be `rollout_batch_size` * `n_samples_per_prompt`. | Required | Type: int | Miles Native | +| `--n-samples-per-prompt` | Number of responses to generate for each prompt, e.g., the group size of GRPO. The default rollout pipeline expects each prompt group to contain exactly `n_samples_per_prompt` samples. | `1` | Type: int | Miles Native | +| `--global-batch-size` | Total samples per optimizer step. Automatically calculated or **overridden** if `num_steps_per_rollout` is set. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--num-steps-per-rollout` | The number of training steps to perform using the data collected in a single rollout round. Setting this to `n` means the policy model will be updated `n` times using the same batch of rollout data. Miles ensures that `(rollout-batch-size * n-samples-per-prompt) = (global-batch-size * num-steps-per-rollout)`. If this value is not provided, you have to set `--global-batch-size` explicitly. If both are provided, `--num-steps-per-rollout` will **override** the global batch size with `num_steps_per_rollout = (rollout_batch_size * n_samples_per_prompt) // num_steps_per_rollout`. | `None` | Type: int | Miles Native | +| `--use-dynamic-batch-size` | Dynamically packs variable-length samples into micro-batches to maximize GPU utilization, ensuring the total token count per batch does not exceed `--max-tokens-per-gpu`. For example, with a 300-token limit, samples of lengths 100, 200, and 300 would be packed into two batches: `[100, 200]` and `[300]`. **Note:** Miles ensures that enabling this optimization does not affect the mathematical correctness of per-sample or per-token loss calculation. It is **strongly recommended** to enable this for maximum efficiency. **Compatibility:** only supported when `--qkv-format` is `thd` (does not work for `bshd`). | `False` | bool flag (set to enable) | Miles Native | +| `--max-tokens-per-gpu` | The maximum number of tokens (Prompt + Response combined) per GPU for dynamic batch size. This parameter defines the total sequence length budget for packing samples into micro-batches during training. Note that when enabling context parallel (CP), the effective capacity is shared, so the value should be approximately `(Total_Sequence_Length) // cp_size`. | `None` | Type: int | Miles Native | +| `--log-probs-max-tokens-per-gpu` | The maximum number of tokens per GPU for calculating log probs. This is used to calculate the log probs of the responses during rollout, and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. | `None` | Type: int | Miles Native | +| `--balance-data` | Repartition each rollout batch so each data-parallel rank gets a similar total token count via the Karmarkar-Karp method. It may be beneficial for training speed, but changes per-rank sample grouping and adds a small CPU scheduling overhead. | `False` | bool flag (set to enable) | Miles Native | +| `--data-pad-size-multiplier` | Multiplier used to calculate the sequence padding boundary. Miles rounds sequence lengths up to a multiple of `tensor_parallel_size * data_pad_size_multiplier`. This optimization ensures that matrix dimensions are aligned with NVIDIA Tensor Core requirements, maximizing throughput and reducing VRAM fragmentation. **Note:** better not change this; values `<128` may trigger accuracy loss under `--qkv-format thd` when `TP>=4`. | `128` | Type: int | Miles Native | +| `--micro-batch-size` | Micro batch size per GPU. Ignored when `--use-dynamic-batch-size` is enabled. Works for both `--qkv-format thd` and `--qkv-format bshd` (and is required for `bshd` because dynamic batch size is unsupported). | `1` | Type: int | Megatron-LM (Reset by Miles) | + +## Evaluation Arguments + +Arguments for configuring the evaluation process during training. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--eval-interval` | Interval (in rollout steps) between evaluations. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--eval-prompt-data` | List of name and path pairs for evaluation datasets (e.g., `aime /path/to/aime.jsonl`). | `None` | Type: List[str] | Miles Native | +| `--eval-config` | Path to an OmegaConf YAML/JSON file describing evaluation datasets (overrides `--eval-prompt-data`). | `None` | Type: str | Miles Native | +| `--skip-eval-before-train` | Skip the evaluation step before training starts. | `False` | bool flag (set to enable) | Miles Native | +| `--n-samples-per-eval-prompt` | Number of responses for each prompt in generation. | `1` | Type: int | Miles Native | +| `--eval-temperature` | Temperature for evaluation (defaults to rollout temperature if not set). | `None` | Type: float | Miles Native | +| `--eval-top-p` | Top-p sampling threshold for evaluation (defaults to rollout top-p if not set). | `None` | Type: float | Miles Native | +| `--eval-top-k` | Top-k sampling threshold for evaluation (defaults to rollout top-k if not set). | `None` | Type: int | Miles Native | +| `--eval-max-response-len` | Maximum response length for evaluation (defaults to rollout max response length if not set). | `None` | Type: int | Miles Native | +| `--eval-max-prompt-len` | Maximum prompt length for evaluation. | `None` | Type: int | Miles Native | +| `--eval-min-new-tokens` | Minimum tokens to generate for evaluation responses (Not used). | `None` | Type: int | Miles Native | +| `--eval-max-context-len` | Maximum context length for evaluation (defaults to rollout max context length if not set). | `None` | Type: int | Miles Native | +| `--eval-function-path` | Path to a custom evaluation function. [Ref](../get_started/customization.md#16-evaluation-function---eval-function-path) | `None` | Type: str | Miles Native | +| `--eval-input-key` | JSON key for input text in evaluation datasets. | `None` | Type: str | Miles Native | +| `--eval-label-key` | JSON key for ground truth labels in evaluation datasets. | `None` | Type: str | Miles Native | +| `--eval-tool-key` | JSON key for tool definitions in evaluation datasets. | `None` | Type: str | Miles Native | + +## Checkpointing and Resuming + +Arguments for saving and loading model states. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--load` | Path to the training model checkpoint to load. | `None` | Type: str | Megatron-LM (Reset by Miles) | +| `--save` | Path to save checkpoints. | `None` | Type: str | Megatron-LM (Reset by Miles) | +| `--save-interval` | Interval (in rollout steps) to save checkpoints. Requires `--save` to be set. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--async-save` | Enable asynchronous checkpoint saving (Megatron backend only). | `False` | bool flag (set to enable) | Megatron-LM (Reset by Miles) | +| `--save-hf` | Path to save the model in HuggingFace format when using Megatron backend. The model will be saved to `save_hf.format(rollout_id)`. | `None` | Type: str | Miles Native | +| `--no-save-optim` | If set, optimizer state is not saved with checkpoints to reduce size, but prevents resumption of training. | `False` | bool flag (set to enable) | Megatron-LM (Reset by Miles) | +| `--ref-load` | Path to the reference model checkpoint. Used as an initial checkpoint if `--load` is not set. | `None` | Type: str | Miles Native | +| `--ref-ckpt-step` | The checkpoint step for the reference model. | `None` | Type: int | Miles Native | +| `--critic-load` | Checkpoint to load for the critic model. | value of `--load` | Type: str | Miles Native | +| `--critic-save` | Path to save the critic model. | `None` | Type: str | Miles Native | +| `--start-rollout-id` | The starting rollout step. If not set, it is inferred from the --load checkpoint when resuming training. Otherwise, if training is not continuous, Miles will start training from scratch | `None` | Type: int | Miles Native | + +--- + +## Algorithm and RL Arguments + +Arguments for reinforcement learning algorithms and loss calculation. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--advantage-estimator` | Advantage estimator to use. | `"grpo"` | `grpo`, `gspo`, `ppo`, `reinforce_plus_plus`, `reinforce_plus_plus_baseline`, `on_policy_distillation` | Miles Native | +| `--loss-type` | Type of loss function to use. | `"policy_loss"` | `policy_loss`, `sft_loss`, `custom_loss` | Miles Native | +| `--custom-loss-function-path` | Path to a custom loss calculation function (requires `--loss-type custom_loss`). [Ref](../get_started/customization.md#9-custom-loss-function---custom-loss-function-path) | `None` | Type: str | Miles Native | +| `--critic-lr` | Learning rate for the Critic. Defaults to `--lr`. | `None` | Type: float | Miles Native | +| `--critic-lr-warmup-iters` | Number of iterations for Critic learning rate linear warmup. | `0` | Type: int | Miles Native | +| `--num-critic-only-steps` | Number of initial steps dedicated to training only the Critic. | `0` | Type: int | Miles Native | +| `--eps-clip` | PPO clip range. | `0.2` | Type: float | Miles Native | +| `--eps-clip-high` | PPO clip upper range (defaults to `--eps-clip` if not set). | `None` | Type: float | Miles Native | +| `--eps-clip-c` | Lower bound for [Dual-clip PPO](https://arxiv.org/pdf/1912.09729). | `None` | Type: float | Miles Native | +| `--value-clip` | Clip range for value loss. | `0.2` | Type: float | Miles Native | +| `--kl-coef` | KL penalty coefficient for reward shaping. This is applied to the reward signal before advantage calculation for PPO and REINFORCE-style estimator. | `0.00` | Type: float | Miles Native | +| `--use-kl-loss` | Enable KL loss term in the final objective (as in GRPO). | `False` | bool flag (set to enable) | Miles Native | +| `--kl-loss-coef` | Weight of the KL loss term in the final objective. | `0.0` | Type: float | Miles Native | +| `--kl-loss-type` | Selection of the KL loss implementation. See [Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for more details. | `k1` | `k1`, `k2`, `k3`, `low_var_kl` | Miles Native | +| `--use-unbiased-kl` | Apply Importance Sampling (IS) correction to the KL estimator. Reduces bias from distribution shift. | `False` | bool flag (set to enable) | Miles Native | +| `--entropy-coef` | Coefficient for entropy regularization term. Penalizes low entropy to encourage exploration and prevent premature convergence. | `0.0` | Type: float | Miles Native | +| `--gamma` | Discount factor for future rewards. Used in PPO (GAE) and REINFORCE++. | `1.0` | Type: float | Miles Native | +| `--lambd` | PPO GAE lambda. | `1.0` | Type: float | Miles Native | +| `--normalize-advantages` | Performs distributed masked whitening of advantages. Normalization statistics are computed globally across the Data-Parallel group, ignoring padding tokens. | `False` | bool flag (set to enable) | Miles Native | +| `--disable-compute-advantages-and-returns` | Disables the calculation of advantages and returns. This is typically used for SFT or custom loss functions where value estimation is not required. | `False` | bool flag (set to enable) | Miles Native | +| `--use-tis` | Enable Token-level Importance Sampling (TIS) from this [blog](https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33). | `False` | bool (set to enable) | Miles Native | +| `--tis-clip` | Clipping threshold C for importance sampling ratios to control variance. | `2.0` | Type: float | Miles Native | +| `--tis-clip-low` | Lower bound clipping threshold C for importance sampling ratios to control variance. | `0.0` | Type: float | Miles Native | +| `--custom-tis-function-path` | Path to a custom TIS or MIS function. [Ref](../get_started/customization.md#10-custom-tisrs-function---custom-tis-function-path) | `None` | Type: str | Miles Native | +| `--custom-pg-loss-reducer-function-path` | Custom reducer function for policy gradient loss. [Ref](../get_started/customization.md#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | `None` | Type: str | Miles Native | +| `--use-routing-replay` | Enable R2 (Routing Replay) for MoE: record expert routing decisions during forward and replay them during backward. [Paper](https://arxiv.org/abs/2507.18071) **Note:** automatically set to `True` when `--use-rollout-routing-replay` is enabled. | `False` | bool flag (set to enable) | Miles Native | +| `--use-rollout-routing-replay` | Enable R3 (Rollout Routing Replay) for MoE: record expert routing decisions during rollout and replay them during training. **Requires `--use-miles-router`**. [Paper](https://arxiv.org/abs/2510.11370) [Ref](miles-router.md#22-rollout-routing-replay-r3-for-moe) | `False` | bool flag (set to enable) | Miles Native | +| `--use-opsm` | Enable Off-Policy Sequence Masking (OPSM). Filters sequences that have **BOTH** negative advantages (bad results) AND high KL divergence (stale data). This stabilizes training by preventing updates from unreliable, highly off-policy samples. | `False` | bool flag (set to enable) | Miles Native | +| `--opsm-delta` | The threshold for Off-Policy Sequence Masking (OPSM). | `1e-4` | Type: float | Miles Native | +| `--get-mismatch-metrics` | Calculate mismatch metrics. If it is set, you need to provide a custom TIS function via `--custom-tis-function-path`. | `False` | bool flag (set to enable) | Miles Native | +| `--ref-update-interval` | Interval (in rollout steps) to update ref model from actor. If `None`, ref model is not updated. | `None` | Type: int | Miles Native | +| `--reset-optimizer-states` | Resets the optimizer state after each rollout round. This clears the optimization history, which can improve stability or satisfy specific experimental requirements. | `False` | bool flag (set to enable) | Miles Native | +| `--disable-grpo-std-normalization` | Disable standard deviation normalization for GRPO. From [Dr.GRPO](https://arxiv.org/pdf/2503.20783) | `False` | bool flag (set to enable) | Miles Native | +| `--disable-rewards-normalization` | Disable the default group-wise reward normalization for GRPO, GSPO, and REINFORCE++. This effectively skips the baseline subtraction step. | `False` | bool flag (set to enable) | Miles Native | +| `--use-rollout-entropy` | Enable entropy calculation when calculating the logprobs from actor and reference model. This is useful for implementing custom entropy-based loss masking. | `False` | bool flag (set to enable) | Miles Native | +| `--use-rollout-logprobs` | Use rollout logprobs as the old-policy logprobs when computing importance sampling ratios / PPO-style KL in GRPO/GSPO/PPO. If not set, Miles recomputes old-policy logprobs with the training actor (e.g., `old_actor` or `actor`, depending on configuration). If `--get-mismatch-metrics` is set, the log probs will still be recomputed by the training engine (one more forward pass will be applied). | `False` | bool flag (set to enable) | Miles Native | +| `--calculate-per-token-loss` | Calculate loss on a per-token basis. | `False` | bool flag (set to enable) | Megatron-LM (Reset by Miles) | +| `--seed` | Random seed for the training process. **Also passed to SGLang servers as `random_seed`** (Miles uses `seed + engine_rank` so each engine has a distinct but reproducible seed). | `1234` | Type: int | Megatron-LM (Reset by Miles) | +| `--clip-grad` | Maximum gradient norm for gradient clipping. | `1.0` | Type: float | Megatron-LM (Reset by Miles) | + +--- + +## Logging and Monitoring + +Arguments for WandB, Tensorboard, and general logging. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--use-wandb` | Enable WandB logging. | `False` | bool flag (set to enable) | Miles Native | +| `--wandb-mode` | WandB operating mode. Overrides `WANDB_MODE`. | `None` | `online`, `offline`, `disabled` | Miles Native | +| `--wandb-project` | WandB project name. | `None` | Type: str | Megatron-LM (Reset by Miles) | +| `--wandb-group` | WandB group name. | `None` | Type: str | Miles Native | +| `--wandb-team` | WandB team name. | `None` | Type: str | Miles Native | +| `--wandb-host` | WandB host address. | `None` | Type: str | Miles Native | +| `--wandb-key` | WandB API key. | `None` | Type: str | Miles Native | +| `--wandb-run-id` | Specific WandB run ID to resume. | `None` | Type: str | Miles Native | +| `--wandb-dir` | Directory to store WandB logs. Default is `./wandb` in current directory. | `None` | Type: str | Miles Native | +| `--disable-wandb-random-suffix` | Disable adding a random suffix to the WandB run name. By default, we will add a random 6 length string with characters to the run name. | `False` | bool flag (set to enable) | Miles Native | +| `--wandb-always-use-train-step` | Use training steps instead of rollout steps for the x-axis. | `False` | bool flag (set to enable) | Miles Native | +| `--use-tensorboard` | Enable Tensorboard logging. | `False` | bool flag (set to enable) | Miles Native | +| `--tb-project-name` | Tensorboard project directory. | `None` | Type: str | Miles Native | +| `--tb-experiment-name` | Tensorboard experiment name. | `None` | Type: str | Miles Native | +| `--tensorboard-dir` | Directory to store Tensorboard logs. | `None` | Type: str | Miles Native | +| `--log-multi-turn` | Log detailed information for multi-turn conversations. | `False` | bool flag (set to enable) | Miles Native | +| `--log-passrate` | Enable logging of `pass@n` metrics. | `False` | bool flag (set to enable) | Miles Native | +| `--log-correct-samples` | Explicitly log metrics for correct samples. | `False` | bool flag (set to enable) | Miles Native | +| `--log-reward-category` | Log reward-category statistics (e.g., why the reward function marked a failure). Use this argument to specify the key in the reward dict. | `None` | Type: str | Miles Native | + +--- + +## Fault Tolerance + +Arguments for handling server failures during rollout. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--use-fault-tolerance` | Enable fault tolerance for rollout engines. Periodically sends `/health_generate` heartbeats. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-health-check-interval` | Interval in seconds between rollout engine `/health_generate` checks during generate/eval. | `30.0` | Type: float | Miles Native | +| `--rollout-health-check-timeout` | Timeout in seconds to wait for a rollout engine `/health_generate` response before killing it. | `30.0` | Type: float | Miles Native | +| `--rollout-health-check-first-wait` | Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm. | `0.0` | Type: float | Miles Native | + +--- + +## Miles Router + +Arguments for the specialized Miles text-based router. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--use-miles-router` | Use Miles Router (FastAPI passthrough proxy) instead of SGLang Model Gateway for rollout routing. Required for features that depend on preserving extra rollout metadata (e.g., R3). [Ref](miles-router.md) | `False` | bool flag (set to enable) | Miles Native | +| `--miles-router-middleware-paths` | Paths to custom MilesRouter middleware functions. [Ref](../get_started/customization.md#18-miles-router-middleware---miles-router-middleware-paths) | `""` | Type: List[str] | Miles Native | +| `--miles-router-timeout` | Timeout for router HTTP requests in seconds. | `None` | Type: float | Miles Native | +| `--miles-router-max-connections` | Max connections for MilesRouter HTTP client. | `None` | Type: int | Miles Native | +| `--miles-router-health-check-failure-threshold` | Number of consecutive failures before marking a worker as unhealthy. | `3` | Type: int | Miles Native | + +--- + +## Reward Model Arguments + +Arguments for configuring reward signals and post-processing. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--rm-type` | Built-in reward model selection. | `None` | `remote_rm`, `deepscaler`, `dapo`, `math`, `f1`, `gpqa`, `ifbench`, `random` | Miles Native | +| `--rm-url` | URL for the reward model service (used with `--rm-type remote_rm`). | `None` | Type: str | Miles Native | +| `--reward-key` | JSON key to extract the numerical reward from a returned dictionary if reward model returns a dict instead of a value. | `None` | Type: str | Miles Native | +| `--eval-reward-key` | Evaluation variant for `--reward-key`. | `None` | Type: str | Miles Native | +| `--custom-rm-path` | Path to a custom Python reward function. [Ref](../get_started/customization.md#3-reward-model---custom-rm-path) | `None` | Type: str | Miles Native | +| `--group-rm` | Defer reward computation to process the entire group of samples (per-prompt) at once. Essential for comparative/ranking reward models and improves throughput. **Not supported in eval**. | `False` | bool flag (set to enable) | Miles Native | +| `--custom-reward-post-process-path` | Path to a custom reward post-processor. [Ref](../get_started/customization.md#12-reward-post-processing---custom-reward-post-process-path) | `None` | Type: str | Miles Native | +| `--custom-convert-samples-to-train-data-path` | Path to a custom data format converter. [Ref](../get_started/customization.md#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | `None` | Type: str | Miles Native | + +--- + +## Rollout Buffer Management + +Arguments for managing the rollout data buffer. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--rollout-buffer-url` | URL for the rollout buffer service. | `None` | Type: str | Miles Native | +| `--fetch-trajectory-retry-times` | Number of times to retry fetching trajectory, -1 means unlimited retry. | `-1` | Type: int | Miles Native | +| `--min-batch-collection-ratio` | Minimum batch collection ratio before proceeding. | `1.0` | Type: float | Miles Native | +| `--disable-rollout-trim-samples` | Disable trim samples in rollout buffer when converting samples to train data. | `False` | bool flag (set to enable) | Miles Native | +| `--use-dynamic-global-batch-size` | Enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-task-type` | Type of task being performed. | `math` | Type: str | Miles Native | +| `--loss-mask-type` | Selection of the token masking logic. | `qwen` | `qwen`, `qwen3`, `distill_qwen` | Miles Native | + +--- + +## Multi-Token Prediction (MTP) Arguments + +Arguments for MTP-based training. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--enable-mtp-training` | Enable MTP layer parameter updates during training. | `False` | bool flag (set to enable) | Miles Native | +| `--mtp-num-layers` | Number of MTP layers to include. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--mtp-loss-scaling-factor` | Scaling factor applied to the MTP loss. | `0.2` | Type: float | Megatron-LM (Reset by Miles) | + +--- + +## SGLang Backend Arguments + +Most SGLang server arguments can be passed through by adding the `--sglang-` prefix (some are intentionally skipped, e.g. `model_path`, `tp_size`, `port`, `nnodes`, `node_rank`). For a full list, refer to the [SGLang Server Arguments documentation](https://docs.sglang.io/advanced_features/server_arguments.html). + +Commonly used arguments: + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--sglang-mem-fraction-static` | Fraction of GPU memory to reserve for SGLang KV cache. | `0.9` | Type: float | SGLang | +| `--sglang-server-concurrency` | Maximum number of concurrent requests. | `512` | Type: int | SGLang | +| `--sglang-router-ip` | IP address of the SGLang router and Miles Router. | `None` | Type: str | SGLang Gateway & Miles Router | +| `--sglang-router-port` | Port of the SGLang router and Miles Router. | `None` | Type: int | SGLang Gateway & Miles Router | +| `--sglang-router-request-timeout-secs` | Timeout for requests to the SGLang router. | `14400` | Type: int | SGLang Gateway | + +--- + +## Megatron Specific Arguments + +Arguments applicable when using `--train-backend megatron`. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--megatron-to-hf-mode` | Method to convert Megatron weights to HuggingFace format for SGLang integration. | `raw` | `raw`, `bridge` | Miles Native | +| `--seq-length` | Megatron’s “maximum sequence length” parameter. **In miles training, this parameter has no effect in most setups**: miles uses varlen/packed samples (no truncation based on `seq_length`), forces variable sequence lengths for PP communication buffers, and uses all-to-all token dispatch for MoE. This parameter mainly matters in Megatron’s dataset pipeline. | `None` | Type: int | Megatron-LM | + +--- + +## FSDP Specific Arguments + +Arguments applicable when using `--train-backend fsdp`. **Note: The FSDP backend is still under development and experimental.** + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--warmup-ratio` | Ratio of total steps for warmup. | `0.03` | Type: float | Miles Native | +| `--weight-decay` | Weight decay for the optimizer. | `0.0` | Type: float | Miles Native | +| `--gradient-checkpointing` | Enable gradient checkpointing. | `False` | bool flag (set to enable) | Miles Native | +| `--fsdp-cpu-offload` | Offload parameters and gradients to CPU. | `False` | bool flag (set to enable) | Miles Native | +| `--fsdp-state-dict-cpu-offload` | Offload full state dict to CPU during collection. | `False` | bool flag (set to enable) | Miles Native | +| `--fsdp-cpu-backend` | CPU backend for FSDP CPU offload. | `gloo` | `gloo`, `None` | Miles Native | +| `--attn-implementation` | Selection of the attention implementation. | `flash_attention_2` | `flash_attention_2`, `sdpa`, `eager` | Miles Native | +| `--use-pytorch-profiler` | Enable PyTorch-native profiling. | `False` | bool flag (set to enable) | Miles Native | +| `--profile-step-start` | Starting step for profiling. | `10` | Type: int | Miles Native | +| `--profile-step-end` | Ending step for profiling. | `12` | Type: int | Miles Native | +| `--lr-wsd-decay-iters` | Number of iterations for WSD decay. | `None` | Type: int | Miles Native | +| `--lr-wsd-decay-style` | Decay style for WSD. | `None` | Type: str | Miles Native | +| `--use-checkpoint-lr-scheduler` | Use the checkpoint's LR scheduler state. | `False` | bool flag (set to enable) | Miles Native | +| `--override-lr-scheduler` | Override the loaded LR scheduler state. | `False` | bool flag (set to enable) | Miles Native | +| `--wandb-run-name` | Specific run name for WandB (FSDP backend). | `None` | Type: str | Miles Native | + +--- + +## Debug and Profiling + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--check-weight-update-equal` | Use SGLang's weight checker to check and ensure that the loaded weight from HF checkpoint and received from Megatron are bit-wise equal. | `False` | bool flag (set to enable) | Miles Native | +| `--save-debug-rollout-data` | Path to save rollout data for offline analysis. [Ref](../developer_guide/debug.md) | `None` | Type: str | Miles Native | +| `--load-debug-rollout-data` | Path to load debug rollout data (bypasses SGLang). [Ref](../developer_guide/debug.md) | `None` | Type: str | Miles Native | +| `--load-debug-rollout-data-subsample` | Percentage of debug data to load (0.0 to 1.0). [Ref](../developer_guide/debug.md) | `None` | Type: float | Miles Native | +| `--debug-rollout-only` | Run the rollout phase only without training. [Ref](../developer_guide/debug.md) | `False` | bool flag (set to enable) | Miles Native | +| `--debug-train-only` | Run the training phase only without launching SGLang servers. [Ref](../developer_guide/debug.md) | `False` | bool flag (set to enable) | Miles Native | +| `--save-debug-train-data` | Path to save training batches for offline math debugging. | `None` | Type: str | Miles Native | +| `--dump-details` | Dump exhaustive training details for post-hoc visualization. | `None` | Type: str | Miles Native | +| `--memory-snapshot-path` | Path to save memory snapshots. | `snapshot.pickle` | Type: str | Miles Native | +| `--record-memory-history` | Record memory history for snapshots. | `False` | bool flag (set to enable) | Miles Native | +| `--memory-snapshot-dir` | Directory for PyTorch memory snapshots. | `.` | Type: str | Miles Native | +| `--memory-snapshot-num-steps` | Number of steps to record before saving snapshot. | `None` | Type: int | Miles Native | +| `--memory-recorder` | Selection of the memory recording backend. | `torch` | `torch`, `memray` | Miles Native | +| `--profile-target` | Training components to profile (accepts multiple). | `train_overall` | `train_overall`, `train_actor`, `train_log_probs` | Miles Native | + +--- + +## Environment Variables + +Miles recognizes several environment variables for advanced configuration. + +| Variable | Description | Source | +| :--- | :--- | :--- | +| `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR` | Set to `1` to enable the experimental rollout implementation refactor. | Miles Native | +| `ENABLE_ROUTING_REPLAY` | Internal variable used to enable MoE routing consistency checks during training. | Miles Native | +| `TENSORBOARD_DIR` | Base directory for Tensorboard logs. | Miles Native | +| `MILES_HOST_IP` | Overrides the host IP used for distributed communication. | Miles Native | +| `PYTHONPATH` | Must include the path to your `Megatron-LM` installation when using the Megatron backend. | System | +| `NCCL_SOCKET_IFNAME` | Specifies the network interface for NCCL communication (e.g., `eth0`, `bond0`). | System | +| `GLOO_SOCKET_IFNAME` | Specifies the network interface for GLOO communication. | System | +| `NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME` | Network interface for NVSHMEM bootstrap. | System | + +--- + +## Multi-Turn and Agentic Arguments + +Arguments for managing interactions and tools. Only available when `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1` and the rollout/generate function exposes `add_arguments`. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--generate-max-turns` | Maximum number of turns in a conversation. | `16` | Type: int | Miles Native | +| `--generate-tool-specs-path` | Path to the tool specifications (JSON). | `None` | Type: str | Miles Native | +| `--generate-tool-call-parser` | The parser used to extract tool calls from text. | `None` | Type: str | Miles Native | +| `--generate-execute-tool-function-path` | Path to the function that executes the tool. | `None` | Type: str | Miles Native | +| `--generate-multi-samples` | Whether to generate multiple samples within one turn. | `False` | bool flag (set to enable) | Miles Native | + +--- + +## Advanced Developer Hooks and CI + +Hooks for custom logic and Continuous Integration testing flags. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--custom-megatron-init-path` | Path to custom Megatron initialization logic. [Ref](../get_started/customization.md#17-megatron-hooks) | `None` | Type: str | Miles Native | +| `--custom-megatron-before-log-prob-hook-path` | Hook called before calculating log probabilities. [Ref](../get_started/customization.md#17-megatron-hooks) | `None` | Type: str | Miles Native | +| `--custom-megatron-before-train-step-hook-path` | Hook called before each training step. [Ref](../get_started/customization.md#17-megatron-hooks) | `None` | Type: str | Miles Native | +| `--ci-test` | Enable Continuous Integration testing mode. | `False` | bool flag (set to enable) | Miles Native | +| `--ci-disable-kl-checker` | Disable KL divergence sanity checks in CI. | `False` | bool flag (set to enable) | Miles Native | +| `--ci-metric-checker-key` | Metric key to monitor for pass/fail in CI. | `None` | Type: str | Miles Native | +| `--ci-metric-checker-threshold` | Pass/fail threshold (minimum value) for the monitored metric. | `None` | Type: float | Miles Native | +| `--ci-save-grad-norm` | Path to save gradient norms for CI comparison. | `None` | Type: str | Miles Native | +| `--ci-load-grad-norm` | Path to load gradient norms for CI verification. | `None` | Type: str | Miles Native | + +--- + +## Miscellaneous and System + +General arguments for infrastructure and configuration overrides. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--http-proxy` | HTTP proxy server for remote reward model calls. | `None` | Type: str | Miles Native | +| `--use-distributed-post` | Use distributed POST requests for remote reward models. | `False` | bool flag (set to enable) | Miles Native | +| `--custom-config-path` | Path to the YAML config for custom function arguments. | `None` | Type: str | Miles Native | +| `--padded-vocab-size` | Manually specify the vocab size for padding. | `None` | Type: int | Miles Native | diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index b1088ce643..8aa63c23fb 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -29,12 +29,19 @@ Below is a summary of all available customization interfaces and their purposes. | [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hooks) | Custom logic before log probability computation. | | [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hooks) | Custom logic before each training step. | | [`--miles-router-middleware-paths`](#18-miles-router-middleware---miles-router-middleware-paths) | Add custom middleware to miles router. | +| [`--custom-model-provider-path`](#20-model-provider---custom-model-provider-path) | Path to a custom function that replaces the default model provider. | ## Detailed Interface Reference ### 1. Rollout Function (`--rollout-function-path`) -**Default**: `miles.rollout.sglang_rollout.generate_rollout` +**Default**: +```python +if enable_experimental_rollout_refactor(): + miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn +else: + miles.rollout.sglang_rollout.generate_rollout +``` **Purpose**: Override the entire rollout generation logic. @@ -417,3 +424,19 @@ Stabilize MoE RL training by recording and replaying expert routing decisions to | `--use-routing-replay` | Forward-backward routing consistency in training. ([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | | `--use-rollout-routing-replay` | R3: Replay routing from rollout during training. **Requires `--use-miles-router`**. ([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | +For detailed explanation of R3 and MilesRouter, see [Miles Router](../advanced/miles-router.md). + +--- + +### 20. Model Provider (`--custom-model-provider-path`) + +**Default**: `None` + +**Purpose**: Path to a custom function that replaces the default model provider (e.g., `'my_module.my_provider'`). The function must return a GPTModel. + +**Signature**: +```python +def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel +``` + + diff --git a/docs/en/get_started/gen_endpoint.md b/docs/en/get_started/gen_endpoint.md new file mode 100644 index 0000000000..a8e9d2ae1e --- /dev/null +++ b/docs/en/get_started/gen_endpoint.md @@ -0,0 +1,104 @@ +# Gen Endpoint Usage + +This document covers generate_hub usage for the `/generate` endpoint. For OpenAI +format usage, see `docs/en/get_started/oai_endpoint.md`. + +## 1. What generate_hub is + +`miles/rollout/generate_hub/` contains reusable generate functions that plug into +rollout through `--custom-generate-function-path`. They use the refactor +interface (`GenerateFnInput` / `GenerateFnOutput`) and are meant to be composed +with custom agents, tool use, or multi-turn logic. + +Key types and entry points: + +- `miles/rollout/base_types.py` defines `GenerateFnInput` and `GenerateFnOutput`. +- `miles/rollout/inference_rollout/inference_rollout_common.py` builds a + `GenerateState` and calls the generate function. +- `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1` enables the new path (see + `examples/openai_format/*.sh`). + +## 2. Generate function basics + +The intended abstraction is: + +1. The rollout engine provides a `GenerateFnInput` with: + - `state` (tokenizer, processor, args, sampling defaults) + - `sample` (prompt, current tokens, response, status) + - `sampling_params` (max_new_tokens, temperature, top_p, etc.) +2. The generate function focuses only on: + - turning the sample into a model request + - executing the request (SGLang `/generate` or OpenAI format) + - updating the `Sample` with tokens, logprobs, loss mask, and status + +Minimal skeleton: + +```python +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.types import Sample + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + sampling_params = input.sampling_params + + # 1) build request from prompt and sampling params + # 2) call backend + # 3) update sample.tokens, sample.response, sample.rollout_log_probs, sample.loss_mask, sample.status + + return GenerateFnOutput(samples=sample) + +def _add_arguments(parser): + parser.add_argument("--your-arg", type=str) + +generate.add_arguments = _add_arguments +``` + +Notes: + +- `generate.add_arguments = _add_arguments` is the hook for custom CLI flags. + Add any arguments you want; they are parsed into `input.args` and can be used + freely by your generator without touching rollout core code. +- Use `compute_prompt_ids_from_sample` and `compute_request_payload` from + `miles/rollout/generate_utils/generate_endpoint_utils.py` to build requests + for the `/generate` endpoint. +- If you want to return multiple samples, set `--generate-multi-samples` and + return a list. + +## 3. /generate endpoint examples + +Examples (library side): + +- `miles/rollout/generate_hub/single_turn.py` + - Single-turn generation using `/generate`. + - Works with text or multimodal prompts. +- `miles/rollout/generate_hub/multi_turn.py` + - Multi-turn tool calling using `/generate`. + - CLI flags: `--generate-max-turns`, `--generate-tool-specs-path`, + `--generate-tool-call-parser`, `--generate-execute-tool-function-path`, + `--generate-multi-samples`. +- `miles/rollout/generate_hub/benchmarkers.py` + - Benchmark helper that forces random output sequence length (OSL). + +## 4. Radix tree middleware helper (full TITO for `/generate`) + +Full TITO caching for the `/generate` endpoint is provided by the radix tree +middleware. This is unrelated to session middleware and works only on the +`/generate` and `/retrieve_from_text` routes. + +What it does: + +- Caches token ids and logprobs by prompt text in a radix tree. +- Lets `/generate` requests include `input_tokens` and avoids re-tokenization. +- Enables `update_sample_from_response` to fetch tokens via + `/retrieve_from_text` for training. + +How to enable: + +``` +--use-miles-router \ +--miles-router-middleware-paths miles.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware +``` + +Make sure `--sglang-router-ip` and `--sglang-router-port` point to the Miles +Router so `/retrieve_from_text` can be reached during rollout. diff --git a/docs/en/get_started/oai_endpoint.md b/docs/en/get_started/oai_endpoint.md new file mode 100644 index 0000000000..9b882ec846 --- /dev/null +++ b/docs/en/get_started/oai_endpoint.md @@ -0,0 +1,136 @@ +# OAI Endpoint Usage + +This document explains how to use the OpenAI-format chat endpoint through Miles +Router sessions. For the `/generate` endpoint, see +`docs/en/get_started/gen_endpoint.md`. + +## 1. Minimal `run_agent` loop + +Your `run_agent` receives a session-scoped `base_url`. Send OpenAI-format chat +requests to `base_url/v1/chat/completions` and pass the `messages` list as the +prompt. + +Minimal custom agent example: + +```python +from miles.utils.http_utils import post + +async def run_agent(base_url: str, prompt, request_kwargs: dict | None = None) -> None: + payload = {"model": "default", "messages": prompt, **(request_kwargs or {})} + await post(f"{base_url}/v1/chat/completions", payload) +``` + +Notes for `run_agent`: + +- `base_url` already includes the session path (e.g. `/sessions/`), so you + should not manually add the session id. Just append the OpenAI route. +- `request_kwargs` already contains the default sampling settings from + `agentic_tool_call.build_chat_request_kwargs`, so you can directly expand it + into the chat request payload. +- If you pass rollout sampling params, `max_new_tokens` will be mapped to the + OpenAI `max_tokens` field before the request is sent. +- If you need structured parsing payloads, use SGLang's + `ChatCompletionRequest`-compatible format. It is compatible with native OpenAI + fields, plus extra SGLang parameters. + +## 2. OpenAI chat messages and the basic request + +The OpenAI-format chat API uses a list of `messages`, each with a `role` and +`content`. + +Minimal request shape: + +```json +{ + "model": "default", + "messages": [ + {"role": "system", "content": "You are a concise assistant."}, + {"role": "user", "content": "Answer with one word: 2+2?"} + ], + "logprobs": true, + "logprob_start_len": 0 +} +``` + +You can pass any OpenAI-compatible parameters in the payload, or any +SGLang-compatible `ChatCompletionRequest` parameters. Note: +`logprobs=True` and `logprob_start_len=0` are required to extract token ids and +logprobs for TITO (see below), and are already set in `request_kwargs`. + +## 3. Quickstart index + +If you just want something runnable, start here: + +Generator entry point: + +- `miles/rollout/generate_hub/agentic_tool_call.py` + - OpenAI-format agent loop via router sessions. + +OpenAI-format examples that use `agentic_tool_call.generate`: + +- `examples/openai_format/dapo_math.py` + - Single-turn OpenAI format agent (DAPO math). +- Launcher scripts: + - `examples/openai_format/run-qwen3-4B-dapo-math.sh` + + +You can customize generate function like: +``` +CUSTOM_ARGS=( + --custom-generate-function-path miles.rollout.generate_hub.agentic_tool_call.generate + --custom-agent-function-path examples.openai_format.dapo_math.run_agent +) +``` + +For OpenAI format, do not add `--apply-chat-template`; the +prompt must remain a `messages` list. + +More agentic multi-turn examples will come in the future. + +## 4. Further customization (OpenAI wrapper generate function) + +For OpenAI-format rollout, the key generate function is +`miles/rollout/generate_hub/agentic_tool_call.generate`. It is a thin wrapper +around your custom agent: + +1. Create a session on Miles Router and build a session-scoped `base_url`. +2. Call the custom agent (from `--custom-agent-function-path`) to send one or + more chat requests to `base_url/v1/chat/completions`, typically using + `prompt` and `request_kwargs`. +3. Collect session records via `OpenAIEndpointTracer`. +4. Convert records into `Sample` objects with + `compute_samples_from_openai_records`. + +If you want general generate-function customization beyond the OpenAI wrapper, +see `docs/en/get_started/gen_endpoint.md`. + +## 5. TITO (token-in token-out) + +TITO needs two things: + +1. Prompt token ids returned by the backend (e.g. `input_logprobs` or + `input_token_ids`). These can come from tokenizing `messages`, or from a + provided `input_ids` payload. +2. Output token ids returned by the backend (`logprobs.content[*].token_id`). + +By default, the session middleware forwards raw `messages` to SGLang. With +`logprobs=True` and `logprob_start_len=0`, SGLang tokenizes the prompt and +returns prompt token ids along with output token ids, which is sufficient for +TITO. You do not need to provide `input_ids`. + +If you prefer to send `input_ids` to SGLang, you can enable token input for chat +completions in the router via +`--miles-router-enable-token-input-for-chat-completions`. The session route +will tokenize `messages` and inject `input_ids` before proxying to SGLang. The +backend still returns prompt token ids, and they should match any `input_ids` +you supplied. + +We can save multi-turn samples within a single session, but we still do not +inherit or reuse prompt tokens across turns. Each request is tokenized +independently, regardless of which option you choose. + +### Common pitfalls + +- Ensure `logprobs=True` in OpenAI chat requests, and ensure + `logprob_start_len=0` if you rely on SGLang to return prompt token ids. +- Ensure the tokenizer matches `--hf-checkpoint`. diff --git a/docs/en/index.rst b/docs/en/index.rst index afafc67966..3f08d98d02 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -41,6 +41,7 @@ miles is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models :caption: Advanced Features _examples_synced/reproducibility/README.md + advanced/miles-router.md advanced/speculative-decoding.md advanced/fault-tolerance.md advanced/arch-support-beyond-megatron.md diff --git a/examples/openai_format/__init__.py b/examples/openai_format/__init__.py new file mode 100644 index 0000000000..30436bcc42 --- /dev/null +++ b/examples/openai_format/__init__.py @@ -0,0 +1 @@ +"""OpenAI format examples.""" diff --git a/examples/openai_format/dapo_math.py b/examples/openai_format/dapo_math.py new file mode 100644 index 0000000000..dae0c4ed9a --- /dev/null +++ b/examples/openai_format/dapo_math.py @@ -0,0 +1,19 @@ +""" +Custom agent example: single-turn DAPO math via OpenAI endpoints. +""" + +from __future__ import annotations + +from typing import Any + + +# Notice: only function based agent can use post API in miles +from miles.utils.http_utils import post + + +async def run_agent( + base_url: str, prompt: list[dict[str, Any]] | str, request_kwargs: dict[str, Any] | None = None +) -> None: + request_kwargs = request_kwargs or {} + payload = {"model": "default", "messages": prompt, "logprobs": True, **request_kwargs} + await post(base_url + "/v1/chat/completions", payload) diff --git a/examples/openai_format/run-qwen3-4B.sh b/examples/openai_format/run-qwen3-4B.sh new file mode 100644 index 0000000000..d6bbfddec6 --- /dev/null +++ b/examples/openai_format/run-qwen3-4B.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 +export MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/shared/Qwen3-4B + #--hf-checkpoint /root/shared/Qwen3-4B-FP8 + --ref-load /root/shared/Qwen3-4B_torch_dist +# --load /root/shared/Qwen3-4B_miles/ + --save /root/shared/Qwen3-4B_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --rollout-shuffle + --rm-type deepscaler + --num-rollout 200 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 1 + + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project miles-oai + --wandb-group qwen3-4B-test + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.8 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash + --use-miles-router +) + +CUSTOM_ARGS=( + --custom-generate-function-path miles.rollout.generate_hub.agentic_tool_call.generate + --custom-agent-function-path examples.openai_format.dapo_math.run_agent +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export CUDA_VISIBLE_DEVICES=4,5,6,7 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} diff --git a/examples/true_on_policy/README.md b/examples/true_on_policy/README.md index 620564d410..553d2de64d 100644 --- a/examples/true_on_policy/README.md +++ b/examples/true_on_policy/README.md @@ -1,6 +1,6 @@ # True On-Policy between Training and Inference -True on-policy ensures that the log probs generated by inference engine (SGLang) is strictly equal to the one generated by the training Engine. Here's our [blog](https://lmsys.org/blog/2025-12-03-miles-fsdp/) for more details. +True on-policy ensures that the log probs generated by inference engine (SGLang) is strictly equal to the one generated by the training Engine. ## Examples diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index a92198a674..f196164876 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -228,7 +228,7 @@ def pad_func(experts, pad): # TODO: maybe extract a common process function for here and get_batch? rollout_routed_experts = [slice_with_cp(r, pad_func, self.parallel_state) for r in rollout_routed_experts] rollout_routed_experts = torch.cat(rollout_routed_experts, dim=0) - pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier + pad_size = self.parallel_state.tp_size * self.args.data_pad_size_multiplier pad = (pad_size - rollout_routed_experts.size(0) % pad_size) % pad_size if pad != 0: rollout_routed_experts = pad_func(rollout_routed_experts, pad) diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index 0eb2bcd444..24496011b1 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -14,7 +14,8 @@ def set_default_megatron_args(args): # TODO: maybe change this after megatron has good fp8 support args.bf16 = not args.fp16 # placeholders - args.seq_length = 4096 + if args.seq_length is None: + args.seq_length = 4096 args.max_position_embeddings = args.seq_length # TODO: revisit this when megatron(dev) have solved the optimizer-cpu-offload ckpt saving bug args.dist_ckpt_save_pre_mcore_014 = True diff --git a/miles/backends/megatron_utils/kernels/int4_qat/setup.py b/miles/backends/megatron_utils/kernels/int4_qat/setup.py index b27967bc98..8715dd7b8a 100644 --- a/miles/backends/megatron_utils/kernels/int4_qat/setup.py +++ b/miles/backends/megatron_utils/kernels/int4_qat/setup.py @@ -1,3 +1,4 @@ +import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension import torch @@ -10,6 +11,16 @@ arch_list.append(f"{major}.{minor}") arch_list = sorted(set(arch_list)) +# Fallback to TORCH_CUDA_ARCH_LIST env var or default architectures when GPU is not available +if not arch_list: + env_arch = os.environ.get("TORCH_CUDA_ARCH_LIST", "") + if env_arch: + # Parse TORCH_CUDA_ARCH_LIST format: "7.0 7.5 8.0 8.6 9.0+PTX" + arch_list = [a.strip().replace("+PTX", "") for a in env_arch.replace(";", " ").split() if a.strip()] + else: + # Default to common architectures (Volta, Turing, Ampere, Ada, Hopper) + arch_list = ["8.0", "8.6", "8.9", "9.0"] + setup( name="fake_int4_quant_cuda", ext_modules=[ @@ -31,7 +42,8 @@ + [ f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}' for arch in arch_list - ], + ] + + ["-gencode=arch=compute_90a,code=sm_90a"], }, ) ], diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py index 67bb30108d..c7b0707425 100644 --- a/miles/backends/training_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -121,7 +121,7 @@ def get_batch( tokens = batch["tokens"] # use 0 as the pad token id should be fine? pad_token_id = 0 - pad_size = parallel_state.dp_size * pad_multiplier + pad_size = parallel_state.tp_size * pad_multiplier # for cp, we need all tokens to calculate logprob batch["unconcat_tokens"] = tokens diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 1a2f176028..2b9891f64e 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -136,7 +136,15 @@ def log_rollout_data( # NOTE: Here we have to do the clone().detach(), otherwise the tensor will be # modified in place and will cause problem for the next rollout. val = torch.cat(val).clone().detach() - if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: + if key in [ + "log_probs", + "ref_log_probs", + "rollout_log_probs", + "returns", + "advantages", + "values", + "entropy", + ]: sum_of_sample_mean = get_sum_of_sample_mean( total_lengths, response_lengths, diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 79c6649be0..27211845d8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -13,8 +13,15 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import call_rollout_fn +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnTrainInput, + call_rollout_fn, +) +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -53,8 +60,14 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) + self.use_experimental_refactor = enable_experimental_rollout_refactor() + if self.use_experimental_refactor: + input = RolloutFnConstructorInput(args=args, data_source=self.data_source) + self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) + self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) + else: + self.generate_rollout = load_function(self.args.rollout_function_path) + self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -142,7 +155,12 @@ def eval(self, rollout_id): return self.health_monitoring_resume() - result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True) + if self.use_experimental_refactor: + result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) + else: + result = call_rollout_fn( + self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True + ) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -224,7 +242,12 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False) + if self.use_experimental_refactor: + data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) + else: + data = call_rollout_fn( + self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False + ) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index faa85c7269..c2644e87f9 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,22 +1,86 @@ +from __future__ import annotations + +from argparse import Namespace from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any +from miles.rollout.data_source import DataSource from miles.utils.types import Sample +if TYPE_CHECKING: + from miles.rollout.inference_rollout.inference_rollout_common import GenerateState + + +@dataclass(frozen=True) +class RolloutFnConstructorInput: + args: Namespace + # TODO may refactor DataSource API + data_source: DataSource + + +@dataclass(frozen=True) +class RolloutFnBaseInput: + rollout_id: int + + @property + def evaluation(self): + raise NotImplementedError + + +# subclassing for different data in the future +@dataclass(frozen=True) +class RolloutFnTrainInput(RolloutFnBaseInput): + @property + def evaluation(self): + return False + +@dataclass(frozen=True) +class RolloutFnEvalInput(RolloutFnBaseInput): + @property + def evaluation(self): + return True + + +# TODO make it frozen @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] metrics: dict[str, Any] = None +# TODO make it frozen @dataclass class RolloutFnEvalOutput: data: dict[str, dict[str, Any]] metrics: dict[str, Any] = None +RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput +RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput + + +@dataclass(frozen=True) +class GenerateFnInput: + state: GenerateState + sample: Sample + sampling_params: dict[str, Any] + evaluation: bool + + @property + def args(self) -> Namespace: + return self.state.args + + +@dataclass(frozen=True) +class GenerateFnOutput: + # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or + # multi-turn with removing thinking tokens. + samples: Sample | list[Sample] + + def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): + """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) # compatibility for legacy version diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py new file mode 100644 index 0000000000..3e1ae9ef54 --- /dev/null +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -0,0 +1,69 @@ +""" +Simple agentic demo with tool calling. +""" + +import argparse +from collections.abc import Callable +from typing import Any + +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) +from miles.rollout.generate_utils.sample_utils import merge_samples +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + tracer = await OpenAIEndpointTracer.create(input.args) + + custom_agent_function: Callable = load_function(input.args.custom_agent_function_path) + assert ( + custom_agent_function is not None + ), f"Custom agent function {input.args.custom_agent_function_path} not found" + await custom_agent_function( + base_url=tracer.base_url, + prompt=input.sample.prompt, + request_kwargs=build_chat_request_kwargs(input.sampling_params), + ) + + records = await tracer.collect_records() + samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + if not input.args.generate_multi_samples: + samples = merge_samples(samples, input.state.tokenizer) + return GenerateFnOutput(samples=samples) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--custom-agent-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true", default=False) + + +generate.add_arguments = _add_arguments + + +# Process keys to match ChatCompletionRequest input +def build_chat_request_kwargs(sampling_params: dict[str, Any]) -> dict[str, Any]: + request_kwargs = dict(sampling_params) + key_map = { + "max_new_tokens": "max_tokens", + "min_new_tokens": "min_tokens", + "sampling_seed": "seed", + } + for src, dst in key_map.items(): + if src in request_kwargs: + if dst not in request_kwargs: + request_kwargs[dst] = request_kwargs[src] + request_kwargs.pop(src, None) + + # Notice: Here we force the inference backend to return token information and start from 0 + # The start len should be 0 to make sure prompt token ids and be correctly returned from SGLang. + request_kwargs["logprobs"] = True + request_kwargs["logprob_start_len"] = 0 + + reserved_keys = {"model", "messages"} + allowed_keys = set(ChatCompletionRequest.model_fields) - reserved_keys + return {key: value for key, value in request_kwargs.items() if key in allowed_keys and value is not None} diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py new file mode 100644 index 0000000000..97814ecb3d --- /dev/null +++ b/miles/rollout/generate_hub/multi_turn.py @@ -0,0 +1,88 @@ +""" +Simple multi-turn generation with tool calling. +""" + +import argparse +from copy import deepcopy + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.rollout.generate_utils.tool_call_utils import ( + create_tool_call_parser, + execute_tool_calls, + update_sample_with_tool_responses, +) +from miles.utils.http_utils import post +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + # ----------------------- Setup ------------------------- + + args = input.args + sample = deepcopy(input.sample) + tokenizer = input.state.tokenizer + assert not args.partial_rollout, "Partial rollout is not supported" + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + execute_tool_function = load_function(args.generate_execute_tool_function_path) + + tool_specs = load_function(args.generate_tool_specs_path) + tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) + + multi_samples = [] + + # ----------------------- Initial prompts ------------------------- + + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) + + sample.tokens = prompt_tokens_ids.copy() + + for _turn in range(args.generate_max_turns): + # ----------------------- Call inference endpoint ------------------------- + + payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) + if payload is None: + sample.status = halt_status + if args.generate_multi_samples and multi_samples: + multi_samples[-1].status = halt_status + break + + if args.generate_multi_samples: + sample = deepcopy(input.sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) + + if args.generate_multi_samples: + multi_samples.append(deepcopy(sample)) + + if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): + break + + # ----------------------- Execute tools ------------------------- + + _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) + if len(tool_calls) == 0: + break + + tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) + update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) + + return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-tool-call-parser", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py new file mode 100644 index 0000000000..5c0a15b5b4 --- /dev/null +++ b/miles/rollout/generate_hub/single_turn.py @@ -0,0 +1,46 @@ +""" +Simple single-turn generation. +""" + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.utils.http_utils import post +from miles.utils.types import Sample + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + sampling_params = input.sampling_params + assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + prompt_ids = compute_prompt_ids_from_sample(input.state, sample) + + # Handle Partial Rollout resuming + if len(sample.response) > 0: + input_ids = sample.tokens + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert sampling_params["max_new_tokens"] >= 0 + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return GenerateFnOutput(samples=sample) + else: + input_ids = prompt_ids + + payload, halt_status = compute_request_payload( + args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs + ) + if payload is None: + sample.status = halt_status + return GenerateFnOutput(samples=sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output) + + return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/generate_utils/__init__.py b/miles/rollout/generate_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/rollout/generate_utils/generate_endpoint_utils.py b/miles/rollout/generate_utils/generate_endpoint_utils.py new file mode 100644 index 0000000000..a91d71f1de --- /dev/null +++ b/miles/rollout/generate_utils/generate_endpoint_utils.py @@ -0,0 +1,112 @@ +""" +Utils to integrate SGLang's `/generate` endpoint with RL things like Sample. +""" + +from copy import deepcopy +from typing import Any + +import numpy as np +import pybase64 + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +# Make this an isolated function because users may want to compute their own +def compute_prompt_ids_from_sample(state, sample, tools=None): + prompt = sample.prompt + + if state.processor: + processor_output = state.processor(text=prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + + # TODO shall we move it to other places? then can make this function immutable + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + + return prompt_ids + else: + if not isinstance(prompt, str): + prompt = state.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True, tools=tools + ) + + return state.tokenizer.encode(prompt, add_special_tokens=False) + + +def compute_request_payload( + args, + input_ids: list[int], + sampling_params: dict, + multimodal_inputs: dict | None = None, +) -> tuple[dict[str, Any] | None, Sample.Status | None]: + sampling_params = deepcopy(sampling_params) + max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) + if x := args.rollout_max_context_len: + max_new_tokens = min(max_new_tokens, x - len(input_ids)) + if max_new_tokens <= 0: + return None, Sample.Status.TRUNCATED + + payload = { + "input_ids": input_ids, + "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens}, + "return_logprob": True, + "return_routed_experts": args.use_rollout_routing_replay, + } + if image_data := (multimodal_inputs or {}).get("images"): + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + return payload, None + + +async def update_sample_from_response( + args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False +): + # Initialize sample.tokens for the first turn + if (len(sample.response) == 0) and not sample.tokens: + sample.tokens = payload["input_ids"] + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + # TODO may rename to match + await postprocess_sample_with_radix_tree(args, sample, output) + + assert not update_loss_mask, "This code branch has not implemented update_loss_mask" + else: + if x := output["meta_info"].get("output_token_logprobs"): + new_response_tokens = [item[1] for item in x] + new_response_log_probs = [item[0] for item in x] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if update_loss_mask: + if sample.loss_mask is None: + sample.loss_mask = [] + sample.loss_mask += [1] * len(new_response_tokens) + + # TODO handle multi-turn cases (may need concat instead of assignment) + sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) + + # TODO may unify (currently there are both methods inside Sample and separate functions) + sample.update_from_meta_info(args, output["meta_info"]) + + +def _get_rollout_routed_experts_from_response(args, sample, output): + info = output["meta_info"].get("routed_experts") + if info is None: + return None + + x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32) + x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk) + return x diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py new file mode 100644 index 0000000000..f5bf52d6ca --- /dev/null +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -0,0 +1,80 @@ +""" +Utilities for the OpenAI endpoint +""" + +import logging +from argparse import Namespace +from copy import deepcopy + +from miles.router.session.sessions import GetSessionResponse, SessionRecord +from miles.utils.http_utils import post +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class OpenAIEndpointTracer: + def __init__(self, router_url: str, session_id: str): + self.router_url = router_url + self.session_id = session_id + self.base_url = f"{router_url}/sessions/{session_id}" + + @staticmethod + async def create(args: Namespace): + router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + response = await post(f"{router_url}/sessions", {}, action="post") + session_id = response["session_id"] + return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) + + async def collect_records(self) -> list[SessionRecord]: + try: + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + except Exception as e: + logger.warning(f"Failed to get session {self.session_id} records: {e}") + raise + response = GetSessionResponse.model_validate(response) + records = response.records + + try: + await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") + except Exception as e: + logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") + + return records or [] + + +def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: + return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records] + + +def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: + # TODO may refine after @guapisolo's implementation + choice = record.response["choices"][0] + + input_token_ids = choice["input_token_ids"] + output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] + output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] + + sample = deepcopy(input_sample) + # sample.tokens = record.request["input_ids"] + output_token_ids + request_input_ids = record.request.get("input_ids") + if request_input_ids is not None: + assert ( + request_input_ids == input_token_ids + ), "for prompt part, input_ids return by sglang should match with the request input_ids" + sample.tokens = input_token_ids + output_token_ids + sample.rollout_log_probs = output_log_probs + sample.response = tokenizer.decode(output_token_ids) + sample.response_length = len(output_token_ids) + sample.loss_mask = [1] * len(output_token_ids) + + # TODO unify with Sample.update_from_meta_info + match choice["finish_reason"]: + case "stop" | "tool_calls": + sample.status = Sample.Status.COMPLETED + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + + return sample diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py new file mode 100644 index 0000000000..6a4e645be5 --- /dev/null +++ b/miles/rollout/generate_utils/sample_utils.py @@ -0,0 +1,115 @@ +from copy import deepcopy +from dataclasses import fields + +from miles.utils.types import Sample + + +def merge_samples(samples: list[Sample], tokenizer) -> Sample: + acc = samples[0] + for sample in samples[1:]: + acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer) + return acc + + +def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: + """Merge two samples generated from sibling inference engine calls.""" + a, b = deepcopy(a), deepcopy(b) + + def _merge_equal_value(field): + x = getattr(a, field) + y = getattr(b, field) + assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}" + return x + + def _fill_defaults(sample: Sample): + if sample.loss_mask is None: + sample.loss_mask = [1] * sample.response_length + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [0.0] * sample.response_length + + _fill_defaults(a) + _fill_defaults(b) + + obs_len = len(b.tokens) - len(a.tokens) - b.response_length + obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len] + # TODO: is this acceptable? + obs_text = tokenizer.decode(obs_tokens) + + try: + a.validate() + b.validate() + assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" + assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" + assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" + if a.rollout_routed_experts is not None: + assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0] + assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" + + return _create_with_all_fields( + Sample, + group_index=_merge_equal_value("group_index"), + index=_merge_equal_value("index"), + prompt=b.prompt, + tokens=b.tokens, + multimodal_inputs=_merge_equal_value("multimodal_inputs"), + multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"), + response=a.response + obs_text + b.response, + response_length=a.response_length + obs_len + b.response_length, + label=_merge_equal_value("label"), + reward=_merge_equal_value("reward"), + loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, + weight_versions=a.weight_versions + b.weight_versions, + rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, + rollout_routed_experts=b.rollout_routed_experts, + remove_sample=_merge_equal_value("remove_sample"), + status=b.status, + metadata=_merge_equal_value("metadata"), + train_metadata=_merge_equal_value("train_metadata"), + non_generation_time=_merge_equal_value("non_generation_time"), + spec_info=_merge_spec_info(a.spec_info, b.spec_info), + prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), + ) + except AssertionError as e: + e.add_note(f"{a=} {b=}") + raise + + +def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.SpecInfo, + spec_accept_token_num=_merge_plus_value("spec_accept_token_num"), + spec_draft_token_num=_merge_plus_value("spec_draft_token_num"), + spec_verify_ct=_merge_plus_value("spec_verify_ct"), + completion_token_num=_merge_plus_value("completion_token_num"), + ) + + +def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.PrefixCacheInfo, + cached_tokens=_merge_plus_value("cached_tokens"), + total_prompt_tokens=_merge_plus_value("total_prompt_tokens"), + ) + + +def _create_with_all_fields(cls, **kwargs): + expected = {f.name for f in fields(cls)} + actual = set(kwargs.keys()) + assert ( + expected == actual + ), f"{cls.__name__} field mismatch. Missing: {expected - actual}, Extra: {actual - expected}" + return cls(**kwargs) + + +def _startswith(*, short, long) -> bool: + if isinstance(short, str) and isinstance(long, str): + return long.startswith(short) + if isinstance(short, list) and isinstance(long, list): + return (len(long) >= len(short)) and (long[: len(short)] == short) + raise NotImplementedError diff --git a/miles/rollout/generate_utils/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py new file mode 100644 index 0000000000..85ea87aeab --- /dev/null +++ b/miles/rollout/generate_utils/tool_call_utils.py @@ -0,0 +1,115 @@ +""" +Utils to handle tool calls. +""" + +import json +import uuid +from collections.abc import Callable +from typing import Any + +from openai.types.chat import ChatCompletionMessageToolCall +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.types import Sample + +_DUMMY_USER = {"role": "user", "content": "dummy"} + + +def create_tool_call_parser(tool_specs, tool_call_parser): + return FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tool_specs), + tool_call_parser=tool_call_parser, + ) + + +async def execute_tool_calls( + tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall], + execute_one: Callable, +) -> list[dict[str, Any]]: + tool_messages = [] + for call in tool_calls: + tool_messages.append(await _execute_tool_call(call, execute_one)) + return tool_messages + + +async def _execute_tool_call( + call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable +) -> dict[str, Any]: + if isinstance(call, ChatCompletionMessageToolCall): + name = call.function.name + params = json.loads(call.function.arguments) if call.function.arguments else {} + tool_call_id = call.id + elif isinstance(call, ToolCallItem): + name = call.name + params = json.loads(call.parameters) if call.parameters else {} + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + else: + raise TypeError(f"Unsupported tool call type: {type(call)}") + + result = await execute_one(name, params) + assert isinstance(result, str) + + return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name} + + +def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): + next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) + sample.response += tokenizer.decode(next_obs_tokens_ids) + sample.response_length += len(next_obs_tokens_ids) + sample.tokens += next_obs_tokens_ids + sample.loss_mask += [0] * len(next_obs_tokens_ids) + sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) + + +# TODO: very naive implementation, need the to-be-implemented e2e test to validate. +def tokenize_tool_responses( + tool_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + return _tokenize_postfix_messages(tool_messages, tokenizer) + + +def _tokenize_postfix_messages( + postfix_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + dummy_assistant = _build_dummy_assistant(postfix_messages) + base_messages = [_DUMMY_USER, dummy_assistant] + + messages_without = base_messages + messages_with = base_messages + postfix_messages + + tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True) + tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False) + + assert tokens_with[: len(tokens_without)] == tokens_without, ( + f"Fail to tokenize_tool_responses caused by token prefix mismatch. " + f"This can happen for thinking model or models with special chat template, " + f"and this simple example does not support it yet, " + f"since this means we cannot have a append-only token id list. " + f"{tokens_with=} {tokens_without=} " + f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} " + ) + return tokens_with[len(tokens_without) :] + + +def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: + return { + "role": "assistant", + "content": "", + "reasoning_content": " ", + "tool_calls": [ + { + "id": resp.get("tool_call_id", f"call0000{i}"), + "type": "function", + "function": { + "name": resp.get("name", "dummy_func"), + "arguments": {}, + }, + } + for i, resp in enumerate(tool_responses) + ], + } diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py new file mode 100644 index 0000000000..33ccf17bfb --- /dev/null +++ b/miles/rollout/inference_rollout/__init__.py @@ -0,0 +1,2 @@ +# This is a refactor of the portions above generate-function in sglang_rollout.py, +# and is give a different name to ensure both code exist at the same time. diff --git a/miles/rollout/inference_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py new file mode 100644 index 0000000000..7711e0dd31 --- /dev/null +++ b/miles/rollout/inference_rollout/compatibility.py @@ -0,0 +1,84 @@ +import inspect +from collections.abc import Callable + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainOutput, +) +from miles.utils.async_utils import run +from miles.utils.misc import load_function + + +class LegacyRolloutFnAdapter: + def __init__(self, input: RolloutFnConstructorInput, fn: Callable): + self.args = input.args + self.data_source = input.data_source + self.fn = fn + + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation) + + # compatibility for legacy version + if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): + output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output) + + return output + + +def load_rollout_function(input: RolloutFnConstructorInput, path: str): + fn = load_function(path) + + if inspect.isclass(fn): + return fn(input) + else: + return LegacyRolloutFnAdapter(input, fn) + + +def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput: + output = fn(input) + + if inspect.iscoroutine(output): + output = run(output) + + return output + + +class LegacyGenerateFnAdapter: + def __init__(self, fn: Callable): + self.fn = fn + self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters + + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + if self._has_evaluation_param: + output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation) + else: + output = await self.fn(input.args, input.sample, input.sampling_params) + + if not isinstance(output, GenerateFnOutput): + output = GenerateFnOutput(samples=output) + + return output + + +def load_generate_function(path: str): + fn = load_function(path) + if fn is None: + return None + + if inspect.isclass(fn): + return fn() + elif _is_legacy_generate_fn(fn): + return LegacyGenerateFnAdapter(fn) + else: + return fn + + +def _is_legacy_generate_fn(fn: Callable) -> bool: + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + return len(params) >= 3 and params[0] != "input" diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py new file mode 100644 index 0000000000..8518c6e020 --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_common.py @@ -0,0 +1,192 @@ +import asyncio +import logging +from argparse import Namespace +from copy import deepcopy +from typing import Any + +from miles.rollout.base_types import ( + GenerateFnInput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.generate_hub.single_turn import generate +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class GenerateState: + def __init__(self, args: Namespace) -> None: + # persistent state for the generation process + self.args = args + self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + + self.generate_fn_semaphore = asyncio.Semaphore( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + self.sampling_params: dict[str, Any] = compute_sampling_params( + args, + temperature=args.rollout_temperature, + top_p=args.rollout_top_p, + top_k=args.rollout_top_k, + max_new_tokens=args.rollout_max_response_len, + ) + + self.generate_function = load_generate_function(args.custom_generate_function_path) or generate + + self.reset() + + def reset(self) -> None: + self.aborted = False + + +async def generate_and_rm( + state: GenerateState, + sample: Sample | list[Sample], + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Sample | list[Sample]: + args = state.args + + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + + # For samples with existing response, check if they're complete + if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: + assert sample.response is not None + if not args.group_rm: + assert sample.reward is not None + return sample + + # generate + async with state.generate_fn_semaphore: + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + + output = await state.generate_function( + GenerateFnInput( + state=state, + sample=sample, + sampling_params=deepcopy(sampling_params), + evaluation=evaluation, + ) + ) + sample = output.samples + + # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below + # for the rm that need the whole group, we will not do the rm here + if args.group_rm: + return sample + + # TODO: unify the two branches into one if we decide to use list as output type + # multi samples + if isinstance(sample, list): + samples = sample + if any([sample.status == Sample.Status.ABORTED for sample in samples]): + return samples + + # for multi agent system, the reward of some sample is calculated during generation. + samples_need_reward = [sample for sample in samples if sample.reward is None] + await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True) + return samples + else: + if sample.status == Sample.Status.ABORTED: + return sample + # for multi-turn environment, a reward could be assigned to the agent. + if sample.reward is None: + sample.reward = await async_rm(args, sample) + + return sample + + +async def generate_and_rm_group( + state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False +) -> list[Sample]: + args = state.args + + if state.aborted: + return group + + tasks = [] + for idx, sample in enumerate(group): + current_sampling_params = sampling_params.copy() + if getattr(args, "sglang_enable_deterministic_inference", False): + current_sampling_params["sampling_seed"] = args.rollout_seed + idx + tasks.append( + asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) + ) + + group = await asyncio.gather(*tasks) + if state.aborted: + return group + + if args.group_rm: + await batched_async_rm(args, group, inplace_set_reward_field=True) + + return group + + +def compute_sampling_params( + args, + *, + # after unifying configuration, this can be further refactored + temperature, + top_p, + top_k, + max_new_tokens, +): + return dict( + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + +class InferenceRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.data_source = input.data_source + self.state = GenerateState(input.args) + self.eval_prompt_dataset_cache = {} + + async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + if input.evaluation: + return await self._call_eval(input) + return await self._call_train(input) + + async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: + from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async + + output, aborted_samples = await generate_rollout_async( + self.state, input.rollout_id, self.data_source.get_samples + ) + self.data_source.add_samples(aborted_samples) + return output + + async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: + from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset + + assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache)) + results_list = await asyncio.gather(*coros) + results = {k: v for r in results_list for k, v in r.items()} + return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py new file mode 100644 index 0000000000..2d052be0ae --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_eval.py @@ -0,0 +1,112 @@ +import asyncio +import copy +import logging +from typing import Any + +from tqdm import tqdm + +from miles.rollout.inference_rollout.inference_rollout_common import ( + GenerateState, + compute_sampling_params, + generate_and_rm, +) +from miles.utils.data import Dataset +from miles.utils.eval_config import EvalDatasetConfig +from miles.utils.misc import as_completed_async +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def eval_rollout_single_dataset( + state: GenerateState, + dataset_cfg: EvalDatasetConfig, + prompt_dataset_cache: dict[Any, Dataset], +) -> dict[str, dict[str, list[Any]]]: + args = state.args + assert not args.group_rm, "Group RM is not supported for eval rollout" + + cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) + if cache_key not in prompt_dataset_cache: + tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + prompt_dataset_cache[cache_key] = Dataset( + path=dataset_cfg.path, + tokenizer=tokenizer, + processor=processor, + max_length=args.eval_max_prompt_len, + prompt_key=dataset_cfg.input_key, + label_key=dataset_cfg.label_key, + multimodal_keys=args.multimodal_keys, + metadata_key=dataset_cfg.metadata_key, + tool_key=dataset_cfg.tool_key, + apply_chat_template=args.apply_chat_template, + apply_chat_template_kwargs=args.apply_chat_template_kwargs, + ) + dataset = prompt_dataset_cache[cache_key] + + base_sampling_params = compute_sampling_params( + args, + temperature=dataset_cfg.temperature, + top_p=dataset_cfg.top_p, + top_k=dataset_cfg.top_k, + max_new_tokens=dataset_cfg.max_response_len, + ) + + tasks = [] + # do multiple samples for eval prompts + sample_index = 0 + for _i, prompt_sample in enumerate(dataset.samples): + for j in range(dataset_cfg.n_samples_per_eval_prompt): + # use the same prompt for multiple samples + sample = copy.deepcopy(prompt_sample) + sample.index = sample_index + sample_index += 1 + sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) + sampling_params = base_sampling_params + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_params = base_sampling_params.copy() + sampling_params["sampling_seed"] = args.rollout_seed + j + tasks.append( + asyncio.create_task( + generate_and_rm( + state, + sample, + sampling_params=sampling_params, + evaluation=True, + ) + ) + ) + + data = [] + do_print = True + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) + async for sample in as_completed_async(tasks): + if do_print: + # TODO improve this after enhancing samples' type + s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample + if s is not None: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(s.prompt) + s.response]} " + f"reward={s.reward}" + ) + do_print = False + if isinstance(sample, list): + data.extend(sample) + else: + data.append(sample) + pbar.update(1) + pbar.close() + + data.sort(key=lambda sample: sample.index) + + reward_key = args.eval_reward_key or args.reward_key + return { + dataset_cfg.name: { + "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], + "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], + "samples": data, + } + } diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py new file mode 100644 index 0000000000..bae94ec67b --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_train.py @@ -0,0 +1,146 @@ +import asyncio +import logging +from argparse import Namespace +from collections.abc import Callable + +import sglang_router +from packaging.version import parse +from tqdm import tqdm + +from miles.rollout.base_types import RolloutFnTrainOutput +from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group +from miles.utils.http_utils import get, post +from miles.utils.misc import as_completed_async, load_function +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: + args = state.args + + assert not state.aborted + state.aborted = True + + urls = await get_worker_urls(args) + logger.info(f"Abort request for {urls}") + await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) + + # make sure all the pending tasks are finished + aborted_samples = [] + async for group in as_completed_async(pendings): + if not args.partial_rollout: + continue + + # for partial rollout, collect the partial samples into the data buffer + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) + + if args.partial_rollout: + logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer") + + return aborted_samples + + +async def get_worker_urls(args: Namespace): + if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + return response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + return [worker["url"] for worker in response["workers"]] + + +def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): + return [ + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + state, + group, + sampling_params=state.sampling_params.copy(), + evaluation=False, + ) + ) + for group in samples + ] + + +async def generate_rollout_async( + state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + args = state.args + assert args.rollout_global_dataset + + # instantiate data filters + dynamic_filter = load_function(args.dynamic_sampling_filter_path) + + metric_gatherer = MetricGatherer() + + # target_data_size is the total number of valid samples to get + target_data_size = args.rollout_batch_size + + pendings = set() + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") + while len(data) < target_data_size: + while len(data) + len(pendings) < target_data_size: + # get samples from the buffer and submit the generation requests. + samples = data_source(args.over_sampling_batch_size) + pendings.update(submit_generate_tasks(state, samples)) + + # wait for the generation to finish + done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group: list[Sample] = task.result() + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) == args.n_samples_per_prompt + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + continue + + # add the samples to the data + # NOTE: here we have not stored all the unused samples back to the data buffer. + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + + pbar.close() + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + logger.info( + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + + # there are still some unfinished requests, abort them + aborted_samples = await abort(state, pendings, rollout_id) + + assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" + data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + + # reset the global state to prevent effects on the next rollout or eval. + state.reset() + + if f := load_function(args.rollout_sample_filter_path): + f(args, data) + # There can be circumstances where users want to process all samples including filtered ones. + if f := load_function(args.rollout_all_samples_process_path): + f(args, all_samples, data_source) + + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py index 62b253ddee..e9ee29db41 100644 --- a/miles/rollout/rm_hub/__init__.py +++ b/miles/rollout/rm_hub/__init__.py @@ -69,8 +69,18 @@ async def async_rm(args, sample: Sample, **kwargs): async def batched_async_rm( args, samples: list[Sample], + inplace_set_reward_field: bool = False, **kwargs, -) -> list[int | float]: +) -> list[int | float] | None: + if inplace_set_reward_field: + rewards = await batched_async_rm(args, samples, **kwargs) + for sample, reward in zip(samples, rewards, strict=True): + assert ( + sample.reward is None + ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?" + sample.reward = reward + return None + if args.custom_rm_path is not None: # Ensure the custom reward function is implemented in batch mode rm_function = load_function(args.custom_rm_path) diff --git a/miles/router/middleware_hub/radix_tree.py b/miles/router/middleware_hub/radix_tree.py index 6e722f1e25..67b9d6fe4e 100644 --- a/miles/router/middleware_hub/radix_tree.py +++ b/miles/router/middleware_hub/radix_tree.py @@ -584,8 +584,8 @@ def retrieve_from_text(self, text: str, return_logprob: bool = True): text: Input text to get tokens for return_logprob: If True, also return log probabilities Returns: - List of token IDs corresponding to the input text if return_logprob is False. - Tuple of (token_ids, logp) if return_logprob is True. + List of token (IDs, logp, loss_mask) corresponding to the input text + if return_logprob is False, all logp will be 0.0 """ # Call find_longest_prefix to get the match result result = self.find_longest_prefix(text) diff --git a/miles/router/middleware_hub/radix_tree_middleware.py b/miles/router/middleware_hub/radix_tree_middleware.py index db57f64564..b9d62d8415 100644 --- a/miles/router/middleware_hub/radix_tree_middleware.py +++ b/miles/router/middleware_hub/radix_tree_middleware.py @@ -66,12 +66,14 @@ def __init__(self, app, *, router): self.router.radix_tree = self.radix_tree async def dispatch(self, request: Request, call_next): - path = request.url.path + if path == "/generate": + return await self._generate(request, call_next) + if path == "/retrieve_from_text": + return await self._retrieve_from_text(request) + return await call_next(request) - if path != "/generate": - return await call_next(request) - + async def _generate(self, request: Request, call_next): request_json = await request.json() if "text" in request_json: input_text = request_json.pop("text", "") @@ -154,6 +156,23 @@ async def dispatch(self, request: Request, call_next): print(f"[miles-router] Warning: Failed to cache trajectory: {e}") return response + async def _retrieve_from_text(self, request: Request): + payload = await request.json() + text = payload.get("text", "") + token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True) + result = { + "response": text, + "tokens": token_ids, + "loss_mask": loss_mask, + "rollout_logp": logp, + "token_length": len(token_ids), + "loss_mask_length": len(loss_mask), + } + assert ( + len(token_ids) == len(logp) == len(loss_mask) + ), "Token IDs, logp, and loss mask must have the same length" + return JSONResponse(result) + async def postprocess_sample_with_radix_tree(args, sample: Sample, output: dict): assert not args.partial_rollout, "Currently partial rollout is not supported when using miles router" diff --git a/miles/router/router.py b/miles/router/router.py index 2e8ecfc41f..f092f359a7 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from starlette.responses import Response +from miles.router.session.sessions import setup_session_routes from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -64,11 +65,12 @@ def __init__(self, args, verbose=False): self.app.add_middleware(middleware, router=self) def _setup_routes(self): - """Setup all the HTTP routes""" + """Setup all the HTTP routes except catch-all proxy""" # sglang-router api self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) - self.app.post("/retrieve_from_text")(self.retrieve_from_text) + # Session routes - must be registered before catch-all + setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) @@ -130,39 +132,51 @@ async def _health_check_loop(self): async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" - # Forward all other paths to SGLang router + result = await self._do_proxy(request, path) + return self._build_proxy_response(result) + + async def _do_proxy( + self, + request: Request, + path: str, + body: bytes | None = None, + headers: dict | None = None, + ) -> dict: + """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" worker_url = self._use_url() url = f"{worker_url}/{path}" - # Get request body and headers - body = await request.body() - headers = dict(request.headers) + if body is None: + body = await request.body() + if headers is None: + headers = dict(request.headers) + if body is not None: + headers = {k: v for k, v in headers.items() if k.lower() not in ("content-length", "transfer-encoding")} try: response = await self.client.request(request.method, url, content=body, headers=headers) - # Eagerly read content so we can return JSON (not streaming) content = await response.aread() - content_type = response.headers.get("content-type", "") - try: - # Prefer parsing JSON if possible - data = json.loads(content) - return JSONResponse( - content=data, - status_code=response.status_code, - headers=dict(response.headers), - ) - except Exception: - # Fall back to raw body with original content type - return Response( - content=content, - status_code=response.status_code, - headers=dict(response.headers), - media_type=content_type or None, - ) - + return { + "request_body": body, + "response_body": content, + "status_code": response.status_code, + "headers": dict(response.headers), + } finally: self._finish_url(worker_url) + def _build_proxy_response(self, result: dict) -> Response: + """Build HTTP response from proxy result.""" + content = result["response_body"] + status_code = result["status_code"] + headers = result["headers"] + content_type = headers.get("content-type", "") + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + async def add_worker(self, request: Request): """Add a new worker to the router. Supports providing the URL via query string or JSON body. @@ -197,28 +211,6 @@ async def list_workers(self, request: Request): """List all registered workers""" return {"urls": list(self.worker_request_counts.keys())} - async def retrieve_from_text(self, request: Request): - """Get token information from text input""" - body = await request.body() - payload = json.loads(body) if body else {} - - text = payload.get("text", "") - - # Use radix tree's retrieve_from_text method (no need to fetch weight version here) - token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True) - - # Handle the result based on whether logp was requested - result = { - "tokens": token_ids, # token IDs - "response": text, # The input text - "loss_mask": loss_mask, # Loss mask for the tokens - "token_length": len(token_ids), - "loss_mask_length": len(loss_mask), - "rollout_logp": logp, - } - - return result - def _use_url(self): """Select worker URL with minimal active requests.""" diff --git a/miles/router/session/naive_trajectory.py b/miles/router/session/naive_trajectory.py new file mode 100644 index 0000000000..3cd4ff1b75 --- /dev/null +++ b/miles/router/session/naive_trajectory.py @@ -0,0 +1,70 @@ +import threading +import uuid +from typing import Any + +from pydantic import BaseModel, Field + +from miles.router.session.session_types import SessionRecord + + +class NaiveTrajectory(BaseModel): + messages: list[dict[str, Any]] = Field(default_factory=list) + records: list[SessionRecord] = Field(default_factory=list) + + def append_session_record(self, record: SessionRecord): + self.records.append(record) + + +# This is only a naive trajectory manager to store history session record. +# Cross turn token input not implemented. +class NaiveTrajectoryManager: + def __init__(self, args, tokenizer: Any): + self.sessions: dict[str, NaiveTrajectory] = {} + self.args = args + self.tokenizer = tokenizer + self._lock = threading.RLock() + + def create_session(self) -> str: + with self._lock: + session_id = uuid.uuid4().hex + self.sessions[session_id] = NaiveTrajectory() + return session_id + + def get_session_records_by_id(self, session_id: str) -> list[SessionRecord] | None: + with self._lock: + session = self.sessions.get(session_id) + if session is None: + return None + return session.records + + def calc_prompt_tokens( + self, + session_id: str, + messages: list[dict[str, Any]], + ) -> list[int] | None: + with self._lock: + session = self.sessions.get(session_id) + if session is None: + return None + token_ids = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_special_tokens=False, + add_generation_prompt=True, + ) + return token_ids + + def delete_session_by_id(self, session_id: str) -> bool | None: + with self._lock: + session = self.sessions.pop(session_id, None) + if session is None: + return None + return True + + def append_session_record(self, session_id: str, record: SessionRecord) -> bool | None: + with self._lock: + session = self.sessions.get(session_id) + if session is None: + return None + session.append_session_record(record) + return True diff --git a/miles/router/session/session_types.py b/miles/router/session/session_types.py new file mode 100644 index 0000000000..c895b5e38d --- /dev/null +++ b/miles/router/session/session_types.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + + +class SessionRecord(BaseModel): + timestamp: float + method: str + path: str + request: dict + response: dict + status_code: int + + +class GetSessionResponse(BaseModel): + session_id: str + records: list[SessionRecord] diff --git a/miles/router/session/sessions.py b/miles/router/session/sessions.py new file mode 100644 index 0000000000..349272f818 --- /dev/null +++ b/miles/router/session/sessions.py @@ -0,0 +1,94 @@ +import json +import logging +import time +from typing import TYPE_CHECKING + +from fastapi import Request +from fastapi.responses import JSONResponse, Response +from transformers import AutoTokenizer + +from miles.router.session.naive_trajectory import NaiveTrajectoryManager +from miles.router.session.session_types import GetSessionResponse, SessionRecord + +if TYPE_CHECKING: + from miles.router.router import MilesRouter + +logger = logging.getLogger(__name__) + + +def setup_session_routes(app, router: "MilesRouter"): + hf_checkpoint = getattr(router.args, "hf_checkpoint", None) + if not hf_checkpoint: + if getattr(router, "verbose", False): + logger.info("[miles-router] Skipping session routes (hf_checkpoint not set).") + return + + tokenizer = AutoTokenizer.from_pretrained(hf_checkpoint, trust_remote_code=True) + manager = NaiveTrajectoryManager(router.args, tokenizer) + + @app.post("/sessions") + async def create_session(): + session_id = manager.create_session() + return {"session_id": session_id} + + @app.get("/sessions/{session_id}") + async def get_session(session_id: str): + records = manager.get_session_records_by_id(session_id) + if records is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return GetSessionResponse( + session_id=session_id, + records=records, + ) + + @app.delete("/sessions/{session_id}") + async def delete_session(session_id: str): + deleted = manager.delete_session_by_id(session_id) + if deleted is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return Response(status_code=204) + + @app.post("/sessions/{session_id}/v1/chat/completions") + async def chat_completions(request: Request, session_id: str): + body = await request.body() + request_body = json.loads(body) if body else {} + + if router.args.miles_router_enable_token_input_for_chat_completions: + if "messages" in request_body and "input_ids" not in request_body: + prompt_token_ids = manager.calc_prompt_tokens(session_id, request_body["messages"]) + if prompt_token_ids is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + request_body["input_ids"] = prompt_token_ids + body = json.dumps(request_body).encode("utf-8") + + result = await router._do_proxy(request, "v1/chat/completions", body=body) + + response = json.loads(result["response_body"]) + + choice = response.get("choices", [{}])[0] + # messages = request_body["messages"] + [choice["message"]] + + if "logprobs" not in choice or "content" not in choice["logprobs"]: + raise RuntimeError("logprobs must be in choice") + logprobs_content = choice["logprobs"]["content"] + + for item in logprobs_content: + if "token_id" not in item: + raise RuntimeError("token_id must be in item") + record = SessionRecord( + timestamp=time.time(), + method=request.method, + path="/v1/chat/completions", + status_code=result["status_code"], + request=request_body, + response=response, + ) + appended = manager.append_session_record(session_id, record) + if appended is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return router._build_proxy_response(result) + + @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def session_proxy(request: Request, session_id: str, path: str): + result = await router._do_proxy(request, path) + return router._build_proxy_response(result) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 79b2c419ca..ac43859521 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,8 +10,10 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger +from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -122,7 +124,7 @@ def add_train_arguments(parser): type=str, choices=["thd", "bshd"], default="thd", - help="The qkv layout for Megatron backend.", + help="The qkv layout.", ) parser.add_argument( "--true-on-policy-mode", @@ -146,7 +148,12 @@ def add_train_arguments(parser): "--disable-weights-backuper", action="store_false", dest="enable_weights_backuper", - help="Whether to disable weights backuper to save host memory.", + help=( + "Applies to `megatron` training backend only. " + "Disables the system that backups model weights (Actor, Ref, Old Actor) to CPU RAM. " + "Disabling saves significant host memory but prevents features that rely on weight-swapping, such as computing KL-divergence against a reference model. " + "Note: do not set `--ref-load` and `--keep-old-actor` if disable weights backuper." + ), ) parser.add_argument( "--megatron-to-hf-mode", @@ -169,7 +176,7 @@ def add_train_arguments(parser): parser.add_argument( "--recompute-loss-function", action="store_true", - help="Whether to disable recompute loss function to save memory during training.", + help="Whether to enable recompute loss function to save memory during training.", ) parser.add_argument( "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" @@ -204,7 +211,11 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-function-path", type=str, - default="miles.rollout.sglang_rollout.generate_rollout", + default=( + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn" + if enable_experimental_rollout_refactor() + else "miles.rollout.sglang_rollout.generate_rollout" + ), help=( "Path to the rollout generation function." "You should use this model to create your own custom rollout function, " @@ -487,10 +498,8 @@ def add_data_arguments(parser): action="store_false", dest="rollout_global_dataset", help=( - "Whether to use a global dataset for rollout. " - "If set, the rollout will use the `--prompt-data` as the prompt dataset, " - "and the prompts for rollout will be sampled from the dataset. " - "If not set, you need to manage the data by your self." + "Disable the global dataset for rollout. By default, Miles loads `--prompt-data` into a global dataset and samples from it for rollout. " + "Setting this flag turns off this behavior, Use this flag only when providing a custom `--rollout-function-path` (and usually a custom `--data-source-path`) that handles data loading independently." ), ) @@ -507,7 +516,7 @@ def add_data_arguments(parser): help=( "The path to the prompt data. " "Currently we only support jsonl format, and each line should contains --input-key and --label-key, " - "which will be used as the prompt and the label respectively. " + "which will be used as the prompt and the label respectively." "If you want to use a custom template, you can set --apply-chat-template to true, in that case, " "the input should be the same structure as an openai message, e.g. [{'role': 'user', 'content': 'blabla'}]. " ), @@ -579,8 +588,8 @@ def add_data_arguments(parser): action="store_true", default=False, help=( - "Balance the number of tokens between data parallel ranks with `karmarkar_karp` for verl. " - "Note that this may allocate the different response of the same prompt into different training steps." + "Repartition each rollout batch so each data-parallel rank gets a similar total token count via Karmarkar-Karp method. " + "It may be beneficial for training speed but changes per-rank sample grouping and adds a small CPU scheduling overhead." ), ) @@ -869,7 +878,7 @@ def add_algo_arguments(parser): "--use-tis", action="store_true", default=False, - help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", + help="Enable TIS from https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33.", ) parser.add_argument( "--tis-clip", @@ -953,6 +962,17 @@ def add_router_arguments(parser): default=3, help="Number of consecutive failures before marking a worker as unhealthy.", ) + parser.add_argument( + "--miles-router-enable-token-input-for-chat-completions", + action="store_true", + default=False, + help=( + "This is an experimental feature, and only supports for text model." + "Whether to enable token input for chat completions. If set, we will calculate " + "the input_ids for the prompt part inside miles and add it to the request body." + "This is reserved for cross turn token in under OAI format." + ), + ) RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) return parser @@ -1016,14 +1036,14 @@ def add_wandb_arguments(parser): default=None, help=( "Log statistics of the category of reward, such as why the reward function considers it as failed. " - "Specify the key in the reward dict using this argument.", + "Specify the key in the reward dict using this argument." ), ) parser.add_argument( "--log-correct-samples", action="store_true", default=False, - help="Whether to turn on passrate logging, which will log the pass@n of the responses in the rollout.", + help="Explicitly log metrics for correct samples.", ) parser.add_argument("--wandb-run-id", type=str, default=None) return parser @@ -1344,6 +1364,20 @@ def add_ci_arguments(parser): ) return parser + def add_user_provided_function_arguments(parser): + args_partial, _ = parser.parse_known_args() + for path in [ + args_partial.rollout_function_path, + args_partial.custom_generate_function_path, + ]: + try: + fn = load_function(path) + except (ModuleNotFoundError, ValueError): + continue + if fn is not None and callable(getattr(fn, "add_arguments", None)): + fn.add_arguments(parser) + return parser + def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) @@ -1374,6 +1408,8 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) + if enable_experimental_rollout_refactor(): + parser = add_user_provided_function_arguments(parser) reset_arg( parser, "--custom-config-path", diff --git a/miles/utils/environ.py b/miles/utils/environ.py new file mode 100644 index 0000000000..35d1f350ee --- /dev/null +++ b/miles/utils/environ.py @@ -0,0 +1,14 @@ +import os + +_printed_experimental_rollout_refactor = False + + +def enable_experimental_rollout_refactor() -> bool: + result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) + + global _printed_experimental_rollout_refactor + if result and not _printed_experimental_rollout_refactor: + print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)") + _printed_experimental_rollout_refactor = True + + return result diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 2b3e6e192f..0abdbbf59d 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -162,11 +162,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60): +async def _post(client, url, payload, max_retries=60, action="post"): retry_count = 0 while retry_count < max_retries: try: - response = await client.post(url, json=payload or {}) + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) response.raise_for_status() try: output = response.json() @@ -240,8 +244,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60): - return await _post(self._client, url, payload, max_retries) + async def do_post(self, url, payload, max_retries=60, action="post"): + return await _post(self._client, url, payload, max_retries, action=action) # Create actors per node created = [] @@ -265,7 +269,8 @@ async def do_post(self, url, payload, max_retries=60): _post_actors = created -async def post(url, payload, max_retries=60): +# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) +async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -274,15 +279,16 @@ async def post(url, payload, max_retries=60): actor = _next_actor() if actor is not None: # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries) + obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) return await asyncio.to_thread(ray.get, obj_ref) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries) + return await _post(_http_client, url, payload, max_retries, action=action) +# TODO unify w/ `post` to add retries and remote-execution async def get(url): response = await _http_client.get(url) response.raise_for_status() diff --git a/miles/utils/misc.py b/miles/utils/misc.py index c0a96d6366..bae72ec0d7 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,17 +1,55 @@ +import asyncio import importlib import subprocess +from contextlib import contextmanager import ray from miles.utils.http_utils import is_port_available +# Mainly used for test purpose where `load_function` needs to load many in-flight generated functions +class FunctionRegistry: + def __init__(self): + self._registry: dict[str, object] = {} + + @contextmanager + def temporary(self, name: str, fn: object): + self._register(name, fn) + try: + yield + finally: + self._unregister(name) + + def get(self, name: str) -> object | None: + return self._registry.get(name) + + def _register(self, name: str, fn: object) -> None: + assert name not in self._registry + self._registry[name] = fn + + def _unregister(self, name: str) -> None: + assert name in self._registry + self._registry.pop(name) + + +function_registry = FunctionRegistry() + + +# TODO may rename to `load_object` since it can be used to load things like tool_specs def load_function(path): """ - Load a function from a module. + Load a function from registry or module. :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. """ + if path is None: + return None + + registered = function_registry.get(path) + if registered is not None: + return registered + module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) @@ -30,8 +68,9 @@ def __call__(cls, *args, **kwargs): cls._instances[cls] = instance return cls._instances[cls] - def clear_instances(cls): - cls._instances = {} + @staticmethod + def clear_all_instances(): + SingletonMeta._instances.clear() def exec_command(cmd: str, capture_output: bool = False) -> str | None: @@ -92,3 +131,8 @@ def should_run_periodic_action( step = rollout_id + 1 return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) + + +async def as_completed_async(tasks): + for coro in asyncio.as_completed(tasks): + yield await coro diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py new file mode 100644 index 0000000000..3647c86265 --- /dev/null +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -0,0 +1,254 @@ +import asyncio +import re +import time +import uuid +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import asdict, dataclass + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.function_call_parser import FunctionCallParser +from transformers import AutoTokenizer + +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class ProcessResultMetaInfo: + weight_version: str | None = None + routed_experts: str | None = None + spec_accept_token_num: int | None = None + spec_draft_token_num: int | None = None + spec_verify_ct: int | None = None + + def to_dict(self) -> dict: + return {k: v for k, v in asdict(self).items() if v is not None} + + +@dataclass(frozen=True) +class ProcessResult: + text: str + finish_reason: str = "stop" + cached_tokens: int = 0 + meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() + + +ProcessFn = Callable[[str], ProcessResult] + + +class MockSGLangServer: + def __init__( + self, + model_name: str, + process_fn: ProcessFn, + host: str, + port: int, + latency: float = 0.0, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.process_fn = process_fn + self.host = host + self.port = port or find_available_port(30000) + self.latency = latency + + self.app = FastAPI() + self._server: UvicornThreadServer | None = None + + self.request_log: list[dict] = [] + self._concurrency = Counter() + + self._setup_routes() + + @property + def max_concurrent(self) -> int: + return self._concurrency.max_value + + def reset_stats(self): + self.request_log.clear() + self._concurrency.reset() + + def start(self): + self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) + self._server.start() + + def stop(self): + if self._server is not None: + self._server.stop() + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def _setup_routes(self): + @self.app.post("/generate") + async def generate(request: Request): + return await self._handle_generate_like_request(request, self._compute_generate_response) + + @self.app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return await self._handle_generate_like_request(request, self._compute_chat_completions_response) + + @self.app.get("/health") + async def health(): + return JSONResponse(content={"status": "ok"}) + + @self.app.post("/abort_request") + async def abort_request(_request: Request): + return JSONResponse(content={"status": "ok"}) + + async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]): + payload = await request.json() + self.request_log.append(payload) + with self._concurrency.track(): + if self.latency > 0: + await asyncio.sleep(self.latency) + response = compute_fn(payload) + return JSONResponse(content=response) + + def _compute_generate_response(self, payload: dict) -> dict: + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) + + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) + + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens + + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + + meta_info = { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": process_result.cached_tokens, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + **process_result.meta_info.to_dict(), + } + + return {"text": process_result.text, "meta_info": meta_info} + + def _compute_chat_completions_response(self, payload: dict) -> dict: + messages = payload.get("messages", []) + tools = payload.get("tools") + + prompt_str = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False) + + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + logprobs_content = [ + { + "token": self.tokenizer.convert_ids_to_tokens(tid), + "token_id": tid, + "logprob": -1 / 128 * i, + } + for i, tid in enumerate(output_ids) + ] + + finish_reason = process_result.finish_reason + tool_calls = None + if tools and finish_reason == "stop": + parser = FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tools), + tool_call_parser="qwen25", + ) + message_content, parsed_calls = parser.parse_non_stream(process_result.text) + if parsed_calls: + finish_reason = "tool_calls" + tool_calls = [ + { + "id": f"call{i:05d}", + "type": "function", + "function": {"name": call.name, "arguments": call.parameters or "{}"}, + } + for i, call in enumerate(parsed_calls) + ] + else: + message_content = process_result.text + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message_content, + "tool_calls": tool_calls, + }, + "logprobs": {"content": logprobs_content}, + "input_token_ids": prompt_ids, + "finish_reason": finish_reason, + } + ], + } + + +class Counter: + def __init__(self): + self._current = 0 + self._max = 0 + + @property + def max_value(self) -> int: + return self._max + + def reset(self): + self._current = 0 + self._max = 0 + + @contextmanager + def track(self): + self._current += 1 + self._max = max(self._max, self._current) + try: + yield + finally: + self._current -= 1 + + +def default_process_fn(prompt: str) -> ProcessResult: + match = re.search(r"What is 1\+(\d+)\?", prompt) + if match: + num = int(match.group(1)) + ans = 1 + num + return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop") + return ProcessResult(text="I don't understand.", finish_reason="stop") + + +@contextmanager +def with_mock_server( + model_name: str = "Qwen/Qwen3-0.6B", + process_fn: ProcessFn = default_process_fn, + host: str = "127.0.0.1", + port: int | None = None, + latency: float = 0.0, +): + server = MockSGLangServer( + model_name=model_name, + process_fn=process_fn, + host=host, + port=port, + latency=latency, + ) + try: + server.start() + yield server + finally: + server.stop() diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py new file mode 100644 index 0000000000..f38344c8c6 --- /dev/null +++ b/miles/utils/test_utils/mock_tools.py @@ -0,0 +1,323 @@ +import json +from copy import deepcopy +from typing import Any + +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult + +AGENTIC_MAX_TURNS: int | None = None +from miles.utils.http_utils import post + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_year", + "description": "Get current year", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_temperature", + "description": "Get temperature for a location", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + }, +] + + +def _get_year(params: dict) -> str: + assert len(params) == 0 + return json.dumps({"year": 2026}) + + +def _get_temperature(params: dict) -> str: + temps = {"Mars": -60, "Earth": 15} + location = params.get("location") + assert location in temps, f"Unknown location: {location}" + return json.dumps({"temperature": temps[location]}) + + +TOOL_EXECUTORS = { + "get_year": _get_year, + "get_temperature": _get_temperature, +} + + +async def execute_tool_call(name: str, params: dict) -> str: + return TOOL_EXECUTORS[name](params) + + +async def run_agentic_tool_call( + base_url: str, + prompt: list[dict[str, Any]] | str, + request_kwargs: dict[str, Any] | None = None, + max_turns: int = 8, +) -> None: + if AGENTIC_MAX_TURNS is not None: + max_turns = AGENTIC_MAX_TURNS + messages = deepcopy(prompt) if isinstance(prompt, list) else [{"role": "user", "content": prompt}] + request_kwargs = request_kwargs or {} + model = request_kwargs.get("model", "default") + tools = request_kwargs.get("tools", SAMPLE_TOOLS) + + for _ in range(max_turns): + payload = {"model": model, "messages": messages, "tools": tools} + response = await post(base_url + "/v1/chat/completions", payload) + choice = response["choices"][0]["message"] + tool_calls = choice.get("tool_calls") or [] + if not tool_calls: + break + + assistant_msg = { + "content": choice.get("content"), + "refusal": choice.get("refusal"), + "role": choice.get("role", "assistant"), + "annotations": choice.get("annotations"), + "audio": choice.get("audio"), + "function_call": choice.get("function_call"), + "tool_calls": tool_calls, + } + messages.append(assistant_msg) + + for tool_call in tool_calls: + name = tool_call["function"]["name"] + raw_args = tool_call["function"].get("arguments") or "{}" + try: + params = json.loads(raw_args) if isinstance(raw_args, str) else raw_args + except json.JSONDecodeError: + params = {} + result = await execute_tool_call(name, params) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.get("id"), + "content": result, + "name": name, + } + ) + + +_SYSTEM_PROMPT = ( + "<|im_start|>system\n" + "# Tools\n" + "\n" + "You may call one or more functions to assist with the user query.\n" + "\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n" + "\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" +) + + +_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + + +class TwoTurnStub: + """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer""" + + USER_QUESTION = "What is 42 + year + temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE, + TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") + + +class ThreeTurnStub: + """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer""" + + USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and Mars temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + SECOND_RESPONSE = ( + "Now let me get Earth temperature.\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"temperature": 15}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature." + SECOND_TOOL_CALLS_OPENAI_FORMAT = [ + { + "id": "call00000", + "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ + { + "content": SECOND_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE, + ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE, + ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py new file mode 100644 index 0000000000..904343c984 --- /dev/null +++ b/miles/utils/test_utils/uvicorn_thread_server.py @@ -0,0 +1,49 @@ +import asyncio +import socket +import threading +import time + +import uvicorn + + +class UvicornThreadServer: + def __init__(self, app, host: str, port: int): + self._app = app + self.host = host + self.port = port + self._server: uvicorn.Server | None = None + self._thread: threading.Thread | None = None + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def start(self) -> None: + config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info") + self._server = uvicorn.Server(config) + + def run() -> None: + asyncio.run(self._server.serve()) + + self._thread = threading.Thread(target=run, daemon=True) + self._thread.start() + self._wait_for_port_open() + + def stop(self) -> None: + if self._server is not None: + self._server.should_exit = True + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2.0) + + def _wait_for_port_open(self) -> None: + for _ in range(50): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex((self.host, self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + raise RuntimeError(f"Failed to start server on {self.url}") diff --git a/miles/utils/types.py b/miles/utils/types.py index 0a2531a7af..5200d625e6 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -145,6 +145,24 @@ def get_reward_value(self, args) -> float: def effective_response_length(self): return sum(self.loss_mask) if self.loss_mask is not None else self.response_length + def validate(self): + assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}" + assert ( + len(self.tokens) >= self.response_length + ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})" + if self.loss_mask is not None: + assert ( + len(self.loss_mask) == self.response_length + ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})" + if self.rollout_log_probs is not None: + assert ( + len(self.rollout_log_probs) == self.response_length + ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" + if self.rollout_routed_experts is not None: + actual = len(self.rollout_routed_experts) + expect = len(self.tokens) - 1 + assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})" + def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. diff --git a/requirements.txt b/requirements.txt index 2c20195fc4..dacd51132c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf pillow +pybase64 pylatexenc pyyaml qwen_vl_utils # for VLM diff --git a/scripts/run-qwen3-235B-A22B.sh b/scripts/run-qwen3-235B-A22B.sh index e42e17ab29..ffd5972ac0 100644 --- a/scripts/run-qwen3-235B-A22B.sh +++ b/scripts/run-qwen3-235B-A22B.sh @@ -161,7 +161,7 @@ RUNTIME_ENV_JSON="{ \"env_vars\": { \"PYTHONPATH\": \"/root/Megatron-LM/\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", \"no_proxy\": \"${no_proxy}\", \"MASTER_ADDR\": \"${MASTER_ADDR}\" } diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh index 83af901563..998f06b7f4 100755 --- a/scripts/run-qwen3-4B-amd.sh +++ b/scripts/run-qwen3-4B-amd.sh @@ -139,16 +139,16 @@ NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 -# "PYTHONPATH": "/workspace/Megatron-LM/", -MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') +# Dynamically detect Megatron-LM installation path +MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json='{ - "env_vars": { - "PYTHONPATH": "/workspace/Megatron-LM/", - "CUDA_DEVICE_MAX_CONNECTIONS": "1" + --runtime-env-json="{ + \"env_vars\": { + \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } - }' \ + }" \ -- python3 train.py \ --actor-num-nodes 1 \ --actor-num-gpus-per-node 8 \ diff --git a/scripts/run-qwen3-4B.sh b/scripts/run-qwen3-4B.sh index c7f01abd93..cecb41704e 100644 --- a/scripts/run-qwen3-4B.sh +++ b/scripts/run-qwen3-4B.sh @@ -10,33 +10,58 @@ sleep 3 pkill -9 ray pkill -9 python -set -ex +set -euxo pipefail -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 +# ==================== Platform Detection ==================== +if [ -e /dev/kfd ] || python3 -c "import torch; assert torch.version.hip" 2>/dev/null; then + GPU_VENDOR="amd" +elif command -v nvidia-smi &>/dev/null; then + GPU_VENDOR="nvidia" else + echo "ERROR: No supported GPU detected (need NVIDIA or AMD)" + exit 1 +fi +echo "Detected GPU vendor: ${GPU_VENDOR}" + +# ==================== Configurable Paths ==================== +MODEL_DIR="${MODEL_DIR:-/root}" +DATA_DIR="${DATA_DIR:-/root}" +export MODEL_DIR DATA_DIR + +# ==================== Platform-Specific Setup ==================== +if [ "$GPU_VENDOR" = "amd" ]; then + export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} + export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} + NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) HAS_NVLINK=0 +else + NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) + if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 + else + HAS_NVLINK=0 + fi + echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/qwen3-4B.sh" CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-4B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load /root/Qwen3-4B_torch_dist - --load /root/Qwen3-4B_miles/ - --save /root/Qwen3-4B_miles/ + --hf-checkpoint ${MODEL_DIR}/Qwen3-4B + #--hf-checkpoint ${MODEL_DIR}/Qwen3-4B-FP8 + --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist + --load ${MODEL_DIR}/Qwen3-4B_miles/ + --save ${MODEL_DIR}/Qwen3-4B_miles/ --save-interval 20 ) ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl --input-key prompt --label-key label --apply-chat-template @@ -54,7 +79,7 @@ ROLLOUT_ARGS=( EVAL_ARGS=( --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 --eval-top-p 1 @@ -108,6 +133,11 @@ SGLANG_ARGS=( --sglang-mem-fraction-static 0.7 ) +# AMD: disable custom all-reduce to prevent driver-level deadlocks with offload enabled +if [ "$GPU_VENDOR" = "amd" ]; then + SGLANG_ARGS+=(--sglang-disable-custom-all-reduce) +fi + MISC_ARGS=( # default dropout in megatron is 0.1 --attention-dropout 0.0 @@ -119,14 +149,17 @@ MISC_ARGS=( --attention-backend flash ) -# launch the master node of ray in container +# ==================== Launch Ray ==================== export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Dynamically detect Megatron-LM installation path +MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") -# Build the runtime environment JSON with proper variable substitution +# Build the runtime environment JSON RUNTIME_ENV_JSON="{ \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" } @@ -136,7 +169,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ @@ -147,4 +180,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ No newline at end of file + ${MISC_ARGS[@]} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py index 9507e2e858..20379f76a2 100644 --- a/tests/ci/gpu_lock_exec.py +++ b/tests/ci/gpu_lock_exec.py @@ -19,11 +19,14 @@ def main(): _execute_print_only(args) return - fd_locks = _try_acquire(args) + if args.count == 0 and not args.devices: + print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True) + else: + fd_locks = _try_acquire(args) - dev_list = ",".join(str(x.gpu_id) for x in fd_locks) - os.environ[args.target_env_name] = dev_list - print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ[args.target_env_name] = dev_list + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) _os_execvp(args) diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep new file mode 100644 index 0000000000..615f2b076c --- /dev/null +++ b/tests/e2e/.gitkeep @@ -0,0 +1 @@ +# TODO: may move e2e tests to this folder \ No newline at end of file diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/e2e/ckpt/test_qwen3_4B_ckpt.py similarity index 98% rename from tests/test_qwen3_4B_ckpt.py rename to tests/e2e/ckpt/test_qwen3_4B_ckpt.py index 22fb2b5fc3..0df4492e10 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/e2e/ckpt/test_qwen3_4B_ckpt.py @@ -124,6 +124,7 @@ def execute(mode: str = ""): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py similarity index 97% rename from tests/test_qwen3_0.6B_fsdp_distributed.py rename to tests/e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py index 3d70f3e4ce..fcd7772882 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py @@ -95,6 +95,7 @@ def execute(): num_gpus_per_node=2 if FEW_GPU else 4, megatron_model_type=None, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py similarity index 95% rename from tests/test_qwen3_0.6B_megatron_fsdp_align.py rename to tests/e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py index 1431d8c3d4..b89a2f283b 100644 --- a/tests/test_qwen3_0.6B_megatron_fsdp_align.py +++ b/tests/e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py @@ -97,6 +97,7 @@ def execute(): train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -109,6 +110,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -135,6 +137,7 @@ def execute(): "--debug-train-only " ), num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, megatron_model_type=MODEL_TYPE, ) diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py similarity index 98% rename from tests/test_qwen3_4B_fsdp_true_on_policy.py rename to tests/e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py index 7c975c7cc2..03ba4094e9 100644 --- a/tests/test_qwen3_4B_fsdp_true_on_policy.py +++ b/tests/e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py @@ -95,6 +95,7 @@ def execute(): "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", "CUBLAS_WORKSPACE_CONFIG": ":4096:8", "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/e2e/fsdp/test_qwen3_vl_4B_fsdp.py similarity index 98% rename from tests/test_qwen3_vl_4B_fsdp.py rename to tests/e2e/fsdp/test_qwen3_vl_4B_fsdp.py index fbdffd237e..bc4ef3293c 100644 --- a/tests/test_qwen3_vl_4B_fsdp.py +++ b/tests/e2e/fsdp/test_qwen3_vl_4B_fsdp.py @@ -92,6 +92,7 @@ def execute(): extra_env_vars = { "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/e2e/image/test_mimo_7B_mtp_only_grad.py similarity index 98% rename from tests/test_mimo_7B_mtp_only_grad.py rename to tests/e2e/image/test_mimo_7B_mtp_only_grad.py index 97c76ace5a..d90a2d7a71 100644 --- a/tests/test_mimo_7B_mtp_only_grad.py +++ b/tests/e2e/image/test_mimo_7B_mtp_only_grad.py @@ -135,6 +135,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_moonlight_16B_A3B.py b/tests/e2e/image/test_moonlight_16B_A3B.py similarity index 98% rename from tests/test_moonlight_16B_A3B.py rename to tests/e2e/image/test_moonlight_16B_A3B.py index b1255982ed..c35943ec15 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/e2e/image/test_moonlight_16B_A3B.py @@ -113,6 +113,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_quick_start_glm4_9B.py b/tests/e2e/image/test_quick_start_glm4_9B.py similarity index 98% rename from tests/test_quick_start_glm4_9B.py rename to tests/e2e/image/test_quick_start_glm4_9B.py index 15ca8ce5fe..ae3c383ae8 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/e2e/image/test_quick_start_glm4_9B.py @@ -115,6 +115,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k.py similarity index 98% rename from tests/test_qwen2.5_0.5B_gsm8k.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k.py index dcdbd58347..4d7f034f6c 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k.py @@ -120,6 +120,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async.py similarity index 98% rename from tests/test_qwen2.5_0.5B_gsm8k_async.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async.py index dcaaf5e1f7..32b60f5937 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async.py @@ -120,6 +120,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py similarity index 98% rename from tests/test_qwen2.5_0.5B_gsm8k_async_short.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py index 90cd15cb68..b1954a4e83 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py +++ b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -118,6 +118,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_short.py similarity index 98% rename from tests/test_qwen2.5_0.5B_gsm8k_short.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k_short.py index 867fdcad60..86e21eac8d 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_short.py +++ b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_short.py @@ -117,6 +117,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py similarity index 97% rename from tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py rename to tests/e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d19b48ced..3d4768e420 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -93,6 +93,7 @@ def execute(): train_args=train_args, num_gpus_per_node=2, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py b/tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py new file mode 100644 index 0000000000..fcd7772882 --- /dev/null +++ b/tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py @@ -0,0 +1,106 @@ +import os +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" + + +FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + # NOTE cannot be exactly multiple of eval-interval, since async causes some offsets + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 65} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + # "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 1 " "--sglang-enable-metrics " + + misc_args = ( + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " + f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--train-backend fsdp " + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=2 if FEW_GPU else 4, + megatron_model_type=None, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py new file mode 100644 index 0000000000..b89a2f283b --- /dev/null +++ b/tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py @@ -0,0 +1,155 @@ +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 4 +CP_SIZE = 1 +MEGATRON_TP_SIZE = 1 +MEGATRON_PP_SIZE = 1 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/" + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 64 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " "--sglang-chunked-prefill-size 4096 " "--sglang-mem-fraction-static 0.75 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " "--colocate " f"--actor-num-gpus-per-node {NUM_GPUS} " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + debug_data_path = "test_rollout_data_megatron_fsdp_align.pt" + grad_norm_path = "grad_norm_fsdp.pt" + + fsdp_args = ( + "--train-backend fsdp " + "--attn-implementation flash_attention_2 " + "--gradient-checkpointing " + f"--context-parallel-size {CP_SIZE} " + f"--update-weight-buffer-size {512 * 1024 * 1024} " + """--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' """ + ) + + try: + U.execute_train( + train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"{fsdp_args}" + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-save-grad-norm {grad_norm_path} " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + f"--tensor-model-parallel-size {MEGATRON_TP_SIZE} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {MEGATRON_PP_SIZE} " + f"--context-parallel-size {CP_SIZE} " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--train-memory-margin-bytes 3221225472 " + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-load-grad-norm {grad_norm_path} " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + megatron_model_type=MODEL_TYPE, + ) + + finally: + if os.path.exists(grad_norm_path): + os.remove(grad_norm_path) + if os.path.exists(debug_data_path): + os.remove(debug_data_path) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/e2e/image/test_qwen3_0.6B_parallel_check.py similarity index 96% rename from tests/test_qwen3_0.6B_parallel_check.py rename to tests/e2e/image/test_qwen3_0.6B_parallel_check.py index 44f5c42fa5..d0ad283d15 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/e2e/image/test_qwen3_0.6B_parallel_check.py @@ -95,6 +95,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) # 8 GPU CPU 1 for num_gpus in [8, 4, 2]: @@ -124,6 +125,7 @@ def execute(): train_args=args, num_gpus_per_node=num_gpus, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) train_args += "--calculate-per-token-loss " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/e2e/image/test_qwen3_30B_A3B.py similarity index 96% rename from tests/test_qwen3_30B_A3B.py rename to tests/e2e/image/test_qwen3_30B_A3B.py index adff108043..95649e2a33 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/e2e/image/test_qwen3_30B_A3B.py @@ -93,7 +93,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " - "--sglang-mem-fraction-static 0.8 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) @@ -139,6 +139,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/e2e/image/test_qwen3_4B_ckpt.py b/tests/e2e/image/test_qwen3_4B_ckpt.py new file mode 100644 index 0000000000..0df4492e10 --- /dev/null +++ b/tests/e2e/image/test_qwen3_4B_ckpt.py @@ -0,0 +1,138 @@ +import os +from argparse import ArgumentParser + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 8 + + +parser = ArgumentParser() +parser.add_argument("--async-save", action="store_true", help="Whether to test async save/load.") + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"rm -rf /root/models/{MODEL_NAME}_miles") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint( + model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS, dir_dst="/root/models" + ) + + +def execute(mode: str = ""): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + if mode == "save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + elif mode == "async_save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + ckpt_args += "--async-save " + elif mode == "load": + ckpt_args += f"--load /root/models/{MODEL_NAME}_miles " + ckpt_args += "--ckpt-step 1 " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.8 --sglang-cuda-graph-bs 1 2 4 8 16 " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + args = parser.parse_args() + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute("save" if not args.async_save else "async_save") + execute("load") diff --git a/tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py b/tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py new file mode 100644 index 0000000000..03ba4094e9 --- /dev/null +++ b/tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py @@ -0,0 +1,113 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +NUM_GPUS = 2 + +MODEL_NAME = "Qwen3-4B" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-p 0.7 " + ) + + fsdp_args = "--train-backend fsdp " "--update-weight-buffer-size 536870912 " + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-decode-log-interval 1000 " + "--sglang-enable-metrics " + "--sglang-enable-deterministic-inference " + "--sglang-rl-on-policy-target fsdp " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + "--deterministic-mode " + "--true-on-policy-mode " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{fsdp_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + extra_env_vars = { + "NCCL_ALGO": "allreduce:tree", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + } + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars=extra_env_vars, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_4B_ppo.py b/tests/e2e/image/test_qwen3_4B_ppo.py similarity index 98% rename from tests/test_qwen3_4B_ppo.py rename to tests/e2e/image/test_qwen3_4B_ppo.py index 962f610fac..d4c1ac273a 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/e2e/image/test_qwen3_4B_ppo.py @@ -122,6 +122,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/e2e/image/test_qwen3_vl_4B_fsdp.py b/tests/e2e/image/test_qwen3_vl_4B_fsdp.py new file mode 100644 index 0000000000..bc4ef3293c --- /dev/null +++ b/tests/e2e/image/test_qwen3_vl_4B_fsdp.py @@ -0,0 +1,112 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +NUM_GPUS = 8 + +MODEL_NAME = "Qwen3-VL-4B-Instruct" +DATASET_NAME = "chenhegu/geo3k_imgurl" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset(DATASET_NAME) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/geo3k_imgurl/train.parquet " + "--input-key problem " + "--label-key answer " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + # multimodal keys required for vlm datasets + multimodal_args = '--multimodal-keys \'{"image": "images"}\' ' + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + ) + + fsdp_args = "--train-backend fsdp " "--gradient-checkpointing " "--update-weight-buffer-size 536870912 " + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-mem-fraction-static 0.6 " + "--sglang-decode-log-interval 1000 " + "--sglang-enable-metrics " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{multimodal_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{fsdp_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + extra_env_vars = { + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + } + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars=extra_env_vars, + ) + + +if __name__ == "__main__": + prepare() + os.environ.pop("http_proxy", None) + os.environ.pop("https_proxy", None) + os.environ.pop("HTTP_PROXY", None) + os.environ.pop("HTTPS_PROXY", None) + execute() diff --git a/tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py new file mode 100644 index 0000000000..4d7f034f6c --- /dev/null +++ b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py @@ -0,0 +1,131 @@ +import os +import miles.utils.external_utils.command_utils as U + + +FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 2 if FEW_GPU else 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 250} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + # "--micro-batch-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.55 " # loose threshold at 250 step + ) + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {2 if FEW_GPU else 4} " + "--colocate " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py new file mode 100644 index 0000000000..32b60f5937 --- /dev/null +++ b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py @@ -0,0 +1,131 @@ +import os +import miles.utils.external_utils.command_utils as U + +FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 2 if FEW_GPU else 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 250} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + # "--micro-batch-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.55 " # loose threshold at 250 step + ) + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " + f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py b/tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py new file mode 100644 index 0000000000..d90a2d7a71 --- /dev/null +++ b/tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py @@ -0,0 +1,147 @@ +"""End-to-end test for MTP-only gradient verification. + +This test verifies that when MTP training is enabled and all outputs are truncated +(due to very short max response length), only MTP parameters receive non-zero +gradients while all other model parameters have zero gradients. + +This validates that the MTP loss computation correctly isolates gradient flow +to only the MTP layers when the main model loss is zero (due to truncation). +""" + +import os + +import miles.utils.external_utils.command_utils as U + + +MODEL_NAME = "MiMo-7B-RL" +MODEL_TYPE = "mimo-7B-rl" +NUM_GPUS = 8 + + +def prepare(): + """Download model and convert checkpoint with MTP layers.""" + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download XiaomiMiMo/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + # Convert checkpoint with MTP layers enabled + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + extra_args=" --mtp-num-layers 1", + dir_dst="/root/models", + ) + + +def execute(): + """Run training with MTP enabled and very short output length to cause truncation.""" + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + + # Use very short rollout-max-response-len to ensure all outputs are truncated + # This should result in zero loss for the main model, leaving only MTP loss + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 2 " + # Very short max response length to cause all outputs to be truncated + "--rollout-max-response-len 128 " + "--rollout-temperature 0.8 " + "--global-batch-size 8 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 4096 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--rollout-num-gpus 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-enable-metrics " + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + ) + + # Enable MTP training with loss scaling + mtp_args = "--mtp-num-layers 1 " "--enable-mtp-training " "--mtp-loss-scaling-factor 0.2 " + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + # MTP grad check is automatically triggered when ci_test and enable_mtp_training are both set + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + # Remove proxy settings that might interfere with local operations + for key in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]: + os.environ.pop(key, None) + execute() diff --git a/tests/e2e/megatron/test_moonlight_16B_A3B.py b/tests/e2e/megatron/test_moonlight_16B_A3B.py new file mode 100644 index 0000000000..c35943ec15 --- /dev/null +++ b/tests/e2e/megatron/test_moonlight_16B_A3B.py @@ -0,0 +1,124 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Moonlight-16B-A3B-Instruct" +MODEL_TYPE = "moonlight" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command( + "hf download moonshotai/Moonlight-16B-A3B-Instruct --local-dir /root/models/Moonlight-16B-A3B-Instruct" + ) + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 2048} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " "--sglang-mem-fraction-static 0.8 " "--sglang-max-running-requests 512 " + ) + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_moonlight_16B_A3B_r3.py b/tests/e2e/megatron/test_moonlight_16B_A3B_r3.py similarity index 100% rename from tests/test_moonlight_16B_A3B_r3.py rename to tests/e2e/megatron/test_moonlight_16B_A3B_r3.py diff --git a/tests/e2e/megatron/test_quick_start_glm4_9B.py b/tests/e2e/megatron/test_quick_start_glm4_9B.py new file mode 100644 index 0000000000..ae3c383ae8 --- /dev/null +++ b/tests/e2e/megatron/test_quick_start_glm4_9B.py @@ -0,0 +1,127 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = U.get_bool_env_var("MILES_TEST_ENABLE_EVAL", "1") +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "GLM-Z1-9B-0414" +MODEL_TYPE = "glm4-9B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/models/GLM-Z1-9B-0414") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_DEVICE_MEMORY else 4608} " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-tis " + "--calculate-per-token-loss " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--use-miles-router " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--rollout-num-gpus 4 " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/megatron/test_qwen3_30B_A3B.py b/tests/e2e/megatron/test_qwen3_30B_A3B.py new file mode 100644 index 0000000000..95649e2a33 --- /dev/null +++ b/tests/e2e/megatron/test_qwen3_30B_A3B.py @@ -0,0 +1,151 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) +USE_DEEPEP = bool(int(os.environ.get("MILES_TEST_USE_DEEPEP", "1"))) +USE_FP8_ROLLOUT = bool(int(os.environ.get("MILES_TEST_USE_FP8_ROLLOUT", "1"))) + +MODEL_NAME = "Qwen3-30B-A3B" +MODEL_TYPE = "qwen3-30B-A3B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-30B-A3B --local-dir /root/models/Qwen3-30B-A3B") + U.exec_command("hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/models/Qwen3-30B-A3B-FP8") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + if USE_FP8_ROLLOUT: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}-FP8 " f"--ref-load /root/{MODEL_NAME}_torch_dist " + else: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 4 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--use-tis " + "--use-routing-replay " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " + "--sglang-max-running-requests 512 " + "--sglang-enable-metrics " + ) + + if USE_DEEPEP: + sglang_args += "--sglang-moe-a2a-backend deepep --sglang-deepep-mode auto " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + if USE_DEEPEP: + misc_args += "--moe-token-dispatcher-type flex --moe-enable-deepep " + else: + misc_args += "--moe-token-dispatcher-type alltoall " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_30B_A3B_r3.py b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py similarity index 98% rename from tests/test_qwen3_30B_A3B_r3.py rename to tests/e2e/megatron/test_qwen3_30B_A3B_r3.py index 5a5b968aa6..8b54176d12 100644 --- a/tests/test_qwen3_30B_A3B_r3.py +++ b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py @@ -94,7 +94,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " - "--sglang-mem-fraction-static 0.8 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) diff --git a/tests/e2e/megatron/test_qwen3_4B_ppo.py b/tests/e2e/megatron/test_qwen3_4B_ppo.py new file mode 100644 index 0000000000..d4c1ac273a --- /dev/null +++ b/tests/e2e/megatron/test_qwen3_4B_ppo.py @@ -0,0 +1,134 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + ppo_args = ( + "--advantage-estimator ppo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--num-critic-only-steps 1 " + "--normalize-advantages " + "--critic-lr 1e-5 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--rollout-num-gpus 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-max-running-requests 512 " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py new file mode 100644 index 0000000000..b89a2f283b --- /dev/null +++ b/tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py @@ -0,0 +1,155 @@ +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 4 +CP_SIZE = 1 +MEGATRON_TP_SIZE = 1 +MEGATRON_PP_SIZE = 1 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/" + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 64 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " "--sglang-chunked-prefill-size 4096 " "--sglang-mem-fraction-static 0.75 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " "--colocate " f"--actor-num-gpus-per-node {NUM_GPUS} " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + debug_data_path = "test_rollout_data_megatron_fsdp_align.pt" + grad_norm_path = "grad_norm_fsdp.pt" + + fsdp_args = ( + "--train-backend fsdp " + "--attn-implementation flash_attention_2 " + "--gradient-checkpointing " + f"--context-parallel-size {CP_SIZE} " + f"--update-weight-buffer-size {512 * 1024 * 1024} " + """--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' """ + ) + + try: + U.execute_train( + train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"{fsdp_args}" + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-save-grad-norm {grad_norm_path} " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + f"--tensor-model-parallel-size {MEGATRON_TP_SIZE} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {MEGATRON_PP_SIZE} " + f"--context-parallel-size {CP_SIZE} " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--train-memory-margin-bytes 3221225472 " + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-load-grad-norm {grad_norm_path} " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + megatron_model_type=MODEL_TYPE, + ) + + finally: + if os.path.exists(grad_norm_path): + os.remove(grad_norm_path) + if os.path.exists(debug_data_path): + os.remove(debug_data_path) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/precision/test_qwen3_0.6B_parallel_check.py b/tests/e2e/precision/test_qwen3_0.6B_parallel_check.py new file mode 100644 index 0000000000..d0ad283d15 --- /dev/null +++ b/tests/e2e/precision/test_qwen3_0.6B_parallel_check.py @@ -0,0 +1,138 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS, dir_dst="/root/models" + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--rollout-num-gpus 8 " "--sglang-mem-fraction-static 0.8 " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + for i in range(2): + U.execute_train( + train_args=train_args + + ( + f"--save-debug-rollout-data data-{i}.pt " + f"--ci-save-grad-norm grad_norms-{i}.pt " + f"--actor-num-gpus-per-node {NUM_GPUS} " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + # 8 GPU CPU 1 + for num_gpus in [8, 4, 2]: + remaining_gpus = num_gpus + for tp_size in [1, 2, 4, 8]: + remaining_gpus /= tp_size + for pp_size in [1, 2, 4]: + if remaining_gpus < pp_size: + continue + remaining_gpus /= pp_size + for cp_size in [1, 2, 4, 8]: + if remaining_gpus < cp_size: + continue + args = train_args + ( + f"--load-debug-rollout-data data-{i}.pt " + f"--ci-load-grad-norm grad_norms-{i}.pt " + f"--context-parallel-size {cp_size} " + f"--tensor-model-parallel-size {tp_size} " + f"--pipeline-model-parallel-size {pp_size} " + "--sequence-parallel " + f"--actor-num-gpus-per-node {num_gpus} " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + U.execute_train( + train_args=args, + num_gpus_per_node=num_gpus, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + train_args += "--calculate-per-token-loss " + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/sglang_patch/__init__.py b/tests/e2e/sglang_patch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/sglang_patch/sglang_server.py b/tests/e2e/sglang_patch/sglang_server.py new file mode 100644 index 0000000000..44214de056 --- /dev/null +++ b/tests/e2e/sglang_patch/sglang_server.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import os +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import IO + +import requests + +from miles.utils.http_utils import find_available_port + +DEFAULT_HOST = "127.0.0.1" +DEFAULT_BASE_PORT = 34000 +DEFAULT_STARTUP_TIMEOUT_SECS = 900.0 +DEFAULT_SHUTDOWN_TIMEOUT_SECS = 30.0 + + +@dataclass +class SGLangServer: + process: subprocess.Popen + host: str + port: int + log_path: Path + _log_file: IO[str] + + @property + def base_url(self) -> str: + return f"http://{self.host}:{self.port}" + + def stop(self, timeout_secs: float = DEFAULT_SHUTDOWN_TIMEOUT_SECS) -> None: + if self.process.poll() is None: + self.process.terminate() + try: + self.process.wait(timeout=timeout_secs) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait(timeout=timeout_secs) + self._log_file.close() + + +def start_sglang_server( + *, + model_path: str, + host: str = DEFAULT_HOST, + port: int | None = None, + startup_timeout_secs: float = DEFAULT_STARTUP_TIMEOUT_SECS, + enable_deterministic_inference: bool = True, + extra_args: list[str] | None = None, +) -> SGLangServer: + if port is None: + port = find_available_port(DEFAULT_BASE_PORT) + + log_path = Path(f"/tmp/sglang_e2e_{port}.log") + log_file = log_path.open("w", encoding="utf-8") + + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model-path", + model_path, + "--host", + host, + "--port", + str(port), + "--trust-remote-code", + ] + if enable_deterministic_inference: + cmd.append("--enable-deterministic-inference") + if extra_args: + cmd.extend(extra_args) + + env = os.environ.copy() + env.setdefault("PYTHONUNBUFFERED", "1") + + process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, env=env) + server = SGLangServer(process=process, host=host, port=port, log_path=log_path, _log_file=log_file) + + _wait_for_ready(server, timeout_secs=startup_timeout_secs) + return server + + +def _wait_for_ready(server: SGLangServer, *, timeout_secs: float) -> None: + deadline = time.monotonic() + timeout_secs + last_error = "" + + while time.monotonic() < deadline: + if server.process.poll() is not None: + log_tail = _read_log_tail(server.log_path) + raise RuntimeError( + "SGLang server exited early. " f"Exit code: {server.process.returncode}. " f"Log tail:\n{log_tail}" + ) + + try: + response = requests.get(f"{server.base_url}/health", timeout=5) + if response.status_code == 200: + return + last_error = f"status_code={response.status_code}" + except requests.RequestException as exc: + last_error = str(exc) + + time.sleep(1.0) + + log_tail = _read_log_tail(server.log_path) + raise TimeoutError( + "Timed out waiting for SGLang server to become healthy. " + f"Last error: {last_error}. " + f"Log tail:\n{log_tail}" + ) + + +def _read_log_tail(path: Path, max_lines: int = 80) -> str: + if not path.exists(): + return "" + + content = path.read_text(encoding="utf-8", errors="ignore") + lines = content.splitlines() + if len(lines) <= max_lines: + return content + return "\n".join(lines[-max_lines:]) diff --git a/tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py b/tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py new file mode 100644 index 0000000000..adcf08fb2b --- /dev/null +++ b/tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py @@ -0,0 +1,122 @@ +import math +import os + +import pytest +import requests +from tests.e2e.sglang_patch.sglang_server import start_sglang_server +from transformers import AutoTokenizer + +MODEL_PATH = os.environ.get("SGLANG_E2E_MODEL_PATH", "Qwen/Qwen3-0.6B") +SEED = 1234 +TEMPERATURE = 1.0 +TOP_P = 1.0 +MAX_COMPLETION_TOKENS = 64 +LOGPROB_TOL = 1e-6 + + +@pytest.fixture(scope="module") +def sglang_server(): + server = start_sglang_server(model_path=MODEL_PATH) + try: + yield server + finally: + server.stop() + + +@pytest.mark.system +def test_chat_completions_input_ids_equivalence(sglang_server): + """Validate that providing input_ids yields the same completion as raw messages.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + messages = _build_messages() + # Build the same prompt two ways: message list vs. explicit input_ids. + input_ids = _build_input_ids(tokenizer, messages) + + # Request completions for both payload variants. + response_a = _post_chat(sglang_server.base_url, _build_payload(messages)) + response_b = _post_chat(sglang_server.base_url, _build_payload(messages, input_ids)) + + choice_a = response_a["choices"][0] + choice_b = response_b["choices"][0] + + # The generated content and finish reason should match across variants. + assert choice_a["message"]["content"] == choice_b["message"]["content"] + assert choice_a["finish_reason"] == choice_b["finish_reason"] + + # Compare token ids and per-token logprobs for exact equivalence. + token_ids_a, logprobs_a = _extract_tokens_and_logprobs(choice_a) + token_ids_b, logprobs_b = _extract_tokens_and_logprobs(choice_b) + + assert token_ids_a == token_ids_b + assert len(logprobs_a) == len(logprobs_b) + + for index, (a_val, b_val) in enumerate(zip(logprobs_a, logprobs_b, strict=True)): + assert math.isclose(a_val, b_val, abs_tol=LOGPROB_TOL), f"logprob mismatch at {index}: {a_val} vs {b_val}" + + +@pytest.mark.system +def test_chat_completions_input_logprobs_prompt_ids_match(sglang_server): + """Ensure input_ids are echoed exactly in input_token_ids and logprobs are present.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + messages = _build_messages() + input_ids = _build_input_ids(tokenizer, messages) + + response = _post_chat(sglang_server.base_url, _build_payload(messages, input_ids)) + choice = response["choices"][0] + + input_token_ids = _extract_input_token_ids(choice) + + assert input_token_ids == input_ids + assert choice.get("logprobs", {}).get("content"), "logprobs content is missing" + + +def _post_chat(base_url: str, payload: dict) -> dict: + response = requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=120) + print(f"response: {response.json()}", flush=True) + assert response.status_code == 200, response.text + return response.json() + + +def _build_messages() -> list[dict]: + return [ + {"role": "system", "content": "You are a concise assistant."}, + {"role": "user", "content": "Answer with one word: 2+2?"}, + ] + + +def _build_input_ids(tokenizer, messages: list[dict]) -> list[int]: + return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + + +def _build_payload(messages: list[dict], input_ids: list[int] | None = None) -> dict: + payload = { + "model": MODEL_PATH, + "temperature": TEMPERATURE, + "top_p": TOP_P, + "max_completion_tokens": MAX_COMPLETION_TOKENS, + "seed": SEED, + "logprobs": True, + "messages": messages, + "logprob_start_len": 0, + } + if input_ids is not None: + payload["input_ids"] = input_ids + return payload + + +def _extract_tokens_and_logprobs(choice: dict) -> tuple[list[int], list[float]]: + logprobs = choice.get("logprobs", {}).get("content") + assert logprobs, "logprobs content is missing" + + token_ids = [] + for item in logprobs: + token_ids.append(item["token_id"]) + values = [item["logprob"] for item in logprobs] + return token_ids, values + + +def _extract_input_token_ids(choice: dict) -> list[int]: + token_ids = choice.get("input_token_ids") + assert token_ids is not None, "input_token_ids is missing in response" + + print(f"input_token_ids: {token_ids}", flush=True) + return token_ids diff --git a/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py new file mode 100644 index 0000000000..b1954a4e83 --- /dev/null +++ b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -0,0 +1,129 @@ +import os +import miles.utils.external_utils.command_utils as U + +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--over-sampling-batch-size 16 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 32 " + ) + + eval_args = ( + "--eval-interval 8 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.55 if TIGHT_DEVICE_MEMORY else 0.65} " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + fault_tolerance_args = ( + "--use-fault-tolerance " + "--rollout-health-check-interval 5 " + "--rollout-health-check-timeout 10 " + "--rollout-health-check-first-wait 0 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 1 " + "--rollout-num-gpus 3 " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{fault_tolerance_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py new file mode 100644 index 0000000000..86e21eac8d --- /dev/null +++ b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py @@ -0,0 +1,128 @@ +import os +import miles.utils.external_utils.command_utils as U + +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--over-sampling-batch-size 16 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 32 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + fault_tolerance_args = ( + "--use-fault-tolerance " + "--rollout-health-check-interval 5 " + "--rollout-health-check-timeout 10 " + "--rollout-health-check-first-wait 0 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--colocate " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{fault_tolerance_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py new file mode 100644 index 0000000000..3d4768e420 --- /dev/null +++ b/tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -0,0 +1,104 @@ +import os +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 60} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + # "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--sglang-decode-log-interval 1000 " "--sglang-enable-metrics " + + fsdp_args = ( + # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default) + # "--fsdp-full-params " # Uncomment this line to enable full params mode + # Set the bucket size for weight update + "--update-weight-buffer-size 536870912 " # 512MB + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step + ) + + misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{sglang_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{eval_args} " + f"{fsdp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=2, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/fast/__init__.py b/tests/fast/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/conftest.py b/tests/fast/conftest.py new file mode 100644 index 0000000000..4cb30e91fa --- /dev/null +++ b/tests/fast/conftest.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from tests.fast.fixtures.generation_fixtures import generation_env +from tests.fast.fixtures.rollout_fixtures import rollout_env + +_ = rollout_env, generation_env + + +@pytest.fixture(autouse=True) +def enable_experimental_rollout_refactor(): + os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1" + yield + os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None) diff --git a/tests/fast/fixtures/__init__.py b/tests/fast/fixtures/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/fast/fixtures/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py new file mode 100644 index 0000000000..2dfabfa3ee --- /dev/null +++ b/tests/fast/fixtures/generation_fixtures.py @@ -0,0 +1,278 @@ +""" +Fixtures to test custom-generate-function +""" + +from argparse import Namespace +from contextlib import contextmanager +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState +from miles.router.router import MilesRouter +from miles.utils.async_utils import run +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils import mock_tools +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer +from miles.utils.types import Sample + +MODEL_NAME = "Qwen/Qwen3-0.6B" +RESPONSE_TEXT = "\\boxed{8}" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} + +VARIANT_TO_GENERATE_FN_PATH = { + "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", + "single_turn": "miles.rollout.generate_hub.single_turn.generate", + "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", + "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", + "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate", + "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", +} + + +def extra_argv_for_variant( + variant: str, + *, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + custom_agent_function_path: str = "miles.utils.test_utils.mock_tools.run_agentic_tool_call", +) -> list[str]: + argv = [ + "--custom-generate-function-path", + custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], + ] + + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + argv += [ + "--generate-max-turns", + str(generate_max_turns), + "--generate-tool-specs-path", + generate_tool_specs_path, + "--generate-execute-tool-function-path", + generate_execute_tool_function_path, + ] + argv += ["--generate-tool-call-parser", generate_tool_call_parser] + if variant == "multi_turn_multi_samples": + argv.append("--generate-multi-samples") + elif variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples"): + argv += ["--custom-agent-function-path", custom_agent_function_path] + if variant == "agentic_tool_call_multi_samples": + argv.append("--generate-multi-samples") + + return argv + + +def listify(x): + return x if isinstance(x, list) else [x] + + +def make_sample( + *, + prompt: str | list[dict] = "What is 1+7?", + tokens: list[int] | None = None, + response: str = "", + response_length: int = 0, + status: Sample.Status = Sample.Status.PENDING, + multimodal_inputs: dict | None = None, +) -> Sample: + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +@dataclass +class GenerateEnv: + args: Namespace + mock_server: Any + + +@dataclass +class GenerateResult: + sample: Sample | list[Sample] + requests: list[dict] + + +def run_generate( + env: GenerateEnv, + sample: Sample, + sampling_params: dict[str, Any] | None = None, + *, + variant: str = "single_turn", +) -> GenerateResult: + env.mock_server.request_log.clear() + result_sample = run( + _call_generate( + env.args, + sample, + sampling_params or DEFAULT_SAMPLING_PARAMS, + variant=variant, + ) + ) + return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) + + +async def _call_generate( + args: Namespace, + sample: Sample, + sampling_params: dict[str, Any], + *, + variant: str = "single_turn", +) -> Sample: + generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant]) + state = GenerateState(args) + input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) + output = await generate_fn(input) + return output.samples + + +def make_args( + *, + variant: str, + router_port: int, + use_rollout_routing_replay: bool = False, + sglang_speculative_algorithm: str | None = None, + model_name: str = MODEL_NAME, + extra_argv: list[str] | None = None, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + rollout_max_context_len: int | None = None, +) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + model_name, + "--prompt-data", + "/dev/null", + "--rm-type", + "math", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + if use_rollout_routing_replay: + argv.append("--use-rollout-routing-replay") + if sglang_speculative_algorithm: + argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) + if rollout_max_context_len is not None: + argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) + + argv.extend( + extra_argv_for_variant( + variant, + custom_generate_function_path=custom_generate_function_path, + generate_max_turns=generate_max_turns, + generate_tool_specs_path=generate_tool_specs_path, + generate_tool_call_parser=generate_tool_call_parser, + generate_execute_tool_function_path=generate_execute_tool_function_path, + ) + ) + + if extra_argv: + argv.extend(extra_argv) + + from miles.utils.arguments import parse_args + + with patch("sys.argv", argv): + args = parse_args() + + init_http_client(args) + return args + + +@contextmanager +def with_miles_router(backend_url: str, model_name: str): + router_args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + miles_router_enable_token_input_for_chat_completions=False, + hf_checkpoint=model_name, + ) + router = MilesRouter(router_args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend_url}) + + try: + yield port + finally: + server.stop() + + +@pytest.fixture +def generation_env(request, variant): + SingletonMeta.clear_all_instances() + params = getattr(request, "param", {}) + args_kwargs = params.get("args_kwargs", {}) + model_name = args_kwargs.get("model_name", MODEL_NAME) + custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant] + + def process_fn(_): + x = params.get("process_fn_kwargs", {}) + return ProcessResult( + text=x.get("response_text", RESPONSE_TEXT), + finish_reason=x.get("finish_reason", "stop"), + cached_tokens=x.get("cached_tokens", 0), + meta_info=ProcessResultMetaInfo( + weight_version=x.get("weight_version"), + routed_experts=x.get("routed_experts"), + spec_accept_token_num=x.get("spec_accept_token_num"), + spec_draft_token_num=x.get("spec_draft_token_num"), + spec_verify_ct=x.get("spec_verify_ct"), + ), + ) + + with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: + with with_miles_router(mock_server.url, model_name) as router_port: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args( + variant=variant, + router_port=router_port, + model_name=model_name, + custom_generate_function_path=custom_generate_function_path, + **other_args_kwargs, + ) + if variant.startswith("agentic_tool_call"): + mock_tools.AGENTIC_MAX_TURNS = args_kwargs.get("generate_max_turns") + yield GenerateEnv(args=args, mock_server=mock_server) + + mock_tools.AGENTIC_MAX_TURNS = None + SingletonMeta.clear_all_instances() diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py new file mode 100644 index 0000000000..44d8a50d79 --- /dev/null +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -0,0 +1,127 @@ +""" +Fixtures to test rollout-function +""" + +import json +from argparse import Namespace +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer +from miles.router.router import MilesRouter +from miles.utils.arguments import parse_args +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class RolloutEnvConfig: + extra_argv: list[str] | None = None + data_rows: list[dict] | None = None + latency: float = 0.0 + + +@dataclass(frozen=True) +class RolloutEnv: + args: Namespace + data_source: DataSource + mock_server: MockSGLangServer + + +def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + data_path, + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--eval-prompt-data", + "toy", + data_path, + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + (extra_argv or []) + with patch("sys.argv", argv): + args = parse_args() + args.miles_router_middleware_paths = [] + init_http_client(args) + return args + + +@contextmanager +def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]: + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + try: + server.start() + yield server + finally: + server.stop() + + +def _write_jsonl(path: str, rows: list[dict]) -> None: + Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") + + +DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] + + +@pytest.fixture +def rollout_env(tmp_path, request) -> RolloutEnv: + config = request.param + assert isinstance(config, RolloutEnvConfig) + + data_rows = config.data_rows or DEFAULT_DATA_ROWS + + data_path = str(tmp_path / "data.jsonl") + _write_jsonl(data_path, data_rows) + + router_port = find_available_port(20000) + args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) + + SingletonMeta.clear_all_instances() + + with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: + with _with_miles_router(args) as router_server: + r = requests.post( + f"{router_server.url}/add_worker", + params={"url": mock_server.url}, + timeout=5.0, + ) + r.raise_for_status() + + data_source = RolloutDataSourceWithBuffer(args) + yield RolloutEnv(args=args, data_source=data_source, mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/rollout/__init__.py b/tests/fast/rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/generate_hub/__init__.py b/tests/fast/rollout/generate_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py new file mode 100644 index 0000000000..5d974aaadd --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -0,0 +1,572 @@ +from copy import deepcopy +from dataclasses import dataclass, replace +from itertools import groupby + +import numpy as np +import pybase64 +import pytest +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub +from miles.utils.types import Sample + +_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub + + +def is_agentic_variant(variant: str) -> bool: + return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples") + + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +@pytest.fixture( + params=[ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ] +) +def variant(request): + return request.param + + +@dataclass(frozen=True) +class SampleParsedChunk: + tokens_decoded_str: str + loss_mask_value: int + rollout_log_probs: list[float] + + +@dataclass +class ExpectedSampleInfo: + chunks: list[SampleParsedChunk] + partial_sample: Sample + + +def token_len(text: str) -> int: + return len(TOKENIZER(text, add_special_tokens=False)["input_ids"]) + + +def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk: + n = token_len(text) + log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n + return SampleParsedChunk(text, loss_mask, log_probs) + + +def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: + prompt_len = len(sample.tokens) - sample.response_length + response_tokens = sample.tokens[prompt_len:] + loss_mask = sample.loss_mask or [] + log_probs = sample.rollout_log_probs or [] + + chunks = [] + idx = 0 + for mask_val, group in groupby(loss_mask): + group_len = len(list(group)) + sli = slice(idx, idx + group_len) + chunks.append( + SampleParsedChunk( + tokens_decoded_str=tokenizer.decode(response_tokens[sli]), + loss_mask_value=mask_val, + rollout_log_probs=log_probs[sli], + ) + ) + idx += group_len + return chunks + + +def expected_partial_sample( + *, + prompt: list[dict], + response: str, + response_length: int, + status: Sample.Status = Sample.Status.COMPLETED, +) -> Sample: + return Sample( + prompt=prompt, + response=response, + response_length=response_length, + status=status, + tokens=[], + loss_mask=[], + rollout_log_probs=[], + weight_versions=[], + spec_info=Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + + +def verify_samples(actual: Sample | list[Sample], expected: list[ExpectedSampleInfo]): + actual = listify(actual) + assert len(actual) == len(expected) + + for actual_item, expected_item in zip(actual, expected, strict=True): + actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER) + assert actual_chunks == expected_item.chunks + + actual_partial = replace( + deepcopy(actual_item), + tokens=[], + loss_mask=[], + rollout_log_probs=[], + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + assert actual_partial == expected_item.partial_sample + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): + return run_generate(env, sample, sampling_params, variant=variant) + + +def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict: + return { + "input_ids": input_ids, + "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "return_logprob": True, + "return_routed_experts": False, + } + + +def expected_openai_request(messages: list[dict]) -> dict: + return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS} + + +SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] +SINGLE_TURN_RESPONSE = "The answer is 2." +_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( + SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS +) +SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] +SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicMultiTurn: + def test_single_turn_no_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="stop" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] + else: + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6 + ), + ), + ], + ) + + def test_two_turns_with_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = TwoTurnStub.process_fn + + S = TwoTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestExitConditions: + def test_partial_rollout_not_supported(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not check partial_rollout flag") + generation_env.args.partial_rollout = True + + with pytest.raises(AssertionError, match="Partial rollout is not supported"): + _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + def test_abort_preserves_content(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not handle abort finish_reason") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="abort" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + status=Sample.Status.ABORTED, + ), + ), + ], + ) + + def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ], + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) + def test_max_turns_reached(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRespectMaxContextLen: + @pytest.mark.parametrize( + "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True + ) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + assert result.requests == [] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED + ), + ) + ] + else: + expected = [] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS) + + token_len(TwoTurnStub.FIRST_RESPONSE) + + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE) + } + } + ], + indirect=True, + ) + def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=partial_response, + response_length=token_len(partial_response), + status=Sample.Status.TRUNCATED, + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}}, + 10, + ), + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}}, + 64, + ), + ], + indirect=["generation_env"], + ) + def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert len(result.requests) >= 2 + assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"] + + +class TestThreeTurn: + """Need to test 3-turn case besides 2-turn, because e.g. merge_samples may behave differently.""" + + def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): + generation_env.mock_server.process_fn = ThreeTurnStub.process_fn + + S = ThreeTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + expected_request(S.THIRD_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = ( + S.FIRST_RESPONSE + + S.FIRST_TOOL_RESPONSE + + S.SECOND_RESPONSE + + S.SECOND_TOOL_RESPONSE + + S.THIRD_RESPONSE + ) + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + expected_chunk(S.SECOND_TOOL_RESPONSE, 0), + expected_chunk(S.THIRD_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.THIRD_RESPONSE, + response_length=token_len(S.THIRD_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRoutedExpertsMultiTurn: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "use_rollout_routing_replay": True, + } + } + ], + indirect=True, + ) + def test_two_turns_routed_experts(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + + S = TwoTurnStub + num_layers, moe_router_topk = 2, 4 + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + + def make_routed_experts(prompt_token_ids, response_text): + total_tokens = len(prompt_token_ids) + token_len(response_text) + routed_experts_len = total_tokens - 1 + return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape( + routed_experts_len, num_layers, moe_router_topk + ) + + first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE) + second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE) + + def process_fn(prompt: str) -> ProcessResult: + if prompt == S.FIRST_PROMPT: + text, routed_experts = S.FIRST_RESPONSE, first_routed_experts + elif prompt == S.SECOND_PROMPT: + text, routed_experts = S.SECOND_RESPONSE, second_routed_experts + else: + raise ValueError(f"Unexpected prompt: {prompt}") + return ProcessResult( + text=text, + finish_reason="stop", + meta_info=ProcessResultMetaInfo( + routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii") + ), + ) + + generation_env.mock_server.process_fn = process_fn + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS) + + sample = result.sample[-1] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == second_routed_experts.shape + np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts) + assert len(sample.tokens) - 1 == second_routed_experts.shape[0] diff --git a/tests/fast/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py new file mode 100644 index 0000000000..a58e6fb3c6 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_single_turn.py @@ -0,0 +1,424 @@ +import numpy as np +import pybase64 +import pytest +import torch +from PIL import Image +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoProcessor + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.types import Sample + +_ = generation_env + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +PROMPT = "What is 1+7?" +PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] +PROMPT_TOKEN_LEN = len(PROMPT_TOKENS) +RESPONSE_TOKENS = [59, 79075, 90, 23, 92] +RESPONSE_TEXT = "\\boxed{8}" +RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] +SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} +DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"] + + +@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) +def variant(request): + return request.param + + +def expected_request( + variant: str, + *, + input_ids: list[int] | None = None, + sampling_params: dict | None = None, + return_routed_experts: bool = False, + image_data: list[str] | None = None, +) -> dict: + result = { + "input_ids": input_ids or PROMPT_TOKENS, + "sampling_params": sampling_params or SAMPLING_PARAMS, + "return_logprob": True, + } + if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts: + result["return_routed_experts"] = return_routed_experts + if image_data is not None: + result["image_data"] = image_data + return result + + +class _Unset: + pass + + +_UNSET = _Unset() + + +def expected_sample( + variant: str, + *, + prompt: str = PROMPT, + response: str = RESPONSE_TEXT, + response_length: int = 5, + tokens: list[int] | None | _Unset = _UNSET, + rollout_log_probs: list[float] | None | _Unset = _UNSET, + status: Sample.Status = Sample.Status.COMPLETED, + cached_tokens: int = 0, + prompt_tokens: int = 7, + weight_versions: list[str] | None = None, + rollout_routed_experts: np.ndarray | None = None, + spec_info: Sample.SpecInfo | None = None, + multimodal_inputs: dict | None = None, + multimodal_train_inputs: dict | None = None, + loss_mask: list[int] | None | _Unset = _UNSET, +) -> Sample: + actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) + if isinstance(loss_mask, _Unset): + loss_mask = ( + [1] * actual_response_length + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") + else None + ) + + return Sample( + group_index=None, + index=None, + prompt=prompt, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=multimodal_train_inputs, + response=response, + response_length=response_length, + label=None, + reward=None, + loss_mask=loss_mask, + weight_versions=weight_versions or [], + rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs, + rollout_routed_experts=rollout_routed_experts, + remove_sample=False, + status=status, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=spec_info or Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens), + ) + + +def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): + return make_sample( + prompt=PROMPT, + tokens=tokens, + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): + return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicGeneration: + def test_basic_generation(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestResumedSingleTurn: + def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + partial_text = "\\boxed" + partial_tokens = [59, 79075] + partial_log_probs = [-0.0, -0.0078125] + + remaining_text = "{8}" + remaining_tokens = [90, 23, 92] + remaining_log_probs = [-0.0, -0.0078125, -0.015625] + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") + sample = _make_sample() + result1 = _run_generate(variant, generation_env, sample) + assert result1.requests == [expected_request(variant)] + assert result1.sample == expected_sample( + variant, + response=partial_text, + response_length=2, + tokens=PROMPT_TOKENS + partial_tokens, + rollout_log_probs=partial_log_probs, + status=Sample.Status.ABORTED, + ) + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") + result2 = _run_generate(variant, generation_env, result1.sample) + tokens_after_turn1 = PROMPT_TOKENS + partial_tokens + assert result2.requests == [ + expected_request( + variant, + input_ids=tokens_after_turn1, + sampling_params={"max_new_tokens": 14, "temperature": 0.7}, + ) + ] + assert result2.sample == expected_sample( + variant, + response=partial_text + remaining_text, + response_length=2 + 3, + tokens=tokens_after_turn1 + remaining_tokens, + rollout_log_probs=partial_log_probs + remaining_log_probs, + prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1), + status=Sample.Status.COMPLETED, + ) + + +class TestFinishReason: + @pytest.mark.parametrize( + "generation_env,expected_status", + [ + ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), + ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), + ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED), + ], + indirect=["generation_env"], + ) + def test_finish_reason_sets_status(self, variant, generation_env, expected_status): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, status=expected_status)] + + +class TestRoutedExperts: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"use_rollout_routing_replay": True}, + "process_fn_kwargs": {"routed_experts": "placeholder"}, + } + ], + indirect=True, + ) + def test_routed_experts_enabled_and_parsed(self, variant, generation_env): + num_layers, moe_router_topk = 2, 4 + num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) + routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( + num_tokens - 1, num_layers, moe_router_topk + ) + + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=RESPONSE_TEXT, + finish_reason="stop", + meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), + ) + + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant, return_routed_experts=True)] + sample = result.sample[0] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) + np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array) + + +class TestMetaInfo: + @pytest.mark.parametrize( + "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True + ) + def test_meta_info_fields_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])] + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, + "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3}, + } + ], + indirect=True, + ) + def test_spec_info_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample( + variant, + spec_info=Sample.SpecInfo( + spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 + ), + ) + ] + + +class TestInputStatusValidation: + @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) + def test_allowed_statuses(self, variant, generation_env, status): + result = _run_generate(variant, generation_env, _make_sample(status=status)) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) + def test_rejected_statuses(self, variant, generation_env, status): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + with pytest.raises(AssertionError): + _run_generate(variant, generation_env, _make_sample(status=status)) + + +class TestPayloadStructure: + def test_sampling_params_passed_through(self, variant, generation_env): + result = _run_generate( + variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9} + ) + assert result.requests == [ + expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + ] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestBoundaryConditions: + def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) + sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) + + result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7}) + assert result.requests == [] + assert result.sample == expected_sample( + variant, + response="x" * 10, + response_length=10, + tokens=existing_tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS), + ], + indirect=["generation_env"], + ) + def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + result = _run_generate(variant, generation_env) + assert len(result.requests) == 1 + assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}], + indirect=True, + ) + def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + +class TestEmptyResponse: + @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) + def test_empty_response(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[]) + ] + + +VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" + + +class TestMultimodal: + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) + def test_multimodal_inputs_processed(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + test_image = Image.new("RGB", (64, 64), color="red") + multimodal_inputs = {"images": [test_image]} + processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) + expected_mti = { + k: v + for k, v in processor(text=PROMPT, **multimodal_inputs).items() + if k not in ["input_ids", "attention_mask"] + } + + result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs)) + + assert result.requests == [ + expected_request( + variant, + input_ids=PROMPT_TOKENS, + image_data=[encode_image_for_rollout_engine(test_image)], + ) + ] + actual_mti = result.sample.multimodal_train_inputs + assert actual_mti is not None + assert set(actual_mti.keys()) == set(expected_mti.keys()) + assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) + assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) + assert result.sample == expected_sample( + variant, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=actual_mti, + ) diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py new file mode 100644 index 0000000000..0f2305e753 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -0,0 +1,99 @@ +import pytest + +from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "XiaomiMiMo/MiMo-7B-RL", +] + +SINGLE_TOOL_CALL_ONLY_MODELS = [ + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo +] + +# Models where tokenize->decode produces extra whitespace vs direct string diff +TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ + "THUDM/glm-4-9b-chat", +] + +SAMPLE_TOOL_RESPONSES = [ + { + "role": "tool", + "tool_call_id": "call00000", + "content": '{"year": 2026}', + "name": "get_year", + }, + { + "role": "tool", + "tool_call_id": "call00001", + "content": '{"temperature": 25}', + "name": "get_temperature", + }, +] + + +class TestTokenizeToolResponses: + @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) + def test_snapshot(self, model_name): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) + decoded = tokenizer.decode(token_ids) + + assert decoded == ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": 25}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize("num_tools", [1, 2]) + @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) + def test_tokenize_tool_responses(self, model_name, num_tools): + if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: + pytest.skip(f"{model_name} only supports single tool call") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] + assert len(tool_responses) == num_tools + + actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer) + actual_str = tokenizer.decode(actual_token_ids) + + dummy_assistant = _build_dummy_assistant(tool_responses) + base_messages = [_DUMMY_USER, dummy_assistant] + expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) + + if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: + # Some models produce whitespace differences between tokenize->decode and direct string diff + actual_str = actual_str.replace(" ", "") + expected_str = expected_str.replace(" ", "") + + assert actual_str == expected_str, f"{model_name=}" + + @staticmethod + def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str: + text_with = tokenizer.apply_chat_template( + base_messages + extra_messages, tokenize=False, add_generation_prompt=True + ) + text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False) + return text_with[len(text_without) :] diff --git a/tests/fast/rollout/generate_utils/__init__.py b/tests/fast/rollout/generate_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/generate_utils/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py new file mode 100644 index 0000000000..c53fbbb56a --- /dev/null +++ b/tests/fast/rollout/generate_utils/test_sample_utils.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.generate_utils.sample_utils import _merge_sample_pair +from miles.utils.types import Sample + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + tokenizer.decode = lambda tokens: f"" + return tokenizer + + +def make_sample( + prompt="test_prompt", + tokens=None, + response="", + response_length=0, + loss_mask=None, + rollout_log_probs=None, + status=Sample.Status.COMPLETED, + label="test_label", + reward=1.0, + index=0, + group_index=0, +): + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + loss_mask=loss_mask, + rollout_log_probs=rollout_log_probs, + status=status, + label=label, + reward=reward, + index=index, + group_index=group_index, + ) + + +class TestMergeSamples: + def test_basic_merge(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3, 10, 11, 12], + response="response1", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.1, -0.2, -0.3], + ) + b = make_sample( + tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32], + response="response2", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.4, -0.5, -0.6], + status=Sample.Status.TRUNCATED, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.tokens == b.tokens + assert merged.response_length == 3 + 2 + 3 + assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] + assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6] + assert merged.prompt == a.prompt + assert merged.status == b.status + assert merged.label == a.label + assert merged.index == a.index + assert merged.group_index == a.group_index + assert "response1" in merged.response + assert "response2" in merged.response + assert "" in merged.response + + def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.loss_mask == [1, 0, 1] + assert merged.rollout_log_probs == [0.0, 0.0, 0.0] + + def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 99, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_field_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + index=0, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=[1], + index=1, + ) + + with pytest.raises(AssertionError, match="index mismatch"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_obs_len_invalid_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="obs_len must be > 0"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_sample_validate_fails_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10, 11], + response_length=2, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 11, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="loss_mask length"): + _merge_sample_pair(a, b, mock_tokenizer) diff --git a/tests/fast/rollout/inference_rollout/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py new file mode 100644 index 0000000000..ca47edeeb6 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/conftest.py @@ -0,0 +1,45 @@ +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import parse_args + + +def _build_mock_args(extra_argv: list[str] | None = None): + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "4", + "--rollout-num-gpus-per-engine", + "2", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + "/dev/null", + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + "30000", + ] + (extra_argv or []) + with patch("sys.argv", argv): + return parse_args() + + +@pytest.fixture +def mock_args(): + return _build_mock_args() diff --git a/tests/fast/rollout/inference_rollout/integration/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py new file mode 100644 index 0000000000..5b791829d5 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py @@ -0,0 +1,69 @@ +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import ( + MODULAR_ROLLOUT_BASE_ARGV, + expected_sample, + load_and_call_train, +) + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function + +_VARIANTS = [ + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--eval-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="old_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="new_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), + id="new_rollout_new_generate", + ), +] + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_train(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert len(group) == env.args.n_samples_per_prompt + assert group[0] == expected_sample(group_index=0) + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_eval(rollout_env): + env = rollout_env + fn = load_rollout_function( + RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path + ) + out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + assert "toy" in out.data + rewards = out.data["toy"]["rewards"] + samples = out.data["toy"]["samples"] + assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt + assert rewards[0] == 1 + assert samples[0] == expected_sample(group_index=None) diff --git a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py new file mode 100644 index 0000000000..69a2359117 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py @@ -0,0 +1,37 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env,expected_seeds", + [ + pytest.param( + integration_env_config( + [ + "--sglang-enable-deterministic-inference", + "--rollout-seed", + "42", + "--n-samples-per-prompt", + "3", + "--rollout-batch-size", + "1", + ] + ), + {42, 43, 44}, + id="enabled", + ), + pytest.param( + integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + {None}, + id="disabled", + ), + ], + indirect=["rollout_env"], +) +def test_sampling_seeds(rollout_env, expected_seeds): + env = rollout_env + load_and_call_train(env.args, env.data_source) + + seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} + assert seeds == expected_seeds diff --git a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py new file mode 100644 index 0000000000..0ca5743ac5 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py @@ -0,0 +1,46 @@ +from contextlib import nullcontext + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + MIXED_DATA_ROWS, + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + + +@pytest.mark.parametrize( + "rollout_env,use_filter,expect_all_correct", + [ + pytest.param( + integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), + False, + False, + id="no_filter", + ), + pytest.param( + integration_env_config( + ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], + data_rows=MIXED_DATA_ROWS, + ), + True, + True, + id="with_filter", + ), + ], + indirect=["rollout_env"], +) +def test_filter_effect(rollout_env, use_filter, expect_all_correct): + env = rollout_env + ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() + + with ctx: + out = load_and_call_train(env.args, env.data_source) + + rewards = {group[0].reward for group in out.samples} + if expect_all_correct: + assert rewards == {1}, "Filter should keep only correct samples" + else: + assert 0 in rewards, "Without filter, incorrect samples should be present" diff --git a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py new file mode 100644 index 0000000000..afd870c302 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py @@ -0,0 +1,22 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + id="group_rm_enabled", + ), + ], + indirect=True, +) +def test_group_rm_rewards_set(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + rewards = [sample.reward for group in out.samples for sample in group] + assert all(r in (0, 1) for r in rewards) diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py new file mode 100644 index 0000000000..2b12d3d88f --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py @@ -0,0 +1,65 @@ +import pytest +from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.misc import function_registry +from miles.utils.types import Sample + + +async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: + sample = input.sample + s1 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=None, + status=Sample.Status.COMPLETED, + ) + s2 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=0.5, + status=Sample.Status.COMPLETED, + ) + return GenerateFnOutput(samples=[s1, s2]) + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + + [ + "--custom-generate-function-path", + "test:multi_sample_generate", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + ], + data_rows=DEFAULT_DATA_ROWS, + ), + id="multi_sample_output", + ), + ], + indirect=True, +) +def test_multi_sample_output_preserves_existing_reward(rollout_env): + env = rollout_env + with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert isinstance(group[0], list) + samples = group[0] + assert len(samples) == 2 + assert samples[0].reward == 1 + assert samples[1].reward == 0.5 diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py new file mode 100644 index 0000000000..c41d713991 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py @@ -0,0 +1,114 @@ +from typing import Any + +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout + +from miles.utils.test_utils.mock_tools import TwoTurnStub +from miles.utils.types import Sample + + +TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}] + +_VARIANT_NAMES = [ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", +] + +_BASE_EXTRA_ARGV = [ + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "2", + "--n-samples-per-eval-prompt", + "2", + "--custom-rm-path", + "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function", +] + + +def _config_for_variant(variant: str) -> RolloutEnvConfig: + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, + data_rows=TWO_TURN_DATA_ROWS, + ) + + +@pytest.mark.parametrize( + "variant,rollout_env", + [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], + indirect=["rollout_env"], +) +@pytest.mark.parametrize("test_type", ["train", "eval"]) +def test_rollout(rollout_env, variant, test_type): + env = rollout_env + env.mock_server.process_fn = TwoTurnStub.process_fn + + out = load_and_call_rollout(env.args, env.data_source, mode=test_type) + + if test_type == "train": + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + _verify_samples(variant, group) + else: + assert "toy" in out.data + samples = out.data["toy"]["samples"] + _verify_samples(variant, samples) + + +def _verify_samples(variant: str, samples: list[Any]): + is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples") + + if is_multi_samples: + if len(samples) > 0 and isinstance(samples[0], list): + # Train mode: list[list[Sample]], grouped by prompt + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for group_sample in samples: + assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate" + _verify_group_samples(group_sample) + else: + # Eval mode: list[Sample], flattened + # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples + assert ( + len(samples) == 4 + ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}" + # Group samples by prompt (every 2 samples form a group) + group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)] + for group_samples in group_samples_list: + _verify_group_samples(group_samples) + else: + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for sample in samples: + assert isinstance(sample, Sample), "single_sample variant should return Sample, not list" + _verify_sample(sample) + + +def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2): + assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)" + for i, sample in enumerate(group_samples): + _verify_sample(sample, expect_answer=(i == len(group_samples) - 1)) + + +def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True): + assert sample.status == Sample.Status.COMPLETED + assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}" + if expect_answer: + assert "2008" in sample.response, "Response should contain final answer '2008'" + + +async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]: + if isinstance(samples, list): + # For multi_samples variants, use the last sample's reward + if getattr(args, "generate_multi_samples", False): + return [_check_reward(samples[-1])] * len(samples) + else: + return [_check_reward(sample) for sample in samples] + else: + return _check_reward(samples) + + +def _check_reward(sample: Sample) -> float: + return float(sample.response and (str(sample.label) in sample.response)) diff --git a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py new file mode 100644 index 0000000000..0812962cc7 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py @@ -0,0 +1,48 @@ +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "wrong"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "wrong"}, +] + +_BASE_ARGV = [ + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", +] + + +def _over_sampling_config(rollout_batch_size: int): + return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) + + +@pytest.mark.parametrize( + "rollout_env,expected_rounds", + [ + pytest.param(_over_sampling_config(1), 1, id="one_round"), + pytest.param(_over_sampling_config(2), 2, id="two_rounds"), + ], + indirect=["rollout_env"], +) +def test_over_sampling_rounds(rollout_env, expected_rounds): + env = rollout_env + + with function_registry.temporary("test:filter_by_reward", filter_by_reward): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + assert all(group[0].reward == 1 for group in out.samples) + + requests_count = len(env.mock_server.request_log) + expected_requests = expected_rounds * env.args.over_sampling_batch_size + assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests" diff --git a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py new file mode 100644 index 0000000000..36e78c16c1 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py @@ -0,0 +1,67 @@ +from unittest.mock import Mock + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +# Data with only 2 reward=1 samples out of 4. +# This ensures all 4 samples must be generated to collect 2 valid ones. +_FILTER_TEST_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, # reward=1 + {"input": "What is 1+8?", "label": "wrong"}, # reward=0 + {"input": "What is 1+9?", "label": "wrong"}, # reward=0 + {"input": "What is 1+6?", "label": "7"}, # reward=1 +] + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config( + [ + "--rollout-batch-size", + "2", + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", + "--rollout-sample-filter-path", + "test:sample_filter", + "--rollout-all-samples-process-path", + "test:all_samples_process", + ], + data_rows=_FILTER_TEST_DATA_ROWS, + ), + id="sample_filter_vs_all_samples", + ), + ], + indirect=True, +) +def test_sample_filter_and_all_samples_process(rollout_env): + env = rollout_env + sample_filter_mock = Mock() + all_samples_process_mock = Mock() + + with ( + function_registry.temporary("test:filter_by_reward", filter_by_reward), + function_registry.temporary("test:sample_filter", sample_filter_mock), + function_registry.temporary("test:all_samples_process", all_samples_process_mock), + ): + load_and_call_train(env.args, env.data_source) + + sample_filter_mock.assert_called_once() + _, filtered_data = sample_filter_mock.call_args[0] + rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data] + assert all(r == 1 for r in rewards) + + all_samples_process_mock.assert_called_once() + _, all_samples, data_source = all_samples_process_mock.call_args[0] + assert data_source is not None + + assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter" diff --git a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py new file mode 100644 index 0000000000..889a9ff8ac --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py @@ -0,0 +1,33 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + +_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] +_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] + + +@pytest.mark.parametrize( + "rollout_env,expected_range", + [ + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (1, 1), + id="limit_1", + ), + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (2, 999), + id="no_limit", + ), + ], + indirect=["rollout_env"], +) +def test_max_concurrent(rollout_env, expected_range): + env = rollout_env + load_and_call_train(env.args, env.data_source) + min_expected, max_expected = expected_range + assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py new file mode 100644 index 0000000000..ad413cf949 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/utils.py @@ -0,0 +1,89 @@ +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig + +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnOutput, + RolloutFnTrainInput, +) +from miles.rollout.filter_hub.base_types import DynamicFilterOutput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.utils.types import Sample + + +def expected_sample(*, group_index: int | None) -> Sample: + return Sample( + group_index=group_index, + index=0, + prompt="What is 1+7?", + tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], + multimodal_inputs=None, + multimodal_train_inputs=None, + response="\\boxed{8}", + response_length=5, + label="8", + reward=1, + loss_mask=None, + weight_versions=[], + rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], + rollout_routed_experts=None, + remove_sample=False, + status=Sample.Status.COMPLETED, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=Sample.SpecInfo( + spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 + ), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), + ) + + +MODULAR_ROLLOUT_BASE_ARGV = [ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", +] + +MIXED_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "9"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "7"}, +] + + +def integration_env_config( + extra_argv: list[str], + data_rows: list[dict] | None = None, + latency: float = 0.0, + variant: str = "single_turn", +): + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, + data_rows=data_rows, + latency=latency, + ) + + +def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput: + function_path = args.rollout_function_path if mode == "train" else args.eval_function_path + fn = load_rollout_function( + RolloutFnConstructorInput(args=args, data_source=data_source), + function_path, + ) + if mode == "train": + return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + else: + return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + +def load_and_call_train(args, data_source): + return load_and_call_rollout(args, data_source, mode="train") + + +def filter_by_reward(args, samples, **kwargs): + reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward + if reward == 1: + return DynamicFilterOutput(keep=True) + return DynamicFilterOutput(keep=False, reason="reward_zero") diff --git a/tests/fast/rollout/inference_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py new file mode 100644 index 0000000000..ddfecd067b --- /dev/null +++ b/tests/fast/rollout/inference_rollout/test_compatibility.py @@ -0,0 +1,196 @@ +import asyncio +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.inference_rollout.compatibility import ( + LegacyGenerateFnAdapter, + LegacyRolloutFnAdapter, + call_rollout_function, + load_generate_function, + load_rollout_function, +) +from miles.utils.async_utils import run +from miles.utils.misc import function_registry + + +@pytest.fixture +def constructor_input(): + return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") + + +@pytest.fixture +def make_generate_fn_input(): + def _make(evaluation: bool = False): + state = MagicMock() + state.args = MagicMock() + + return GenerateFnInput( + state=state, + sample={"text": "test prompt"}, + sampling_params={"temperature": 0.7}, + evaluation=evaluation, + ) + + return _make + + +class TestSupportedRolloutFormats: + """ + Documentation test to show various supported rollout function formats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return {"metric": {"accuracy": 0.9}} + return [[{"text": "sample"}]] + + with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_rollout") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, LegacyRolloutFnAdapter) + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"metric": {"accuracy": 0.9}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "sample"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) + return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) + + with function_registry.temporary("test:legacy_typed", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_typed") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"ds": {"acc": 0.95}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "typed"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_sync_class(self, constructor_input, evaluation): + class SyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + def __call__(self, input): + if input.evaluation: + return RolloutFnEvalOutput(data={"test": {"score": 1}}) + return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) + + with function_registry.temporary("test:sync_class", SyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:sync_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, SyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_async_class(self, constructor_input, evaluation): + class AsyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + async def __call__(self, input): + await asyncio.sleep(0.001) + if input.evaluation: + return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) + return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) + + with function_registry.temporary("test:async_class", AsyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:async_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, AsyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + +class TestSupportedGenerateFormats: + """ + Documentation test similar to TestSupportedRolloutFormats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): + return "my_sample" + + with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen_eval") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params): + return "my_sample" + + with function_registry.temporary("test:legacy_gen", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): + async def generate(input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_async", generate): + fn = load_generate_function("test:new_async") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): + class MyGenerateFn: + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_class", MyGenerateFn): + fn = load_generate_function("test:new_class") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, MyGenerateFn) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" diff --git a/tests/fast/rollout/rm_hub/__init__.py b/tests/fast/rollout/rm_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py new file mode 100644 index 0000000000..bd4c606a68 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_deepscaler.py @@ -0,0 +1,26 @@ +import pytest + +from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward + + +class TestGetDeepscalerRuleBasedReward: + @pytest.mark.parametrize( + "response,label,expected", + [ + (r"Let me analyze...The answer is \boxed{42}", "42", 1), + (r"Thinking...The answer is \boxed{wrong}", "42", 0), + (r"###Response\boxed{42}", "42", 1), + (r"###Response\boxed{wrong}", "42", 0), + (r"The answer is \boxed{42}", "42", 0), + (r"The answer is 42", "42", 0), + (r"\boxed{42}", "", 0), + (r"\boxed{42}", r"\boxed{42}", 1), + (r"\boxed{123}", 123, 1), + (r"\boxed{3.14}", 3.14, 1), + (r"\boxed{1/2}", "0.5", 1), + (r"\boxed{\frac{1}{2}}", "0.5", 1), + (r"First thoughtSecond thought\boxed{42}", "42", 1), + ], + ) + def test_get_deepscaler_rule_based_reward(self, response, label, expected): + assert get_deepscaler_rule_based_reward(response, label) == expected diff --git a/tests/fast/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py new file mode 100644 index 0000000000..c9ecf9614d --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_f1.py @@ -0,0 +1,44 @@ +import pytest + +from miles.rollout.rm_hub.f1 import f1_score, normalize_answer + + +class TestNormalizeAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("The quick brown fox", "quick brown fox"), + ("A cat and a dog", "cat and dog"), + ("Hello, world!", "hello world"), + (" multiple spaces ", "multiple spaces"), + ("An apple", "apple"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_normalize_answer(self, input_str, expected): + assert normalize_answer(input_str) == expected + + +class TestF1Score: + @pytest.mark.parametrize( + "prediction,ground_truth,expected_f1,expected_prec,expected_recall", + [ + ("hello world", "hello world", 1.0, 1.0, 1.0), + ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3), + ("abc", "xyz", 0, 0, 0), + (None, "anything", 0, 0, 0), + ("yes", "no", 0, 0, 0), + ("no", "yes", 0, 0, 0), + ("yes", "yes", 1.0, 1.0, 1.0), + ("noanswer", "yes", 0, 0, 0), + ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0), + ("hello, world!", "hello world", 1.0, 1.0, 1.0), + ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5), + ], + ) + def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall): + f1, prec, recall = f1_score(prediction, ground_truth) + assert f1 == expected_f1 + assert prec == expected_prec + assert recall == expected_recall diff --git a/tests/fast/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py new file mode 100644 index 0000000000..45cefd2015 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_gpqa.py @@ -0,0 +1,86 @@ +import pytest + +from miles.rollout.rm_hub.gpqa import ( + _extract_letter_from_response, + _normalize_text, + _strip_chain_of_thought, + compute_gpqa_reward, +) + + +class TestStripChainOfThought: + @pytest.mark.parametrize( + "text,expected", + [ + ("Let me think...The answer is A", "The answer is A"), + ("The answer is A", "The answer is A"), + ("", ""), + (None, ""), + ], + ) + def test_strip_chain_of_thought(self, text, expected): + assert _strip_chain_of_thought(text) == expected + + +class TestNormalizeText: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("Test-123", "test 123"), + ("A, B, C", "a b c"), + ("", ""), + ], + ) + def test_normalize_text(self, input_str, expected): + assert _normalize_text(input_str) == expected + + +class TestExtractLetterFromResponse: + @pytest.mark.parametrize( + "response,expected", + [ + ("The answer is A", "A"), + ("answer: B", "B"), + ("I think C is correct", "C"), + ("final answer: D", "D"), + ("Option A is the best choice", "A"), + ("The answer is B", "B"), + ("After analysis, my choice is C", "C"), + ("A B C D", "D"), + ("No valid letter here", None), + ("", None), + (None, None), + ("The answer is Z", None), + ], + ) + def test_extract_letter(self, response, expected): + assert _extract_letter_from_response(response, "ABCD") == expected + + +class TestComputeGpqaReward: + @pytest.mark.parametrize( + "response,label,metadata,expected", + [ + ("Answer: A", "A", None, 1.0), + ("Answer: A", "B", None, 0.0), + (None, "A", None, 0.0), + ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0), + ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0), + ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0), + ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0), + ( + "I believe the answer is Paris", + "", + {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"}, + 1.0, + ), + ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0), + ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0), + ("Let me think step by step...The answer is A", "A", None, 1.0), + ], + ) + def test_compute_gpqa_reward(self, response, label, metadata, expected): + assert compute_gpqa_reward(response, label, metadata=metadata) == expected diff --git a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py new file mode 100644 index 0000000000..56a7f6d1f9 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py @@ -0,0 +1,108 @@ +import pytest + +from miles.rollout.rm_hub.math_dapo_utils import ( + compute_score, + is_correct_minerva, + is_correct_strict_box, + last_boxed_only_string, + normalize_final_answer, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2}", r"\boxed{x^2}"), + (r"No boxed", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x + 1}", "x + 1"), + ], + ) + def test_remove_boxed_valid(self, input_str, expected): + assert remove_boxed(input_str) == expected + + def test_remove_boxed_invalid(self): + with pytest.raises(AssertionError): + remove_boxed("not boxed") + + +class TestNormalizeFinalAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("42", "42"), + (" 42 ", "42"), + (r"\text{hello}", "hello"), + (r"\textbf{bold}", "bold"), + (r"x = 42", "42"), + (r"100 square", "100"), + (r"$50$ dollars", "50"), + (r"\boxed{42}", "42"), + (r"\frac12", r"\frac{1}{2}"), + (r"\sqrt3", r"\sqrt{3}"), + ("1,000", "1000"), + ("<|im_end|>", ""), + ], + ) + def test_normalize_final_answer(self, input_str, expected): + assert normalize_final_answer(input_str) == expected + + +class TestIsCorrectMinerva: + @pytest.mark.parametrize( + "solution,gt,gt_need_extract,expected_correct", + [ + ("Answer: 42", "42", False, True), + ("Answer: 100", "42", False, False), + ("Answer: wrong", "42", False, False), + ("Answer: 42", r"\boxed{42}", True, True), + ], + ) + def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct): + correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract) + assert correct == expected_correct + + +class TestIsCorrectStrictBox: + @pytest.mark.parametrize( + "pred,gt,expected_score,expected_pred", + [ + (r"blah blah \boxed{42}", "42", 1, "42"), + (r"\boxed{wrong}", "42", -1, "wrong"), + ("no box here", "42", -1, None), + ], + ) + def test_is_correct_strict_box(self, pred, gt, expected_score, expected_pred): + score, extracted = is_correct_strict_box(pred, gt) + assert score == expected_score + assert extracted == expected_pred + + +class TestComputeScore: + @pytest.mark.parametrize( + "solution,gt,strict_box,expected_score,expected_acc", + [ + ("Answer: 42", "42", False, 1.0, True), + ("Answer: wrong", "42", False, -1.0, False), + (r"\boxed{42}", "42", True, 1.0, True), + ("x" * 500 + " Answer: 42", "42", False, 1.0, True), + ], + ) + def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc): + result = compute_score(solution, gt, strict_box_verify=strict_box) + assert result["score"] == expected_score + assert result["acc"] == expected_acc diff --git a/tests/fast/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py new file mode 100644 index 0000000000..2423ed4acc --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_utils.py @@ -0,0 +1,129 @@ +import pytest + +from miles.rollout.rm_hub.math_utils import ( + _normalize, + extract_answer, + grade_answer_mathd, + grade_answer_sympy, + grade_answer_verl, + last_boxed_only_string, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"), + (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"), + (r"No boxed here", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"), + (r"\fbox{fbox content}", r"\fbox{fbox content}"), + ("", None), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x^2 + 1}", "x^2 + 1"), + (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"), + ("not boxed", None), + ], + ) + def test_remove_boxed(self, input_str, expected): + assert remove_boxed(input_str) == expected + + +class TestExtractAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", "42"), + (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"), + (r"Multiple \boxed{1} then \boxed{final}", "final"), + (r"No boxed here", None), + ("", None), + ], + ) + def test_extract_answer(self, input_str, expected): + assert extract_answer(input_str) == expected + + +class TestNormalize: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("1,000", "1000"), + (r"\text{hello}", "hello"), + (" 42 ", "42"), + (r"100%", "100"), + (r"\$50", "50"), + ("HELLO", "hello"), + ("1,234,567", "1234567"), + (None, None), + ], + ) + def test_normalize(self, input_str, expected): + assert _normalize(input_str) == expected + + +class TestGradeAnswerMathd: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + (" 42 ", "42", True), + (r"\frac{1}{2}", r"\frac{1}{2}", True), + ("wrong", "42", False), + ("", "42", False), + ], + ) + def test_grade_answer_mathd(self, given, ground_truth, expected): + assert grade_answer_mathd(given, ground_truth) == expected + + +class TestGradeAnswerSympy: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + ("x^2", "x^2", True), + ("1/2", "0.5", True), + (r"\frac{1}{2}", "0.5", True), + ("wrong", "42", False), + ("", "42", False), + ("(1,2)", "(1,2)", True), + ("(1,2,3)", "(1,2)", False), + ("42", None, False), + ], + ) + def test_grade_answer_sympy(self, given, ground_truth, expected): + assert grade_answer_sympy(given, ground_truth) == expected + + +class TestGradeAnswerVerl: + @pytest.mark.parametrize( + "solution,ground_truth,expected", + [ + (r"\boxed{42}", "42", True), + (r"The answer is \boxed{42}", "42", True), + (r"\boxed{1/2}", r"\frac{1}{2}", True), + (r"\boxed{wrong}", "42", False), + ("no boxed", "42", False), + (r"\boxed{42}", r"\boxed{42}", True), + ("", "42", False), + (r"\boxed{42}", "", False), + (r"\boxed{42}", None, False), + ], + ) + def test_grade_answer_verl(self, solution, ground_truth, expected): + assert grade_answer_verl(solution, ground_truth) == expected diff --git a/tests/fast/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py new file mode 100644 index 0000000000..a3dadbdaf0 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_rm_hub.py @@ -0,0 +1,126 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.async_utils import run +from miles.utils.types import Sample + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.custom_rm_path = None + args.rm_type = None + args.rm_url = None + return args + + +class TestAsyncRm: + @pytest.mark.parametrize( + "rm_type,response,label,expected", + [ + ("math", r"\boxed{42}", "42", 1), + ("math", r"\boxed{wrong}", "42", 0), + ("f1", "hello world", "hello world", 1.0), + ("dapo", "Answer: 42", "42", {"score": 1.0}), + ("deepscaler", r"\boxed{42}", "42", 1), + ("gpqa", "Answer: A", "A", 1.0), + ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0), + ], + ) + def test_rm_types(self, mock_args, rm_type, response, label, expected): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response=response, label=label) + reward = run(async_rm(mock_args, sample)) + if isinstance(expected, dict): + for k, v in expected.items(): + assert reward[k] == v + else: + assert reward == expected + + def test_f1_rm_partial(self, mock_args): + mock_args.rm_type = "f1" + sample = Sample(prompt="", response="hello", label="hello world") + reward = run(async_rm(mock_args, sample)) + assert 0 < reward < 1 + + def test_random_rm(self, mock_args): + mock_args.rm_type = "random" + sample = Sample(prompt="", response="anything", label="anything") + reward = run(async_rm(mock_args, sample)) + assert reward in [0, 1] + + def test_rm_type_from_metadata(self, mock_args): + mock_args.rm_type = None + sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}) + reward = run(async_rm(mock_args, sample)) + assert reward == 1 + + @pytest.mark.parametrize( + "rm_type,match", + [ + ("unknown_type", "not implemented"), + ("", "not specified"), + ], + ) + def test_invalid_rm_type_raises(self, mock_args, rm_type, match): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response="test", label="test") + with pytest.raises(NotImplementedError, match=match): + run(async_rm(mock_args, sample)) + + +class TestBatchedAsyncRm: + @pytest.mark.parametrize( + "rm_type,samples_data,expected", + [ + ( + "math", + [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")], + [1, 1, 0], + ), + ( + "f1", + [("hello world", "hello world"), ("different", "something else")], + [1.0, 0], + ), + ], + ) + def test_batched_rm(self, mock_args, rm_type, samples_data, expected): + mock_args.rm_type = rm_type + samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards == expected + + def test_inplace_set_reward_field(self, mock_args): + mock_args.rm_type = "math" + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42"), + Sample(prompt="", response=r"\boxed{100}", label="100"), + ] + result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + assert result is None + assert samples[0].reward == 1 + assert samples[1].reward == 1 + + def test_inplace_raises_on_existing_reward(self, mock_args): + mock_args.rm_type = "math" + samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)] + with pytest.raises(AssertionError, match="Overriding"): + run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + + def test_empty_samples(self, mock_args): + mock_args.rm_type = "math" + rewards = run(batched_async_rm(mock_args, [])) + assert rewards == [] + + def test_mixed_rm_types_via_metadata(self, mock_args): + mock_args.rm_type = None + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}), + Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}), + ] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards[0] == 1 + assert rewards[1] == 1.0 diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/router/test_router.py b/tests/fast/router/test_router.py new file mode 100644 index 0000000000..7c645fe304 --- /dev/null +++ b/tests/fast/router/test_router.py @@ -0,0 +1,204 @@ +import asyncio +from argparse import Namespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +def make_router_args(router_port: int, **overrides) -> Namespace: + defaults = dict( + sglang_router_ip="127.0.0.1", + sglang_router_port=router_port, + rollout_health_check_interval=1.0, + miles_router_health_check_failure_threshold=3, + miles_router_max_connections=100, + miles_router_timeout=None, + miles_router_middleware_paths=[], + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def create_mock_worker(start_port: int = 30000) -> MockSGLangServer: + port = find_available_port(start_port) + return MockSGLangServer( + model_name="Qwen/Qwen3-0.6B", + process_fn=default_process_fn, + host="127.0.0.1", + port=port, + latency=0.0, + ) + + +class RouterEnv: + def __init__(self, router: MilesRouter, server: UvicornThreadServer): + self.router = router + self.server = server + + @property + def url(self) -> str: + return self.server.url + + +@pytest.fixture +def router_env(): + args = make_router_args(find_available_port(20000)) + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + server.start() + yield RouterEnv(router, server) + server.stop() + + +@pytest.fixture +def mock_worker(): + server = create_mock_worker() + server.start() + yield server + server.stop() + + +@pytest.fixture +def mock_worker_factory(): + servers = [] + + def _create(): + start_port = 30000 + len(servers) * 100 + server = create_mock_worker(start_port) + server.start() + servers.append(server) + return server + + yield _create + for s in servers: + s.stop() + + +@pytest.fixture +def router_factory(): + def _create(**overrides) -> MilesRouter: + args = make_router_args(find_available_port(20000), **overrides) + return MilesRouter(args, verbose=False) + + return _create + + +class TestWorkerManagement: + def test_add_worker_via_query_param(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30001" + r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + assert router_env.router.worker_request_counts[worker_url] == 0 + + def test_add_worker_via_body(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30002" + r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_duplicate(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30003" + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + + assert len(router_env.router.worker_request_counts) == 1 + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_missing_url(self, router_env: RouterEnv): + r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0) + assert r.status_code == 400 + assert "error" in r.json() + + def test_list_workers(self, router_env: RouterEnv): + worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"] + for url in worker_urls: + requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/list_workers", timeout=5.0) + r.raise_for_status() + assert set(r.json()["urls"]) == set(worker_urls) + + +class TestLoadBalancing: + def test_use_url_selects_min_load(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} + + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_use_url_excludes_dead_workers(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3} + router.dead_workers = {"http://w2:8000"} + + selected = router._use_url() + assert selected == "http://w3:8000" + assert router.worker_request_counts["http://w3:8000"] == 4 + + def test_use_url_raises_when_all_dead(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 0} + router.dead_workers = {"http://w1:8000"} + + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +# TODO: extract main body inside `_health_check_loop`, then can test that function +class TestHealthCheck: + def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health(mock_worker.url)) + assert url == mock_worker.url + assert healthy is True + + def test_check_worker_health_failure(self, router_factory): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999")) + assert url == "http://127.0.0.1:59999" + assert healthy is False + + +class TestProxyIntegration: + def test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status() + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0) + r.raise_for_status() + + assert "text" in r.json() + assert len(mock_worker.request_log) == 1 + assert mock_worker.request_log[0] == payload + + def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory): + worker1, worker2 = mock_worker_factory(), mock_worker_factory() + requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0) + requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0) + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + for _ in range(4): + requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status() + + all_requests = worker1.request_log + worker2.request_log + assert len(all_requests) == 4 + assert all(req == payload for req in all_requests) + + def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/health", timeout=5.0) + r.raise_for_status() + assert r.json()["status"] == "ok" diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py new file mode 100644 index 0000000000..566bb938f7 --- /dev/null +++ b/tests/fast/router/test_sessions.py @@ -0,0 +1,210 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.router.session.naive_trajectory import NaiveTrajectoryManager +from miles.router.session.session_types import SessionRecord +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +class DummyTokenizer: + """Minimal tokenizer stub for testing NaiveTrajectoryManager.""" + + def apply_chat_template( + self, + messages, + tokenize: bool = True, + add_special_tokens: bool = False, + add_generation_prompt: bool = True, + ): + """Return deterministic token ids based on message count.""" + base = len(messages) or 1 + return [base, base + 1, base + 2] + + +@pytest.fixture +def naive_manager(): + """Create a NaiveTrajectoryManager with a dummy tokenizer.""" + args = SimpleNamespace() + tokenizer = DummyTokenizer() + return NaiveTrajectoryManager(args, tokenizer) + + +class TestNaiveTrajectoryManager: + def test_create_session(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + assert session_id is not None + assert len(session_id) == 32 + assert session_id in naive_manager.sessions + + def test_get_session_records_by_id(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + records = naive_manager.get_session_records_by_id(session_id) + assert records == [] + + def test_get_session_records_by_id_not_found(self, naive_manager: NaiveTrajectoryManager): + records = naive_manager.get_session_records_by_id("nonexistent") + assert records is None + + def test_calc_prompt_tokens_for_existing_session(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + messages = [{"role": "user", "content": "hello"}] + + token_ids = naive_manager.calc_prompt_tokens(session_id, messages) + + assert token_ids == [1, 2, 3] + + def test_calc_prompt_tokens_for_missing_session(self, naive_manager: NaiveTrajectoryManager): + messages = [{"role": "user", "content": "hello"}] + token_ids = naive_manager.calc_prompt_tokens("missing", messages) + assert token_ids is None + + def test_delete_session_by_id(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + assert naive_manager.delete_session_by_id(session_id) is True + assert session_id not in naive_manager.sessions + assert naive_manager.delete_session_by_id(session_id) is None + + def test_append_session_record(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + record = SessionRecord( + timestamp=0.0, + method="POST", + path="/v1/chat/completions", + status_code=200, + request={"messages": [{"role": "user", "content": "hello"}]}, + response={"choices": []}, + ) + + appended = naive_manager.append_session_record(session_id, record) + + assert appended is True + records = naive_manager.get_session_records_by_id(session_id) + assert records is not None + assert len(records) == 1 + assert records[0].path == record.path + + def test_append_session_record_missing_session(self, naive_manager: NaiveTrajectoryManager): + record = SessionRecord( + timestamp=0.0, + method="POST", + path="/v1/chat/completions", + status_code=200, + request={}, + response={}, + ) + appended = naive_manager.append_session_record("missing", record) + assert appended is None + + +@pytest.fixture(scope="class") +def router_env(): + """Create a MilesRouter with session routes and a mock backend.""" + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text=f"echo: {prompt}", finish_reason="stop") + + original_chat_response = MockSGLangServer._compute_chat_completions_response + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + logprobs_content = response["choices"][0]["logprobs"]["content"] + for item in logprobs_content: + item["token_id"] = self.tokenizer.convert_tokens_to_ids(item["token"]) + return response + + with patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response): + with with_mock_server(process_fn=process_fn) as backend: + args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + miles_router_enable_token_input_for_chat_completions=False, + hf_checkpoint="Qwen/Qwen3-0.6B", + trajectory_manager="naive_trajectory", + ) + router = MilesRouter(args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend.url}, timeout=5.0) + + try: + yield SimpleNamespace(url=url) + finally: + server.stop() + + +class TestSessionRoutes: + def test_create_session(self, router_env): + response = requests.post(f"{router_env.url}/sessions", timeout=5.0) + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert len(data["session_id"]) == 32 + + def test_get_session_initial_state(self, router_env): + session_id = requests.post(f"{router_env.url}/sessions", timeout=5.0).json()["session_id"] + + get_resp = requests.get(f"{router_env.url}/sessions/{session_id}", timeout=5.0) + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert data["records"] == [] + + def test_get_session_not_found(self, router_env): + response = requests.get(f"{router_env.url}/sessions/nonexistent", timeout=5.0) + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_delete_session(self, router_env): + session_id = requests.post(f"{router_env.url}/sessions", timeout=5.0).json()["session_id"] + + delete_resp = requests.delete(f"{router_env.url}/sessions/{session_id}", timeout=5.0) + assert delete_resp.status_code == 204 + assert delete_resp.text == "" + + assert requests.delete(f"{router_env.url}/sessions/{session_id}", timeout=5.0).status_code == 404 + + def test_delete_session_not_found(self, router_env): + response = requests.delete(f"{router_env.url}/sessions/nonexistent", timeout=5.0) + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + +class TestSessionProxy: + def test_proxy_chat_appends_record(self, router_env): + session_id = requests.post(f"{router_env.url}/sessions", timeout=5.0).json()["session_id"] + + payload = { + "messages": [{"role": "user", "content": "What is 1+2?"}], + "return_logprob": True, + } + resp = requests.post( + f"{router_env.url}/sessions/{session_id}/v1/chat/completions", + json=payload, + timeout=10.0, + ) + assert resp.status_code == 200 + body = resp.json() + assert "choices" in body + assert body["choices"] + + get_resp = requests.get(f"{router_env.url}/sessions/{session_id}", timeout=5.0) + records = get_resp.json()["records"] + + assert isinstance(records, list) + assert len(records) == 1 + record = records[0] + assert record["path"] == "/v1/chat/completions" + assert record["status_code"] == 200 diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py new file mode 100644 index 0000000000..9bd1a620d6 --- /dev/null +++ b/tests/fast/utils/test_arguments.py @@ -0,0 +1,58 @@ +import argparse +import sys +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import get_miles_extra_args_provider +from miles.utils.misc import function_registry + +PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] +REQUIRED_ARGS = ["--rollout-batch-size", "64"] + + +def make_class_with_add_arguments(): + class MyFn: + @classmethod + def add_arguments(cls, parser): + parser.add_argument("--my-custom-arg", type=int, default=42) + + return MyFn + + +def make_function_with_add_arguments(): + def my_fn(): + pass + + my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42) + return my_fn + + +def make_function_without_add_arguments(): + def my_fn(): + pass + + return my_fn + + +@pytest.mark.parametrize("path_arg", PATH_ARGS) +class TestAddArgumentsSupport: + + @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments]) + def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory): + fn = fn_factory() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) + args, _ = parser.parse_known_args() + assert args.my_custom_arg == 100 + + def test_skips_function_without_add_arguments(self, path_arg): + fn = make_function_without_add_arguments() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) diff --git a/tests/utils/test_mask_utils.py b/tests/fast/utils/test_mask_utils.py similarity index 100% rename from tests/utils/test_mask_utils.py rename to tests/fast/utils/test_mask_utils.py diff --git a/tests/fast/utils/test_misc.py b/tests/fast/utils/test_misc.py new file mode 100644 index 0000000000..810c2b67c7 --- /dev/null +++ b/tests/fast/utils/test_misc.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from miles.utils.misc import FunctionRegistry, function_registry, load_function + + +def _fn_a(): + return "a" + + +def _fn_b(): + return "b" + + +class TestFunctionRegistry: + def test_register_and_get(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + + def test_register_duplicate_raises(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + with pytest.raises(AssertionError): + with registry.temporary("my_fn", _fn_b): + pass + + def test_unregister(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + assert registry.get("my_fn") is None + + def test_temporary_cleanup_on_exception(self): + registry = FunctionRegistry() + with pytest.raises(RuntimeError): + with registry.temporary("temp_fn", _fn_a): + raise RuntimeError("test") + assert registry.get("temp_fn") is None + + +class TestLoadFunction: + def test_load_from_module(self): + import os.path + + assert load_function("os.path.join") is os.path.join + + def test_load_none_returns_none(self): + assert load_function(None) is None + + def test_load_from_registry(self): + with function_registry.temporary("test:my_fn", _fn_a): + assert load_function("test:my_fn") is _fn_a + + def test_registry_takes_precedence(self): + with function_registry.temporary("os.path.join", _fn_b): + assert load_function("os.path.join") is _fn_b + assert load_function("os.path.join") is os.path.join diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py new file mode 100644 index 0000000000..e387fd78bd --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_sglang_server.py @@ -0,0 +1,434 @@ +import asyncio +import concurrent.futures +import time + +import pytest +import requests + +from miles.utils.test_utils.mock_sglang_server import ( + Counter, + ProcessResult, + ProcessResultMetaInfo, + default_process_fn, + with_mock_server, +) +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub + + +def expected_logprobs(tokenizer, text: str) -> list[dict]: + output_ids = tokenizer.encode(text, add_special_tokens=False) + return [ + {"token": tokenizer.convert_ids_to_tokens(tid), "token_id": tid, "logprob": -i / 128} + for i, tid in enumerate(output_ids) + ] + + +def expected_input_token_ids(tokenizer, messages: list[dict], tools: list[dict] | None) -> list[int]: + prompt_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + return tokenizer.encode(prompt_str, add_special_tokens=False) + + +@pytest.fixture(scope="module") +def mock_server(): + with with_mock_server() as server: + yield server + + +class TestProcessResultMetaInfo: + def test_to_dict_empty(self): + assert ProcessResultMetaInfo().to_dict() == {} + + def test_to_dict_single_field(self): + assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} + + def test_to_dict_partial_fields(self): + assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { + "weight_version": "v1", + "spec_accept_token_num": 10, + } + + def test_to_dict_all_fields(self): + assert ProcessResultMetaInfo( + weight_version="v1", + routed_experts="abc", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ).to_dict() == { + "weight_version": "v1", + "routed_experts": "abc", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + } + + +class TestDefaultProcessFn: + def test_math_question(self): + assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") + assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") + + def test_unknown_question(self): + assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") + + +class TestCounter: + def test_tracks_max(self): + counter = Counter() + assert counter.max_value == 0 + + with counter.track(): + assert counter.max_value == 1 + with counter.track(): + assert counter.max_value == 2 + + counter.reset() + assert counter.max_value == 0 + + def test_concurrent_tasks(self): + counter = Counter() + + async def task(): + with counter.track(): + await asyncio.sleep(0.1) + + async def run_all(): + await asyncio.gather(task(), task(), task()) + + asyncio.run(run_all()) + assert counter.max_value == 3 + + +class TestMockServerBasic: + def test_start_stop(self, mock_server): + assert mock_server.port > 0 + assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url + + def test_request_log_and_reset_stats(self, mock_server): + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + + payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} + requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) + assert len(mock_server.request_log) == 1 + assert mock_server.request_log[0] == payload + + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + assert mock_server.max_concurrent == 0 + + @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) + def test_latency(self, latency, min_time, max_time): + with with_mock_server(latency=latency) as server: + start = time.time() + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + elapsed = time.time() - start + assert min_time <= elapsed < max_time + + def test_max_concurrent_with_latency(self): + with with_mock_server(latency=0.1) as server: + + def send_request(): + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(send_request) for _ in range(3)] + concurrent.futures.wait(futures) + + assert server.max_concurrent == 3 + + def test_health_endpoint(self, mock_server): + response = requests.get(f"{mock_server.url}/health", timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + def test_abort_request_endpoint(self, mock_server): + response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +class TestGenerateEndpoint: + def test_basic(self, mock_server): + prompt = "What is 1+7?" + input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) + assert input_ids == [3838, 374, 220, 16, 10, 22, 30] + + response = requests.post( + f"{mock_server.url}/generate", + json={ + "input_ids": input_ids, + "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, + "return_logprob": True, + }, + timeout=5.0, + ) + assert response.status_code == 200 + assert response.json() == { + "text": "\\boxed{8}", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": len(input_ids), + "cached_tokens": 0, + "completion_tokens": 5, + "output_token_logprobs": [ + [-0.0, 59], + [-0.0078125, 79075], + [-0.015625, 90], + [-0.0234375, 23], + [-0.03125, 92], + ], + }, + } + + def test_with_meta_info(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult( + text="ok", + finish_reason="stop", + cached_tokens=5, + meta_info=ProcessResultMetaInfo( + weight_version="v2.0", + routed_experts="encoded_data", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ), + ) + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + + assert response.json() == { + "text": "ok", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": 3, + "cached_tokens": 5, + "completion_tokens": 1, + "output_token_logprobs": [[-0.0, 562]], + "weight_version": "v2.0", + "routed_experts": "encoded_data", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + }, + } + + def test_finish_reason_length(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text="truncated output", finish_reason="length") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + data = response.json() + + finish_reason = data["meta_info"]["finish_reason"] + assert finish_reason["type"] == "length" + assert finish_reason["length"] == data["meta_info"]["completion_tokens"] + + +class TestChatCompletionsEndpoint: + def test_basic(self, mock_server): + messages = [{"role": "user", "content": "What is 1+5?"}] + response = requests.post( + f"{mock_server.url}/v1/chat/completions", + json={ + "model": "test-model", + "messages": messages, + }, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + + assert data["id"].startswith("chatcmpl-") + assert isinstance(data["created"], int) + assert data == { + "id": data["id"], + "object": "chat.completion", + "created": data["created"], + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, + "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, + "input_token_ids": expected_input_token_ids(mock_server.tokenizer, messages, None), + "finish_reason": "stop", + } + ], + } + + def test_with_tool_calls(self): + tool_call_response = 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n' + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=tool_call_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year is it?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "Let me check for you.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}} + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, + "input_token_ids": expected_input_token_ids( + server.tokenizer, + [{"role": "user", "content": "What year is it?"}], + SAMPLE_TOOLS, + ), + "finish_reason": "tool_calls", + } + + def test_with_tools_but_no_tool_call(self): + response_text = "The weather is sunny today." + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=response_text, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": {"role": "assistant", "content": response_text, "tool_calls": None}, + "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, + "input_token_ids": expected_input_token_ids( + server.tokenizer, + [{"role": "user", "content": "What's the weather?"}], + SAMPLE_TOOLS, + ), + "finish_reason": "stop", + } + + def test_with_multiple_tool_calls(self): + multi_tool_response = ( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n' + ) + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=multi_tool_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year and temperature?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "I will get year and temperature.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, + { + "id": "call00001", + "type": "function", + "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'}, + }, + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, + "input_token_ids": expected_input_token_ids( + server.tokenizer, + [{"role": "user", "content": "What year and temperature?"}], + SAMPLE_TOOLS, + ), + "finish_reason": "tool_calls", + } + + +class TestMultiTurnToolCallProcessFn: + @pytest.mark.parametrize( + "prompt,expected_response", + [ + pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"), + pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"), + ], + ) + def test_generate_endpoint(self, prompt, expected_response): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) + response = requests.post( + f"{server.url}/generate", + json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["text"] == expected_response + assert data["meta_info"]["finish_reason"] == {"type": "stop"} + + @pytest.mark.parametrize( + "messages,expected_content,expected_tool_calls,expected_finish_reason", + [ + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, + TwoTurnStub.FIRST_RESPONSE_CONTENT, + TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT, + "tool_calls", + id="first_turn", + ), + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, + TwoTurnStub.SECOND_RESPONSE, + None, + "stop", + id="second_turn", + ), + ], + ) + def test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == expected_content + assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls + assert data["choices"][0]["finish_reason"] == expected_finish_reason diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py new file mode 100644 index 0000000000..3f2116ec01 --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_tools.py @@ -0,0 +1,111 @@ +import asyncio + +import pytest +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call + + +class TestExecuteToolCall: + def test_execute_get_year(self): + result = asyncio.run(execute_tool_call("get_year", {})) + assert result == '{"year": 2026}' + + def test_execute_get_temperature(self): + result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"})) + assert result == '{"temperature": -60}' + + +class TestApplyChatTemplateWithTools: + EXPECTED_PROMPT_WITHOUT_TOOLS = ( + "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n" + ) + + EXPECTED_PROMPT_WITH_TOOLS = ( + "<|im_start|>system\n" + "# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What's the weather in Paris?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize( + "tools,expected", + [ + pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"), + pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"), + ], + ) + def test_apply_chat_template(self, tools, expected): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + + assert prompt == expected + + +class TestSGLangFunctionCallParser: + """Test to demonstrate and ensure SGLang function call parser have features we need without breaking changes.""" + + @pytest.mark.parametrize( + "model_output,expected", + [ + pytest.param( + 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', + ( + "Let me check for you.", + [ToolCallItem(tool_index=-1, name="get_year", parameters="{}")], + ), + id="single_tool_call", + ), + pytest.param( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n', + ( + "I will get year and temperature.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'), + ], + ), + id="multi_tool_calls", + ), + pytest.param( + "The weather is sunny today.", + ("The weather is sunny today.", []), + id="no_tool_call", + ), + pytest.param( + TwoTurnStub.FIRST_RESPONSE, + ( + "Let me get the year and temperature first.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'), + ], + ), + id="multi_turn_first_response", + ), + ], + ) + def test_parse_non_stream(self, model_output, expected): + tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS) + parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25") + assert parser.parse_non_stream(model_output) == expected diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c53..9b6e69c295 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -126,6 +126,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=_launch_background, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index d6fddf386c..994700293b 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -1,6 +1,7 @@ import gc import os import shutil +from functools import wraps import torch import torch.distributed as dist @@ -11,6 +12,7 @@ import miles_plugins.mbridge # noqa: F401 from mbridge import AutoBridge +from mbridge.core.bridge import Bridge from miles.backends.megatron_utils.arguments import set_default_megatron_args from miles.backends.megatron_utils.initialize import init from miles.backends.megatron_utils.model_provider import get_model_provider_func @@ -18,6 +20,24 @@ from miles.utils.memory_utils import print_memory +def patch_weight_to_mcore_format_preserve_fp32(): + + original_method = Bridge._weight_to_mcore_format + + @wraps(original_method) + def patched_method(self, mcore_weights_name, hf_weights): + original_dtype = getattr(self, "dtype", None) + self.dtype = None + try: + result = original_method(self, mcore_weights_name, hf_weights) + finally: + self.dtype = original_dtype + return result + + Bridge._weight_to_mcore_format = patched_method + print("[Patch] Applied patch to preserve FP32 precision in _weight_to_mcore_format") + + def add_convertion_args(parser): """Add conversion arguments to the parser""" parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path") @@ -111,6 +131,10 @@ def main(): # Load model hf_model_path = args.hf_checkpoint bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True) + + # Patch to preserve FP32 precision for _keep_fp32 params + patch_weight_to_mcore_format_preserve_fp32() + bridge.load_weights(model, hf_model_path, memory_efficient=True) print(f"Model loaded: {hf_model_path}")