From 474d542e33513f27f1103a032a65e38ea76453bd Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 03:52:46 +0800 Subject: [PATCH 01/77] Enable experimental rollout flag for CI tests (#492) Co-authored-by: Ethan (Yusheng) Su --- .github/workflows/pr-test.yml | 40 ++ .github/workflows/pr-test.yml.j2 | 19 +- miles/ray/rollout.py | 33 +- miles/rollout/base_types.py | 66 +- .../rollout/generate_hub/agentic_tool_call.py | 85 +++ miles/rollout/generate_hub/multi_turn.py | 88 +++ miles/rollout/generate_hub/single_turn.py | 46 ++ miles/rollout/generate_utils/__init__.py | 0 .../generate_utils/generate_endpoint_utils.py | 112 ++++ .../generate_utils/openai_endpoint_utils.py | 67 ++ miles/rollout/generate_utils/sample_utils.py | 115 ++++ .../rollout/generate_utils/tool_call_utils.py | 115 ++++ miles/rollout/inference_rollout/__init__.py | 2 + .../inference_rollout/compatibility.py | 84 +++ .../inference_rollout_common.py | 192 ++++++ .../inference_rollout_eval.py | 112 ++++ .../inference_rollout_train.py | 146 +++++ miles/rollout/rm_hub/__init__.py | 12 +- miles/router/router.py | 47 +- miles/router/sessions.py | 124 ++++ miles/utils/arguments.py | 24 +- miles/utils/environ.py | 14 + miles/utils/http_utils.py | 20 +- miles/utils/misc.py | 50 +- miles/utils/test_utils/__init__.py | 0 miles/utils/test_utils/mock_sglang_server.py | 248 ++++++++ miles/utils/test_utils/mock_tools.py | 268 ++++++++ .../utils/test_utils/uvicorn_thread_server.py | 49 ++ miles/utils/types.py | 18 + requirements.txt | 1 + tests/__init__.py | 1 + tests/ci/gpu_lock_exec.py | 11 +- tests/e2e/.gitkeep | 1 + tests/fast/__init__.py | 0 tests/fast/conftest.py | 15 + tests/fast/fixtures/__init__.py | 1 + tests/fast/fixtures/generation_fixtures.py | 274 +++++++++ tests/fast/fixtures/rollout_fixtures.py | 127 ++++ tests/fast/rollout/__init__.py | 0 tests/fast/rollout/generate_hub/__init__.py | 0 .../rollout/generate_hub/test_multi_turn.py | 572 ++++++++++++++++++ .../rollout/generate_hub/test_single_turn.py | 424 +++++++++++++ .../generate_hub/test_tool_call_utils.py | 99 +++ tests/fast/rollout/generate_utils/__init__.py | 0 .../generate_utils/test_sample_utils.py | 156 +++++ .../rollout/inference_rollout/__init__.py | 0 .../rollout/inference_rollout/conftest.py | 45 ++ .../inference_rollout/integration/__init__.py | 0 .../integration/test_basic.py | 69 +++ .../integration/test_deterministic.py | 37 ++ .../integration/test_dynamic_filter.py | 46 ++ .../integration/test_group_rm.py | 22 + .../integration/test_multi_sample.py | 65 ++ .../integration/test_multi_turn.py | 114 ++++ .../integration/test_over_sampling.py | 48 ++ .../integration/test_sample_filter.py | 67 ++ .../integration/test_semaphore.py | 33 + .../inference_rollout/integration/utils.py | 89 +++ .../inference_rollout/test_compatibility.py | 196 ++++++ tests/fast/rollout/rm_hub/__init__.py | 0 tests/fast/rollout/rm_hub/test_deepscaler.py | 26 + tests/fast/rollout/rm_hub/test_f1.py | 44 ++ tests/fast/rollout/rm_hub/test_gpqa.py | 86 +++ .../rollout/rm_hub/test_math_dapo_utils.py | 108 ++++ tests/fast/rollout/rm_hub/test_math_utils.py | 129 ++++ tests/fast/rollout/rm_hub/test_rm_hub.py | 126 ++++ tests/fast/router/__init__.py | 0 tests/fast/router/test_router.py | 204 +++++++ tests/fast/router/test_sessions.py | 195 ++++++ tests/fast/utils/__init__.py | 0 tests/fast/utils/test_arguments.py | 58 ++ tests/{ => fast}/utils/test_mask_utils.py | 0 tests/fast/utils/test_misc.py | 59 ++ tests/fast/utils/test_utils/__init__.py | 0 .../test_utils/test_mock_sglang_server.py | 409 +++++++++++++ .../fast/utils/test_utils/test_mock_tools.py | 111 ++++ tests/test_external_rollout.py | 1 + tests/test_mimo_7B_mtp_only_grad.py | 1 + tests/test_moonlight_16B_A3B.py | 1 + tests/test_quick_start_glm4_9B.py | 1 + tests/test_qwen2.5_0.5B_gsm8k.py | 1 + tests/test_qwen2.5_0.5B_gsm8k_async.py | 1 + tests/test_qwen2.5_0.5B_gsm8k_async_short.py | 1 + tests/test_qwen2.5_0.5B_gsm8k_short.py | 1 + tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 1 + tests/test_qwen3_0.6B_fsdp_distributed.py | 1 + tests/test_qwen3_0.6B_megatron_fsdp_align.py | 3 + tests/test_qwen3_0.6B_parallel_check.py | 2 + tests/test_qwen3_30B_A3B.py | 1 + tests/test_qwen3_4B_ckpt.py | 1 + tests/test_qwen3_4B_fsdp_true_on_policy.py | 1 + tests/test_qwen3_4B_ppo.py | 1 + tests/test_qwen3_vl_4B_fsdp.py | 1 + 93 files changed, 6230 insertions(+), 54 deletions(-) create mode 100644 miles/rollout/generate_hub/agentic_tool_call.py create mode 100644 miles/rollout/generate_hub/multi_turn.py create mode 100644 miles/rollout/generate_hub/single_turn.py create mode 100644 miles/rollout/generate_utils/__init__.py create mode 100644 miles/rollout/generate_utils/generate_endpoint_utils.py create mode 100644 miles/rollout/generate_utils/openai_endpoint_utils.py create mode 100644 miles/rollout/generate_utils/sample_utils.py create mode 100644 miles/rollout/generate_utils/tool_call_utils.py create mode 100644 miles/rollout/inference_rollout/__init__.py create mode 100644 miles/rollout/inference_rollout/compatibility.py create mode 100644 miles/rollout/inference_rollout/inference_rollout_common.py create mode 100644 miles/rollout/inference_rollout/inference_rollout_eval.py create mode 100644 miles/rollout/inference_rollout/inference_rollout_train.py create mode 100644 miles/router/sessions.py create mode 100644 miles/utils/environ.py create mode 100644 miles/utils/test_utils/__init__.py create mode 100644 miles/utils/test_utils/mock_sglang_server.py create mode 100644 miles/utils/test_utils/mock_tools.py create mode 100644 miles/utils/test_utils/uvicorn_thread_server.py create mode 100644 tests/__init__.py create mode 100644 tests/e2e/.gitkeep create mode 100644 tests/fast/__init__.py create mode 100644 tests/fast/conftest.py create mode 100644 tests/fast/fixtures/__init__.py create mode 100644 tests/fast/fixtures/generation_fixtures.py create mode 100644 tests/fast/fixtures/rollout_fixtures.py create mode 100644 tests/fast/rollout/__init__.py create mode 100644 tests/fast/rollout/generate_hub/__init__.py create mode 100644 tests/fast/rollout/generate_hub/test_multi_turn.py create mode 100644 tests/fast/rollout/generate_hub/test_single_turn.py create mode 100644 tests/fast/rollout/generate_hub/test_tool_call_utils.py create mode 100644 tests/fast/rollout/generate_utils/__init__.py create mode 100644 tests/fast/rollout/generate_utils/test_sample_utils.py create mode 100644 tests/fast/rollout/inference_rollout/__init__.py create mode 100644 tests/fast/rollout/inference_rollout/conftest.py create mode 100644 tests/fast/rollout/inference_rollout/integration/__init__.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_basic.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_deterministic.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_group_rm.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_multi_sample.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_multi_turn.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_over_sampling.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_sample_filter.py create mode 100644 tests/fast/rollout/inference_rollout/integration/test_semaphore.py create mode 100644 tests/fast/rollout/inference_rollout/integration/utils.py create mode 100644 tests/fast/rollout/inference_rollout/test_compatibility.py create mode 100644 tests/fast/rollout/rm_hub/__init__.py create mode 100644 tests/fast/rollout/rm_hub/test_deepscaler.py create mode 100644 tests/fast/rollout/rm_hub/test_f1.py create mode 100644 tests/fast/rollout/rm_hub/test_gpqa.py create mode 100644 tests/fast/rollout/rm_hub/test_math_dapo_utils.py create mode 100644 tests/fast/rollout/rm_hub/test_math_utils.py create mode 100644 tests/fast/rollout/rm_hub/test_rm_hub.py create mode 100644 tests/fast/router/__init__.py create mode 100644 tests/fast/router/test_router.py create mode 100644 tests/fast/router/test_sessions.py create mode 100644 tests/fast/utils/__init__.py create mode 100644 tests/fast/utils/test_arguments.py rename tests/{ => fast}/utils/test_mask_utils.py (100%) create mode 100644 tests/fast/utils/test_misc.py create mode 100644 tests/fast/utils/test_utils/__init__.py create mode 100644 tests/fast/utils/test_utils/test_mock_sglang_server.py create mode 100644 tests/fast/utils/test_utils/test_mock_tools.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index d34c823aa3..4b8b5dc82c 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,6 +25,46 @@ concurrency: jobs: + fast: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=16g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 0, "test_file": "fast"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 37b6fa4463..c052b8494f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,4 +1,10 @@ <% set jobs = { + 'fast': { + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'fast', 'num_gpus': 0}, + ], + }, 'e2e-test-short': { 'label': 'run-ci-short', 'tests': [ @@ -98,7 +104,7 @@ concurrency: jobs: <% for job_name, config in jobs.items() %> << job_name >>: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>')) + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>) runs-on: self-hosted container: image: << config.image if config.image else 'radixark/miles:latest' >> @@ -153,14 +159,5 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 79c6649be0..27211845d8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -13,8 +13,15 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import call_rollout_fn +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnTrainInput, + call_rollout_fn, +) +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -53,8 +60,14 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) + self.use_experimental_refactor = enable_experimental_rollout_refactor() + if self.use_experimental_refactor: + input = RolloutFnConstructorInput(args=args, data_source=self.data_source) + self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) + self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) + else: + self.generate_rollout = load_function(self.args.rollout_function_path) + self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -142,7 +155,12 @@ def eval(self, rollout_id): return self.health_monitoring_resume() - result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True) + if self.use_experimental_refactor: + result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) + else: + result = call_rollout_fn( + self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True + ) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -224,7 +242,12 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False) + if self.use_experimental_refactor: + data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) + else: + data = call_rollout_fn( + self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False + ) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index faa85c7269..c2644e87f9 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,22 +1,86 @@ +from __future__ import annotations + +from argparse import Namespace from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any +from miles.rollout.data_source import DataSource from miles.utils.types import Sample +if TYPE_CHECKING: + from miles.rollout.inference_rollout.inference_rollout_common import GenerateState + + +@dataclass(frozen=True) +class RolloutFnConstructorInput: + args: Namespace + # TODO may refactor DataSource API + data_source: DataSource + + +@dataclass(frozen=True) +class RolloutFnBaseInput: + rollout_id: int + + @property + def evaluation(self): + raise NotImplementedError + + +# subclassing for different data in the future +@dataclass(frozen=True) +class RolloutFnTrainInput(RolloutFnBaseInput): + @property + def evaluation(self): + return False + +@dataclass(frozen=True) +class RolloutFnEvalInput(RolloutFnBaseInput): + @property + def evaluation(self): + return True + + +# TODO make it frozen @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] metrics: dict[str, Any] = None +# TODO make it frozen @dataclass class RolloutFnEvalOutput: data: dict[str, dict[str, Any]] metrics: dict[str, Any] = None +RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput +RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput + + +@dataclass(frozen=True) +class GenerateFnInput: + state: GenerateState + sample: Sample + sampling_params: dict[str, Any] + evaluation: bool + + @property + def args(self) -> Namespace: + return self.state.args + + +@dataclass(frozen=True) +class GenerateFnOutput: + # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or + # multi-turn with removing thinking tokens. + samples: Sample | list[Sample] + + def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): + """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) # compatibility for legacy version diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py new file mode 100644 index 0000000000..05223a6544 --- /dev/null +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -0,0 +1,85 @@ +""" +Simple agentic demo with tool calling. +""" + +import argparse +from copy import deepcopy +from typing import Any + +from openai import AsyncOpenAI + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) +from miles.rollout.generate_utils.sample_utils import merge_samples +from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + tracer = await OpenAIEndpointTracer.create(input.args) + + await _run_blackbox_tool_call_agent( + base_url=tracer.base_url, + prompt=input.sample.prompt, + max_turns=input.args.generate_max_turns, + tool_specs_path=input.args.generate_tool_specs_path, + execute_tool_function_path=input.args.generate_execute_tool_function_path, + ) + + records = await tracer.collect_records() + samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + if not input.args.generate_multi_samples: + samples = merge_samples(samples, input.state.tokenizer) + return GenerateFnOutput(samples=samples) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments + + +async def _run_blackbox_tool_call_agent( + base_url: str, + prompt: list[dict[str, Any]], + max_turns: int, + tool_specs_path: str, + execute_tool_function_path: str, +): + """ + Imagine this is a black-box agent, e.g. SWE-agent, which does arbitrarily complex work, + only understands OpenAI compatible API, and never understands Miles or the Sample data structure. + """ + + # ----------------------- Setup ------------------------- + + client = AsyncOpenAI(base_url=base_url, api_key="empty") + execute_tool_function = load_function(execute_tool_function_path) + tool_specs = load_function(tool_specs_path) + + # ----------------------- Initial prompts ------------------------- + + messages = deepcopy(prompt) + + for _turn in range(max_turns): + # ----------------------- Call inference endpoint ------------------------- + + response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs) + + choice = response.choices[0] + messages.append(choice.message.model_dump()) + + if choice.finish_reason in ("stop", "length"): + break + + # ----------------------- Execute tools ------------------------- + + if x := choice.message.tool_calls: + messages += await execute_tool_calls(x, execute_tool_function) diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py new file mode 100644 index 0000000000..97814ecb3d --- /dev/null +++ b/miles/rollout/generate_hub/multi_turn.py @@ -0,0 +1,88 @@ +""" +Simple multi-turn generation with tool calling. +""" + +import argparse +from copy import deepcopy + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.rollout.generate_utils.tool_call_utils import ( + create_tool_call_parser, + execute_tool_calls, + update_sample_with_tool_responses, +) +from miles.utils.http_utils import post +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + # ----------------------- Setup ------------------------- + + args = input.args + sample = deepcopy(input.sample) + tokenizer = input.state.tokenizer + assert not args.partial_rollout, "Partial rollout is not supported" + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + execute_tool_function = load_function(args.generate_execute_tool_function_path) + + tool_specs = load_function(args.generate_tool_specs_path) + tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) + + multi_samples = [] + + # ----------------------- Initial prompts ------------------------- + + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) + + sample.tokens = prompt_tokens_ids.copy() + + for _turn in range(args.generate_max_turns): + # ----------------------- Call inference endpoint ------------------------- + + payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) + if payload is None: + sample.status = halt_status + if args.generate_multi_samples and multi_samples: + multi_samples[-1].status = halt_status + break + + if args.generate_multi_samples: + sample = deepcopy(input.sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) + + if args.generate_multi_samples: + multi_samples.append(deepcopy(sample)) + + if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): + break + + # ----------------------- Execute tools ------------------------- + + _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) + if len(tool_calls) == 0: + break + + tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) + update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) + + return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-tool-call-parser", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py new file mode 100644 index 0000000000..5c0a15b5b4 --- /dev/null +++ b/miles/rollout/generate_hub/single_turn.py @@ -0,0 +1,46 @@ +""" +Simple single-turn generation. +""" + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.utils.http_utils import post +from miles.utils.types import Sample + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + sampling_params = input.sampling_params + assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + prompt_ids = compute_prompt_ids_from_sample(input.state, sample) + + # Handle Partial Rollout resuming + if len(sample.response) > 0: + input_ids = sample.tokens + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert sampling_params["max_new_tokens"] >= 0 + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return GenerateFnOutput(samples=sample) + else: + input_ids = prompt_ids + + payload, halt_status = compute_request_payload( + args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs + ) + if payload is None: + sample.status = halt_status + return GenerateFnOutput(samples=sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output) + + return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/generate_utils/__init__.py b/miles/rollout/generate_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/rollout/generate_utils/generate_endpoint_utils.py b/miles/rollout/generate_utils/generate_endpoint_utils.py new file mode 100644 index 0000000000..a91d71f1de --- /dev/null +++ b/miles/rollout/generate_utils/generate_endpoint_utils.py @@ -0,0 +1,112 @@ +""" +Utils to integrate SGLang's `/generate` endpoint with RL things like Sample. +""" + +from copy import deepcopy +from typing import Any + +import numpy as np +import pybase64 + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +# Make this an isolated function because users may want to compute their own +def compute_prompt_ids_from_sample(state, sample, tools=None): + prompt = sample.prompt + + if state.processor: + processor_output = state.processor(text=prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + + # TODO shall we move it to other places? then can make this function immutable + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + + return prompt_ids + else: + if not isinstance(prompt, str): + prompt = state.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True, tools=tools + ) + + return state.tokenizer.encode(prompt, add_special_tokens=False) + + +def compute_request_payload( + args, + input_ids: list[int], + sampling_params: dict, + multimodal_inputs: dict | None = None, +) -> tuple[dict[str, Any] | None, Sample.Status | None]: + sampling_params = deepcopy(sampling_params) + max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) + if x := args.rollout_max_context_len: + max_new_tokens = min(max_new_tokens, x - len(input_ids)) + if max_new_tokens <= 0: + return None, Sample.Status.TRUNCATED + + payload = { + "input_ids": input_ids, + "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens}, + "return_logprob": True, + "return_routed_experts": args.use_rollout_routing_replay, + } + if image_data := (multimodal_inputs or {}).get("images"): + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + return payload, None + + +async def update_sample_from_response( + args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False +): + # Initialize sample.tokens for the first turn + if (len(sample.response) == 0) and not sample.tokens: + sample.tokens = payload["input_ids"] + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + # TODO may rename to match + await postprocess_sample_with_radix_tree(args, sample, output) + + assert not update_loss_mask, "This code branch has not implemented update_loss_mask" + else: + if x := output["meta_info"].get("output_token_logprobs"): + new_response_tokens = [item[1] for item in x] + new_response_log_probs = [item[0] for item in x] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if update_loss_mask: + if sample.loss_mask is None: + sample.loss_mask = [] + sample.loss_mask += [1] * len(new_response_tokens) + + # TODO handle multi-turn cases (may need concat instead of assignment) + sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) + + # TODO may unify (currently there are both methods inside Sample and separate functions) + sample.update_from_meta_info(args, output["meta_info"]) + + +def _get_rollout_routed_experts_from_response(args, sample, output): + info = output["meta_info"].get("routed_experts") + if info is None: + return None + + x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32) + x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk) + return x diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py new file mode 100644 index 0000000000..73ba8198bf --- /dev/null +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -0,0 +1,67 @@ +""" +Utilities for the OpenAI endpoint +""" + +import logging +from argparse import Namespace +from copy import deepcopy + +from miles.router.sessions import GetSessionResponse, SessionRecord +from miles.utils.http_utils import post +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class OpenAIEndpointTracer: + def __init__(self, router_url: str, session_id: str): + self.router_url = router_url + self.session_id = session_id + self.base_url = f"{router_url}/sessions/{session_id}/v1" + + @staticmethod + async def create(args: Namespace): + router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + session_id = (await post(f"{router_url}/sessions", {}))["session_id"] + return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) + + async def collect_records(self) -> list[SessionRecord]: + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + response = GetSessionResponse.model_validate(response) + records = response.records + + try: + await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") + except Exception as e: + logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") + + return records + + +def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: + return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records] + + +def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: + # TODO may refine after @guapisolo's implementation + choice = record.response["choices"][0] + output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] + output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] + + sample = deepcopy(input_sample) + sample.tokens = record.request["input_ids"] + output_token_ids + sample.rollout_log_probs = output_log_probs + sample.response = tokenizer.decode(output_token_ids) + sample.response_length = len(output_token_ids) + sample.loss_mask = [1] * len(output_token_ids) + + # TODO unify with Sample.update_from_meta_info + match choice["finish_reason"]: + case "stop" | "tool_calls": + sample.status = Sample.Status.COMPLETED + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + + return sample diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py new file mode 100644 index 0000000000..6a4e645be5 --- /dev/null +++ b/miles/rollout/generate_utils/sample_utils.py @@ -0,0 +1,115 @@ +from copy import deepcopy +from dataclasses import fields + +from miles.utils.types import Sample + + +def merge_samples(samples: list[Sample], tokenizer) -> Sample: + acc = samples[0] + for sample in samples[1:]: + acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer) + return acc + + +def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: + """Merge two samples generated from sibling inference engine calls.""" + a, b = deepcopy(a), deepcopy(b) + + def _merge_equal_value(field): + x = getattr(a, field) + y = getattr(b, field) + assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}" + return x + + def _fill_defaults(sample: Sample): + if sample.loss_mask is None: + sample.loss_mask = [1] * sample.response_length + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [0.0] * sample.response_length + + _fill_defaults(a) + _fill_defaults(b) + + obs_len = len(b.tokens) - len(a.tokens) - b.response_length + obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len] + # TODO: is this acceptable? + obs_text = tokenizer.decode(obs_tokens) + + try: + a.validate() + b.validate() + assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" + assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" + assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" + if a.rollout_routed_experts is not None: + assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0] + assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" + + return _create_with_all_fields( + Sample, + group_index=_merge_equal_value("group_index"), + index=_merge_equal_value("index"), + prompt=b.prompt, + tokens=b.tokens, + multimodal_inputs=_merge_equal_value("multimodal_inputs"), + multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"), + response=a.response + obs_text + b.response, + response_length=a.response_length + obs_len + b.response_length, + label=_merge_equal_value("label"), + reward=_merge_equal_value("reward"), + loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, + weight_versions=a.weight_versions + b.weight_versions, + rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, + rollout_routed_experts=b.rollout_routed_experts, + remove_sample=_merge_equal_value("remove_sample"), + status=b.status, + metadata=_merge_equal_value("metadata"), + train_metadata=_merge_equal_value("train_metadata"), + non_generation_time=_merge_equal_value("non_generation_time"), + spec_info=_merge_spec_info(a.spec_info, b.spec_info), + prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), + ) + except AssertionError as e: + e.add_note(f"{a=} {b=}") + raise + + +def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.SpecInfo, + spec_accept_token_num=_merge_plus_value("spec_accept_token_num"), + spec_draft_token_num=_merge_plus_value("spec_draft_token_num"), + spec_verify_ct=_merge_plus_value("spec_verify_ct"), + completion_token_num=_merge_plus_value("completion_token_num"), + ) + + +def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.PrefixCacheInfo, + cached_tokens=_merge_plus_value("cached_tokens"), + total_prompt_tokens=_merge_plus_value("total_prompt_tokens"), + ) + + +def _create_with_all_fields(cls, **kwargs): + expected = {f.name for f in fields(cls)} + actual = set(kwargs.keys()) + assert ( + expected == actual + ), f"{cls.__name__} field mismatch. Missing: {expected - actual}, Extra: {actual - expected}" + return cls(**kwargs) + + +def _startswith(*, short, long) -> bool: + if isinstance(short, str) and isinstance(long, str): + return long.startswith(short) + if isinstance(short, list) and isinstance(long, list): + return (len(long) >= len(short)) and (long[: len(short)] == short) + raise NotImplementedError diff --git a/miles/rollout/generate_utils/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py new file mode 100644 index 0000000000..85ea87aeab --- /dev/null +++ b/miles/rollout/generate_utils/tool_call_utils.py @@ -0,0 +1,115 @@ +""" +Utils to handle tool calls. +""" + +import json +import uuid +from collections.abc import Callable +from typing import Any + +from openai.types.chat import ChatCompletionMessageToolCall +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.types import Sample + +_DUMMY_USER = {"role": "user", "content": "dummy"} + + +def create_tool_call_parser(tool_specs, tool_call_parser): + return FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tool_specs), + tool_call_parser=tool_call_parser, + ) + + +async def execute_tool_calls( + tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall], + execute_one: Callable, +) -> list[dict[str, Any]]: + tool_messages = [] + for call in tool_calls: + tool_messages.append(await _execute_tool_call(call, execute_one)) + return tool_messages + + +async def _execute_tool_call( + call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable +) -> dict[str, Any]: + if isinstance(call, ChatCompletionMessageToolCall): + name = call.function.name + params = json.loads(call.function.arguments) if call.function.arguments else {} + tool_call_id = call.id + elif isinstance(call, ToolCallItem): + name = call.name + params = json.loads(call.parameters) if call.parameters else {} + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + else: + raise TypeError(f"Unsupported tool call type: {type(call)}") + + result = await execute_one(name, params) + assert isinstance(result, str) + + return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name} + + +def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): + next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) + sample.response += tokenizer.decode(next_obs_tokens_ids) + sample.response_length += len(next_obs_tokens_ids) + sample.tokens += next_obs_tokens_ids + sample.loss_mask += [0] * len(next_obs_tokens_ids) + sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) + + +# TODO: very naive implementation, need the to-be-implemented e2e test to validate. +def tokenize_tool_responses( + tool_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + return _tokenize_postfix_messages(tool_messages, tokenizer) + + +def _tokenize_postfix_messages( + postfix_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + dummy_assistant = _build_dummy_assistant(postfix_messages) + base_messages = [_DUMMY_USER, dummy_assistant] + + messages_without = base_messages + messages_with = base_messages + postfix_messages + + tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True) + tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False) + + assert tokens_with[: len(tokens_without)] == tokens_without, ( + f"Fail to tokenize_tool_responses caused by token prefix mismatch. " + f"This can happen for thinking model or models with special chat template, " + f"and this simple example does not support it yet, " + f"since this means we cannot have a append-only token id list. " + f"{tokens_with=} {tokens_without=} " + f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} " + ) + return tokens_with[len(tokens_without) :] + + +def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: + return { + "role": "assistant", + "content": "", + "reasoning_content": " ", + "tool_calls": [ + { + "id": resp.get("tool_call_id", f"call0000{i}"), + "type": "function", + "function": { + "name": resp.get("name", "dummy_func"), + "arguments": {}, + }, + } + for i, resp in enumerate(tool_responses) + ], + } diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py new file mode 100644 index 0000000000..33ccf17bfb --- /dev/null +++ b/miles/rollout/inference_rollout/__init__.py @@ -0,0 +1,2 @@ +# This is a refactor of the portions above generate-function in sglang_rollout.py, +# and is give a different name to ensure both code exist at the same time. diff --git a/miles/rollout/inference_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py new file mode 100644 index 0000000000..7711e0dd31 --- /dev/null +++ b/miles/rollout/inference_rollout/compatibility.py @@ -0,0 +1,84 @@ +import inspect +from collections.abc import Callable + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainOutput, +) +from miles.utils.async_utils import run +from miles.utils.misc import load_function + + +class LegacyRolloutFnAdapter: + def __init__(self, input: RolloutFnConstructorInput, fn: Callable): + self.args = input.args + self.data_source = input.data_source + self.fn = fn + + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation) + + # compatibility for legacy version + if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): + output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output) + + return output + + +def load_rollout_function(input: RolloutFnConstructorInput, path: str): + fn = load_function(path) + + if inspect.isclass(fn): + return fn(input) + else: + return LegacyRolloutFnAdapter(input, fn) + + +def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput: + output = fn(input) + + if inspect.iscoroutine(output): + output = run(output) + + return output + + +class LegacyGenerateFnAdapter: + def __init__(self, fn: Callable): + self.fn = fn + self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters + + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + if self._has_evaluation_param: + output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation) + else: + output = await self.fn(input.args, input.sample, input.sampling_params) + + if not isinstance(output, GenerateFnOutput): + output = GenerateFnOutput(samples=output) + + return output + + +def load_generate_function(path: str): + fn = load_function(path) + if fn is None: + return None + + if inspect.isclass(fn): + return fn() + elif _is_legacy_generate_fn(fn): + return LegacyGenerateFnAdapter(fn) + else: + return fn + + +def _is_legacy_generate_fn(fn: Callable) -> bool: + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + return len(params) >= 3 and params[0] != "input" diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py new file mode 100644 index 0000000000..8518c6e020 --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_common.py @@ -0,0 +1,192 @@ +import asyncio +import logging +from argparse import Namespace +from copy import deepcopy +from typing import Any + +from miles.rollout.base_types import ( + GenerateFnInput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.generate_hub.single_turn import generate +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class GenerateState: + def __init__(self, args: Namespace) -> None: + # persistent state for the generation process + self.args = args + self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + + self.generate_fn_semaphore = asyncio.Semaphore( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + self.sampling_params: dict[str, Any] = compute_sampling_params( + args, + temperature=args.rollout_temperature, + top_p=args.rollout_top_p, + top_k=args.rollout_top_k, + max_new_tokens=args.rollout_max_response_len, + ) + + self.generate_function = load_generate_function(args.custom_generate_function_path) or generate + + self.reset() + + def reset(self) -> None: + self.aborted = False + + +async def generate_and_rm( + state: GenerateState, + sample: Sample | list[Sample], + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Sample | list[Sample]: + args = state.args + + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + + # For samples with existing response, check if they're complete + if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: + assert sample.response is not None + if not args.group_rm: + assert sample.reward is not None + return sample + + # generate + async with state.generate_fn_semaphore: + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + + output = await state.generate_function( + GenerateFnInput( + state=state, + sample=sample, + sampling_params=deepcopy(sampling_params), + evaluation=evaluation, + ) + ) + sample = output.samples + + # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below + # for the rm that need the whole group, we will not do the rm here + if args.group_rm: + return sample + + # TODO: unify the two branches into one if we decide to use list as output type + # multi samples + if isinstance(sample, list): + samples = sample + if any([sample.status == Sample.Status.ABORTED for sample in samples]): + return samples + + # for multi agent system, the reward of some sample is calculated during generation. + samples_need_reward = [sample for sample in samples if sample.reward is None] + await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True) + return samples + else: + if sample.status == Sample.Status.ABORTED: + return sample + # for multi-turn environment, a reward could be assigned to the agent. + if sample.reward is None: + sample.reward = await async_rm(args, sample) + + return sample + + +async def generate_and_rm_group( + state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False +) -> list[Sample]: + args = state.args + + if state.aborted: + return group + + tasks = [] + for idx, sample in enumerate(group): + current_sampling_params = sampling_params.copy() + if getattr(args, "sglang_enable_deterministic_inference", False): + current_sampling_params["sampling_seed"] = args.rollout_seed + idx + tasks.append( + asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) + ) + + group = await asyncio.gather(*tasks) + if state.aborted: + return group + + if args.group_rm: + await batched_async_rm(args, group, inplace_set_reward_field=True) + + return group + + +def compute_sampling_params( + args, + *, + # after unifying configuration, this can be further refactored + temperature, + top_p, + top_k, + max_new_tokens, +): + return dict( + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + +class InferenceRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.data_source = input.data_source + self.state = GenerateState(input.args) + self.eval_prompt_dataset_cache = {} + + async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + if input.evaluation: + return await self._call_eval(input) + return await self._call_train(input) + + async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: + from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async + + output, aborted_samples = await generate_rollout_async( + self.state, input.rollout_id, self.data_source.get_samples + ) + self.data_source.add_samples(aborted_samples) + return output + + async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: + from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset + + assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache)) + results_list = await asyncio.gather(*coros) + results = {k: v for r in results_list for k, v in r.items()} + return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py new file mode 100644 index 0000000000..2d052be0ae --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_eval.py @@ -0,0 +1,112 @@ +import asyncio +import copy +import logging +from typing import Any + +from tqdm import tqdm + +from miles.rollout.inference_rollout.inference_rollout_common import ( + GenerateState, + compute_sampling_params, + generate_and_rm, +) +from miles.utils.data import Dataset +from miles.utils.eval_config import EvalDatasetConfig +from miles.utils.misc import as_completed_async +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def eval_rollout_single_dataset( + state: GenerateState, + dataset_cfg: EvalDatasetConfig, + prompt_dataset_cache: dict[Any, Dataset], +) -> dict[str, dict[str, list[Any]]]: + args = state.args + assert not args.group_rm, "Group RM is not supported for eval rollout" + + cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) + if cache_key not in prompt_dataset_cache: + tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + prompt_dataset_cache[cache_key] = Dataset( + path=dataset_cfg.path, + tokenizer=tokenizer, + processor=processor, + max_length=args.eval_max_prompt_len, + prompt_key=dataset_cfg.input_key, + label_key=dataset_cfg.label_key, + multimodal_keys=args.multimodal_keys, + metadata_key=dataset_cfg.metadata_key, + tool_key=dataset_cfg.tool_key, + apply_chat_template=args.apply_chat_template, + apply_chat_template_kwargs=args.apply_chat_template_kwargs, + ) + dataset = prompt_dataset_cache[cache_key] + + base_sampling_params = compute_sampling_params( + args, + temperature=dataset_cfg.temperature, + top_p=dataset_cfg.top_p, + top_k=dataset_cfg.top_k, + max_new_tokens=dataset_cfg.max_response_len, + ) + + tasks = [] + # do multiple samples for eval prompts + sample_index = 0 + for _i, prompt_sample in enumerate(dataset.samples): + for j in range(dataset_cfg.n_samples_per_eval_prompt): + # use the same prompt for multiple samples + sample = copy.deepcopy(prompt_sample) + sample.index = sample_index + sample_index += 1 + sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) + sampling_params = base_sampling_params + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_params = base_sampling_params.copy() + sampling_params["sampling_seed"] = args.rollout_seed + j + tasks.append( + asyncio.create_task( + generate_and_rm( + state, + sample, + sampling_params=sampling_params, + evaluation=True, + ) + ) + ) + + data = [] + do_print = True + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) + async for sample in as_completed_async(tasks): + if do_print: + # TODO improve this after enhancing samples' type + s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample + if s is not None: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(s.prompt) + s.response]} " + f"reward={s.reward}" + ) + do_print = False + if isinstance(sample, list): + data.extend(sample) + else: + data.append(sample) + pbar.update(1) + pbar.close() + + data.sort(key=lambda sample: sample.index) + + reward_key = args.eval_reward_key or args.reward_key + return { + dataset_cfg.name: { + "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], + "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], + "samples": data, + } + } diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py new file mode 100644 index 0000000000..bae94ec67b --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_train.py @@ -0,0 +1,146 @@ +import asyncio +import logging +from argparse import Namespace +from collections.abc import Callable + +import sglang_router +from packaging.version import parse +from tqdm import tqdm + +from miles.rollout.base_types import RolloutFnTrainOutput +from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group +from miles.utils.http_utils import get, post +from miles.utils.misc import as_completed_async, load_function +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: + args = state.args + + assert not state.aborted + state.aborted = True + + urls = await get_worker_urls(args) + logger.info(f"Abort request for {urls}") + await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) + + # make sure all the pending tasks are finished + aborted_samples = [] + async for group in as_completed_async(pendings): + if not args.partial_rollout: + continue + + # for partial rollout, collect the partial samples into the data buffer + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) + + if args.partial_rollout: + logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer") + + return aborted_samples + + +async def get_worker_urls(args: Namespace): + if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + return response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + return [worker["url"] for worker in response["workers"]] + + +def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): + return [ + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + state, + group, + sampling_params=state.sampling_params.copy(), + evaluation=False, + ) + ) + for group in samples + ] + + +async def generate_rollout_async( + state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + args = state.args + assert args.rollout_global_dataset + + # instantiate data filters + dynamic_filter = load_function(args.dynamic_sampling_filter_path) + + metric_gatherer = MetricGatherer() + + # target_data_size is the total number of valid samples to get + target_data_size = args.rollout_batch_size + + pendings = set() + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") + while len(data) < target_data_size: + while len(data) + len(pendings) < target_data_size: + # get samples from the buffer and submit the generation requests. + samples = data_source(args.over_sampling_batch_size) + pendings.update(submit_generate_tasks(state, samples)) + + # wait for the generation to finish + done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group: list[Sample] = task.result() + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) == args.n_samples_per_prompt + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + continue + + # add the samples to the data + # NOTE: here we have not stored all the unused samples back to the data buffer. + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + + pbar.close() + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + logger.info( + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + + # there are still some unfinished requests, abort them + aborted_samples = await abort(state, pendings, rollout_id) + + assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" + data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + + # reset the global state to prevent effects on the next rollout or eval. + state.reset() + + if f := load_function(args.rollout_sample_filter_path): + f(args, data) + # There can be circumstances where users want to process all samples including filtered ones. + if f := load_function(args.rollout_all_samples_process_path): + f(args, all_samples, data_source) + + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py index 62b253ddee..e9ee29db41 100644 --- a/miles/rollout/rm_hub/__init__.py +++ b/miles/rollout/rm_hub/__init__.py @@ -69,8 +69,18 @@ async def async_rm(args, sample: Sample, **kwargs): async def batched_async_rm( args, samples: list[Sample], + inplace_set_reward_field: bool = False, **kwargs, -) -> list[int | float]: +) -> list[int | float] | None: + if inplace_set_reward_field: + rewards = await batched_async_rm(args, samples, **kwargs) + for sample, reward in zip(samples, rewards, strict=True): + assert ( + sample.reward is None + ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?" + sample.reward = reward + return None + if args.custom_rm_path is not None: # Ensure the custom reward function is implemented in batch mode rm_function = load_function(args.custom_rm_path) diff --git a/miles/router/router.py b/miles/router/router.py index 2e8ecfc41f..7d3ecd9806 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from starlette.responses import Response +from miles.router.sessions import setup_session_routes from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -69,6 +70,8 @@ def _setup_routes(self): self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) self.app.post("/retrieve_from_text")(self.retrieve_from_text) + # Session routes - must be registered before catch-all + setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) @@ -130,39 +133,41 @@ async def _health_check_loop(self): async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" - # Forward all other paths to SGLang router + result = await self._do_proxy(request, path) + return self._build_proxy_response(result) + + async def _do_proxy(self, request: Request, path: str) -> dict: + """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" worker_url = self._use_url() url = f"{worker_url}/{path}" - # Get request body and headers body = await request.body() headers = dict(request.headers) try: response = await self.client.request(request.method, url, content=body, headers=headers) - # Eagerly read content so we can return JSON (not streaming) content = await response.aread() - content_type = response.headers.get("content-type", "") - try: - # Prefer parsing JSON if possible - data = json.loads(content) - return JSONResponse( - content=data, - status_code=response.status_code, - headers=dict(response.headers), - ) - except Exception: - # Fall back to raw body with original content type - return Response( - content=content, - status_code=response.status_code, - headers=dict(response.headers), - media_type=content_type or None, - ) - + return { + "request_body": body, + "response_body": content, + "status_code": response.status_code, + "headers": dict(response.headers), + } finally: self._finish_url(worker_url) + def _build_proxy_response(self, result: dict) -> Response: + """Build HTTP response from proxy result.""" + content = result["response_body"] + status_code = result["status_code"] + headers = result["headers"] + content_type = headers.get("content-type", "") + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + async def add_worker(self, request: Request): """Add a new worker to the router. Supports providing the URL via query string or JSON body. diff --git a/miles/router/sessions.py b/miles/router/sessions.py new file mode 100644 index 0000000000..9d753e5975 --- /dev/null +++ b/miles/router/sessions.py @@ -0,0 +1,124 @@ +import json +import time +import uuid +from typing import TYPE_CHECKING + +from fastapi import Request +from fastapi.responses import JSONResponse, Response +from pydantic import BaseModel +from transformers import AutoTokenizer + +if TYPE_CHECKING: + from miles.router.router import MilesRouter + + +class SessionRecord(BaseModel): + timestamp: float + method: str + path: str + request: dict + response: dict + status_code: int + + +class GetSessionResponse(BaseModel): + session_id: str + records: list[SessionRecord] + + +class SessionManager: + def __init__(self): + self.sessions: dict[str, list[SessionRecord]] = {} + + def create_session(self) -> str: + session_id = uuid.uuid4().hex + self.sessions[session_id] = [] + return session_id + + def get_session(self, session_id: str) -> list[SessionRecord] | None: + return self.sessions.get(session_id) + + def delete_session(self, session_id: str) -> list[SessionRecord]: + assert session_id in self.sessions + return self.sessions.pop(session_id) + + def add_record(self, session_id: str, record: SessionRecord): + assert session_id in self.sessions + self.sessions[session_id].append(record) + + +def setup_session_routes(app, router: "MilesRouter"): + manager = SessionManager() + + # TODO temporary hack before @guapisolo implements TITO + # ============================= HACK START =============================== + # Lazy load tokenizer only when needed (for tests that don't have hf_checkpoint) + tokenizer = None + + def get_tokenizer(): + nonlocal tokenizer + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) + return tokenizer + + # ============================= HACK END =============================== + + @app.post("/sessions") + async def create_session(): + session_id = manager.create_session() + return {"session_id": session_id} + + @app.get("/sessions/{session_id}") + async def get_session(session_id: str): + records = manager.get_session(session_id) + if records is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return GetSessionResponse(session_id=session_id, records=records) + + @app.delete("/sessions/{session_id}") + async def delete_session(session_id: str): + if session_id not in manager.sessions: + return JSONResponse(status_code=404, content={"error": "session not found"}) + manager.delete_session(session_id) + return Response(status_code=204) + + @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def session_proxy(request: Request, session_id: str, path: str): + if session_id not in manager.sessions: + return JSONResponse(status_code=404, content={"error": "session not found"}) + + result = await router._do_proxy(request, path) + + request_body = json.loads(result["request_body"]) + response_body = json.loads(result["response_body"]) + + # TODO: remove this hack when @guapisolo implements the real TITO + # ============================= HACK START =============================== + if "messages" in request_body and "input_ids" not in request_body: + request_body["input_ids"] = get_tokenizer().apply_chat_template( + request_body["messages"], + add_generation_prompt=True, + add_special_tokens=False, + tools=request_body.get("tools"), + ) + if ( + "logprobs" in response_body.get("choices", [{}])[0] + and "content" in response_body["choices"][0]["logprobs"] + ): + logprobs_content = response_body["choices"][0]["logprobs"]["content"] + for item in logprobs_content: + if "token" in item and "token_id" not in item: + item["token_id"] = get_tokenizer().convert_tokens_to_ids(item["token"]) + # ============================= HACK END =============================== + + record = SessionRecord( + timestamp=time.time(), + method=request.method, + path=path, + request=request_body, + response=response_body, + status_code=result["status_code"], + ) + manager.add_record(session_id, record) + + return router._build_proxy_response(result) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 79b2c419ca..0710202924 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,8 +10,10 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger +from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -204,7 +206,11 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-function-path", type=str, - default="miles.rollout.sglang_rollout.generate_rollout", + default=( + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn" + if enable_experimental_rollout_refactor() + else "miles.rollout.sglang_rollout.generate_rollout" + ), help=( "Path to the rollout generation function." "You should use this model to create your own custom rollout function, " @@ -1344,6 +1350,20 @@ def add_ci_arguments(parser): ) return parser + def add_user_provided_function_arguments(parser): + args_partial, _ = parser.parse_known_args() + for path in [ + args_partial.rollout_function_path, + args_partial.custom_generate_function_path, + ]: + try: + fn = load_function(path) + except (ModuleNotFoundError, ValueError): + continue + if fn is not None and callable(getattr(fn, "add_arguments", None)): + fn.add_arguments(parser) + return parser + def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) @@ -1374,6 +1394,8 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) + if enable_experimental_rollout_refactor(): + parser = add_user_provided_function_arguments(parser) reset_arg( parser, "--custom-config-path", diff --git a/miles/utils/environ.py b/miles/utils/environ.py new file mode 100644 index 0000000000..35d1f350ee --- /dev/null +++ b/miles/utils/environ.py @@ -0,0 +1,14 @@ +import os + +_printed_experimental_rollout_refactor = False + + +def enable_experimental_rollout_refactor() -> bool: + result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) + + global _printed_experimental_rollout_refactor + if result and not _printed_experimental_rollout_refactor: + print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)") + _printed_experimental_rollout_refactor = True + + return result diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 2b3e6e192f..0abdbbf59d 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -162,11 +162,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60): +async def _post(client, url, payload, max_retries=60, action="post"): retry_count = 0 while retry_count < max_retries: try: - response = await client.post(url, json=payload or {}) + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) response.raise_for_status() try: output = response.json() @@ -240,8 +244,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60): - return await _post(self._client, url, payload, max_retries) + async def do_post(self, url, payload, max_retries=60, action="post"): + return await _post(self._client, url, payload, max_retries, action=action) # Create actors per node created = [] @@ -265,7 +269,8 @@ async def do_post(self, url, payload, max_retries=60): _post_actors = created -async def post(url, payload, max_retries=60): +# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) +async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -274,15 +279,16 @@ async def post(url, payload, max_retries=60): actor = _next_actor() if actor is not None: # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries) + obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) return await asyncio.to_thread(ray.get, obj_ref) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries) + return await _post(_http_client, url, payload, max_retries, action=action) +# TODO unify w/ `post` to add retries and remote-execution async def get(url): response = await _http_client.get(url) response.raise_for_status() diff --git a/miles/utils/misc.py b/miles/utils/misc.py index c0a96d6366..bae72ec0d7 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,17 +1,55 @@ +import asyncio import importlib import subprocess +from contextlib import contextmanager import ray from miles.utils.http_utils import is_port_available +# Mainly used for test purpose where `load_function` needs to load many in-flight generated functions +class FunctionRegistry: + def __init__(self): + self._registry: dict[str, object] = {} + + @contextmanager + def temporary(self, name: str, fn: object): + self._register(name, fn) + try: + yield + finally: + self._unregister(name) + + def get(self, name: str) -> object | None: + return self._registry.get(name) + + def _register(self, name: str, fn: object) -> None: + assert name not in self._registry + self._registry[name] = fn + + def _unregister(self, name: str) -> None: + assert name in self._registry + self._registry.pop(name) + + +function_registry = FunctionRegistry() + + +# TODO may rename to `load_object` since it can be used to load things like tool_specs def load_function(path): """ - Load a function from a module. + Load a function from registry or module. :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. """ + if path is None: + return None + + registered = function_registry.get(path) + if registered is not None: + return registered + module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) @@ -30,8 +68,9 @@ def __call__(cls, *args, **kwargs): cls._instances[cls] = instance return cls._instances[cls] - def clear_instances(cls): - cls._instances = {} + @staticmethod + def clear_all_instances(): + SingletonMeta._instances.clear() def exec_command(cmd: str, capture_output: bool = False) -> str | None: @@ -92,3 +131,8 @@ def should_run_periodic_action( step = rollout_id + 1 return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) + + +async def as_completed_async(tasks): + for coro in asyncio.as_completed(tasks): + yield await coro diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py new file mode 100644 index 0000000000..2c0dddfe54 --- /dev/null +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -0,0 +1,248 @@ +import asyncio +import re +import time +import uuid +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import asdict, dataclass + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.function_call_parser import FunctionCallParser +from transformers import AutoTokenizer + +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class ProcessResultMetaInfo: + weight_version: str | None = None + routed_experts: str | None = None + spec_accept_token_num: int | None = None + spec_draft_token_num: int | None = None + spec_verify_ct: int | None = None + + def to_dict(self) -> dict: + return {k: v for k, v in asdict(self).items() if v is not None} + + +@dataclass(frozen=True) +class ProcessResult: + text: str + finish_reason: str = "stop" + cached_tokens: int = 0 + meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() + + +ProcessFn = Callable[[str], ProcessResult] + + +class MockSGLangServer: + def __init__( + self, + model_name: str, + process_fn: ProcessFn, + host: str, + port: int, + latency: float = 0.0, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.process_fn = process_fn + self.host = host + self.port = port or find_available_port(30000) + self.latency = latency + + self.app = FastAPI() + self._server: UvicornThreadServer | None = None + + self.request_log: list[dict] = [] + self._concurrency = Counter() + + self._setup_routes() + + @property + def max_concurrent(self) -> int: + return self._concurrency.max_value + + def reset_stats(self): + self.request_log.clear() + self._concurrency.reset() + + def start(self): + self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) + self._server.start() + + def stop(self): + if self._server is not None: + self._server.stop() + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def _setup_routes(self): + @self.app.post("/generate") + async def generate(request: Request): + return await self._handle_generate_like_request(request, self._compute_generate_response) + + @self.app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return await self._handle_generate_like_request(request, self._compute_chat_completions_response) + + @self.app.get("/health") + async def health(): + return JSONResponse(content={"status": "ok"}) + + @self.app.post("/abort_request") + async def abort_request(_request: Request): + return JSONResponse(content={"status": "ok"}) + + async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]): + payload = await request.json() + self.request_log.append(payload) + with self._concurrency.track(): + if self.latency > 0: + await asyncio.sleep(self.latency) + response = compute_fn(payload) + return JSONResponse(content=response) + + def _compute_generate_response(self, payload: dict) -> dict: + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) + + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) + + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens + + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + + meta_info = { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": process_result.cached_tokens, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + **process_result.meta_info.to_dict(), + } + + return {"text": process_result.text, "meta_info": meta_info} + + def _compute_chat_completions_response(self, payload: dict) -> dict: + messages = payload.get("messages", []) + tools = payload.get("tools") + + prompt_str = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + logprobs_content = [ + {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i} + for i, tid in enumerate(output_ids) + ] + + finish_reason = process_result.finish_reason + tool_calls = None + if tools and finish_reason == "stop": + parser = FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tools), + tool_call_parser="qwen25", + ) + message_content, parsed_calls = parser.parse_non_stream(process_result.text) + if parsed_calls: + finish_reason = "tool_calls" + tool_calls = [ + { + "id": f"call{i:05d}", + "type": "function", + "function": {"name": call.name, "arguments": call.parameters or "{}"}, + } + for i, call in enumerate(parsed_calls) + ] + else: + message_content = process_result.text + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message_content, + "tool_calls": tool_calls, + }, + "logprobs": {"content": logprobs_content}, + "finish_reason": finish_reason, + } + ], + } + + +class Counter: + def __init__(self): + self._current = 0 + self._max = 0 + + @property + def max_value(self) -> int: + return self._max + + def reset(self): + self._current = 0 + self._max = 0 + + @contextmanager + def track(self): + self._current += 1 + self._max = max(self._max, self._current) + try: + yield + finally: + self._current -= 1 + + +def default_process_fn(prompt: str) -> ProcessResult: + match = re.search(r"What is 1\+(\d+)\?", prompt) + if match: + num = int(match.group(1)) + ans = 1 + num + return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop") + return ProcessResult(text="I don't understand.", finish_reason="stop") + + +@contextmanager +def with_mock_server( + model_name: str = "Qwen/Qwen3-0.6B", + process_fn: ProcessFn = default_process_fn, + host: str = "127.0.0.1", + port: int | None = None, + latency: float = 0.0, +): + server = MockSGLangServer( + model_name=model_name, + process_fn=process_fn, + host=host, + port=port, + latency=latency, + ) + try: + server.start() + yield server + finally: + server.stop() diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py new file mode 100644 index 0000000000..6b99e36739 --- /dev/null +++ b/miles/utils/test_utils/mock_tools.py @@ -0,0 +1,268 @@ +import json + +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_year", + "description": "Get current year", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_temperature", + "description": "Get temperature for a location", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + }, +] + + +def _get_year(params: dict) -> str: + assert len(params) == 0 + return json.dumps({"year": 2026}) + + +def _get_temperature(params: dict) -> str: + temps = {"Mars": -60, "Earth": 15} + location = params.get("location") + assert location in temps, f"Unknown location: {location}" + return json.dumps({"temperature": temps[location]}) + + +TOOL_EXECUTORS = { + "get_year": _get_year, + "get_temperature": _get_temperature, +} + + +async def execute_tool_call(name: str, params: dict) -> str: + return TOOL_EXECUTORS[name](params) + + +_SYSTEM_PROMPT = ( + "<|im_start|>system\n" + "# Tools\n" + "\n" + "You may call one or more functions to assist with the user query.\n" + "\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n" + "\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" +) + + +_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + + +class TwoTurnStub: + """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer""" + + USER_QUESTION = "What is 42 + year + temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE, + TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") + + +class ThreeTurnStub: + """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer""" + + USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and Mars temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + SECOND_RESPONSE = ( + "Now let me get Earth temperature.\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"temperature": 15}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature." + SECOND_TOOL_CALLS_OPENAI_FORMAT = [ + { + "id": "call00000", + "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ + { + "content": SECOND_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE, + ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE, + ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py new file mode 100644 index 0000000000..904343c984 --- /dev/null +++ b/miles/utils/test_utils/uvicorn_thread_server.py @@ -0,0 +1,49 @@ +import asyncio +import socket +import threading +import time + +import uvicorn + + +class UvicornThreadServer: + def __init__(self, app, host: str, port: int): + self._app = app + self.host = host + self.port = port + self._server: uvicorn.Server | None = None + self._thread: threading.Thread | None = None + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def start(self) -> None: + config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info") + self._server = uvicorn.Server(config) + + def run() -> None: + asyncio.run(self._server.serve()) + + self._thread = threading.Thread(target=run, daemon=True) + self._thread.start() + self._wait_for_port_open() + + def stop(self) -> None: + if self._server is not None: + self._server.should_exit = True + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2.0) + + def _wait_for_port_open(self) -> None: + for _ in range(50): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex((self.host, self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + raise RuntimeError(f"Failed to start server on {self.url}") diff --git a/miles/utils/types.py b/miles/utils/types.py index 0a2531a7af..5200d625e6 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -145,6 +145,24 @@ def get_reward_value(self, args) -> float: def effective_response_length(self): return sum(self.loss_mask) if self.loss_mask is not None else self.response_length + def validate(self): + assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}" + assert ( + len(self.tokens) >= self.response_length + ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})" + if self.loss_mask is not None: + assert ( + len(self.loss_mask) == self.response_length + ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})" + if self.rollout_log_probs is not None: + assert ( + len(self.rollout_log_probs) == self.response_length + ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" + if self.rollout_routed_experts is not None: + actual = len(self.rollout_routed_experts) + expect = len(self.tokens) - 1 + assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})" + def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. diff --git a/requirements.txt b/requirements.txt index 2c20195fc4..dacd51132c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf pillow +pybase64 pylatexenc pyyaml qwen_vl_utils # for VLM diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py index 9507e2e858..20379f76a2 100644 --- a/tests/ci/gpu_lock_exec.py +++ b/tests/ci/gpu_lock_exec.py @@ -19,11 +19,14 @@ def main(): _execute_print_only(args) return - fd_locks = _try_acquire(args) + if args.count == 0 and not args.devices: + print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True) + else: + fd_locks = _try_acquire(args) - dev_list = ",".join(str(x.gpu_id) for x in fd_locks) - os.environ[args.target_env_name] = dev_list - print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ[args.target_env_name] = dev_list + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) _os_execvp(args) diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep new file mode 100644 index 0000000000..615f2b076c --- /dev/null +++ b/tests/e2e/.gitkeep @@ -0,0 +1 @@ +# TODO: may move e2e tests to this folder \ No newline at end of file diff --git a/tests/fast/__init__.py b/tests/fast/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/conftest.py b/tests/fast/conftest.py new file mode 100644 index 0000000000..4cb30e91fa --- /dev/null +++ b/tests/fast/conftest.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from tests.fast.fixtures.generation_fixtures import generation_env +from tests.fast.fixtures.rollout_fixtures import rollout_env + +_ = rollout_env, generation_env + + +@pytest.fixture(autouse=True) +def enable_experimental_rollout_refactor(): + os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1" + yield + os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None) diff --git a/tests/fast/fixtures/__init__.py b/tests/fast/fixtures/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/fast/fixtures/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py new file mode 100644 index 0000000000..816371ee3a --- /dev/null +++ b/tests/fast/fixtures/generation_fixtures.py @@ -0,0 +1,274 @@ +""" +Fixtures to test custom-generate-function +""" + +from argparse import Namespace +from contextlib import contextmanager +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState +from miles.router.router import MilesRouter +from miles.utils.async_utils import run +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer +from miles.utils.types import Sample + +MODEL_NAME = "Qwen/Qwen3-0.6B" +RESPONSE_TEXT = "\\boxed{8}" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} + +VARIANT_TO_GENERATE_FN_PATH = { + "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", + "single_turn": "miles.rollout.generate_hub.single_turn.generate", + "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", + "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", + "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate", + "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", +} + + +def extra_argv_for_variant( + variant: str, + *, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", +) -> list[str]: + argv = [ + "--custom-generate-function-path", + custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], + ] + + if variant in ( + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ): + argv += [ + "--generate-max-turns", + str(generate_max_turns), + "--generate-tool-specs-path", + generate_tool_specs_path, + "--generate-execute-tool-function-path", + generate_execute_tool_function_path, + ] + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + argv += ["--generate-tool-call-parser", generate_tool_call_parser] + if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): + argv.append("--generate-multi-samples") + + return argv + + +def listify(x): + return x if isinstance(x, list) else [x] + + +def make_sample( + *, + prompt: str | list[dict] = "What is 1+7?", + tokens: list[int] | None = None, + response: str = "", + response_length: int = 0, + status: Sample.Status = Sample.Status.PENDING, + multimodal_inputs: dict | None = None, +) -> Sample: + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +@dataclass +class GenerateEnv: + args: Namespace + mock_server: Any + + +@dataclass +class GenerateResult: + sample: Sample | list[Sample] + requests: list[dict] + + +def run_generate( + env: GenerateEnv, + sample: Sample, + sampling_params: dict[str, Any] | None = None, + *, + variant: str = "single_turn", +) -> GenerateResult: + env.mock_server.request_log.clear() + result_sample = run( + _call_generate( + env.args, + sample, + sampling_params or DEFAULT_SAMPLING_PARAMS, + variant=variant, + ) + ) + return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) + + +async def _call_generate( + args: Namespace, + sample: Sample, + sampling_params: dict[str, Any], + *, + variant: str = "single_turn", +) -> Sample: + generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant]) + state = GenerateState(args) + input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) + output = await generate_fn(input) + return output.samples + + +def make_args( + *, + variant: str, + router_port: int, + use_rollout_routing_replay: bool = False, + sglang_speculative_algorithm: str | None = None, + model_name: str = MODEL_NAME, + extra_argv: list[str] | None = None, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + rollout_max_context_len: int | None = None, +) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + model_name, + "--prompt-data", + "/dev/null", + "--rm-type", + "math", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + if use_rollout_routing_replay: + argv.append("--use-rollout-routing-replay") + if sglang_speculative_algorithm: + argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) + if rollout_max_context_len is not None: + argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) + + argv.extend( + extra_argv_for_variant( + variant, + custom_generate_function_path=custom_generate_function_path, + generate_max_turns=generate_max_turns, + generate_tool_specs_path=generate_tool_specs_path, + generate_tool_call_parser=generate_tool_call_parser, + generate_execute_tool_function_path=generate_execute_tool_function_path, + ) + ) + + if extra_argv: + argv.extend(extra_argv) + + from miles.utils.arguments import parse_args + + with patch("sys.argv", argv): + args = parse_args() + + init_http_client(args) + return args + + +@contextmanager +def with_miles_router(backend_url: str, model_name: str): + router_args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + hf_checkpoint=model_name, + ) + router = MilesRouter(router_args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend_url}) + + try: + yield port + finally: + server.stop() + + +@pytest.fixture +def generation_env(request, variant): + SingletonMeta.clear_all_instances() + params = getattr(request, "param", {}) + args_kwargs = params.get("args_kwargs", {}) + model_name = args_kwargs.get("model_name", MODEL_NAME) + custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant] + + def process_fn(_): + x = params.get("process_fn_kwargs", {}) + return ProcessResult( + text=x.get("response_text", RESPONSE_TEXT), + finish_reason=x.get("finish_reason", "stop"), + cached_tokens=x.get("cached_tokens", 0), + meta_info=ProcessResultMetaInfo( + weight_version=x.get("weight_version"), + routed_experts=x.get("routed_experts"), + spec_accept_token_num=x.get("spec_accept_token_num"), + spec_draft_token_num=x.get("spec_draft_token_num"), + spec_verify_ct=x.get("spec_verify_ct"), + ), + ) + + with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: + with with_miles_router(mock_server.url, model_name) as router_port: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args( + variant=variant, + router_port=router_port, + model_name=model_name, + custom_generate_function_path=custom_generate_function_path, + **other_args_kwargs, + ) + yield GenerateEnv(args=args, mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py new file mode 100644 index 0000000000..44d8a50d79 --- /dev/null +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -0,0 +1,127 @@ +""" +Fixtures to test rollout-function +""" + +import json +from argparse import Namespace +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer +from miles.router.router import MilesRouter +from miles.utils.arguments import parse_args +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class RolloutEnvConfig: + extra_argv: list[str] | None = None + data_rows: list[dict] | None = None + latency: float = 0.0 + + +@dataclass(frozen=True) +class RolloutEnv: + args: Namespace + data_source: DataSource + mock_server: MockSGLangServer + + +def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + data_path, + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--eval-prompt-data", + "toy", + data_path, + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + (extra_argv or []) + with patch("sys.argv", argv): + args = parse_args() + args.miles_router_middleware_paths = [] + init_http_client(args) + return args + + +@contextmanager +def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]: + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + try: + server.start() + yield server + finally: + server.stop() + + +def _write_jsonl(path: str, rows: list[dict]) -> None: + Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") + + +DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] + + +@pytest.fixture +def rollout_env(tmp_path, request) -> RolloutEnv: + config = request.param + assert isinstance(config, RolloutEnvConfig) + + data_rows = config.data_rows or DEFAULT_DATA_ROWS + + data_path = str(tmp_path / "data.jsonl") + _write_jsonl(data_path, data_rows) + + router_port = find_available_port(20000) + args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) + + SingletonMeta.clear_all_instances() + + with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: + with _with_miles_router(args) as router_server: + r = requests.post( + f"{router_server.url}/add_worker", + params={"url": mock_server.url}, + timeout=5.0, + ) + r.raise_for_status() + + data_source = RolloutDataSourceWithBuffer(args) + yield RolloutEnv(args=args, data_source=data_source, mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/rollout/__init__.py b/tests/fast/rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/generate_hub/__init__.py b/tests/fast/rollout/generate_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py new file mode 100644 index 0000000000..5d974aaadd --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -0,0 +1,572 @@ +from copy import deepcopy +from dataclasses import dataclass, replace +from itertools import groupby + +import numpy as np +import pybase64 +import pytest +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub +from miles.utils.types import Sample + +_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub + + +def is_agentic_variant(variant: str) -> bool: + return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples") + + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +@pytest.fixture( + params=[ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ] +) +def variant(request): + return request.param + + +@dataclass(frozen=True) +class SampleParsedChunk: + tokens_decoded_str: str + loss_mask_value: int + rollout_log_probs: list[float] + + +@dataclass +class ExpectedSampleInfo: + chunks: list[SampleParsedChunk] + partial_sample: Sample + + +def token_len(text: str) -> int: + return len(TOKENIZER(text, add_special_tokens=False)["input_ids"]) + + +def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk: + n = token_len(text) + log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n + return SampleParsedChunk(text, loss_mask, log_probs) + + +def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: + prompt_len = len(sample.tokens) - sample.response_length + response_tokens = sample.tokens[prompt_len:] + loss_mask = sample.loss_mask or [] + log_probs = sample.rollout_log_probs or [] + + chunks = [] + idx = 0 + for mask_val, group in groupby(loss_mask): + group_len = len(list(group)) + sli = slice(idx, idx + group_len) + chunks.append( + SampleParsedChunk( + tokens_decoded_str=tokenizer.decode(response_tokens[sli]), + loss_mask_value=mask_val, + rollout_log_probs=log_probs[sli], + ) + ) + idx += group_len + return chunks + + +def expected_partial_sample( + *, + prompt: list[dict], + response: str, + response_length: int, + status: Sample.Status = Sample.Status.COMPLETED, +) -> Sample: + return Sample( + prompt=prompt, + response=response, + response_length=response_length, + status=status, + tokens=[], + loss_mask=[], + rollout_log_probs=[], + weight_versions=[], + spec_info=Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + + +def verify_samples(actual: Sample | list[Sample], expected: list[ExpectedSampleInfo]): + actual = listify(actual) + assert len(actual) == len(expected) + + for actual_item, expected_item in zip(actual, expected, strict=True): + actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER) + assert actual_chunks == expected_item.chunks + + actual_partial = replace( + deepcopy(actual_item), + tokens=[], + loss_mask=[], + rollout_log_probs=[], + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + assert actual_partial == expected_item.partial_sample + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): + return run_generate(env, sample, sampling_params, variant=variant) + + +def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict: + return { + "input_ids": input_ids, + "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "return_logprob": True, + "return_routed_experts": False, + } + + +def expected_openai_request(messages: list[dict]) -> dict: + return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS} + + +SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] +SINGLE_TURN_RESPONSE = "The answer is 2." +_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( + SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS +) +SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] +SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicMultiTurn: + def test_single_turn_no_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="stop" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] + else: + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6 + ), + ), + ], + ) + + def test_two_turns_with_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = TwoTurnStub.process_fn + + S = TwoTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestExitConditions: + def test_partial_rollout_not_supported(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not check partial_rollout flag") + generation_env.args.partial_rollout = True + + with pytest.raises(AssertionError, match="Partial rollout is not supported"): + _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + def test_abort_preserves_content(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not handle abort finish_reason") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="abort" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + status=Sample.Status.ABORTED, + ), + ), + ], + ) + + def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ], + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) + def test_max_turns_reached(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRespectMaxContextLen: + @pytest.mark.parametrize( + "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True + ) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + assert result.requests == [] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED + ), + ) + ] + else: + expected = [] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS) + + token_len(TwoTurnStub.FIRST_RESPONSE) + + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE) + } + } + ], + indirect=True, + ) + def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=partial_response, + response_length=token_len(partial_response), + status=Sample.Status.TRUNCATED, + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}}, + 10, + ), + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}}, + 64, + ), + ], + indirect=["generation_env"], + ) + def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert len(result.requests) >= 2 + assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"] + + +class TestThreeTurn: + """Need to test 3-turn case besides 2-turn, because e.g. merge_samples may behave differently.""" + + def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): + generation_env.mock_server.process_fn = ThreeTurnStub.process_fn + + S = ThreeTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + expected_request(S.THIRD_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = ( + S.FIRST_RESPONSE + + S.FIRST_TOOL_RESPONSE + + S.SECOND_RESPONSE + + S.SECOND_TOOL_RESPONSE + + S.THIRD_RESPONSE + ) + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + expected_chunk(S.SECOND_TOOL_RESPONSE, 0), + expected_chunk(S.THIRD_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.THIRD_RESPONSE, + response_length=token_len(S.THIRD_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRoutedExpertsMultiTurn: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "use_rollout_routing_replay": True, + } + } + ], + indirect=True, + ) + def test_two_turns_routed_experts(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + + S = TwoTurnStub + num_layers, moe_router_topk = 2, 4 + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + + def make_routed_experts(prompt_token_ids, response_text): + total_tokens = len(prompt_token_ids) + token_len(response_text) + routed_experts_len = total_tokens - 1 + return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape( + routed_experts_len, num_layers, moe_router_topk + ) + + first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE) + second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE) + + def process_fn(prompt: str) -> ProcessResult: + if prompt == S.FIRST_PROMPT: + text, routed_experts = S.FIRST_RESPONSE, first_routed_experts + elif prompt == S.SECOND_PROMPT: + text, routed_experts = S.SECOND_RESPONSE, second_routed_experts + else: + raise ValueError(f"Unexpected prompt: {prompt}") + return ProcessResult( + text=text, + finish_reason="stop", + meta_info=ProcessResultMetaInfo( + routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii") + ), + ) + + generation_env.mock_server.process_fn = process_fn + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS) + + sample = result.sample[-1] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == second_routed_experts.shape + np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts) + assert len(sample.tokens) - 1 == second_routed_experts.shape[0] diff --git a/tests/fast/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py new file mode 100644 index 0000000000..a58e6fb3c6 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_single_turn.py @@ -0,0 +1,424 @@ +import numpy as np +import pybase64 +import pytest +import torch +from PIL import Image +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoProcessor + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.types import Sample + +_ = generation_env + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +PROMPT = "What is 1+7?" +PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] +PROMPT_TOKEN_LEN = len(PROMPT_TOKENS) +RESPONSE_TOKENS = [59, 79075, 90, 23, 92] +RESPONSE_TEXT = "\\boxed{8}" +RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] +SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} +DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"] + + +@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) +def variant(request): + return request.param + + +def expected_request( + variant: str, + *, + input_ids: list[int] | None = None, + sampling_params: dict | None = None, + return_routed_experts: bool = False, + image_data: list[str] | None = None, +) -> dict: + result = { + "input_ids": input_ids or PROMPT_TOKENS, + "sampling_params": sampling_params or SAMPLING_PARAMS, + "return_logprob": True, + } + if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts: + result["return_routed_experts"] = return_routed_experts + if image_data is not None: + result["image_data"] = image_data + return result + + +class _Unset: + pass + + +_UNSET = _Unset() + + +def expected_sample( + variant: str, + *, + prompt: str = PROMPT, + response: str = RESPONSE_TEXT, + response_length: int = 5, + tokens: list[int] | None | _Unset = _UNSET, + rollout_log_probs: list[float] | None | _Unset = _UNSET, + status: Sample.Status = Sample.Status.COMPLETED, + cached_tokens: int = 0, + prompt_tokens: int = 7, + weight_versions: list[str] | None = None, + rollout_routed_experts: np.ndarray | None = None, + spec_info: Sample.SpecInfo | None = None, + multimodal_inputs: dict | None = None, + multimodal_train_inputs: dict | None = None, + loss_mask: list[int] | None | _Unset = _UNSET, +) -> Sample: + actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) + if isinstance(loss_mask, _Unset): + loss_mask = ( + [1] * actual_response_length + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") + else None + ) + + return Sample( + group_index=None, + index=None, + prompt=prompt, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=multimodal_train_inputs, + response=response, + response_length=response_length, + label=None, + reward=None, + loss_mask=loss_mask, + weight_versions=weight_versions or [], + rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs, + rollout_routed_experts=rollout_routed_experts, + remove_sample=False, + status=status, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=spec_info or Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens), + ) + + +def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): + return make_sample( + prompt=PROMPT, + tokens=tokens, + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): + return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicGeneration: + def test_basic_generation(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestResumedSingleTurn: + def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + partial_text = "\\boxed" + partial_tokens = [59, 79075] + partial_log_probs = [-0.0, -0.0078125] + + remaining_text = "{8}" + remaining_tokens = [90, 23, 92] + remaining_log_probs = [-0.0, -0.0078125, -0.015625] + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") + sample = _make_sample() + result1 = _run_generate(variant, generation_env, sample) + assert result1.requests == [expected_request(variant)] + assert result1.sample == expected_sample( + variant, + response=partial_text, + response_length=2, + tokens=PROMPT_TOKENS + partial_tokens, + rollout_log_probs=partial_log_probs, + status=Sample.Status.ABORTED, + ) + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") + result2 = _run_generate(variant, generation_env, result1.sample) + tokens_after_turn1 = PROMPT_TOKENS + partial_tokens + assert result2.requests == [ + expected_request( + variant, + input_ids=tokens_after_turn1, + sampling_params={"max_new_tokens": 14, "temperature": 0.7}, + ) + ] + assert result2.sample == expected_sample( + variant, + response=partial_text + remaining_text, + response_length=2 + 3, + tokens=tokens_after_turn1 + remaining_tokens, + rollout_log_probs=partial_log_probs + remaining_log_probs, + prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1), + status=Sample.Status.COMPLETED, + ) + + +class TestFinishReason: + @pytest.mark.parametrize( + "generation_env,expected_status", + [ + ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), + ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), + ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED), + ], + indirect=["generation_env"], + ) + def test_finish_reason_sets_status(self, variant, generation_env, expected_status): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, status=expected_status)] + + +class TestRoutedExperts: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"use_rollout_routing_replay": True}, + "process_fn_kwargs": {"routed_experts": "placeholder"}, + } + ], + indirect=True, + ) + def test_routed_experts_enabled_and_parsed(self, variant, generation_env): + num_layers, moe_router_topk = 2, 4 + num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) + routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( + num_tokens - 1, num_layers, moe_router_topk + ) + + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=RESPONSE_TEXT, + finish_reason="stop", + meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), + ) + + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant, return_routed_experts=True)] + sample = result.sample[0] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) + np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array) + + +class TestMetaInfo: + @pytest.mark.parametrize( + "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True + ) + def test_meta_info_fields_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])] + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, + "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3}, + } + ], + indirect=True, + ) + def test_spec_info_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample( + variant, + spec_info=Sample.SpecInfo( + spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 + ), + ) + ] + + +class TestInputStatusValidation: + @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) + def test_allowed_statuses(self, variant, generation_env, status): + result = _run_generate(variant, generation_env, _make_sample(status=status)) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) + def test_rejected_statuses(self, variant, generation_env, status): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + with pytest.raises(AssertionError): + _run_generate(variant, generation_env, _make_sample(status=status)) + + +class TestPayloadStructure: + def test_sampling_params_passed_through(self, variant, generation_env): + result = _run_generate( + variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9} + ) + assert result.requests == [ + expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + ] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestBoundaryConditions: + def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) + sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) + + result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7}) + assert result.requests == [] + assert result.sample == expected_sample( + variant, + response="x" * 10, + response_length=10, + tokens=existing_tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS), + ], + indirect=["generation_env"], + ) + def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + result = _run_generate(variant, generation_env) + assert len(result.requests) == 1 + assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}], + indirect=True, + ) + def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + +class TestEmptyResponse: + @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) + def test_empty_response(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[]) + ] + + +VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" + + +class TestMultimodal: + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) + def test_multimodal_inputs_processed(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + test_image = Image.new("RGB", (64, 64), color="red") + multimodal_inputs = {"images": [test_image]} + processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) + expected_mti = { + k: v + for k, v in processor(text=PROMPT, **multimodal_inputs).items() + if k not in ["input_ids", "attention_mask"] + } + + result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs)) + + assert result.requests == [ + expected_request( + variant, + input_ids=PROMPT_TOKENS, + image_data=[encode_image_for_rollout_engine(test_image)], + ) + ] + actual_mti = result.sample.multimodal_train_inputs + assert actual_mti is not None + assert set(actual_mti.keys()) == set(expected_mti.keys()) + assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) + assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) + assert result.sample == expected_sample( + variant, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=actual_mti, + ) diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py new file mode 100644 index 0000000000..0f2305e753 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -0,0 +1,99 @@ +import pytest + +from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "XiaomiMiMo/MiMo-7B-RL", +] + +SINGLE_TOOL_CALL_ONLY_MODELS = [ + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo +] + +# Models where tokenize->decode produces extra whitespace vs direct string diff +TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ + "THUDM/glm-4-9b-chat", +] + +SAMPLE_TOOL_RESPONSES = [ + { + "role": "tool", + "tool_call_id": "call00000", + "content": '{"year": 2026}', + "name": "get_year", + }, + { + "role": "tool", + "tool_call_id": "call00001", + "content": '{"temperature": 25}', + "name": "get_temperature", + }, +] + + +class TestTokenizeToolResponses: + @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) + def test_snapshot(self, model_name): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) + decoded = tokenizer.decode(token_ids) + + assert decoded == ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": 25}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize("num_tools", [1, 2]) + @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) + def test_tokenize_tool_responses(self, model_name, num_tools): + if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: + pytest.skip(f"{model_name} only supports single tool call") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] + assert len(tool_responses) == num_tools + + actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer) + actual_str = tokenizer.decode(actual_token_ids) + + dummy_assistant = _build_dummy_assistant(tool_responses) + base_messages = [_DUMMY_USER, dummy_assistant] + expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) + + if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: + # Some models produce whitespace differences between tokenize->decode and direct string diff + actual_str = actual_str.replace(" ", "") + expected_str = expected_str.replace(" ", "") + + assert actual_str == expected_str, f"{model_name=}" + + @staticmethod + def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str: + text_with = tokenizer.apply_chat_template( + base_messages + extra_messages, tokenize=False, add_generation_prompt=True + ) + text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False) + return text_with[len(text_without) :] diff --git a/tests/fast/rollout/generate_utils/__init__.py b/tests/fast/rollout/generate_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/generate_utils/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py new file mode 100644 index 0000000000..c53fbbb56a --- /dev/null +++ b/tests/fast/rollout/generate_utils/test_sample_utils.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.generate_utils.sample_utils import _merge_sample_pair +from miles.utils.types import Sample + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + tokenizer.decode = lambda tokens: f"" + return tokenizer + + +def make_sample( + prompt="test_prompt", + tokens=None, + response="", + response_length=0, + loss_mask=None, + rollout_log_probs=None, + status=Sample.Status.COMPLETED, + label="test_label", + reward=1.0, + index=0, + group_index=0, +): + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + loss_mask=loss_mask, + rollout_log_probs=rollout_log_probs, + status=status, + label=label, + reward=reward, + index=index, + group_index=group_index, + ) + + +class TestMergeSamples: + def test_basic_merge(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3, 10, 11, 12], + response="response1", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.1, -0.2, -0.3], + ) + b = make_sample( + tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32], + response="response2", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.4, -0.5, -0.6], + status=Sample.Status.TRUNCATED, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.tokens == b.tokens + assert merged.response_length == 3 + 2 + 3 + assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] + assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6] + assert merged.prompt == a.prompt + assert merged.status == b.status + assert merged.label == a.label + assert merged.index == a.index + assert merged.group_index == a.group_index + assert "response1" in merged.response + assert "response2" in merged.response + assert "" in merged.response + + def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.loss_mask == [1, 0, 1] + assert merged.rollout_log_probs == [0.0, 0.0, 0.0] + + def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 99, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_field_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + index=0, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=[1], + index=1, + ) + + with pytest.raises(AssertionError, match="index mismatch"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_obs_len_invalid_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="obs_len must be > 0"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_sample_validate_fails_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10, 11], + response_length=2, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 11, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="loss_mask length"): + _merge_sample_pair(a, b, mock_tokenizer) diff --git a/tests/fast/rollout/inference_rollout/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py new file mode 100644 index 0000000000..ca47edeeb6 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/conftest.py @@ -0,0 +1,45 @@ +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import parse_args + + +def _build_mock_args(extra_argv: list[str] | None = None): + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "4", + "--rollout-num-gpus-per-engine", + "2", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + "/dev/null", + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + "30000", + ] + (extra_argv or []) + with patch("sys.argv", argv): + return parse_args() + + +@pytest.fixture +def mock_args(): + return _build_mock_args() diff --git a/tests/fast/rollout/inference_rollout/integration/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py new file mode 100644 index 0000000000..5b791829d5 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py @@ -0,0 +1,69 @@ +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import ( + MODULAR_ROLLOUT_BASE_ARGV, + expected_sample, + load_and_call_train, +) + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function + +_VARIANTS = [ + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--eval-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="old_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="new_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), + id="new_rollout_new_generate", + ), +] + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_train(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert len(group) == env.args.n_samples_per_prompt + assert group[0] == expected_sample(group_index=0) + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_eval(rollout_env): + env = rollout_env + fn = load_rollout_function( + RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path + ) + out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + assert "toy" in out.data + rewards = out.data["toy"]["rewards"] + samples = out.data["toy"]["samples"] + assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt + assert rewards[0] == 1 + assert samples[0] == expected_sample(group_index=None) diff --git a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py new file mode 100644 index 0000000000..69a2359117 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py @@ -0,0 +1,37 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env,expected_seeds", + [ + pytest.param( + integration_env_config( + [ + "--sglang-enable-deterministic-inference", + "--rollout-seed", + "42", + "--n-samples-per-prompt", + "3", + "--rollout-batch-size", + "1", + ] + ), + {42, 43, 44}, + id="enabled", + ), + pytest.param( + integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + {None}, + id="disabled", + ), + ], + indirect=["rollout_env"], +) +def test_sampling_seeds(rollout_env, expected_seeds): + env = rollout_env + load_and_call_train(env.args, env.data_source) + + seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} + assert seeds == expected_seeds diff --git a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py new file mode 100644 index 0000000000..0ca5743ac5 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py @@ -0,0 +1,46 @@ +from contextlib import nullcontext + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + MIXED_DATA_ROWS, + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + + +@pytest.mark.parametrize( + "rollout_env,use_filter,expect_all_correct", + [ + pytest.param( + integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), + False, + False, + id="no_filter", + ), + pytest.param( + integration_env_config( + ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], + data_rows=MIXED_DATA_ROWS, + ), + True, + True, + id="with_filter", + ), + ], + indirect=["rollout_env"], +) +def test_filter_effect(rollout_env, use_filter, expect_all_correct): + env = rollout_env + ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() + + with ctx: + out = load_and_call_train(env.args, env.data_source) + + rewards = {group[0].reward for group in out.samples} + if expect_all_correct: + assert rewards == {1}, "Filter should keep only correct samples" + else: + assert 0 in rewards, "Without filter, incorrect samples should be present" diff --git a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py new file mode 100644 index 0000000000..afd870c302 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py @@ -0,0 +1,22 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + id="group_rm_enabled", + ), + ], + indirect=True, +) +def test_group_rm_rewards_set(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + rewards = [sample.reward for group in out.samples for sample in group] + assert all(r in (0, 1) for r in rewards) diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py new file mode 100644 index 0000000000..2b12d3d88f --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py @@ -0,0 +1,65 @@ +import pytest +from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.misc import function_registry +from miles.utils.types import Sample + + +async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: + sample = input.sample + s1 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=None, + status=Sample.Status.COMPLETED, + ) + s2 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=0.5, + status=Sample.Status.COMPLETED, + ) + return GenerateFnOutput(samples=[s1, s2]) + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + + [ + "--custom-generate-function-path", + "test:multi_sample_generate", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + ], + data_rows=DEFAULT_DATA_ROWS, + ), + id="multi_sample_output", + ), + ], + indirect=True, +) +def test_multi_sample_output_preserves_existing_reward(rollout_env): + env = rollout_env + with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert isinstance(group[0], list) + samples = group[0] + assert len(samples) == 2 + assert samples[0].reward == 1 + assert samples[1].reward == 0.5 diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py new file mode 100644 index 0000000000..c41d713991 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py @@ -0,0 +1,114 @@ +from typing import Any + +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout + +from miles.utils.test_utils.mock_tools import TwoTurnStub +from miles.utils.types import Sample + + +TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}] + +_VARIANT_NAMES = [ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", +] + +_BASE_EXTRA_ARGV = [ + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "2", + "--n-samples-per-eval-prompt", + "2", + "--custom-rm-path", + "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function", +] + + +def _config_for_variant(variant: str) -> RolloutEnvConfig: + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, + data_rows=TWO_TURN_DATA_ROWS, + ) + + +@pytest.mark.parametrize( + "variant,rollout_env", + [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], + indirect=["rollout_env"], +) +@pytest.mark.parametrize("test_type", ["train", "eval"]) +def test_rollout(rollout_env, variant, test_type): + env = rollout_env + env.mock_server.process_fn = TwoTurnStub.process_fn + + out = load_and_call_rollout(env.args, env.data_source, mode=test_type) + + if test_type == "train": + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + _verify_samples(variant, group) + else: + assert "toy" in out.data + samples = out.data["toy"]["samples"] + _verify_samples(variant, samples) + + +def _verify_samples(variant: str, samples: list[Any]): + is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples") + + if is_multi_samples: + if len(samples) > 0 and isinstance(samples[0], list): + # Train mode: list[list[Sample]], grouped by prompt + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for group_sample in samples: + assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate" + _verify_group_samples(group_sample) + else: + # Eval mode: list[Sample], flattened + # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples + assert ( + len(samples) == 4 + ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}" + # Group samples by prompt (every 2 samples form a group) + group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)] + for group_samples in group_samples_list: + _verify_group_samples(group_samples) + else: + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for sample in samples: + assert isinstance(sample, Sample), "single_sample variant should return Sample, not list" + _verify_sample(sample) + + +def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2): + assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)" + for i, sample in enumerate(group_samples): + _verify_sample(sample, expect_answer=(i == len(group_samples) - 1)) + + +def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True): + assert sample.status == Sample.Status.COMPLETED + assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}" + if expect_answer: + assert "2008" in sample.response, "Response should contain final answer '2008'" + + +async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]: + if isinstance(samples, list): + # For multi_samples variants, use the last sample's reward + if getattr(args, "generate_multi_samples", False): + return [_check_reward(samples[-1])] * len(samples) + else: + return [_check_reward(sample) for sample in samples] + else: + return _check_reward(samples) + + +def _check_reward(sample: Sample) -> float: + return float(sample.response and (str(sample.label) in sample.response)) diff --git a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py new file mode 100644 index 0000000000..0812962cc7 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py @@ -0,0 +1,48 @@ +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "wrong"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "wrong"}, +] + +_BASE_ARGV = [ + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", +] + + +def _over_sampling_config(rollout_batch_size: int): + return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) + + +@pytest.mark.parametrize( + "rollout_env,expected_rounds", + [ + pytest.param(_over_sampling_config(1), 1, id="one_round"), + pytest.param(_over_sampling_config(2), 2, id="two_rounds"), + ], + indirect=["rollout_env"], +) +def test_over_sampling_rounds(rollout_env, expected_rounds): + env = rollout_env + + with function_registry.temporary("test:filter_by_reward", filter_by_reward): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + assert all(group[0].reward == 1 for group in out.samples) + + requests_count = len(env.mock_server.request_log) + expected_requests = expected_rounds * env.args.over_sampling_batch_size + assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests" diff --git a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py new file mode 100644 index 0000000000..36e78c16c1 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py @@ -0,0 +1,67 @@ +from unittest.mock import Mock + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +# Data with only 2 reward=1 samples out of 4. +# This ensures all 4 samples must be generated to collect 2 valid ones. +_FILTER_TEST_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, # reward=1 + {"input": "What is 1+8?", "label": "wrong"}, # reward=0 + {"input": "What is 1+9?", "label": "wrong"}, # reward=0 + {"input": "What is 1+6?", "label": "7"}, # reward=1 +] + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config( + [ + "--rollout-batch-size", + "2", + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", + "--rollout-sample-filter-path", + "test:sample_filter", + "--rollout-all-samples-process-path", + "test:all_samples_process", + ], + data_rows=_FILTER_TEST_DATA_ROWS, + ), + id="sample_filter_vs_all_samples", + ), + ], + indirect=True, +) +def test_sample_filter_and_all_samples_process(rollout_env): + env = rollout_env + sample_filter_mock = Mock() + all_samples_process_mock = Mock() + + with ( + function_registry.temporary("test:filter_by_reward", filter_by_reward), + function_registry.temporary("test:sample_filter", sample_filter_mock), + function_registry.temporary("test:all_samples_process", all_samples_process_mock), + ): + load_and_call_train(env.args, env.data_source) + + sample_filter_mock.assert_called_once() + _, filtered_data = sample_filter_mock.call_args[0] + rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data] + assert all(r == 1 for r in rewards) + + all_samples_process_mock.assert_called_once() + _, all_samples, data_source = all_samples_process_mock.call_args[0] + assert data_source is not None + + assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter" diff --git a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py new file mode 100644 index 0000000000..889a9ff8ac --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py @@ -0,0 +1,33 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + +_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] +_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] + + +@pytest.mark.parametrize( + "rollout_env,expected_range", + [ + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (1, 1), + id="limit_1", + ), + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (2, 999), + id="no_limit", + ), + ], + indirect=["rollout_env"], +) +def test_max_concurrent(rollout_env, expected_range): + env = rollout_env + load_and_call_train(env.args, env.data_source) + min_expected, max_expected = expected_range + assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py new file mode 100644 index 0000000000..ad413cf949 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/utils.py @@ -0,0 +1,89 @@ +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig + +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnOutput, + RolloutFnTrainInput, +) +from miles.rollout.filter_hub.base_types import DynamicFilterOutput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.utils.types import Sample + + +def expected_sample(*, group_index: int | None) -> Sample: + return Sample( + group_index=group_index, + index=0, + prompt="What is 1+7?", + tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], + multimodal_inputs=None, + multimodal_train_inputs=None, + response="\\boxed{8}", + response_length=5, + label="8", + reward=1, + loss_mask=None, + weight_versions=[], + rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], + rollout_routed_experts=None, + remove_sample=False, + status=Sample.Status.COMPLETED, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=Sample.SpecInfo( + spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 + ), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), + ) + + +MODULAR_ROLLOUT_BASE_ARGV = [ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", +] + +MIXED_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "9"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "7"}, +] + + +def integration_env_config( + extra_argv: list[str], + data_rows: list[dict] | None = None, + latency: float = 0.0, + variant: str = "single_turn", +): + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, + data_rows=data_rows, + latency=latency, + ) + + +def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput: + function_path = args.rollout_function_path if mode == "train" else args.eval_function_path + fn = load_rollout_function( + RolloutFnConstructorInput(args=args, data_source=data_source), + function_path, + ) + if mode == "train": + return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + else: + return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + +def load_and_call_train(args, data_source): + return load_and_call_rollout(args, data_source, mode="train") + + +def filter_by_reward(args, samples, **kwargs): + reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward + if reward == 1: + return DynamicFilterOutput(keep=True) + return DynamicFilterOutput(keep=False, reason="reward_zero") diff --git a/tests/fast/rollout/inference_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py new file mode 100644 index 0000000000..ddfecd067b --- /dev/null +++ b/tests/fast/rollout/inference_rollout/test_compatibility.py @@ -0,0 +1,196 @@ +import asyncio +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.inference_rollout.compatibility import ( + LegacyGenerateFnAdapter, + LegacyRolloutFnAdapter, + call_rollout_function, + load_generate_function, + load_rollout_function, +) +from miles.utils.async_utils import run +from miles.utils.misc import function_registry + + +@pytest.fixture +def constructor_input(): + return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") + + +@pytest.fixture +def make_generate_fn_input(): + def _make(evaluation: bool = False): + state = MagicMock() + state.args = MagicMock() + + return GenerateFnInput( + state=state, + sample={"text": "test prompt"}, + sampling_params={"temperature": 0.7}, + evaluation=evaluation, + ) + + return _make + + +class TestSupportedRolloutFormats: + """ + Documentation test to show various supported rollout function formats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return {"metric": {"accuracy": 0.9}} + return [[{"text": "sample"}]] + + with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_rollout") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, LegacyRolloutFnAdapter) + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"metric": {"accuracy": 0.9}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "sample"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) + return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) + + with function_registry.temporary("test:legacy_typed", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_typed") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"ds": {"acc": 0.95}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "typed"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_sync_class(self, constructor_input, evaluation): + class SyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + def __call__(self, input): + if input.evaluation: + return RolloutFnEvalOutput(data={"test": {"score": 1}}) + return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) + + with function_registry.temporary("test:sync_class", SyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:sync_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, SyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_async_class(self, constructor_input, evaluation): + class AsyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + async def __call__(self, input): + await asyncio.sleep(0.001) + if input.evaluation: + return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) + return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) + + with function_registry.temporary("test:async_class", AsyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:async_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, AsyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + +class TestSupportedGenerateFormats: + """ + Documentation test similar to TestSupportedRolloutFormats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): + return "my_sample" + + with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen_eval") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params): + return "my_sample" + + with function_registry.temporary("test:legacy_gen", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): + async def generate(input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_async", generate): + fn = load_generate_function("test:new_async") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): + class MyGenerateFn: + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_class", MyGenerateFn): + fn = load_generate_function("test:new_class") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, MyGenerateFn) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" diff --git a/tests/fast/rollout/rm_hub/__init__.py b/tests/fast/rollout/rm_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py new file mode 100644 index 0000000000..bd4c606a68 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_deepscaler.py @@ -0,0 +1,26 @@ +import pytest + +from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward + + +class TestGetDeepscalerRuleBasedReward: + @pytest.mark.parametrize( + "response,label,expected", + [ + (r"Let me analyze...The answer is \boxed{42}", "42", 1), + (r"Thinking...The answer is \boxed{wrong}", "42", 0), + (r"###Response\boxed{42}", "42", 1), + (r"###Response\boxed{wrong}", "42", 0), + (r"The answer is \boxed{42}", "42", 0), + (r"The answer is 42", "42", 0), + (r"\boxed{42}", "", 0), + (r"\boxed{42}", r"\boxed{42}", 1), + (r"\boxed{123}", 123, 1), + (r"\boxed{3.14}", 3.14, 1), + (r"\boxed{1/2}", "0.5", 1), + (r"\boxed{\frac{1}{2}}", "0.5", 1), + (r"First thoughtSecond thought\boxed{42}", "42", 1), + ], + ) + def test_get_deepscaler_rule_based_reward(self, response, label, expected): + assert get_deepscaler_rule_based_reward(response, label) == expected diff --git a/tests/fast/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py new file mode 100644 index 0000000000..c9ecf9614d --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_f1.py @@ -0,0 +1,44 @@ +import pytest + +from miles.rollout.rm_hub.f1 import f1_score, normalize_answer + + +class TestNormalizeAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("The quick brown fox", "quick brown fox"), + ("A cat and a dog", "cat and dog"), + ("Hello, world!", "hello world"), + (" multiple spaces ", "multiple spaces"), + ("An apple", "apple"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_normalize_answer(self, input_str, expected): + assert normalize_answer(input_str) == expected + + +class TestF1Score: + @pytest.mark.parametrize( + "prediction,ground_truth,expected_f1,expected_prec,expected_recall", + [ + ("hello world", "hello world", 1.0, 1.0, 1.0), + ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3), + ("abc", "xyz", 0, 0, 0), + (None, "anything", 0, 0, 0), + ("yes", "no", 0, 0, 0), + ("no", "yes", 0, 0, 0), + ("yes", "yes", 1.0, 1.0, 1.0), + ("noanswer", "yes", 0, 0, 0), + ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0), + ("hello, world!", "hello world", 1.0, 1.0, 1.0), + ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5), + ], + ) + def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall): + f1, prec, recall = f1_score(prediction, ground_truth) + assert f1 == expected_f1 + assert prec == expected_prec + assert recall == expected_recall diff --git a/tests/fast/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py new file mode 100644 index 0000000000..45cefd2015 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_gpqa.py @@ -0,0 +1,86 @@ +import pytest + +from miles.rollout.rm_hub.gpqa import ( + _extract_letter_from_response, + _normalize_text, + _strip_chain_of_thought, + compute_gpqa_reward, +) + + +class TestStripChainOfThought: + @pytest.mark.parametrize( + "text,expected", + [ + ("Let me think...The answer is A", "The answer is A"), + ("The answer is A", "The answer is A"), + ("", ""), + (None, ""), + ], + ) + def test_strip_chain_of_thought(self, text, expected): + assert _strip_chain_of_thought(text) == expected + + +class TestNormalizeText: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("Test-123", "test 123"), + ("A, B, C", "a b c"), + ("", ""), + ], + ) + def test_normalize_text(self, input_str, expected): + assert _normalize_text(input_str) == expected + + +class TestExtractLetterFromResponse: + @pytest.mark.parametrize( + "response,expected", + [ + ("The answer is A", "A"), + ("answer: B", "B"), + ("I think C is correct", "C"), + ("final answer: D", "D"), + ("Option A is the best choice", "A"), + ("The answer is B", "B"), + ("After analysis, my choice is C", "C"), + ("A B C D", "D"), + ("No valid letter here", None), + ("", None), + (None, None), + ("The answer is Z", None), + ], + ) + def test_extract_letter(self, response, expected): + assert _extract_letter_from_response(response, "ABCD") == expected + + +class TestComputeGpqaReward: + @pytest.mark.parametrize( + "response,label,metadata,expected", + [ + ("Answer: A", "A", None, 1.0), + ("Answer: A", "B", None, 0.0), + (None, "A", None, 0.0), + ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0), + ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0), + ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0), + ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0), + ( + "I believe the answer is Paris", + "", + {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"}, + 1.0, + ), + ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0), + ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0), + ("Let me think step by step...The answer is A", "A", None, 1.0), + ], + ) + def test_compute_gpqa_reward(self, response, label, metadata, expected): + assert compute_gpqa_reward(response, label, metadata=metadata) == expected diff --git a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py new file mode 100644 index 0000000000..56a7f6d1f9 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py @@ -0,0 +1,108 @@ +import pytest + +from miles.rollout.rm_hub.math_dapo_utils import ( + compute_score, + is_correct_minerva, + is_correct_strict_box, + last_boxed_only_string, + normalize_final_answer, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2}", r"\boxed{x^2}"), + (r"No boxed", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x + 1}", "x + 1"), + ], + ) + def test_remove_boxed_valid(self, input_str, expected): + assert remove_boxed(input_str) == expected + + def test_remove_boxed_invalid(self): + with pytest.raises(AssertionError): + remove_boxed("not boxed") + + +class TestNormalizeFinalAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("42", "42"), + (" 42 ", "42"), + (r"\text{hello}", "hello"), + (r"\textbf{bold}", "bold"), + (r"x = 42", "42"), + (r"100 square", "100"), + (r"$50$ dollars", "50"), + (r"\boxed{42}", "42"), + (r"\frac12", r"\frac{1}{2}"), + (r"\sqrt3", r"\sqrt{3}"), + ("1,000", "1000"), + ("<|im_end|>", ""), + ], + ) + def test_normalize_final_answer(self, input_str, expected): + assert normalize_final_answer(input_str) == expected + + +class TestIsCorrectMinerva: + @pytest.mark.parametrize( + "solution,gt,gt_need_extract,expected_correct", + [ + ("Answer: 42", "42", False, True), + ("Answer: 100", "42", False, False), + ("Answer: wrong", "42", False, False), + ("Answer: 42", r"\boxed{42}", True, True), + ], + ) + def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct): + correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract) + assert correct == expected_correct + + +class TestIsCorrectStrictBox: + @pytest.mark.parametrize( + "pred,gt,expected_score,expected_pred", + [ + (r"blah blah \boxed{42}", "42", 1, "42"), + (r"\boxed{wrong}", "42", -1, "wrong"), + ("no box here", "42", -1, None), + ], + ) + def test_is_correct_strict_box(self, pred, gt, expected_score, expected_pred): + score, extracted = is_correct_strict_box(pred, gt) + assert score == expected_score + assert extracted == expected_pred + + +class TestComputeScore: + @pytest.mark.parametrize( + "solution,gt,strict_box,expected_score,expected_acc", + [ + ("Answer: 42", "42", False, 1.0, True), + ("Answer: wrong", "42", False, -1.0, False), + (r"\boxed{42}", "42", True, 1.0, True), + ("x" * 500 + " Answer: 42", "42", False, 1.0, True), + ], + ) + def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc): + result = compute_score(solution, gt, strict_box_verify=strict_box) + assert result["score"] == expected_score + assert result["acc"] == expected_acc diff --git a/tests/fast/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py new file mode 100644 index 0000000000..2423ed4acc --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_utils.py @@ -0,0 +1,129 @@ +import pytest + +from miles.rollout.rm_hub.math_utils import ( + _normalize, + extract_answer, + grade_answer_mathd, + grade_answer_sympy, + grade_answer_verl, + last_boxed_only_string, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"), + (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"), + (r"No boxed here", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"), + (r"\fbox{fbox content}", r"\fbox{fbox content}"), + ("", None), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x^2 + 1}", "x^2 + 1"), + (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"), + ("not boxed", None), + ], + ) + def test_remove_boxed(self, input_str, expected): + assert remove_boxed(input_str) == expected + + +class TestExtractAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", "42"), + (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"), + (r"Multiple \boxed{1} then \boxed{final}", "final"), + (r"No boxed here", None), + ("", None), + ], + ) + def test_extract_answer(self, input_str, expected): + assert extract_answer(input_str) == expected + + +class TestNormalize: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("1,000", "1000"), + (r"\text{hello}", "hello"), + (" 42 ", "42"), + (r"100%", "100"), + (r"\$50", "50"), + ("HELLO", "hello"), + ("1,234,567", "1234567"), + (None, None), + ], + ) + def test_normalize(self, input_str, expected): + assert _normalize(input_str) == expected + + +class TestGradeAnswerMathd: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + (" 42 ", "42", True), + (r"\frac{1}{2}", r"\frac{1}{2}", True), + ("wrong", "42", False), + ("", "42", False), + ], + ) + def test_grade_answer_mathd(self, given, ground_truth, expected): + assert grade_answer_mathd(given, ground_truth) == expected + + +class TestGradeAnswerSympy: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + ("x^2", "x^2", True), + ("1/2", "0.5", True), + (r"\frac{1}{2}", "0.5", True), + ("wrong", "42", False), + ("", "42", False), + ("(1,2)", "(1,2)", True), + ("(1,2,3)", "(1,2)", False), + ("42", None, False), + ], + ) + def test_grade_answer_sympy(self, given, ground_truth, expected): + assert grade_answer_sympy(given, ground_truth) == expected + + +class TestGradeAnswerVerl: + @pytest.mark.parametrize( + "solution,ground_truth,expected", + [ + (r"\boxed{42}", "42", True), + (r"The answer is \boxed{42}", "42", True), + (r"\boxed{1/2}", r"\frac{1}{2}", True), + (r"\boxed{wrong}", "42", False), + ("no boxed", "42", False), + (r"\boxed{42}", r"\boxed{42}", True), + ("", "42", False), + (r"\boxed{42}", "", False), + (r"\boxed{42}", None, False), + ], + ) + def test_grade_answer_verl(self, solution, ground_truth, expected): + assert grade_answer_verl(solution, ground_truth) == expected diff --git a/tests/fast/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py new file mode 100644 index 0000000000..a3dadbdaf0 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_rm_hub.py @@ -0,0 +1,126 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.async_utils import run +from miles.utils.types import Sample + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.custom_rm_path = None + args.rm_type = None + args.rm_url = None + return args + + +class TestAsyncRm: + @pytest.mark.parametrize( + "rm_type,response,label,expected", + [ + ("math", r"\boxed{42}", "42", 1), + ("math", r"\boxed{wrong}", "42", 0), + ("f1", "hello world", "hello world", 1.0), + ("dapo", "Answer: 42", "42", {"score": 1.0}), + ("deepscaler", r"\boxed{42}", "42", 1), + ("gpqa", "Answer: A", "A", 1.0), + ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0), + ], + ) + def test_rm_types(self, mock_args, rm_type, response, label, expected): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response=response, label=label) + reward = run(async_rm(mock_args, sample)) + if isinstance(expected, dict): + for k, v in expected.items(): + assert reward[k] == v + else: + assert reward == expected + + def test_f1_rm_partial(self, mock_args): + mock_args.rm_type = "f1" + sample = Sample(prompt="", response="hello", label="hello world") + reward = run(async_rm(mock_args, sample)) + assert 0 < reward < 1 + + def test_random_rm(self, mock_args): + mock_args.rm_type = "random" + sample = Sample(prompt="", response="anything", label="anything") + reward = run(async_rm(mock_args, sample)) + assert reward in [0, 1] + + def test_rm_type_from_metadata(self, mock_args): + mock_args.rm_type = None + sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}) + reward = run(async_rm(mock_args, sample)) + assert reward == 1 + + @pytest.mark.parametrize( + "rm_type,match", + [ + ("unknown_type", "not implemented"), + ("", "not specified"), + ], + ) + def test_invalid_rm_type_raises(self, mock_args, rm_type, match): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response="test", label="test") + with pytest.raises(NotImplementedError, match=match): + run(async_rm(mock_args, sample)) + + +class TestBatchedAsyncRm: + @pytest.mark.parametrize( + "rm_type,samples_data,expected", + [ + ( + "math", + [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")], + [1, 1, 0], + ), + ( + "f1", + [("hello world", "hello world"), ("different", "something else")], + [1.0, 0], + ), + ], + ) + def test_batched_rm(self, mock_args, rm_type, samples_data, expected): + mock_args.rm_type = rm_type + samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards == expected + + def test_inplace_set_reward_field(self, mock_args): + mock_args.rm_type = "math" + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42"), + Sample(prompt="", response=r"\boxed{100}", label="100"), + ] + result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + assert result is None + assert samples[0].reward == 1 + assert samples[1].reward == 1 + + def test_inplace_raises_on_existing_reward(self, mock_args): + mock_args.rm_type = "math" + samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)] + with pytest.raises(AssertionError, match="Overriding"): + run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + + def test_empty_samples(self, mock_args): + mock_args.rm_type = "math" + rewards = run(batched_async_rm(mock_args, [])) + assert rewards == [] + + def test_mixed_rm_types_via_metadata(self, mock_args): + mock_args.rm_type = None + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}), + Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}), + ] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards[0] == 1 + assert rewards[1] == 1.0 diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/router/test_router.py b/tests/fast/router/test_router.py new file mode 100644 index 0000000000..7c645fe304 --- /dev/null +++ b/tests/fast/router/test_router.py @@ -0,0 +1,204 @@ +import asyncio +from argparse import Namespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +def make_router_args(router_port: int, **overrides) -> Namespace: + defaults = dict( + sglang_router_ip="127.0.0.1", + sglang_router_port=router_port, + rollout_health_check_interval=1.0, + miles_router_health_check_failure_threshold=3, + miles_router_max_connections=100, + miles_router_timeout=None, + miles_router_middleware_paths=[], + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def create_mock_worker(start_port: int = 30000) -> MockSGLangServer: + port = find_available_port(start_port) + return MockSGLangServer( + model_name="Qwen/Qwen3-0.6B", + process_fn=default_process_fn, + host="127.0.0.1", + port=port, + latency=0.0, + ) + + +class RouterEnv: + def __init__(self, router: MilesRouter, server: UvicornThreadServer): + self.router = router + self.server = server + + @property + def url(self) -> str: + return self.server.url + + +@pytest.fixture +def router_env(): + args = make_router_args(find_available_port(20000)) + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + server.start() + yield RouterEnv(router, server) + server.stop() + + +@pytest.fixture +def mock_worker(): + server = create_mock_worker() + server.start() + yield server + server.stop() + + +@pytest.fixture +def mock_worker_factory(): + servers = [] + + def _create(): + start_port = 30000 + len(servers) * 100 + server = create_mock_worker(start_port) + server.start() + servers.append(server) + return server + + yield _create + for s in servers: + s.stop() + + +@pytest.fixture +def router_factory(): + def _create(**overrides) -> MilesRouter: + args = make_router_args(find_available_port(20000), **overrides) + return MilesRouter(args, verbose=False) + + return _create + + +class TestWorkerManagement: + def test_add_worker_via_query_param(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30001" + r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + assert router_env.router.worker_request_counts[worker_url] == 0 + + def test_add_worker_via_body(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30002" + r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_duplicate(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30003" + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + + assert len(router_env.router.worker_request_counts) == 1 + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_missing_url(self, router_env: RouterEnv): + r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0) + assert r.status_code == 400 + assert "error" in r.json() + + def test_list_workers(self, router_env: RouterEnv): + worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"] + for url in worker_urls: + requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/list_workers", timeout=5.0) + r.raise_for_status() + assert set(r.json()["urls"]) == set(worker_urls) + + +class TestLoadBalancing: + def test_use_url_selects_min_load(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} + + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_use_url_excludes_dead_workers(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3} + router.dead_workers = {"http://w2:8000"} + + selected = router._use_url() + assert selected == "http://w3:8000" + assert router.worker_request_counts["http://w3:8000"] == 4 + + def test_use_url_raises_when_all_dead(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 0} + router.dead_workers = {"http://w1:8000"} + + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +# TODO: extract main body inside `_health_check_loop`, then can test that function +class TestHealthCheck: + def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health(mock_worker.url)) + assert url == mock_worker.url + assert healthy is True + + def test_check_worker_health_failure(self, router_factory): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999")) + assert url == "http://127.0.0.1:59999" + assert healthy is False + + +class TestProxyIntegration: + def test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status() + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0) + r.raise_for_status() + + assert "text" in r.json() + assert len(mock_worker.request_log) == 1 + assert mock_worker.request_log[0] == payload + + def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory): + worker1, worker2 = mock_worker_factory(), mock_worker_factory() + requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0) + requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0) + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + for _ in range(4): + requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status() + + all_requests = worker1.request_log + worker2.request_log + assert len(all_requests) == 4 + assert all(req == payload for req in all_requests) + + def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/health", timeout=5.0) + r.raise_for_status() + assert r.json()["status"] == "ok" diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py new file mode 100644 index 0000000000..5c6edafe20 --- /dev/null +++ b/tests/fast/router/test_sessions.py @@ -0,0 +1,195 @@ +from types import SimpleNamespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.router.sessions import SessionManager, SessionRecord +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +class TestSessionManager: + def test_create_session(self): + manager = SessionManager() + session_id = manager.create_session() + assert session_id is not None + assert len(session_id) == 32 + assert session_id in manager.sessions + assert manager.sessions[session_id] == [] + + def test_get_session_exists(self): + manager = SessionManager() + session_id = manager.create_session() + records = manager.get_session(session_id) + assert records == [] + + def test_get_session_not_exists(self): + manager = SessionManager() + records = manager.get_session("nonexistent") + assert records is None + + def test_delete_session_exists(self): + manager = SessionManager() + session_id = manager.create_session() + records = manager.delete_session(session_id) + assert records == [] + assert session_id not in manager.sessions + + def test_delete_session_not_exists(self): + manager = SessionManager() + with pytest.raises(AssertionError): + manager.delete_session("nonexistent") + + def test_add_record(self): + manager = SessionManager() + session_id = manager.create_session() + record = SessionRecord( + timestamp=1234567890.0, + method="POST", + path="generate", + request={"prompt": "hello"}, + response={"text": "world"}, + status_code=200, + ) + manager.add_record(session_id, record) + assert len(manager.sessions[session_id]) == 1 + assert manager.sessions[session_id][0] == record + + def test_add_record_nonexistent_session(self): + manager = SessionManager() + record = SessionRecord( + timestamp=1234567890.0, + method="POST", + path="generate", + request={}, + response={}, + status_code=200, + ) + with pytest.raises(AssertionError): + manager.add_record("nonexistent", record) + + +@pytest.fixture(scope="class") +def router_url(): + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text=f"echo: {prompt}", finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as backend: + args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + hf_checkpoint="Qwen/Qwen3-0.6B", + ) + router = MilesRouter(args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend.url}) + + try: + yield url + finally: + server.stop() + + +class TestSessionRoutes: + def test_create_session(self, router_url): + response = requests.post(f"{router_url}/sessions") + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert len(data["session_id"]) == 32 + + def test_get_session(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert data["records"] == [] + + def test_get_session_not_found(self, router_url): + response = requests.get(f"{router_url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_get_with_records(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + ) + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert len(data["records"]) == 1 + + def test_delete_session(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 + assert delete_resp.text == "" + + assert requests.delete(f"{router_url}/sessions/{session_id}").status_code == 404 + + def test_delete_session_not_found(self, router_url): + response = requests.delete(f"{router_url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + +class TestSessionProxy: + def test_proxy_session_not_found(self, router_url): + response = requests.post(f"{router_url}/sessions/nonexistent/generate", json={}) + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_proxy_records_request_response(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + resp = requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + ) + assert resp.status_code == 200 + assert "text" in resp.json() + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + records = get_resp.json()["records"] + assert len(records) == 1 + assert records[0]["method"] == "POST" + assert records[0]["path"] == "generate" + assert records[0]["request"]["input_ids"] == [1, 2, 3] + assert "text" in records[0]["response"] + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 + + def test_proxy_accumulates_records(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + for _ in range(3): + requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1], "sampling_params": {}, "return_logprob": True}, + ) + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + records = get_resp.json()["records"] + assert len(records) == 3 + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py new file mode 100644 index 0000000000..9bd1a620d6 --- /dev/null +++ b/tests/fast/utils/test_arguments.py @@ -0,0 +1,58 @@ +import argparse +import sys +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import get_miles_extra_args_provider +from miles.utils.misc import function_registry + +PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] +REQUIRED_ARGS = ["--rollout-batch-size", "64"] + + +def make_class_with_add_arguments(): + class MyFn: + @classmethod + def add_arguments(cls, parser): + parser.add_argument("--my-custom-arg", type=int, default=42) + + return MyFn + + +def make_function_with_add_arguments(): + def my_fn(): + pass + + my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42) + return my_fn + + +def make_function_without_add_arguments(): + def my_fn(): + pass + + return my_fn + + +@pytest.mark.parametrize("path_arg", PATH_ARGS) +class TestAddArgumentsSupport: + + @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments]) + def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory): + fn = fn_factory() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) + args, _ = parser.parse_known_args() + assert args.my_custom_arg == 100 + + def test_skips_function_without_add_arguments(self, path_arg): + fn = make_function_without_add_arguments() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) diff --git a/tests/utils/test_mask_utils.py b/tests/fast/utils/test_mask_utils.py similarity index 100% rename from tests/utils/test_mask_utils.py rename to tests/fast/utils/test_mask_utils.py diff --git a/tests/fast/utils/test_misc.py b/tests/fast/utils/test_misc.py new file mode 100644 index 0000000000..810c2b67c7 --- /dev/null +++ b/tests/fast/utils/test_misc.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from miles.utils.misc import FunctionRegistry, function_registry, load_function + + +def _fn_a(): + return "a" + + +def _fn_b(): + return "b" + + +class TestFunctionRegistry: + def test_register_and_get(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + + def test_register_duplicate_raises(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + with pytest.raises(AssertionError): + with registry.temporary("my_fn", _fn_b): + pass + + def test_unregister(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + assert registry.get("my_fn") is None + + def test_temporary_cleanup_on_exception(self): + registry = FunctionRegistry() + with pytest.raises(RuntimeError): + with registry.temporary("temp_fn", _fn_a): + raise RuntimeError("test") + assert registry.get("temp_fn") is None + + +class TestLoadFunction: + def test_load_from_module(self): + import os.path + + assert load_function("os.path.join") is os.path.join + + def test_load_none_returns_none(self): + assert load_function(None) is None + + def test_load_from_registry(self): + with function_registry.temporary("test:my_fn", _fn_a): + assert load_function("test:my_fn") is _fn_a + + def test_registry_takes_precedence(self): + with function_registry.temporary("os.path.join", _fn_b): + assert load_function("os.path.join") is _fn_b + assert load_function("os.path.join") is os.path.join diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py new file mode 100644 index 0000000000..6633678da1 --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_sglang_server.py @@ -0,0 +1,409 @@ +import asyncio +import concurrent.futures +import time + +import pytest +import requests + +from miles.utils.test_utils.mock_sglang_server import ( + Counter, + ProcessResult, + ProcessResultMetaInfo, + default_process_fn, + with_mock_server, +) +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub + + +def expected_logprobs(tokenizer, text: str) -> list[dict]: + output_ids = tokenizer.encode(text, add_special_tokens=False) + return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)] + + +@pytest.fixture(scope="module") +def mock_server(): + with with_mock_server() as server: + yield server + + +class TestProcessResultMetaInfo: + def test_to_dict_empty(self): + assert ProcessResultMetaInfo().to_dict() == {} + + def test_to_dict_single_field(self): + assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} + + def test_to_dict_partial_fields(self): + assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { + "weight_version": "v1", + "spec_accept_token_num": 10, + } + + def test_to_dict_all_fields(self): + assert ProcessResultMetaInfo( + weight_version="v1", + routed_experts="abc", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ).to_dict() == { + "weight_version": "v1", + "routed_experts": "abc", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + } + + +class TestDefaultProcessFn: + def test_math_question(self): + assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") + assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") + + def test_unknown_question(self): + assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") + + +class TestCounter: + def test_tracks_max(self): + counter = Counter() + assert counter.max_value == 0 + + with counter.track(): + assert counter.max_value == 1 + with counter.track(): + assert counter.max_value == 2 + + counter.reset() + assert counter.max_value == 0 + + def test_concurrent_tasks(self): + counter = Counter() + + async def task(): + with counter.track(): + await asyncio.sleep(0.1) + + async def run_all(): + await asyncio.gather(task(), task(), task()) + + asyncio.run(run_all()) + assert counter.max_value == 3 + + +class TestMockServerBasic: + def test_start_stop(self, mock_server): + assert mock_server.port > 0 + assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url + + def test_request_log_and_reset_stats(self, mock_server): + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + + payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} + requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) + assert len(mock_server.request_log) == 1 + assert mock_server.request_log[0] == payload + + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + assert mock_server.max_concurrent == 0 + + @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) + def test_latency(self, latency, min_time, max_time): + with with_mock_server(latency=latency) as server: + start = time.time() + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + elapsed = time.time() - start + assert min_time <= elapsed < max_time + + def test_max_concurrent_with_latency(self): + with with_mock_server(latency=0.1) as server: + + def send_request(): + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(send_request) for _ in range(3)] + concurrent.futures.wait(futures) + + assert server.max_concurrent == 3 + + def test_health_endpoint(self, mock_server): + response = requests.get(f"{mock_server.url}/health", timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + def test_abort_request_endpoint(self, mock_server): + response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +class TestGenerateEndpoint: + def test_basic(self, mock_server): + prompt = "What is 1+7?" + input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) + assert input_ids == [3838, 374, 220, 16, 10, 22, 30] + + response = requests.post( + f"{mock_server.url}/generate", + json={ + "input_ids": input_ids, + "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, + "return_logprob": True, + }, + timeout=5.0, + ) + assert response.status_code == 200 + assert response.json() == { + "text": "\\boxed{8}", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": len(input_ids), + "cached_tokens": 0, + "completion_tokens": 5, + "output_token_logprobs": [ + [-0.0, 59], + [-0.0078125, 79075], + [-0.015625, 90], + [-0.0234375, 23], + [-0.03125, 92], + ], + }, + } + + def test_with_meta_info(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult( + text="ok", + finish_reason="stop", + cached_tokens=5, + meta_info=ProcessResultMetaInfo( + weight_version="v2.0", + routed_experts="encoded_data", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ), + ) + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + + assert response.json() == { + "text": "ok", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": 3, + "cached_tokens": 5, + "completion_tokens": 1, + "output_token_logprobs": [[-0.0, 562]], + "weight_version": "v2.0", + "routed_experts": "encoded_data", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + }, + } + + def test_finish_reason_length(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text="truncated output", finish_reason="length") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + data = response.json() + + finish_reason = data["meta_info"]["finish_reason"] + assert finish_reason["type"] == "length" + assert finish_reason["length"] == data["meta_info"]["completion_tokens"] + + +class TestChatCompletionsEndpoint: + def test_basic(self, mock_server): + response = requests.post( + f"{mock_server.url}/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "What is 1+5?"}], + }, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + + assert data["id"].startswith("chatcmpl-") + assert isinstance(data["created"], int) + assert data == { + "id": data["id"], + "object": "chat.completion", + "created": data["created"], + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, + "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, + "finish_reason": "stop", + } + ], + } + + def test_with_tool_calls(self): + tool_call_response = 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n' + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=tool_call_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year is it?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "Let me check for you.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}} + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, + "finish_reason": "tool_calls", + } + + def test_with_tools_but_no_tool_call(self): + response_text = "The weather is sunny today." + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=response_text, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": {"role": "assistant", "content": response_text, "tool_calls": None}, + "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, + "finish_reason": "stop", + } + + def test_with_multiple_tool_calls(self): + multi_tool_response = ( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n' + ) + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=multi_tool_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year and temperature?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "I will get year and temperature.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, + { + "id": "call00001", + "type": "function", + "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'}, + }, + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, + "finish_reason": "tool_calls", + } + + +class TestMultiTurnToolCallProcessFn: + @pytest.mark.parametrize( + "prompt,expected_response", + [ + pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"), + pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"), + ], + ) + def test_generate_endpoint(self, prompt, expected_response): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) + response = requests.post( + f"{server.url}/generate", + json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["text"] == expected_response + assert data["meta_info"]["finish_reason"] == {"type": "stop"} + + @pytest.mark.parametrize( + "messages,expected_content,expected_tool_calls,expected_finish_reason", + [ + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, + TwoTurnStub.FIRST_RESPONSE_CONTENT, + TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT, + "tool_calls", + id="first_turn", + ), + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, + TwoTurnStub.SECOND_RESPONSE, + None, + "stop", + id="second_turn", + ), + ], + ) + def test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == expected_content + assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls + assert data["choices"][0]["finish_reason"] == expected_finish_reason diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py new file mode 100644 index 0000000000..3f2116ec01 --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_tools.py @@ -0,0 +1,111 @@ +import asyncio + +import pytest +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call + + +class TestExecuteToolCall: + def test_execute_get_year(self): + result = asyncio.run(execute_tool_call("get_year", {})) + assert result == '{"year": 2026}' + + def test_execute_get_temperature(self): + result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"})) + assert result == '{"temperature": -60}' + + +class TestApplyChatTemplateWithTools: + EXPECTED_PROMPT_WITHOUT_TOOLS = ( + "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n" + ) + + EXPECTED_PROMPT_WITH_TOOLS = ( + "<|im_start|>system\n" + "# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What's the weather in Paris?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize( + "tools,expected", + [ + pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"), + pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"), + ], + ) + def test_apply_chat_template(self, tools, expected): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + + assert prompt == expected + + +class TestSGLangFunctionCallParser: + """Test to demonstrate and ensure SGLang function call parser have features we need without breaking changes.""" + + @pytest.mark.parametrize( + "model_output,expected", + [ + pytest.param( + 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', + ( + "Let me check for you.", + [ToolCallItem(tool_index=-1, name="get_year", parameters="{}")], + ), + id="single_tool_call", + ), + pytest.param( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n', + ( + "I will get year and temperature.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'), + ], + ), + id="multi_tool_calls", + ), + pytest.param( + "The weather is sunny today.", + ("The weather is sunny today.", []), + id="no_tool_call", + ), + pytest.param( + TwoTurnStub.FIRST_RESPONSE, + ( + "Let me get the year and temperature first.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'), + ], + ), + id="multi_turn_first_response", + ), + ], + ) + def test_parse_non_stream(self, model_output, expected): + tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS) + parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25") + assert parser.parse_non_stream(model_output) == expected diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c53..9b6e69c295 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -126,6 +126,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=_launch_background, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py index 97c76ace5a..d90a2d7a71 100644 --- a/tests/test_mimo_7B_mtp_only_grad.py +++ b/tests/test_mimo_7B_mtp_only_grad.py @@ -135,6 +135,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py index b1255982ed..c35943ec15 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/test_moonlight_16B_A3B.py @@ -113,6 +113,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index 15ca8ce5fe..ae3c383ae8 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -115,6 +115,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index dcdbd58347..4d7f034f6c 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -120,6 +120,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index dcaaf5e1f7..32b60f5937 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -120,6 +120,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py index 90cd15cb68..b1954a4e83 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -118,6 +118,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/test_qwen2.5_0.5B_gsm8k_short.py index 867fdcad60..86e21eac8d 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_short.py @@ -117,6 +117,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d19b48ced..3d4768e420 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -93,6 +93,7 @@ def execute(): train_args=train_args, num_gpus_per_node=2, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 3d70f3e4ce..fcd7772882 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -95,6 +95,7 @@ def execute(): num_gpus_per_node=2 if FEW_GPU else 4, megatron_model_type=None, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/test_qwen3_0.6B_megatron_fsdp_align.py index 1431d8c3d4..b89a2f283b 100644 --- a/tests/test_qwen3_0.6B_megatron_fsdp_align.py +++ b/tests/test_qwen3_0.6B_megatron_fsdp_align.py @@ -97,6 +97,7 @@ def execute(): train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -109,6 +110,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -135,6 +137,7 @@ def execute(): "--debug-train-only " ), num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, megatron_model_type=MODEL_TYPE, ) diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py index 44f5c42fa5..d0ad283d15 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -95,6 +95,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) # 8 GPU CPU 1 for num_gpus in [8, 4, 2]: @@ -124,6 +125,7 @@ def execute(): train_args=args, num_gpus_per_node=num_gpus, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) train_args += "--calculate-per-token-loss " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index adff108043..b30eeed8e5 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -139,6 +139,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py index 22fb2b5fc3..0df4492e10 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/test_qwen3_4B_ckpt.py @@ -124,6 +124,7 @@ def execute(mode: str = ""): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py index 7c975c7cc2..03ba4094e9 100644 --- a/tests/test_qwen3_4B_fsdp_true_on_policy.py +++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py @@ -95,6 +95,7 @@ def execute(): "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", "CUBLAS_WORKSPACE_CONFIG": ":4096:8", "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py index 962f610fac..d4c1ac273a 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/test_qwen3_4B_ppo.py @@ -122,6 +122,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/test_qwen3_vl_4B_fsdp.py index fbdffd237e..bc4ef3293c 100644 --- a/tests/test_qwen3_vl_4B_fsdp.py +++ b/tests/test_qwen3_vl_4B_fsdp.py @@ -92,6 +92,7 @@ def execute(): extra_env_vars = { "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( From 72bafb1437e0fa676dd4136867f32d323e196323 Mon Sep 17 00:00:00 2001 From: lizamd <161388580+lizamd@users.noreply.github.com> Date: Thu, 22 Jan 2026 14:50:49 -0800 Subject: [PATCH 02/77] Fix PYTHONPATH for AMD container Megatron-LM location (#506) --- scripts/run-qwen3-4B-amd.sh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh index 83af901563..998f06b7f4 100755 --- a/scripts/run-qwen3-4B-amd.sh +++ b/scripts/run-qwen3-4B-amd.sh @@ -139,16 +139,16 @@ NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 -# "PYTHONPATH": "/workspace/Megatron-LM/", -MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') +# Dynamically detect Megatron-LM installation path +MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json='{ - "env_vars": { - "PYTHONPATH": "/workspace/Megatron-LM/", - "CUDA_DEVICE_MAX_CONNECTIONS": "1" + --runtime-env-json="{ + \"env_vars\": { + \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } - }' \ + }" \ -- python3 train.py \ --actor-num-nodes 1 \ --actor-num-gpus-per-node 8 \ From 37c96a58a5b7957fb0487087b6a86c4ece8aa094 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:56:37 +0800 Subject: [PATCH 03/77] Revert "Enable experimental rollout flag for CI tests" (#507) --- .github/workflows/pr-test.yml | 40 -- .github/workflows/pr-test.yml.j2 | 19 +- miles/ray/rollout.py | 33 +- miles/rollout/base_types.py | 66 +- .../rollout/generate_hub/agentic_tool_call.py | 85 --- miles/rollout/generate_hub/multi_turn.py | 88 --- miles/rollout/generate_hub/single_turn.py | 46 -- miles/rollout/generate_utils/__init__.py | 0 .../generate_utils/generate_endpoint_utils.py | 112 ---- .../generate_utils/openai_endpoint_utils.py | 67 -- miles/rollout/generate_utils/sample_utils.py | 115 ---- .../rollout/generate_utils/tool_call_utils.py | 115 ---- miles/rollout/inference_rollout/__init__.py | 2 - .../inference_rollout/compatibility.py | 84 --- .../inference_rollout_common.py | 192 ------ .../inference_rollout_eval.py | 112 ---- .../inference_rollout_train.py | 146 ----- miles/rollout/rm_hub/__init__.py | 12 +- miles/router/router.py | 47 +- miles/router/sessions.py | 124 ---- miles/utils/arguments.py | 24 +- miles/utils/environ.py | 14 - miles/utils/http_utils.py | 20 +- miles/utils/misc.py | 50 +- miles/utils/test_utils/__init__.py | 0 miles/utils/test_utils/mock_sglang_server.py | 248 -------- miles/utils/test_utils/mock_tools.py | 268 -------- .../utils/test_utils/uvicorn_thread_server.py | 49 -- miles/utils/types.py | 18 - requirements.txt | 1 - tests/__init__.py | 1 - tests/ci/gpu_lock_exec.py | 11 +- tests/e2e/.gitkeep | 1 - tests/fast/__init__.py | 0 tests/fast/conftest.py | 15 - tests/fast/fixtures/__init__.py | 1 - tests/fast/fixtures/generation_fixtures.py | 274 --------- tests/fast/fixtures/rollout_fixtures.py | 127 ---- tests/fast/rollout/__init__.py | 0 tests/fast/rollout/generate_hub/__init__.py | 0 .../rollout/generate_hub/test_multi_turn.py | 572 ------------------ .../rollout/generate_hub/test_single_turn.py | 424 ------------- .../generate_hub/test_tool_call_utils.py | 99 --- tests/fast/rollout/generate_utils/__init__.py | 0 .../generate_utils/test_sample_utils.py | 156 ----- .../rollout/inference_rollout/__init__.py | 0 .../rollout/inference_rollout/conftest.py | 45 -- .../inference_rollout/integration/__init__.py | 0 .../integration/test_basic.py | 69 --- .../integration/test_deterministic.py | 37 -- .../integration/test_dynamic_filter.py | 46 -- .../integration/test_group_rm.py | 22 - .../integration/test_multi_sample.py | 65 -- .../integration/test_multi_turn.py | 114 ---- .../integration/test_over_sampling.py | 48 -- .../integration/test_sample_filter.py | 67 -- .../integration/test_semaphore.py | 33 - .../inference_rollout/integration/utils.py | 89 --- .../inference_rollout/test_compatibility.py | 196 ------ tests/fast/rollout/rm_hub/__init__.py | 0 tests/fast/rollout/rm_hub/test_deepscaler.py | 26 - tests/fast/rollout/rm_hub/test_f1.py | 44 -- tests/fast/rollout/rm_hub/test_gpqa.py | 86 --- .../rollout/rm_hub/test_math_dapo_utils.py | 108 ---- tests/fast/rollout/rm_hub/test_math_utils.py | 129 ---- tests/fast/rollout/rm_hub/test_rm_hub.py | 126 ---- tests/fast/router/__init__.py | 0 tests/fast/router/test_router.py | 204 ------- tests/fast/router/test_sessions.py | 195 ------ tests/fast/utils/__init__.py | 0 tests/fast/utils/test_arguments.py | 58 -- tests/fast/utils/test_misc.py | 59 -- tests/fast/utils/test_utils/__init__.py | 0 .../test_utils/test_mock_sglang_server.py | 409 ------------- .../fast/utils/test_utils/test_mock_tools.py | 111 ---- tests/test_external_rollout.py | 1 - tests/test_mimo_7B_mtp_only_grad.py | 1 - tests/test_moonlight_16B_A3B.py | 1 - tests/test_quick_start_glm4_9B.py | 1 - tests/test_qwen2.5_0.5B_gsm8k.py | 1 - tests/test_qwen2.5_0.5B_gsm8k_async.py | 1 - tests/test_qwen2.5_0.5B_gsm8k_async_short.py | 1 - tests/test_qwen2.5_0.5B_gsm8k_short.py | 1 - tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 1 - tests/test_qwen3_0.6B_fsdp_distributed.py | 1 - tests/test_qwen3_0.6B_megatron_fsdp_align.py | 3 - tests/test_qwen3_0.6B_parallel_check.py | 2 - tests/test_qwen3_30B_A3B.py | 1 - tests/test_qwen3_4B_ckpt.py | 1 - tests/test_qwen3_4B_fsdp_true_on_policy.py | 1 - tests/test_qwen3_4B_ppo.py | 1 - tests/test_qwen3_vl_4B_fsdp.py | 1 - tests/{fast => }/utils/test_mask_utils.py | 0 93 files changed, 54 insertions(+), 6230 deletions(-) delete mode 100644 miles/rollout/generate_hub/agentic_tool_call.py delete mode 100644 miles/rollout/generate_hub/multi_turn.py delete mode 100644 miles/rollout/generate_hub/single_turn.py delete mode 100644 miles/rollout/generate_utils/__init__.py delete mode 100644 miles/rollout/generate_utils/generate_endpoint_utils.py delete mode 100644 miles/rollout/generate_utils/openai_endpoint_utils.py delete mode 100644 miles/rollout/generate_utils/sample_utils.py delete mode 100644 miles/rollout/generate_utils/tool_call_utils.py delete mode 100644 miles/rollout/inference_rollout/__init__.py delete mode 100644 miles/rollout/inference_rollout/compatibility.py delete mode 100644 miles/rollout/inference_rollout/inference_rollout_common.py delete mode 100644 miles/rollout/inference_rollout/inference_rollout_eval.py delete mode 100644 miles/rollout/inference_rollout/inference_rollout_train.py delete mode 100644 miles/router/sessions.py delete mode 100644 miles/utils/environ.py delete mode 100644 miles/utils/test_utils/__init__.py delete mode 100644 miles/utils/test_utils/mock_sglang_server.py delete mode 100644 miles/utils/test_utils/mock_tools.py delete mode 100644 miles/utils/test_utils/uvicorn_thread_server.py delete mode 100644 tests/__init__.py delete mode 100644 tests/e2e/.gitkeep delete mode 100644 tests/fast/__init__.py delete mode 100644 tests/fast/conftest.py delete mode 100644 tests/fast/fixtures/__init__.py delete mode 100644 tests/fast/fixtures/generation_fixtures.py delete mode 100644 tests/fast/fixtures/rollout_fixtures.py delete mode 100644 tests/fast/rollout/__init__.py delete mode 100644 tests/fast/rollout/generate_hub/__init__.py delete mode 100644 tests/fast/rollout/generate_hub/test_multi_turn.py delete mode 100644 tests/fast/rollout/generate_hub/test_single_turn.py delete mode 100644 tests/fast/rollout/generate_hub/test_tool_call_utils.py delete mode 100644 tests/fast/rollout/generate_utils/__init__.py delete mode 100644 tests/fast/rollout/generate_utils/test_sample_utils.py delete mode 100644 tests/fast/rollout/inference_rollout/__init__.py delete mode 100644 tests/fast/rollout/inference_rollout/conftest.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/__init__.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_basic.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_deterministic.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_group_rm.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_multi_sample.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_multi_turn.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_over_sampling.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_sample_filter.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/test_semaphore.py delete mode 100644 tests/fast/rollout/inference_rollout/integration/utils.py delete mode 100644 tests/fast/rollout/inference_rollout/test_compatibility.py delete mode 100644 tests/fast/rollout/rm_hub/__init__.py delete mode 100644 tests/fast/rollout/rm_hub/test_deepscaler.py delete mode 100644 tests/fast/rollout/rm_hub/test_f1.py delete mode 100644 tests/fast/rollout/rm_hub/test_gpqa.py delete mode 100644 tests/fast/rollout/rm_hub/test_math_dapo_utils.py delete mode 100644 tests/fast/rollout/rm_hub/test_math_utils.py delete mode 100644 tests/fast/rollout/rm_hub/test_rm_hub.py delete mode 100644 tests/fast/router/__init__.py delete mode 100644 tests/fast/router/test_router.py delete mode 100644 tests/fast/router/test_sessions.py delete mode 100644 tests/fast/utils/__init__.py delete mode 100644 tests/fast/utils/test_arguments.py delete mode 100644 tests/fast/utils/test_misc.py delete mode 100644 tests/fast/utils/test_utils/__init__.py delete mode 100644 tests/fast/utils/test_utils/test_mock_sglang_server.py delete mode 100644 tests/fast/utils/test_utils/test_mock_tools.py rename tests/{fast => }/utils/test_mask_utils.py (100%) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4b8b5dc82c..d34c823aa3 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,46 +25,6 @@ concurrency: jobs: - fast: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request) - runs-on: self-hosted - container: - image: radixark/miles:latest - options: > - --gpus all - --ipc=host - --shm-size=16g - --ulimit memlock=-1 - --ulimit stack=67108864 - --memory=0 - --memory-swap=0 - -v /mnt/nvme0n1/miles_ci:/data/miles_ci - -v /mnt/nvme0n1/miles_ci/models:/root/models - -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets - strategy: - fail-fast: false - matrix: - info: [{"num_gpus": 0, "test_file": "fast"}] - defaults: - run: - working-directory: ${{ github.workspace }} - env: - GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install - shell: bash - run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages - - - name: Execute - shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index c052b8494f..37b6fa4463 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,10 +1,4 @@ <% set jobs = { - 'fast': { - 'test_executor': 'pytest', - 'tests': [ - {'test_file': 'fast', 'num_gpus': 0}, - ], - }, 'e2e-test-short': { 'label': 'run-ci-short', 'tests': [ @@ -104,7 +98,7 @@ concurrency: jobs: <% for job_name, config in jobs.items() %> << job_name >>: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>) + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>')) runs-on: self-hosted container: image: << config.image if config.image else 'radixark/miles:latest' >> @@ -159,5 +153,14 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true <% endfor %> \ No newline at end of file diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 27211845d8..79c6649be0 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -13,15 +13,8 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import ( - RolloutFnConstructorInput, - RolloutFnEvalInput, - RolloutFnTrainInput, - call_rollout_fn, -) -from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.rollout.base_types import call_rollout_fn from miles.utils import tracking_utils -from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -60,14 +53,8 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.use_experimental_refactor = enable_experimental_rollout_refactor() - if self.use_experimental_refactor: - input = RolloutFnConstructorInput(args=args, data_source=self.data_source) - self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) - self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) - else: - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) + self.generate_rollout = load_function(self.args.rollout_function_path) + self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -155,12 +142,7 @@ def eval(self, rollout_id): return self.health_monitoring_resume() - if self.use_experimental_refactor: - result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) - else: - result = call_rollout_fn( - self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True - ) + result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -242,12 +224,7 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - if self.use_experimental_refactor: - data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) - else: - data = call_rollout_fn( - self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False - ) + data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index c2644e87f9..faa85c7269 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,86 +1,22 @@ -from __future__ import annotations - -from argparse import Namespace from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any -from miles.rollout.data_source import DataSource from miles.utils.types import Sample -if TYPE_CHECKING: - from miles.rollout.inference_rollout.inference_rollout_common import GenerateState - - -@dataclass(frozen=True) -class RolloutFnConstructorInput: - args: Namespace - # TODO may refactor DataSource API - data_source: DataSource - - -@dataclass(frozen=True) -class RolloutFnBaseInput: - rollout_id: int - - @property - def evaluation(self): - raise NotImplementedError - - -# subclassing for different data in the future -@dataclass(frozen=True) -class RolloutFnTrainInput(RolloutFnBaseInput): - @property - def evaluation(self): - return False - -@dataclass(frozen=True) -class RolloutFnEvalInput(RolloutFnBaseInput): - @property - def evaluation(self): - return True - - -# TODO make it frozen @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] metrics: dict[str, Any] = None -# TODO make it frozen @dataclass class RolloutFnEvalOutput: data: dict[str, dict[str, Any]] metrics: dict[str, Any] = None -RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput -RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput - - -@dataclass(frozen=True) -class GenerateFnInput: - state: GenerateState - sample: Sample - sampling_params: dict[str, Any] - evaluation: bool - - @property - def args(self) -> Namespace: - return self.state.args - - -@dataclass(frozen=True) -class GenerateFnOutput: - # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or - # multi-turn with removing thinking tokens. - samples: Sample | list[Sample] - - def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): - """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) # compatibility for legacy version diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py deleted file mode 100644 index 05223a6544..0000000000 --- a/miles/rollout/generate_hub/agentic_tool_call.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Simple agentic demo with tool calling. -""" - -import argparse -from copy import deepcopy -from typing import Any - -from openai import AsyncOpenAI - -from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_utils.openai_endpoint_utils import ( - OpenAIEndpointTracer, - compute_samples_from_openai_records, -) -from miles.rollout.generate_utils.sample_utils import merge_samples -from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls -from miles.utils.misc import load_function - - -async def generate(input: GenerateFnInput) -> GenerateFnOutput: - tracer = await OpenAIEndpointTracer.create(input.args) - - await _run_blackbox_tool_call_agent( - base_url=tracer.base_url, - prompt=input.sample.prompt, - max_turns=input.args.generate_max_turns, - tool_specs_path=input.args.generate_tool_specs_path, - execute_tool_function_path=input.args.generate_execute_tool_function_path, - ) - - records = await tracer.collect_records() - samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) - if not input.args.generate_multi_samples: - samples = merge_samples(samples, input.state.tokenizer) - return GenerateFnOutput(samples=samples) - - -def _add_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--generate-max-turns", type=int, default=16) - parser.add_argument("--generate-tool-specs-path", type=str) - parser.add_argument("--generate-execute-tool-function-path", type=str) - parser.add_argument("--generate-multi-samples", action="store_true") - - -generate.add_arguments = _add_arguments - - -async def _run_blackbox_tool_call_agent( - base_url: str, - prompt: list[dict[str, Any]], - max_turns: int, - tool_specs_path: str, - execute_tool_function_path: str, -): - """ - Imagine this is a black-box agent, e.g. SWE-agent, which does arbitrarily complex work, - only understands OpenAI compatible API, and never understands Miles or the Sample data structure. - """ - - # ----------------------- Setup ------------------------- - - client = AsyncOpenAI(base_url=base_url, api_key="empty") - execute_tool_function = load_function(execute_tool_function_path) - tool_specs = load_function(tool_specs_path) - - # ----------------------- Initial prompts ------------------------- - - messages = deepcopy(prompt) - - for _turn in range(max_turns): - # ----------------------- Call inference endpoint ------------------------- - - response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs) - - choice = response.choices[0] - messages.append(choice.message.model_dump()) - - if choice.finish_reason in ("stop", "length"): - break - - # ----------------------- Execute tools ------------------------- - - if x := choice.message.tool_calls: - messages += await execute_tool_calls(x, execute_tool_function) diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py deleted file mode 100644 index 97814ecb3d..0000000000 --- a/miles/rollout/generate_hub/multi_turn.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Simple multi-turn generation with tool calling. -""" - -import argparse -from copy import deepcopy - -from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_utils.generate_endpoint_utils import ( - compute_prompt_ids_from_sample, - compute_request_payload, - update_sample_from_response, -) -from miles.rollout.generate_utils.tool_call_utils import ( - create_tool_call_parser, - execute_tool_calls, - update_sample_with_tool_responses, -) -from miles.utils.http_utils import post -from miles.utils.misc import load_function - - -async def generate(input: GenerateFnInput) -> GenerateFnOutput: - # ----------------------- Setup ------------------------- - - args = input.args - sample = deepcopy(input.sample) - tokenizer = input.state.tokenizer - assert not args.partial_rollout, "Partial rollout is not supported" - - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - - execute_tool_function = load_function(args.generate_execute_tool_function_path) - - tool_specs = load_function(args.generate_tool_specs_path) - tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) - - multi_samples = [] - - # ----------------------- Initial prompts ------------------------- - - prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) - - sample.tokens = prompt_tokens_ids.copy() - - for _turn in range(args.generate_max_turns): - # ----------------------- Call inference endpoint ------------------------- - - payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) - if payload is None: - sample.status = halt_status - if args.generate_multi_samples and multi_samples: - multi_samples[-1].status = halt_status - break - - if args.generate_multi_samples: - sample = deepcopy(input.sample) - - output = await post(url, payload) - await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) - - if args.generate_multi_samples: - multi_samples.append(deepcopy(sample)) - - if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): - break - - # ----------------------- Execute tools ------------------------- - - _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) - if len(tool_calls) == 0: - break - - tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) - update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) - - return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample) - - -def _add_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--generate-max-turns", type=int, default=16) - parser.add_argument("--generate-tool-specs-path", type=str) - parser.add_argument("--generate-tool-call-parser", type=str) - parser.add_argument("--generate-execute-tool-function-path", type=str) - parser.add_argument("--generate-multi-samples", action="store_true") - - -generate.add_arguments = _add_arguments diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py deleted file mode 100644 index 5c0a15b5b4..0000000000 --- a/miles/rollout/generate_hub/single_turn.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Simple single-turn generation. -""" - -from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_utils.generate_endpoint_utils import ( - compute_prompt_ids_from_sample, - compute_request_payload, - update_sample_from_response, -) -from miles.utils.http_utils import post -from miles.utils.types import Sample - - -async def generate(input: GenerateFnInput) -> GenerateFnOutput: - args = input.args - sample = input.sample - sampling_params = input.sampling_params - assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - - prompt_ids = compute_prompt_ids_from_sample(input.state, sample) - - # Handle Partial Rollout resuming - if len(sample.response) > 0: - input_ids = sample.tokens - sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) - - assert sampling_params["max_new_tokens"] >= 0 - if sampling_params["max_new_tokens"] == 0: - sample.status = Sample.Status.TRUNCATED - return GenerateFnOutput(samples=sample) - else: - input_ids = prompt_ids - - payload, halt_status = compute_request_payload( - args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs - ) - if payload is None: - sample.status = halt_status - return GenerateFnOutput(samples=sample) - - output = await post(url, payload) - await update_sample_from_response(args, sample, payload=payload, output=output) - - return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/generate_utils/__init__.py b/miles/rollout/generate_utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/miles/rollout/generate_utils/generate_endpoint_utils.py b/miles/rollout/generate_utils/generate_endpoint_utils.py deleted file mode 100644 index a91d71f1de..0000000000 --- a/miles/rollout/generate_utils/generate_endpoint_utils.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Utils to integrate SGLang's `/generate` endpoint with RL things like Sample. -""" - -from copy import deepcopy -from typing import Any - -import numpy as np -import pybase64 - -from miles.utils.processing_utils import encode_image_for_rollout_engine -from miles.utils.types import Sample - - -# Make this an isolated function because users may want to compute their own -def compute_prompt_ids_from_sample(state, sample, tools=None): - prompt = sample.prompt - - if state.processor: - processor_output = state.processor(text=prompt, **sample.multimodal_inputs) - prompt_ids = processor_output["input_ids"][0] - - # TODO shall we move it to other places? then can make this function immutable - sample.multimodal_train_inputs = { - k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] - } or None - - return prompt_ids - else: - if not isinstance(prompt, str): - prompt = state.tokenizer.apply_chat_template( - prompt, tokenize=False, add_generation_prompt=True, tools=tools - ) - - return state.tokenizer.encode(prompt, add_special_tokens=False) - - -def compute_request_payload( - args, - input_ids: list[int], - sampling_params: dict, - multimodal_inputs: dict | None = None, -) -> tuple[dict[str, Any] | None, Sample.Status | None]: - sampling_params = deepcopy(sampling_params) - max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) - if x := args.rollout_max_context_len: - max_new_tokens = min(max_new_tokens, x - len(input_ids)) - if max_new_tokens <= 0: - return None, Sample.Status.TRUNCATED - - payload = { - "input_ids": input_ids, - "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens}, - "return_logprob": True, - "return_routed_experts": args.use_rollout_routing_replay, - } - if image_data := (multimodal_inputs or {}).get("images"): - payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - - return payload, None - - -async def update_sample_from_response( - args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False -): - # Initialize sample.tokens for the first turn - if (len(sample.response) == 0) and not sample.tokens: - sample.tokens = payload["input_ids"] - - if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: - from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree - - # TODO may rename to match - await postprocess_sample_with_radix_tree(args, sample, output) - - assert not update_loss_mask, "This code branch has not implemented update_loss_mask" - else: - if x := output["meta_info"].get("output_token_logprobs"): - new_response_tokens = [item[1] for item in x] - new_response_log_probs = [item[0] for item in x] - else: - new_response_tokens, new_response_log_probs = [], [] - - # Update sample with tokens directly - avoiding re-tokenization - sample.tokens = sample.tokens + new_response_tokens - sample.response_length += len(new_response_tokens) - sample.response += output["text"] - - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += new_response_log_probs - - if update_loss_mask: - if sample.loss_mask is None: - sample.loss_mask = [] - sample.loss_mask += [1] * len(new_response_tokens) - - # TODO handle multi-turn cases (may need concat instead of assignment) - sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) - - # TODO may unify (currently there are both methods inside Sample and separate functions) - sample.update_from_meta_info(args, output["meta_info"]) - - -def _get_rollout_routed_experts_from_response(args, sample, output): - info = output["meta_info"].get("routed_experts") - if info is None: - return None - - x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32) - x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk) - return x diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py deleted file mode 100644 index 73ba8198bf..0000000000 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Utilities for the OpenAI endpoint -""" - -import logging -from argparse import Namespace -from copy import deepcopy - -from miles.router.sessions import GetSessionResponse, SessionRecord -from miles.utils.http_utils import post -from miles.utils.types import Sample - -logger = logging.getLogger(__name__) - - -class OpenAIEndpointTracer: - def __init__(self, router_url: str, session_id: str): - self.router_url = router_url - self.session_id = session_id - self.base_url = f"{router_url}/sessions/{session_id}/v1" - - @staticmethod - async def create(args: Namespace): - router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" - session_id = (await post(f"{router_url}/sessions", {}))["session_id"] - return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) - - async def collect_records(self) -> list[SessionRecord]: - response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") - response = GetSessionResponse.model_validate(response) - records = response.records - - try: - await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") - except Exception as e: - logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") - - return records - - -def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: - return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records] - - -def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: - # TODO may refine after @guapisolo's implementation - choice = record.response["choices"][0] - output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] - output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] - - sample = deepcopy(input_sample) - sample.tokens = record.request["input_ids"] + output_token_ids - sample.rollout_log_probs = output_log_probs - sample.response = tokenizer.decode(output_token_ids) - sample.response_length = len(output_token_ids) - sample.loss_mask = [1] * len(output_token_ids) - - # TODO unify with Sample.update_from_meta_info - match choice["finish_reason"]: - case "stop" | "tool_calls": - sample.status = Sample.Status.COMPLETED - case "length": - sample.status = Sample.Status.TRUNCATED - case "abort": - sample.status = Sample.Status.ABORTED - - return sample diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py deleted file mode 100644 index 6a4e645be5..0000000000 --- a/miles/rollout/generate_utils/sample_utils.py +++ /dev/null @@ -1,115 +0,0 @@ -from copy import deepcopy -from dataclasses import fields - -from miles.utils.types import Sample - - -def merge_samples(samples: list[Sample], tokenizer) -> Sample: - acc = samples[0] - for sample in samples[1:]: - acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer) - return acc - - -def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: - """Merge two samples generated from sibling inference engine calls.""" - a, b = deepcopy(a), deepcopy(b) - - def _merge_equal_value(field): - x = getattr(a, field) - y = getattr(b, field) - assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}" - return x - - def _fill_defaults(sample: Sample): - if sample.loss_mask is None: - sample.loss_mask = [1] * sample.response_length - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [0.0] * sample.response_length - - _fill_defaults(a) - _fill_defaults(b) - - obs_len = len(b.tokens) - len(a.tokens) - b.response_length - obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len] - # TODO: is this acceptable? - obs_text = tokenizer.decode(obs_tokens) - - try: - a.validate() - b.validate() - assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" - assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" - assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" - if a.rollout_routed_experts is not None: - assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0] - assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" - - return _create_with_all_fields( - Sample, - group_index=_merge_equal_value("group_index"), - index=_merge_equal_value("index"), - prompt=b.prompt, - tokens=b.tokens, - multimodal_inputs=_merge_equal_value("multimodal_inputs"), - multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"), - response=a.response + obs_text + b.response, - response_length=a.response_length + obs_len + b.response_length, - label=_merge_equal_value("label"), - reward=_merge_equal_value("reward"), - loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, - weight_versions=a.weight_versions + b.weight_versions, - rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, - rollout_routed_experts=b.rollout_routed_experts, - remove_sample=_merge_equal_value("remove_sample"), - status=b.status, - metadata=_merge_equal_value("metadata"), - train_metadata=_merge_equal_value("train_metadata"), - non_generation_time=_merge_equal_value("non_generation_time"), - spec_info=_merge_spec_info(a.spec_info, b.spec_info), - prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), - ) - except AssertionError as e: - e.add_note(f"{a=} {b=}") - raise - - -def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo: - def _merge_plus_value(field): - return getattr(a, field) + getattr(b, field) - - return _create_with_all_fields( - Sample.SpecInfo, - spec_accept_token_num=_merge_plus_value("spec_accept_token_num"), - spec_draft_token_num=_merge_plus_value("spec_draft_token_num"), - spec_verify_ct=_merge_plus_value("spec_verify_ct"), - completion_token_num=_merge_plus_value("completion_token_num"), - ) - - -def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo: - def _merge_plus_value(field): - return getattr(a, field) + getattr(b, field) - - return _create_with_all_fields( - Sample.PrefixCacheInfo, - cached_tokens=_merge_plus_value("cached_tokens"), - total_prompt_tokens=_merge_plus_value("total_prompt_tokens"), - ) - - -def _create_with_all_fields(cls, **kwargs): - expected = {f.name for f in fields(cls)} - actual = set(kwargs.keys()) - assert ( - expected == actual - ), f"{cls.__name__} field mismatch. Missing: {expected - actual}, Extra: {actual - expected}" - return cls(**kwargs) - - -def _startswith(*, short, long) -> bool: - if isinstance(short, str) and isinstance(long, str): - return long.startswith(short) - if isinstance(short, list) and isinstance(long, list): - return (len(long) >= len(short)) and (long[: len(short)] == short) - raise NotImplementedError diff --git a/miles/rollout/generate_utils/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py deleted file mode 100644 index 85ea87aeab..0000000000 --- a/miles/rollout/generate_utils/tool_call_utils.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Utils to handle tool calls. -""" - -import json -import uuid -from collections.abc import Callable -from typing import Any - -from openai.types.chat import ChatCompletionMessageToolCall -from pydantic import TypeAdapter -from sglang.srt.entrypoints.openai.protocol import Tool -from sglang.srt.function_call.core_types import ToolCallItem -from sglang.srt.function_call.function_call_parser import FunctionCallParser - -from miles.utils.types import Sample - -_DUMMY_USER = {"role": "user", "content": "dummy"} - - -def create_tool_call_parser(tool_specs, tool_call_parser): - return FunctionCallParser( - tools=TypeAdapter(list[Tool]).validate_python(tool_specs), - tool_call_parser=tool_call_parser, - ) - - -async def execute_tool_calls( - tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall], - execute_one: Callable, -) -> list[dict[str, Any]]: - tool_messages = [] - for call in tool_calls: - tool_messages.append(await _execute_tool_call(call, execute_one)) - return tool_messages - - -async def _execute_tool_call( - call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable -) -> dict[str, Any]: - if isinstance(call, ChatCompletionMessageToolCall): - name = call.function.name - params = json.loads(call.function.arguments) if call.function.arguments else {} - tool_call_id = call.id - elif isinstance(call, ToolCallItem): - name = call.name - params = json.loads(call.parameters) if call.parameters else {} - tool_call_id = f"call_{uuid.uuid4().hex[:24]}" - else: - raise TypeError(f"Unsupported tool call type: {type(call)}") - - result = await execute_one(name, params) - assert isinstance(result, str) - - return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name} - - -def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): - next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) - sample.response += tokenizer.decode(next_obs_tokens_ids) - sample.response_length += len(next_obs_tokens_ids) - sample.tokens += next_obs_tokens_ids - sample.loss_mask += [0] * len(next_obs_tokens_ids) - sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) - - -# TODO: very naive implementation, need the to-be-implemented e2e test to validate. -def tokenize_tool_responses( - tool_messages: list[dict[str, Any]], - tokenizer, -) -> list[int]: - return _tokenize_postfix_messages(tool_messages, tokenizer) - - -def _tokenize_postfix_messages( - postfix_messages: list[dict[str, Any]], - tokenizer, -) -> list[int]: - dummy_assistant = _build_dummy_assistant(postfix_messages) - base_messages = [_DUMMY_USER, dummy_assistant] - - messages_without = base_messages - messages_with = base_messages + postfix_messages - - tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True) - tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False) - - assert tokens_with[: len(tokens_without)] == tokens_without, ( - f"Fail to tokenize_tool_responses caused by token prefix mismatch. " - f"This can happen for thinking model or models with special chat template, " - f"and this simple example does not support it yet, " - f"since this means we cannot have a append-only token id list. " - f"{tokens_with=} {tokens_without=} " - f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} " - ) - return tokens_with[len(tokens_without) :] - - -def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: - return { - "role": "assistant", - "content": "", - "reasoning_content": " ", - "tool_calls": [ - { - "id": resp.get("tool_call_id", f"call0000{i}"), - "type": "function", - "function": { - "name": resp.get("name", "dummy_func"), - "arguments": {}, - }, - } - for i, resp in enumerate(tool_responses) - ], - } diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py deleted file mode 100644 index 33ccf17bfb..0000000000 --- a/miles/rollout/inference_rollout/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# This is a refactor of the portions above generate-function in sglang_rollout.py, -# and is give a different name to ensure both code exist at the same time. diff --git a/miles/rollout/inference_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py deleted file mode 100644 index 7711e0dd31..0000000000 --- a/miles/rollout/inference_rollout/compatibility.py +++ /dev/null @@ -1,84 +0,0 @@ -import inspect -from collections.abc import Callable - -from miles.rollout.base_types import ( - GenerateFnInput, - GenerateFnOutput, - RolloutFnConstructorInput, - RolloutFnEvalOutput, - RolloutFnInput, - RolloutFnOutput, - RolloutFnTrainOutput, -) -from miles.utils.async_utils import run -from miles.utils.misc import load_function - - -class LegacyRolloutFnAdapter: - def __init__(self, input: RolloutFnConstructorInput, fn: Callable): - self.args = input.args - self.data_source = input.data_source - self.fn = fn - - def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: - output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation) - - # compatibility for legacy version - if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): - output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output) - - return output - - -def load_rollout_function(input: RolloutFnConstructorInput, path: str): - fn = load_function(path) - - if inspect.isclass(fn): - return fn(input) - else: - return LegacyRolloutFnAdapter(input, fn) - - -def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput: - output = fn(input) - - if inspect.iscoroutine(output): - output = run(output) - - return output - - -class LegacyGenerateFnAdapter: - def __init__(self, fn: Callable): - self.fn = fn - self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters - - async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: - if self._has_evaluation_param: - output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation) - else: - output = await self.fn(input.args, input.sample, input.sampling_params) - - if not isinstance(output, GenerateFnOutput): - output = GenerateFnOutput(samples=output) - - return output - - -def load_generate_function(path: str): - fn = load_function(path) - if fn is None: - return None - - if inspect.isclass(fn): - return fn() - elif _is_legacy_generate_fn(fn): - return LegacyGenerateFnAdapter(fn) - else: - return fn - - -def _is_legacy_generate_fn(fn: Callable) -> bool: - sig = inspect.signature(fn) - params = list(sig.parameters.keys()) - return len(params) >= 3 and params[0] != "input" diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py deleted file mode 100644 index 8518c6e020..0000000000 --- a/miles/rollout/inference_rollout/inference_rollout_common.py +++ /dev/null @@ -1,192 +0,0 @@ -import asyncio -import logging -from argparse import Namespace -from copy import deepcopy -from typing import Any - -from miles.rollout.base_types import ( - GenerateFnInput, - RolloutFnConstructorInput, - RolloutFnEvalInput, - RolloutFnEvalOutput, - RolloutFnInput, - RolloutFnOutput, - RolloutFnTrainInput, - RolloutFnTrainOutput, -) -from miles.rollout.generate_hub.single_turn import generate -from miles.rollout.inference_rollout.compatibility import load_generate_function -from miles.rollout.rm_hub import async_rm, batched_async_rm -from miles.utils.processing_utils import load_processor, load_tokenizer -from miles.utils.types import Sample - -logger = logging.getLogger(__name__) - - -class GenerateState: - def __init__(self, args: Namespace) -> None: - # persistent state for the generation process - self.args = args - self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) - self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - - self.generate_fn_semaphore = asyncio.Semaphore( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine - ) - self.sampling_params: dict[str, Any] = compute_sampling_params( - args, - temperature=args.rollout_temperature, - top_p=args.rollout_top_p, - top_k=args.rollout_top_k, - max_new_tokens=args.rollout_max_response_len, - ) - - self.generate_function = load_generate_function(args.custom_generate_function_path) or generate - - self.reset() - - def reset(self) -> None: - self.aborted = False - - -async def generate_and_rm( - state: GenerateState, - sample: Sample | list[Sample], - sampling_params: dict[str, Any], - evaluation: bool = False, -) -> Sample | list[Sample]: - args = state.args - - # mask previous off-policy generation for partial rollout - if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: - sample.loss_mask = [0] * sample.response_length - - # For samples with existing response, check if they're complete - if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: - assert sample.response is not None - if not args.group_rm: - assert sample.reward is not None - return sample - - # generate - async with state.generate_fn_semaphore: - if state.aborted: - sample.status = Sample.Status.ABORTED - return sample - - output = await state.generate_function( - GenerateFnInput( - state=state, - sample=sample, - sampling_params=deepcopy(sampling_params), - evaluation=evaluation, - ) - ) - sample = output.samples - - # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below - # for the rm that need the whole group, we will not do the rm here - if args.group_rm: - return sample - - # TODO: unify the two branches into one if we decide to use list as output type - # multi samples - if isinstance(sample, list): - samples = sample - if any([sample.status == Sample.Status.ABORTED for sample in samples]): - return samples - - # for multi agent system, the reward of some sample is calculated during generation. - samples_need_reward = [sample for sample in samples if sample.reward is None] - await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True) - return samples - else: - if sample.status == Sample.Status.ABORTED: - return sample - # for multi-turn environment, a reward could be assigned to the agent. - if sample.reward is None: - sample.reward = await async_rm(args, sample) - - return sample - - -async def generate_and_rm_group( - state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False -) -> list[Sample]: - args = state.args - - if state.aborted: - return group - - tasks = [] - for idx, sample in enumerate(group): - current_sampling_params = sampling_params.copy() - if getattr(args, "sglang_enable_deterministic_inference", False): - current_sampling_params["sampling_seed"] = args.rollout_seed + idx - tasks.append( - asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) - ) - - group = await asyncio.gather(*tasks) - if state.aborted: - return group - - if args.group_rm: - await batched_async_rm(args, group, inplace_set_reward_field=True) - - return group - - -def compute_sampling_params( - args, - *, - # after unifying configuration, this can be further refactored - temperature, - top_p, - top_k, - max_new_tokens, -): - return dict( - temperature=temperature, - top_p=top_p, - top_k=top_k, - max_new_tokens=max_new_tokens, - stop=args.rollout_stop, - stop_token_ids=args.rollout_stop_token_ids, - skip_special_tokens=args.rollout_skip_special_tokens, - no_stop_trim=True, - spaces_between_special_tokens=False, - ) - - -class InferenceRolloutFn: - def __init__(self, input: RolloutFnConstructorInput): - self.data_source = input.data_source - self.state = GenerateState(input.args) - self.eval_prompt_dataset_cache = {} - - async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: - if input.evaluation: - return await self._call_eval(input) - return await self._call_train(input) - - async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: - from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async - - output, aborted_samples = await generate_rollout_async( - self.state, input.rollout_id, self.data_source.get_samples - ) - self.data_source.add_samples(aborted_samples) - return output - - async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: - from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset - - assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" - - coros = [] - for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: - coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache)) - results_list = await asyncio.gather(*coros) - results = {k: v for r in results_list for k, v in r.items()} - return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py deleted file mode 100644 index 2d052be0ae..0000000000 --- a/miles/rollout/inference_rollout/inference_rollout_eval.py +++ /dev/null @@ -1,112 +0,0 @@ -import asyncio -import copy -import logging -from typing import Any - -from tqdm import tqdm - -from miles.rollout.inference_rollout.inference_rollout_common import ( - GenerateState, - compute_sampling_params, - generate_and_rm, -) -from miles.utils.data import Dataset -from miles.utils.eval_config import EvalDatasetConfig -from miles.utils.misc import as_completed_async -from miles.utils.processing_utils import load_processor, load_tokenizer -from miles.utils.types import Sample - -logger = logging.getLogger(__name__) - - -async def eval_rollout_single_dataset( - state: GenerateState, - dataset_cfg: EvalDatasetConfig, - prompt_dataset_cache: dict[Any, Dataset], -) -> dict[str, dict[str, list[Any]]]: - args = state.args - assert not args.group_rm, "Group RM is not supported for eval rollout" - - cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) - if cache_key not in prompt_dataset_cache: - tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) - processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - prompt_dataset_cache[cache_key] = Dataset( - path=dataset_cfg.path, - tokenizer=tokenizer, - processor=processor, - max_length=args.eval_max_prompt_len, - prompt_key=dataset_cfg.input_key, - label_key=dataset_cfg.label_key, - multimodal_keys=args.multimodal_keys, - metadata_key=dataset_cfg.metadata_key, - tool_key=dataset_cfg.tool_key, - apply_chat_template=args.apply_chat_template, - apply_chat_template_kwargs=args.apply_chat_template_kwargs, - ) - dataset = prompt_dataset_cache[cache_key] - - base_sampling_params = compute_sampling_params( - args, - temperature=dataset_cfg.temperature, - top_p=dataset_cfg.top_p, - top_k=dataset_cfg.top_k, - max_new_tokens=dataset_cfg.max_response_len, - ) - - tasks = [] - # do multiple samples for eval prompts - sample_index = 0 - for _i, prompt_sample in enumerate(dataset.samples): - for j in range(dataset_cfg.n_samples_per_eval_prompt): - # use the same prompt for multiple samples - sample = copy.deepcopy(prompt_sample) - sample.index = sample_index - sample_index += 1 - sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) - sampling_params = base_sampling_params - if getattr(args, "sglang_enable_deterministic_inference", False): - sampling_params = base_sampling_params.copy() - sampling_params["sampling_seed"] = args.rollout_seed + j - tasks.append( - asyncio.create_task( - generate_and_rm( - state, - sample, - sampling_params=sampling_params, - evaluation=True, - ) - ) - ) - - data = [] - do_print = True - pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) - async for sample in as_completed_async(tasks): - if do_print: - # TODO improve this after enhancing samples' type - s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample - if s is not None: - logger.info( - "eval_rollout_single_dataset example data: " - f"{[str(s.prompt) + s.response]} " - f"reward={s.reward}" - ) - do_print = False - if isinstance(sample, list): - data.extend(sample) - else: - data.append(sample) - pbar.update(1) - pbar.close() - - data.sort(key=lambda sample: sample.index) - - reward_key = args.eval_reward_key or args.reward_key - return { - dataset_cfg.name: { - "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], - "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], - "samples": data, - } - } diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py deleted file mode 100644 index bae94ec67b..0000000000 --- a/miles/rollout/inference_rollout/inference_rollout_train.py +++ /dev/null @@ -1,146 +0,0 @@ -import asyncio -import logging -from argparse import Namespace -from collections.abc import Callable - -import sglang_router -from packaging.version import parse -from tqdm import tqdm - -from miles.rollout.base_types import RolloutFnTrainOutput -from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter -from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group -from miles.utils.http_utils import get, post -from miles.utils.misc import as_completed_async, load_function -from miles.utils.types import Sample - -logger = logging.getLogger(__name__) - - -async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: - args = state.args - - assert not state.aborted - state.aborted = True - - urls = await get_worker_urls(args) - logger.info(f"Abort request for {urls}") - await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) - - # make sure all the pending tasks are finished - aborted_samples = [] - async for group in as_completed_async(pendings): - if not args.partial_rollout: - continue - - # for partial rollout, collect the partial samples into the data buffer - for sample in group: - if sample.response and "start_rollout_id" not in sample.metadata: - sample.metadata["start_rollout_id"] = rollout_id - aborted_samples.append(group) - - if args.partial_rollout: - logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer") - - return aborted_samples - - -async def get_worker_urls(args: Namespace): - if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") - return response["urls"] - else: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") - return [worker["url"] for worker in response["workers"]] - - -def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): - return [ - asyncio.create_task( - # submit a group of samples as a single task. - generate_and_rm_group( - state, - group, - sampling_params=state.sampling_params.copy(), - evaluation=False, - ) - ) - for group in samples - ] - - -async def generate_rollout_async( - state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] -) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: - args = state.args - assert args.rollout_global_dataset - - # instantiate data filters - dynamic_filter = load_function(args.dynamic_sampling_filter_path) - - metric_gatherer = MetricGatherer() - - # target_data_size is the total number of valid samples to get - target_data_size = args.rollout_batch_size - - pendings = set() - data = [] - all_data = [] - do_print = True - pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") - while len(data) < target_data_size: - while len(data) + len(pendings) < target_data_size: - # get samples from the buffer and submit the generation requests. - samples = data_source(args.over_sampling_batch_size) - pendings.update(submit_generate_tasks(state, samples)) - - # wait for the generation to finish - done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) - for task in done: - group: list[Sample] = task.result() - - if do_print: - sample = group[0][0] if isinstance(group[0], list) else group[0] - logger.info( - f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", - ) - do_print = False - - assert len(group) == args.n_samples_per_prompt - all_data.append(group) - dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) - if not dynamic_filter_output.keep: - metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) - continue - - # add the samples to the data - # NOTE: here we have not stored all the unused samples back to the data buffer. - if len(data) < target_data_size: - data.append(group) - pbar.update(args.n_samples_per_prompt) - - pbar.close() - sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] - logger.info( - f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", - ) - - # there are still some unfinished requests, abort them - aborted_samples = await abort(state, pendings, rollout_id) - - assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" - data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) - all_samples = sorted( - all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index - ) - - # reset the global state to prevent effects on the next rollout or eval. - state.reset() - - if f := load_function(args.rollout_sample_filter_path): - f(args, data) - # There can be circumstances where users want to process all samples including filtered ones. - if f := load_function(args.rollout_all_samples_process_path): - f(args, all_samples, data_source) - - return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py index e9ee29db41..62b253ddee 100644 --- a/miles/rollout/rm_hub/__init__.py +++ b/miles/rollout/rm_hub/__init__.py @@ -69,18 +69,8 @@ async def async_rm(args, sample: Sample, **kwargs): async def batched_async_rm( args, samples: list[Sample], - inplace_set_reward_field: bool = False, **kwargs, -) -> list[int | float] | None: - if inplace_set_reward_field: - rewards = await batched_async_rm(args, samples, **kwargs) - for sample, reward in zip(samples, rewards, strict=True): - assert ( - sample.reward is None - ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?" - sample.reward = reward - return None - +) -> list[int | float]: if args.custom_rm_path is not None: # Ensure the custom reward function is implemented in batch mode rm_function = load_function(args.custom_rm_path) diff --git a/miles/router/router.py b/miles/router/router.py index 7d3ecd9806..2e8ecfc41f 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -9,7 +9,6 @@ from fastapi.responses import JSONResponse from starlette.responses import Response -from miles.router.sessions import setup_session_routes from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -70,8 +69,6 @@ def _setup_routes(self): self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) self.app.post("/retrieve_from_text")(self.retrieve_from_text) - # Session routes - must be registered before catch-all - setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) @@ -133,41 +130,39 @@ async def _health_check_loop(self): async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" - result = await self._do_proxy(request, path) - return self._build_proxy_response(result) - - async def _do_proxy(self, request: Request, path: str) -> dict: - """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" + # Forward all other paths to SGLang router worker_url = self._use_url() url = f"{worker_url}/{path}" + # Get request body and headers body = await request.body() headers = dict(request.headers) try: response = await self.client.request(request.method, url, content=body, headers=headers) + # Eagerly read content so we can return JSON (not streaming) content = await response.aread() - return { - "request_body": body, - "response_body": content, - "status_code": response.status_code, - "headers": dict(response.headers), - } + content_type = response.headers.get("content-type", "") + try: + # Prefer parsing JSON if possible + data = json.loads(content) + return JSONResponse( + content=data, + status_code=response.status_code, + headers=dict(response.headers), + ) + except Exception: + # Fall back to raw body with original content type + return Response( + content=content, + status_code=response.status_code, + headers=dict(response.headers), + media_type=content_type or None, + ) + finally: self._finish_url(worker_url) - def _build_proxy_response(self, result: dict) -> Response: - """Build HTTP response from proxy result.""" - content = result["response_body"] - status_code = result["status_code"] - headers = result["headers"] - content_type = headers.get("content-type", "") - try: - data = json.loads(content) - return JSONResponse(content=data, status_code=status_code, headers=headers) - except Exception: - return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) - async def add_worker(self, request: Request): """Add a new worker to the router. Supports providing the URL via query string or JSON body. diff --git a/miles/router/sessions.py b/miles/router/sessions.py deleted file mode 100644 index 9d753e5975..0000000000 --- a/miles/router/sessions.py +++ /dev/null @@ -1,124 +0,0 @@ -import json -import time -import uuid -from typing import TYPE_CHECKING - -from fastapi import Request -from fastapi.responses import JSONResponse, Response -from pydantic import BaseModel -from transformers import AutoTokenizer - -if TYPE_CHECKING: - from miles.router.router import MilesRouter - - -class SessionRecord(BaseModel): - timestamp: float - method: str - path: str - request: dict - response: dict - status_code: int - - -class GetSessionResponse(BaseModel): - session_id: str - records: list[SessionRecord] - - -class SessionManager: - def __init__(self): - self.sessions: dict[str, list[SessionRecord]] = {} - - def create_session(self) -> str: - session_id = uuid.uuid4().hex - self.sessions[session_id] = [] - return session_id - - def get_session(self, session_id: str) -> list[SessionRecord] | None: - return self.sessions.get(session_id) - - def delete_session(self, session_id: str) -> list[SessionRecord]: - assert session_id in self.sessions - return self.sessions.pop(session_id) - - def add_record(self, session_id: str, record: SessionRecord): - assert session_id in self.sessions - self.sessions[session_id].append(record) - - -def setup_session_routes(app, router: "MilesRouter"): - manager = SessionManager() - - # TODO temporary hack before @guapisolo implements TITO - # ============================= HACK START =============================== - # Lazy load tokenizer only when needed (for tests that don't have hf_checkpoint) - tokenizer = None - - def get_tokenizer(): - nonlocal tokenizer - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) - return tokenizer - - # ============================= HACK END =============================== - - @app.post("/sessions") - async def create_session(): - session_id = manager.create_session() - return {"session_id": session_id} - - @app.get("/sessions/{session_id}") - async def get_session(session_id: str): - records = manager.get_session(session_id) - if records is None: - return JSONResponse(status_code=404, content={"error": "session not found"}) - return GetSessionResponse(session_id=session_id, records=records) - - @app.delete("/sessions/{session_id}") - async def delete_session(session_id: str): - if session_id not in manager.sessions: - return JSONResponse(status_code=404, content={"error": "session not found"}) - manager.delete_session(session_id) - return Response(status_code=204) - - @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) - async def session_proxy(request: Request, session_id: str, path: str): - if session_id not in manager.sessions: - return JSONResponse(status_code=404, content={"error": "session not found"}) - - result = await router._do_proxy(request, path) - - request_body = json.loads(result["request_body"]) - response_body = json.loads(result["response_body"]) - - # TODO: remove this hack when @guapisolo implements the real TITO - # ============================= HACK START =============================== - if "messages" in request_body and "input_ids" not in request_body: - request_body["input_ids"] = get_tokenizer().apply_chat_template( - request_body["messages"], - add_generation_prompt=True, - add_special_tokens=False, - tools=request_body.get("tools"), - ) - if ( - "logprobs" in response_body.get("choices", [{}])[0] - and "content" in response_body["choices"][0]["logprobs"] - ): - logprobs_content = response_body["choices"][0]["logprobs"]["content"] - for item in logprobs_content: - if "token" in item and "token_id" not in item: - item["token_id"] = get_tokenizer().convert_tokens_to_ids(item["token"]) - # ============================= HACK END =============================== - - record = SessionRecord( - timestamp=time.time(), - method=request.method, - path=path, - request=request_body, - response=response_body, - status_code=result["status_code"], - ) - manager.add_record(session_id, record) - - return router._build_proxy_response(result) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 0710202924..79b2c419ca 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,10 +10,8 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args -from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger -from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -206,11 +204,7 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-function-path", type=str, - default=( - "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn" - if enable_experimental_rollout_refactor() - else "miles.rollout.sglang_rollout.generate_rollout" - ), + default="miles.rollout.sglang_rollout.generate_rollout", help=( "Path to the rollout generation function." "You should use this model to create your own custom rollout function, " @@ -1350,20 +1344,6 @@ def add_ci_arguments(parser): ) return parser - def add_user_provided_function_arguments(parser): - args_partial, _ = parser.parse_known_args() - for path in [ - args_partial.rollout_function_path, - args_partial.custom_generate_function_path, - ]: - try: - fn = load_function(path) - except (ModuleNotFoundError, ValueError): - continue - if fn is not None and callable(getattr(fn, "add_arguments", None)): - fn.add_arguments(parser) - return parser - def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) @@ -1394,8 +1374,6 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) - if enable_experimental_rollout_refactor(): - parser = add_user_provided_function_arguments(parser) reset_arg( parser, "--custom-config-path", diff --git a/miles/utils/environ.py b/miles/utils/environ.py deleted file mode 100644 index 35d1f350ee..0000000000 --- a/miles/utils/environ.py +++ /dev/null @@ -1,14 +0,0 @@ -import os - -_printed_experimental_rollout_refactor = False - - -def enable_experimental_rollout_refactor() -> bool: - result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) - - global _printed_experimental_rollout_refactor - if result and not _printed_experimental_rollout_refactor: - print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)") - _printed_experimental_rollout_refactor = True - - return result diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 0abdbbf59d..2b3e6e192f 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -162,15 +162,11 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60, action="post"): +async def _post(client, url, payload, max_retries=60): retry_count = 0 while retry_count < max_retries: try: - if action in ("delete", "get"): - assert not payload - response = await getattr(client, action)(url) - else: - response = await getattr(client, action)(url, json=payload or {}) + response = await client.post(url, json=payload or {}) response.raise_for_status() try: output = response.json() @@ -244,8 +240,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60, action="post"): - return await _post(self._client, url, payload, max_retries, action=action) + async def do_post(self, url, payload, max_retries=60): + return await _post(self._client, url, payload, max_retries) # Create actors per node created = [] @@ -269,8 +265,7 @@ async def do_post(self, url, payload, max_retries=60, action="post"): _post_actors = created -# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) -async def post(url, payload, max_retries=60, action="post"): +async def post(url, payload, max_retries=60): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -279,16 +274,15 @@ async def post(url, payload, max_retries=60, action="post"): actor = _next_actor() if actor is not None: # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) + obj_ref = actor.do_post.remote(url, payload, max_retries) return await asyncio.to_thread(ray.get, obj_ref) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, action=action) + return await _post(_http_client, url, payload, max_retries) -# TODO unify w/ `post` to add retries and remote-execution async def get(url): response = await _http_client.get(url) response.raise_for_status() diff --git a/miles/utils/misc.py b/miles/utils/misc.py index bae72ec0d7..c0a96d6366 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,55 +1,17 @@ -import asyncio import importlib import subprocess -from contextlib import contextmanager import ray from miles.utils.http_utils import is_port_available -# Mainly used for test purpose where `load_function` needs to load many in-flight generated functions -class FunctionRegistry: - def __init__(self): - self._registry: dict[str, object] = {} - - @contextmanager - def temporary(self, name: str, fn: object): - self._register(name, fn) - try: - yield - finally: - self._unregister(name) - - def get(self, name: str) -> object | None: - return self._registry.get(name) - - def _register(self, name: str, fn: object) -> None: - assert name not in self._registry - self._registry[name] = fn - - def _unregister(self, name: str) -> None: - assert name in self._registry - self._registry.pop(name) - - -function_registry = FunctionRegistry() - - -# TODO may rename to `load_object` since it can be used to load things like tool_specs def load_function(path): """ - Load a function from registry or module. + Load a function from a module. :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. """ - if path is None: - return None - - registered = function_registry.get(path) - if registered is not None: - return registered - module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) @@ -68,9 +30,8 @@ def __call__(cls, *args, **kwargs): cls._instances[cls] = instance return cls._instances[cls] - @staticmethod - def clear_all_instances(): - SingletonMeta._instances.clear() + def clear_instances(cls): + cls._instances = {} def exec_command(cmd: str, capture_output: bool = False) -> str | None: @@ -131,8 +92,3 @@ def should_run_periodic_action( step = rollout_id + 1 return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) - - -async def as_completed_async(tasks): - for coro in asyncio.as_completed(tasks): - yield await coro diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py deleted file mode 100644 index 2c0dddfe54..0000000000 --- a/miles/utils/test_utils/mock_sglang_server.py +++ /dev/null @@ -1,248 +0,0 @@ -import asyncio -import re -import time -import uuid -from collections.abc import Callable -from contextlib import contextmanager -from dataclasses import asdict, dataclass - -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse -from pydantic import TypeAdapter -from sglang.srt.entrypoints.openai.protocol import Tool -from sglang.srt.function_call.function_call_parser import FunctionCallParser -from transformers import AutoTokenizer - -from miles.utils.http_utils import find_available_port -from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer - - -@dataclass(frozen=True) -class ProcessResultMetaInfo: - weight_version: str | None = None - routed_experts: str | None = None - spec_accept_token_num: int | None = None - spec_draft_token_num: int | None = None - spec_verify_ct: int | None = None - - def to_dict(self) -> dict: - return {k: v for k, v in asdict(self).items() if v is not None} - - -@dataclass(frozen=True) -class ProcessResult: - text: str - finish_reason: str = "stop" - cached_tokens: int = 0 - meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() - - -ProcessFn = Callable[[str], ProcessResult] - - -class MockSGLangServer: - def __init__( - self, - model_name: str, - process_fn: ProcessFn, - host: str, - port: int, - latency: float = 0.0, - ): - self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - self.process_fn = process_fn - self.host = host - self.port = port or find_available_port(30000) - self.latency = latency - - self.app = FastAPI() - self._server: UvicornThreadServer | None = None - - self.request_log: list[dict] = [] - self._concurrency = Counter() - - self._setup_routes() - - @property - def max_concurrent(self) -> int: - return self._concurrency.max_value - - def reset_stats(self): - self.request_log.clear() - self._concurrency.reset() - - def start(self): - self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) - self._server.start() - - def stop(self): - if self._server is not None: - self._server.stop() - - @property - def url(self) -> str: - return f"http://{self.host}:{self.port}" - - def _setup_routes(self): - @self.app.post("/generate") - async def generate(request: Request): - return await self._handle_generate_like_request(request, self._compute_generate_response) - - @self.app.post("/v1/chat/completions") - async def chat_completions(request: Request): - return await self._handle_generate_like_request(request, self._compute_chat_completions_response) - - @self.app.get("/health") - async def health(): - return JSONResponse(content={"status": "ok"}) - - @self.app.post("/abort_request") - async def abort_request(_request: Request): - return JSONResponse(content={"status": "ok"}) - - async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]): - payload = await request.json() - self.request_log.append(payload) - with self._concurrency.track(): - if self.latency > 0: - await asyncio.sleep(self.latency) - response = compute_fn(payload) - return JSONResponse(content=response) - - def _compute_generate_response(self, payload: dict) -> dict: - assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" - input_ids = payload.get("input_ids", []) - - prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) - process_result = self.process_fn(prompt_str) - output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) - - prompt_tokens = len(input_ids) - completion_tokens = len(output_ids) - - finish_reason_dict = {"type": process_result.finish_reason} - if process_result.finish_reason == "length": - finish_reason_dict["length"] = completion_tokens - - output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] - - meta_info = { - "finish_reason": finish_reason_dict, - "prompt_tokens": prompt_tokens, - "cached_tokens": process_result.cached_tokens, - "completion_tokens": completion_tokens, - "output_token_logprobs": output_token_logprobs, - **process_result.meta_info.to_dict(), - } - - return {"text": process_result.text, "meta_info": meta_info} - - def _compute_chat_completions_response(self, payload: dict) -> dict: - messages = payload.get("messages", []) - tools = payload.get("tools") - - prompt_str = self.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, tools=tools - ) - - process_result = self.process_fn(prompt_str) - output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) - - logprobs_content = [ - {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i} - for i, tid in enumerate(output_ids) - ] - - finish_reason = process_result.finish_reason - tool_calls = None - if tools and finish_reason == "stop": - parser = FunctionCallParser( - tools=TypeAdapter(list[Tool]).validate_python(tools), - tool_call_parser="qwen25", - ) - message_content, parsed_calls = parser.parse_non_stream(process_result.text) - if parsed_calls: - finish_reason = "tool_calls" - tool_calls = [ - { - "id": f"call{i:05d}", - "type": "function", - "function": {"name": call.name, "arguments": call.parameters or "{}"}, - } - for i, call in enumerate(parsed_calls) - ] - else: - message_content = process_result.text - - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", - "object": "chat.completion", - "created": int(time.time()), - "model": "mock-model", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": message_content, - "tool_calls": tool_calls, - }, - "logprobs": {"content": logprobs_content}, - "finish_reason": finish_reason, - } - ], - } - - -class Counter: - def __init__(self): - self._current = 0 - self._max = 0 - - @property - def max_value(self) -> int: - return self._max - - def reset(self): - self._current = 0 - self._max = 0 - - @contextmanager - def track(self): - self._current += 1 - self._max = max(self._max, self._current) - try: - yield - finally: - self._current -= 1 - - -def default_process_fn(prompt: str) -> ProcessResult: - match = re.search(r"What is 1\+(\d+)\?", prompt) - if match: - num = int(match.group(1)) - ans = 1 + num - return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop") - return ProcessResult(text="I don't understand.", finish_reason="stop") - - -@contextmanager -def with_mock_server( - model_name: str = "Qwen/Qwen3-0.6B", - process_fn: ProcessFn = default_process_fn, - host: str = "127.0.0.1", - port: int | None = None, - latency: float = 0.0, -): - server = MockSGLangServer( - model_name=model_name, - process_fn=process_fn, - host=host, - port=port, - latency=latency, - ) - try: - server.start() - yield server - finally: - server.stop() diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py deleted file mode 100644 index 6b99e36739..0000000000 --- a/miles/utils/test_utils/mock_tools.py +++ /dev/null @@ -1,268 +0,0 @@ -import json - -from transformers import AutoTokenizer - -from miles.utils.test_utils.mock_sglang_server import ProcessResult - -SAMPLE_TOOLS = [ - { - "type": "function", - "function": { - "name": "get_year", - "description": "Get current year", - "parameters": { - "type": "object", - "properties": {}, - "required": [], - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_temperature", - "description": "Get temperature for a location", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - }, - }, -] - - -def _get_year(params: dict) -> str: - assert len(params) == 0 - return json.dumps({"year": 2026}) - - -def _get_temperature(params: dict) -> str: - temps = {"Mars": -60, "Earth": 15} - location = params.get("location") - assert location in temps, f"Unknown location: {location}" - return json.dumps({"temperature": temps[location]}) - - -TOOL_EXECUTORS = { - "get_year": _get_year, - "get_temperature": _get_temperature, -} - - -async def execute_tool_call(name: str, params: dict) -> str: - return TOOL_EXECUTORS[name](params) - - -_SYSTEM_PROMPT = ( - "<|im_start|>system\n" - "# Tools\n" - "\n" - "You may call one or more functions to assist with the user query.\n" - "\n" - "You are provided with function signatures within XML tags:\n" - "\n" - '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' - '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' - "\n" - "\n" - "For each function call, return a json object with function name and arguments within XML tags:\n" - "\n" - '{"name": , "arguments": }\n' - "<|im_end|>\n" -) - - -_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) - - -class TwoTurnStub: - """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer""" - - USER_QUESTION = "What is 42 + year + temperature?" - - FIRST_RESPONSE = ( - "Let me get the year and temperature first.\n" - "\n" - '{"name": "get_year", "arguments": {}}\n' - "\n" - "\n" - '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' - "<|im_end|>\n" - ) - - FIRST_TOOL_RESPONSE = ( - "<|im_start|>user\n" - "\n" - '{"year": 2026}\n' - "\n" - "\n" - '{"temperature": -60}\n' - "<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." - - FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" - SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE - - PROMPT = [{"role": "user", "content": USER_QUESTION}] - - FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] - SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] - - FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." - FIRST_TOOL_CALLS_OPENAI_FORMAT = [ - {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, - { - "id": "call00001", - "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, - "type": "function", - }, - ] - - OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] - - OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ - { - "content": FIRST_RESPONSE_CONTENT, - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, - }, - {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, - {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, - ] - - @staticmethod - def process_fn(prompt: str) -> ProcessResult: - prompt_response_pairs = { - TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE, - TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE, - } - - for expect_prompt, response in prompt_response_pairs.items(): - if prompt == expect_prompt: - return ProcessResult(text=response, finish_reason="stop") - - raise ValueError(f"Unexpected {prompt=}") - - -class ThreeTurnStub: - """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer""" - - USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?" - - FIRST_RESPONSE = ( - "Let me get the year and Mars temperature first.\n" - "\n" - '{"name": "get_year", "arguments": {}}\n' - "\n" - "\n" - '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' - "<|im_end|>\n" - ) - - SECOND_RESPONSE = ( - "Now let me get Earth temperature.\n" - "\n" - '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n' - "<|im_end|>\n" - ) - - FIRST_TOOL_RESPONSE = ( - "<|im_start|>user\n" - "\n" - '{"year": 2026}\n' - "\n" - "\n" - '{"temperature": -60}\n' - "<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - SECOND_TOOL_RESPONSE = ( - "<|im_start|>user\n" - "\n" - '{"temperature": 15}\n' - "<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023." - - FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" - SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE - THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE - - PROMPT = [{"role": "user", "content": USER_QUESTION}] - - FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] - SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] - THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] - - FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first." - FIRST_TOOL_CALLS_OPENAI_FORMAT = [ - {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, - { - "id": "call00001", - "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, - "type": "function", - }, - ] - - SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature." - SECOND_TOOL_CALLS_OPENAI_FORMAT = [ - { - "id": "call00000", - "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"}, - "type": "function", - }, - ] - - OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] - - OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ - { - "content": FIRST_RESPONSE_CONTENT, - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, - }, - {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, - {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, - ] - - OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ - { - "content": SECOND_RESPONSE_CONTENT, - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, - }, - {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, - ] - - @staticmethod - def process_fn(prompt: str) -> ProcessResult: - prompt_response_pairs = { - ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE, - ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE, - ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE, - } - - for expect_prompt, response in prompt_response_pairs.items(): - if prompt == expect_prompt: - return ProcessResult(text=response, finish_reason="stop") - - raise ValueError(f"Unexpected {prompt=}") diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py deleted file mode 100644 index 904343c984..0000000000 --- a/miles/utils/test_utils/uvicorn_thread_server.py +++ /dev/null @@ -1,49 +0,0 @@ -import asyncio -import socket -import threading -import time - -import uvicorn - - -class UvicornThreadServer: - def __init__(self, app, host: str, port: int): - self._app = app - self.host = host - self.port = port - self._server: uvicorn.Server | None = None - self._thread: threading.Thread | None = None - - @property - def url(self) -> str: - return f"http://{self.host}:{self.port}" - - def start(self) -> None: - config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info") - self._server = uvicorn.Server(config) - - def run() -> None: - asyncio.run(self._server.serve()) - - self._thread = threading.Thread(target=run, daemon=True) - self._thread.start() - self._wait_for_port_open() - - def stop(self) -> None: - if self._server is not None: - self._server.should_exit = True - if self._thread is not None and self._thread.is_alive(): - self._thread.join(timeout=2.0) - - def _wait_for_port_open(self) -> None: - for _ in range(50): - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - result = sock.connect_ex((self.host, self.port)) - sock.close() - if result == 0: - return - except Exception: - pass - time.sleep(0.1) - raise RuntimeError(f"Failed to start server on {self.url}") diff --git a/miles/utils/types.py b/miles/utils/types.py index 5200d625e6..0a2531a7af 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -145,24 +145,6 @@ def get_reward_value(self, args) -> float: def effective_response_length(self): return sum(self.loss_mask) if self.loss_mask is not None else self.response_length - def validate(self): - assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}" - assert ( - len(self.tokens) >= self.response_length - ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})" - if self.loss_mask is not None: - assert ( - len(self.loss_mask) == self.response_length - ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})" - if self.rollout_log_probs is not None: - assert ( - len(self.rollout_log_probs) == self.response_length - ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" - if self.rollout_routed_experts is not None: - actual = len(self.rollout_routed_experts) - expect = len(self.tokens) - 1 - assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})" - def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. diff --git a/requirements.txt b/requirements.txt index dacd51132c..2c20195fc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf pillow -pybase64 pylatexenc pyyaml qwen_vl_utils # for VLM diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py index 20379f76a2..9507e2e858 100644 --- a/tests/ci/gpu_lock_exec.py +++ b/tests/ci/gpu_lock_exec.py @@ -19,14 +19,11 @@ def main(): _execute_print_only(args) return - if args.count == 0 and not args.devices: - print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True) - else: - fd_locks = _try_acquire(args) + fd_locks = _try_acquire(args) - dev_list = ",".join(str(x.gpu_id) for x in fd_locks) - os.environ[args.target_env_name] = dev_list - print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ[args.target_env_name] = dev_list + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) _os_execvp(args) diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep deleted file mode 100644 index 615f2b076c..0000000000 --- a/tests/e2e/.gitkeep +++ /dev/null @@ -1 +0,0 @@ -# TODO: may move e2e tests to this folder \ No newline at end of file diff --git a/tests/fast/__init__.py b/tests/fast/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/conftest.py b/tests/fast/conftest.py deleted file mode 100644 index 4cb30e91fa..0000000000 --- a/tests/fast/conftest.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -import pytest - -from tests.fast.fixtures.generation_fixtures import generation_env -from tests.fast.fixtures.rollout_fixtures import rollout_env - -_ = rollout_env, generation_env - - -@pytest.fixture(autouse=True) -def enable_experimental_rollout_refactor(): - os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1" - yield - os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None) diff --git a/tests/fast/fixtures/__init__.py b/tests/fast/fixtures/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/tests/fast/fixtures/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py deleted file mode 100644 index 816371ee3a..0000000000 --- a/tests/fast/fixtures/generation_fixtures.py +++ /dev/null @@ -1,274 +0,0 @@ -""" -Fixtures to test custom-generate-function -""" - -from argparse import Namespace -from contextlib import contextmanager -from dataclasses import dataclass -from types import SimpleNamespace -from typing import Any -from unittest.mock import patch - -import pytest -import requests - -from miles.rollout.base_types import GenerateFnInput -from miles.rollout.inference_rollout.compatibility import load_generate_function -from miles.rollout.inference_rollout.inference_rollout_common import GenerateState -from miles.router.router import MilesRouter -from miles.utils.async_utils import run -from miles.utils.http_utils import find_available_port, init_http_client -from miles.utils.misc import SingletonMeta -from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server -from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer -from miles.utils.types import Sample - -MODEL_NAME = "Qwen/Qwen3-0.6B" -RESPONSE_TEXT = "\\boxed{8}" -DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} - -VARIANT_TO_GENERATE_FN_PATH = { - "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", - "single_turn": "miles.rollout.generate_hub.single_turn.generate", - "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", - "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", - "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate", - "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", -} - - -def extra_argv_for_variant( - variant: str, - *, - custom_generate_function_path: str | None = None, - generate_max_turns: int = 16, - generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", - generate_tool_call_parser: str = "qwen25", - generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", -) -> list[str]: - argv = [ - "--custom-generate-function-path", - custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], - ] - - if variant in ( - "multi_turn_single_sample", - "multi_turn_multi_samples", - "agentic_tool_call_single_sample", - "agentic_tool_call_multi_samples", - ): - argv += [ - "--generate-max-turns", - str(generate_max_turns), - "--generate-tool-specs-path", - generate_tool_specs_path, - "--generate-execute-tool-function-path", - generate_execute_tool_function_path, - ] - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - argv += ["--generate-tool-call-parser", generate_tool_call_parser] - if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): - argv.append("--generate-multi-samples") - - return argv - - -def listify(x): - return x if isinstance(x, list) else [x] - - -def make_sample( - *, - prompt: str | list[dict] = "What is 1+7?", - tokens: list[int] | None = None, - response: str = "", - response_length: int = 0, - status: Sample.Status = Sample.Status.PENDING, - multimodal_inputs: dict | None = None, -) -> Sample: - return Sample( - prompt=prompt, - tokens=tokens or [], - response=response, - response_length=response_length, - status=status, - multimodal_inputs=multimodal_inputs, - ) - - -@dataclass -class GenerateEnv: - args: Namespace - mock_server: Any - - -@dataclass -class GenerateResult: - sample: Sample | list[Sample] - requests: list[dict] - - -def run_generate( - env: GenerateEnv, - sample: Sample, - sampling_params: dict[str, Any] | None = None, - *, - variant: str = "single_turn", -) -> GenerateResult: - env.mock_server.request_log.clear() - result_sample = run( - _call_generate( - env.args, - sample, - sampling_params or DEFAULT_SAMPLING_PARAMS, - variant=variant, - ) - ) - return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) - - -async def _call_generate( - args: Namespace, - sample: Sample, - sampling_params: dict[str, Any], - *, - variant: str = "single_turn", -) -> Sample: - generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant]) - state = GenerateState(args) - input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) - output = await generate_fn(input) - return output.samples - - -def make_args( - *, - variant: str, - router_port: int, - use_rollout_routing_replay: bool = False, - sglang_speculative_algorithm: str | None = None, - model_name: str = MODEL_NAME, - extra_argv: list[str] | None = None, - custom_generate_function_path: str | None = None, - generate_max_turns: int = 16, - generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", - generate_tool_call_parser: str = "qwen25", - generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", - rollout_max_context_len: int | None = None, -) -> Namespace: - argv = [ - "pytest", - "--train-backend", - "fsdp", - "--rollout-batch-size", - "1", - "--num-rollout", - "1", - "--rollout-num-gpus", - "1", - "--rollout-num-gpus-per-engine", - "1", - "--hf-checkpoint", - model_name, - "--prompt-data", - "/dev/null", - "--rm-type", - "math", - "--sglang-router-ip", - "127.0.0.1", - "--sglang-router-port", - str(router_port), - "--rollout-max-response-len", - "16", - ] - if use_rollout_routing_replay: - argv.append("--use-rollout-routing-replay") - if sglang_speculative_algorithm: - argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) - if rollout_max_context_len is not None: - argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) - - argv.extend( - extra_argv_for_variant( - variant, - custom_generate_function_path=custom_generate_function_path, - generate_max_turns=generate_max_turns, - generate_tool_specs_path=generate_tool_specs_path, - generate_tool_call_parser=generate_tool_call_parser, - generate_execute_tool_function_path=generate_execute_tool_function_path, - ) - ) - - if extra_argv: - argv.extend(extra_argv) - - from miles.utils.arguments import parse_args - - with patch("sys.argv", argv): - args = parse_args() - - init_http_client(args) - return args - - -@contextmanager -def with_miles_router(backend_url: str, model_name: str): - router_args = SimpleNamespace( - miles_router_max_connections=10, - miles_router_timeout=30, - miles_router_middleware_paths=[], - rollout_health_check_interval=60, - miles_router_health_check_failure_threshold=3, - hf_checkpoint=model_name, - ) - router = MilesRouter(router_args) - - port = find_available_port(31000) - server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) - server.start() - - url = f"http://127.0.0.1:{port}" - requests.post(f"{url}/add_worker", json={"url": backend_url}) - - try: - yield port - finally: - server.stop() - - -@pytest.fixture -def generation_env(request, variant): - SingletonMeta.clear_all_instances() - params = getattr(request, "param", {}) - args_kwargs = params.get("args_kwargs", {}) - model_name = args_kwargs.get("model_name", MODEL_NAME) - custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant] - - def process_fn(_): - x = params.get("process_fn_kwargs", {}) - return ProcessResult( - text=x.get("response_text", RESPONSE_TEXT), - finish_reason=x.get("finish_reason", "stop"), - cached_tokens=x.get("cached_tokens", 0), - meta_info=ProcessResultMetaInfo( - weight_version=x.get("weight_version"), - routed_experts=x.get("routed_experts"), - spec_accept_token_num=x.get("spec_accept_token_num"), - spec_draft_token_num=x.get("spec_draft_token_num"), - spec_verify_ct=x.get("spec_verify_ct"), - ), - ) - - with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: - with with_miles_router(mock_server.url, model_name) as router_port: - other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} - args = make_args( - variant=variant, - router_port=router_port, - model_name=model_name, - custom_generate_function_path=custom_generate_function_path, - **other_args_kwargs, - ) - yield GenerateEnv(args=args, mock_server=mock_server) - - SingletonMeta.clear_all_instances() diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py deleted file mode 100644 index 44d8a50d79..0000000000 --- a/tests/fast/fixtures/rollout_fixtures.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Fixtures to test rollout-function -""" - -import json -from argparse import Namespace -from collections.abc import Iterator -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from unittest.mock import patch - -import pytest -import requests - -from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer -from miles.router.router import MilesRouter -from miles.utils.arguments import parse_args -from miles.utils.http_utils import find_available_port, init_http_client -from miles.utils.misc import SingletonMeta -from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server -from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer - - -@dataclass(frozen=True) -class RolloutEnvConfig: - extra_argv: list[str] | None = None - data_rows: list[dict] | None = None - latency: float = 0.0 - - -@dataclass(frozen=True) -class RolloutEnv: - args: Namespace - data_source: DataSource - mock_server: MockSGLangServer - - -def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: - argv = [ - "pytest", - "--train-backend", - "fsdp", - "--rollout-batch-size", - "1", - "--n-samples-per-prompt", - "1", - "--num-rollout", - "1", - "--rollout-num-gpus", - "1", - "--rollout-num-gpus-per-engine", - "1", - "--hf-checkpoint", - "Qwen/Qwen3-0.6B", - "--prompt-data", - data_path, - "--input-key", - "input", - "--label-key", - "label", - "--rm-type", - "math", - "--eval-prompt-data", - "toy", - data_path, - "--use-miles-router", - "--sglang-router-ip", - "127.0.0.1", - "--sglang-router-port", - str(router_port), - "--rollout-max-response-len", - "16", - ] + (extra_argv or []) - with patch("sys.argv", argv): - args = parse_args() - args.miles_router_middleware_paths = [] - init_http_client(args) - return args - - -@contextmanager -def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]: - router = MilesRouter(args, verbose=False) - server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) - try: - server.start() - yield server - finally: - server.stop() - - -def _write_jsonl(path: str, rows: list[dict]) -> None: - Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") - - -DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] - - -@pytest.fixture -def rollout_env(tmp_path, request) -> RolloutEnv: - config = request.param - assert isinstance(config, RolloutEnvConfig) - - data_rows = config.data_rows or DEFAULT_DATA_ROWS - - data_path = str(tmp_path / "data.jsonl") - _write_jsonl(data_path, data_rows) - - router_port = find_available_port(20000) - args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) - - SingletonMeta.clear_all_instances() - - with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: - with _with_miles_router(args) as router_server: - r = requests.post( - f"{router_server.url}/add_worker", - params={"url": mock_server.url}, - timeout=5.0, - ) - r.raise_for_status() - - data_source = RolloutDataSourceWithBuffer(args) - yield RolloutEnv(args=args, data_source=data_source, mock_server=mock_server) - - SingletonMeta.clear_all_instances() diff --git a/tests/fast/rollout/__init__.py b/tests/fast/rollout/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/rollout/generate_hub/__init__.py b/tests/fast/rollout/generate_hub/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py deleted file mode 100644 index 5d974aaadd..0000000000 --- a/tests/fast/rollout/generate_hub/test_multi_turn.py +++ /dev/null @@ -1,572 +0,0 @@ -from copy import deepcopy -from dataclasses import dataclass, replace -from itertools import groupby - -import numpy as np -import pybase64 -import pytest -from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate -from transformers import AutoTokenizer - -from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo -from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub -from miles.utils.types import Sample - -_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub - - -def is_agentic_variant(variant: str) -> bool: - return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples") - - -# ------------------------------------ fixtures and consts ---------------------------------------- - - -MODEL_NAME = "Qwen/Qwen3-0.6B" -DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} -TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) - - -@pytest.fixture( - params=[ - "multi_turn_single_sample", - "multi_turn_multi_samples", - "agentic_tool_call_single_sample", - "agentic_tool_call_multi_samples", - ] -) -def variant(request): - return request.param - - -@dataclass(frozen=True) -class SampleParsedChunk: - tokens_decoded_str: str - loss_mask_value: int - rollout_log_probs: list[float] - - -@dataclass -class ExpectedSampleInfo: - chunks: list[SampleParsedChunk] - partial_sample: Sample - - -def token_len(text: str) -> int: - return len(TOKENIZER(text, add_special_tokens=False)["input_ids"]) - - -def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk: - n = token_len(text) - log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n - return SampleParsedChunk(text, loss_mask, log_probs) - - -def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: - prompt_len = len(sample.tokens) - sample.response_length - response_tokens = sample.tokens[prompt_len:] - loss_mask = sample.loss_mask or [] - log_probs = sample.rollout_log_probs or [] - - chunks = [] - idx = 0 - for mask_val, group in groupby(loss_mask): - group_len = len(list(group)) - sli = slice(idx, idx + group_len) - chunks.append( - SampleParsedChunk( - tokens_decoded_str=tokenizer.decode(response_tokens[sli]), - loss_mask_value=mask_val, - rollout_log_probs=log_probs[sli], - ) - ) - idx += group_len - return chunks - - -def expected_partial_sample( - *, - prompt: list[dict], - response: str, - response_length: int, - status: Sample.Status = Sample.Status.COMPLETED, -) -> Sample: - return Sample( - prompt=prompt, - response=response, - response_length=response_length, - status=status, - tokens=[], - loss_mask=[], - rollout_log_probs=[], - weight_versions=[], - spec_info=Sample.SpecInfo(), - prefix_cache_info=Sample.PrefixCacheInfo(), - ) - - -def verify_samples(actual: Sample | list[Sample], expected: list[ExpectedSampleInfo]): - actual = listify(actual) - assert len(actual) == len(expected) - - for actual_item, expected_item in zip(actual, expected, strict=True): - actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER) - assert actual_chunks == expected_item.chunks - - actual_partial = replace( - deepcopy(actual_item), - tokens=[], - loss_mask=[], - rollout_log_probs=[], - prefix_cache_info=Sample.PrefixCacheInfo(), - ) - assert actual_partial == expected_item.partial_sample - - -def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): - return run_generate(env, sample, sampling_params, variant=variant) - - -def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict: - return { - "input_ids": input_ids, - "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, - "return_logprob": True, - "return_routed_experts": False, - } - - -def expected_openai_request(messages: list[dict]) -> dict: - return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS} - - -SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] -SINGLE_TURN_RESPONSE = "The answer is 2." -_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( - SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS -) -SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] -SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) - - -# ------------------------------------ tests ---------------------------------------- - - -class TestBasicMultiTurn: - def test_single_turn_no_tool_call(self, variant, generation_env): - generation_env.mock_server.process_fn = lambda _: ProcessResult( - text=SINGLE_TURN_RESPONSE, finish_reason="stop" - ) - - result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - - if is_agentic_variant(variant): - assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] - else: - assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] - verify_samples( - result.sample, - [ - ExpectedSampleInfo( - chunks=[ - SampleParsedChunk( - tokens_decoded_str=SINGLE_TURN_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(6)], - ), - ], - partial_sample=expected_partial_sample( - prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6 - ), - ), - ], - ) - - def test_two_turns_with_tool_call(self, variant, generation_env): - generation_env.mock_server.process_fn = TwoTurnStub.process_fn - - S = TwoTurnStub - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - - if is_agentic_variant(variant): - assert result.requests == [ - expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), - expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), - ] - else: - assert result.requests == [ - expected_request(S.FIRST_PROMPT_TOKEN_IDS), - expected_request(S.SECOND_PROMPT_TOKEN_IDS), - ] - if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): - full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE - expected = [ - ExpectedSampleInfo( - chunks=[ - expected_chunk(S.FIRST_RESPONSE, 1), - expected_chunk(S.FIRST_TOOL_RESPONSE, 0), - expected_chunk(S.SECOND_RESPONSE, 1), - ], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=full_response, - response_length=token_len(full_response), - ), - ), - ] - else: - expected = [ - ExpectedSampleInfo( - chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.FIRST_RESPONSE, - response_length=token_len(S.FIRST_RESPONSE), - ), - ), - ExpectedSampleInfo( - chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.SECOND_RESPONSE, - response_length=token_len(S.SECOND_RESPONSE), - ), - ), - ] - verify_samples(result.sample, expected) - - -class TestExitConditions: - def test_partial_rollout_not_supported(self, variant, generation_env): - if is_agentic_variant(variant): - pytest.skip("agentic_tool_call does not check partial_rollout flag") - generation_env.args.partial_rollout = True - - with pytest.raises(AssertionError, match="Partial rollout is not supported"): - _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - - def test_abort_preserves_content(self, variant, generation_env): - if is_agentic_variant(variant): - pytest.skip("agentic_tool_call does not handle abort finish_reason") - generation_env.mock_server.process_fn = lambda _: ProcessResult( - text=SINGLE_TURN_RESPONSE, finish_reason="abort" - ) - - result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - - assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] - verify_samples( - result.sample, - [ - ExpectedSampleInfo( - chunks=[ - SampleParsedChunk( - tokens_decoded_str=SINGLE_TURN_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(6)], - ), - ], - partial_sample=expected_partial_sample( - prompt=SINGLE_TURN_PROMPT, - response=SINGLE_TURN_RESPONSE, - response_length=6, - status=Sample.Status.ABORTED, - ), - ), - ], - ) - - def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): - S = TwoTurnStub - generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length") - - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - - if is_agentic_variant(variant): - assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] - else: - assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] - verify_samples( - result.sample, - [ - ExpectedSampleInfo( - chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.FIRST_RESPONSE, - response_length=token_len(S.FIRST_RESPONSE), - status=Sample.Status.TRUNCATED, - ), - ), - ], - ) - - @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) - def test_max_turns_reached(self, variant, generation_env): - S = TwoTurnStub - generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop") - - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - - if is_agentic_variant(variant): - assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] - else: - assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] - if variant == "multi_turn_single_sample": - expected = [ - ExpectedSampleInfo( - chunks=[ - expected_chunk(S.FIRST_RESPONSE, 1), - expected_chunk(S.FIRST_TOOL_RESPONSE, 0), - ], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE, - response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE), - ), - ), - ] - else: - expected = [ - ExpectedSampleInfo( - chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.FIRST_RESPONSE, - response_length=token_len(S.FIRST_RESPONSE), - ), - ), - ] - verify_samples(result.sample, expected) - - -class TestRespectMaxContextLen: - @pytest.mark.parametrize( - "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True - ) - def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): - if is_agentic_variant(variant): - pytest.skip("TODO: implement") - result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - assert result.requests == [] - if variant == "multi_turn_single_sample": - expected = [ - ExpectedSampleInfo( - chunks=[], - partial_sample=expected_partial_sample( - prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED - ), - ) - ] - else: - expected = [] - verify_samples(result.sample, expected) - - @pytest.mark.parametrize( - "generation_env", - [ - { - "args_kwargs": { - "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS) - + token_len(TwoTurnStub.FIRST_RESPONSE) - + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE) - } - } - ], - indirect=True, - ) - def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): - if is_agentic_variant(variant): - pytest.skip("TODO: implement") - S = TwoTurnStub - generation_env.mock_server.process_fn = S.process_fn - - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - - assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] - if variant == "multi_turn_single_sample": - partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE - expected = [ - ExpectedSampleInfo( - chunks=[ - expected_chunk(S.FIRST_RESPONSE, 1), - expected_chunk(S.FIRST_TOOL_RESPONSE, 0), - ], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=partial_response, - response_length=token_len(partial_response), - status=Sample.Status.TRUNCATED, - ), - ), - ] - else: - expected = [ - ExpectedSampleInfo( - chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.FIRST_RESPONSE, - response_length=token_len(S.FIRST_RESPONSE), - status=Sample.Status.TRUNCATED, - ), - ), - ] - verify_samples(result.sample, expected) - - @pytest.mark.parametrize( - "generation_env,expected_max_new_tokens", - [ - ( - {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}}, - 10, - ), - ( - {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}}, - 64, - ), - ], - indirect=["generation_env"], - ) - def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): - if is_agentic_variant(variant): - pytest.skip("TODO: implement") - S = TwoTurnStub - generation_env.mock_server.process_fn = S.process_fn - - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - - assert len(result.requests) >= 2 - assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens - assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"] - - -class TestThreeTurn: - """Need to test 3-turn case besides 2-turn, because e.g. merge_samples may behave differently.""" - - def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): - generation_env.mock_server.process_fn = ThreeTurnStub.process_fn - - S = ThreeTurnStub - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - - if is_agentic_variant(variant): - assert result.requests == [ - expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), - expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), - expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), - ] - else: - assert result.requests == [ - expected_request(S.FIRST_PROMPT_TOKEN_IDS), - expected_request(S.SECOND_PROMPT_TOKEN_IDS), - expected_request(S.THIRD_PROMPT_TOKEN_IDS), - ] - if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): - full_response = ( - S.FIRST_RESPONSE - + S.FIRST_TOOL_RESPONSE - + S.SECOND_RESPONSE - + S.SECOND_TOOL_RESPONSE - + S.THIRD_RESPONSE - ) - expected = [ - ExpectedSampleInfo( - chunks=[ - expected_chunk(S.FIRST_RESPONSE, 1), - expected_chunk(S.FIRST_TOOL_RESPONSE, 0), - expected_chunk(S.SECOND_RESPONSE, 1), - expected_chunk(S.SECOND_TOOL_RESPONSE, 0), - expected_chunk(S.THIRD_RESPONSE, 1), - ], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=full_response, - response_length=token_len(full_response), - ), - ), - ] - else: - expected = [ - ExpectedSampleInfo( - chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.FIRST_RESPONSE, - response_length=token_len(S.FIRST_RESPONSE), - ), - ), - ExpectedSampleInfo( - chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.SECOND_RESPONSE, - response_length=token_len(S.SECOND_RESPONSE), - ), - ), - ExpectedSampleInfo( - chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], - partial_sample=expected_partial_sample( - prompt=S.PROMPT, - response=S.THIRD_RESPONSE, - response_length=token_len(S.THIRD_RESPONSE), - ), - ), - ] - verify_samples(result.sample, expected) - - -class TestRoutedExpertsMultiTurn: - @pytest.mark.parametrize( - "generation_env", - [ - { - "args_kwargs": { - "use_rollout_routing_replay": True, - } - } - ], - indirect=True, - ) - def test_two_turns_routed_experts(self, variant, generation_env): - if is_agentic_variant(variant): - pytest.skip("TODO: implement") - - S = TwoTurnStub - num_layers, moe_router_topk = 2, 4 - generation_env.args.num_layers = num_layers - generation_env.args.moe_router_topk = moe_router_topk - - def make_routed_experts(prompt_token_ids, response_text): - total_tokens = len(prompt_token_ids) + token_len(response_text) - routed_experts_len = total_tokens - 1 - return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape( - routed_experts_len, num_layers, moe_router_topk - ) - - first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE) - second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE) - - def process_fn(prompt: str) -> ProcessResult: - if prompt == S.FIRST_PROMPT: - text, routed_experts = S.FIRST_RESPONSE, first_routed_experts - elif prompt == S.SECOND_PROMPT: - text, routed_experts = S.SECOND_RESPONSE, second_routed_experts - else: - raise ValueError(f"Unexpected prompt: {prompt}") - return ProcessResult( - text=text, - finish_reason="stop", - meta_info=ProcessResultMetaInfo( - routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii") - ), - ) - - generation_env.mock_server.process_fn = process_fn - result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS) - - sample = result.sample[-1] if isinstance(result.sample, list) else result.sample - assert sample.rollout_routed_experts is not None - assert sample.rollout_routed_experts.shape == second_routed_experts.shape - np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts) - assert len(sample.tokens) - 1 == second_routed_experts.shape[0] diff --git a/tests/fast/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py deleted file mode 100644 index a58e6fb3c6..0000000000 --- a/tests/fast/rollout/generate_hub/test_single_turn.py +++ /dev/null @@ -1,424 +0,0 @@ -import numpy as np -import pybase64 -import pytest -import torch -from PIL import Image -from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate -from transformers import AutoProcessor - -from miles.utils.processing_utils import encode_image_for_rollout_engine -from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo -from miles.utils.types import Sample - -_ = generation_env - -# ------------------------------------ fixtures and consts ---------------------------------------- - - -MODEL_NAME = "Qwen/Qwen3-0.6B" -PROMPT = "What is 1+7?" -PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] -PROMPT_TOKEN_LEN = len(PROMPT_TOKENS) -RESPONSE_TOKENS = [59, 79075, 90, 23, 92] -RESPONSE_TEXT = "\\boxed{8}" -RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] -SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} -DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"] - - -@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) -def variant(request): - return request.param - - -def expected_request( - variant: str, - *, - input_ids: list[int] | None = None, - sampling_params: dict | None = None, - return_routed_experts: bool = False, - image_data: list[str] | None = None, -) -> dict: - result = { - "input_ids": input_ids or PROMPT_TOKENS, - "sampling_params": sampling_params or SAMPLING_PARAMS, - "return_logprob": True, - } - if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts: - result["return_routed_experts"] = return_routed_experts - if image_data is not None: - result["image_data"] = image_data - return result - - -class _Unset: - pass - - -_UNSET = _Unset() - - -def expected_sample( - variant: str, - *, - prompt: str = PROMPT, - response: str = RESPONSE_TEXT, - response_length: int = 5, - tokens: list[int] | None | _Unset = _UNSET, - rollout_log_probs: list[float] | None | _Unset = _UNSET, - status: Sample.Status = Sample.Status.COMPLETED, - cached_tokens: int = 0, - prompt_tokens: int = 7, - weight_versions: list[str] | None = None, - rollout_routed_experts: np.ndarray | None = None, - spec_info: Sample.SpecInfo | None = None, - multimodal_inputs: dict | None = None, - multimodal_train_inputs: dict | None = None, - loss_mask: list[int] | None | _Unset = _UNSET, -) -> Sample: - actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) - if isinstance(loss_mask, _Unset): - loss_mask = ( - [1] * actual_response_length - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") - else None - ) - - return Sample( - group_index=None, - index=None, - prompt=prompt, - tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens, - multimodal_inputs=multimodal_inputs, - multimodal_train_inputs=multimodal_train_inputs, - response=response, - response_length=response_length, - label=None, - reward=None, - loss_mask=loss_mask, - weight_versions=weight_versions or [], - rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs, - rollout_routed_experts=rollout_routed_experts, - remove_sample=False, - status=status, - metadata={}, - train_metadata=None, - non_generation_time=0.0, - spec_info=spec_info or Sample.SpecInfo(), - prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens), - ) - - -def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): - return make_sample( - prompt=PROMPT, - tokens=tokens, - response=response, - response_length=response_length, - status=status, - multimodal_inputs=multimodal_inputs, - ) - - -def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): - return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant) - - -# ------------------------------------ tests ---------------------------------------- - - -class TestBasicGeneration: - def test_basic_generation(self, variant, generation_env): - result = _run_generate(variant, generation_env) - assert result.requests == [expected_request(variant)] - assert listify(result.sample) == [expected_sample(variant)] - - -class TestResumedSingleTurn: - def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - pytest.skip("not tested yet") - partial_text = "\\boxed" - partial_tokens = [59, 79075] - partial_log_probs = [-0.0, -0.0078125] - - remaining_text = "{8}" - remaining_tokens = [90, 23, 92] - remaining_log_probs = [-0.0, -0.0078125, -0.015625] - - generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") - sample = _make_sample() - result1 = _run_generate(variant, generation_env, sample) - assert result1.requests == [expected_request(variant)] - assert result1.sample == expected_sample( - variant, - response=partial_text, - response_length=2, - tokens=PROMPT_TOKENS + partial_tokens, - rollout_log_probs=partial_log_probs, - status=Sample.Status.ABORTED, - ) - - generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") - result2 = _run_generate(variant, generation_env, result1.sample) - tokens_after_turn1 = PROMPT_TOKENS + partial_tokens - assert result2.requests == [ - expected_request( - variant, - input_ids=tokens_after_turn1, - sampling_params={"max_new_tokens": 14, "temperature": 0.7}, - ) - ] - assert result2.sample == expected_sample( - variant, - response=partial_text + remaining_text, - response_length=2 + 3, - tokens=tokens_after_turn1 + remaining_tokens, - rollout_log_probs=partial_log_probs + remaining_log_probs, - prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1), - status=Sample.Status.COMPLETED, - ) - - -class TestFinishReason: - @pytest.mark.parametrize( - "generation_env,expected_status", - [ - ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), - ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), - ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED), - ], - indirect=["generation_env"], - ) - def test_finish_reason_sets_status(self, variant, generation_env, expected_status): - result = _run_generate(variant, generation_env) - assert result.requests == [expected_request(variant)] - assert listify(result.sample) == [expected_sample(variant, status=expected_status)] - - -class TestRoutedExperts: - @pytest.mark.parametrize( - "generation_env", - [ - { - "args_kwargs": {"use_rollout_routing_replay": True}, - "process_fn_kwargs": {"routed_experts": "placeholder"}, - } - ], - indirect=True, - ) - def test_routed_experts_enabled_and_parsed(self, variant, generation_env): - num_layers, moe_router_topk = 2, 4 - num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) - routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( - num_tokens - 1, num_layers, moe_router_topk - ) - - generation_env.args.num_layers = num_layers - generation_env.args.moe_router_topk = moe_router_topk - routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") - generation_env.mock_server.process_fn = lambda _: ProcessResult( - text=RESPONSE_TEXT, - finish_reason="stop", - meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), - ) - - result = _run_generate(variant, generation_env) - assert result.requests == [expected_request(variant, return_routed_experts=True)] - sample = result.sample[0] if isinstance(result.sample, list) else result.sample - assert sample.rollout_routed_experts is not None - assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) - np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array) - - -class TestMetaInfo: - @pytest.mark.parametrize( - "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True - ) - def test_meta_info_fields_updated(self, variant, generation_env): - result = _run_generate(variant, generation_env) - assert result.requests == [expected_request(variant)] - assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])] - - @pytest.mark.parametrize( - "generation_env", - [ - { - "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, - "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3}, - } - ], - indirect=True, - ) - def test_spec_info_updated(self, variant, generation_env): - result = _run_generate(variant, generation_env) - assert result.requests == [expected_request(variant)] - assert listify(result.sample) == [ - expected_sample( - variant, - spec_info=Sample.SpecInfo( - spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 - ), - ) - ] - - -class TestInputStatusValidation: - @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) - def test_allowed_statuses(self, variant, generation_env, status): - result = _run_generate(variant, generation_env, _make_sample(status=status)) - assert result.requests == [expected_request(variant)] - assert listify(result.sample) == [expected_sample(variant)] - - @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) - def test_rejected_statuses(self, variant, generation_env, status): - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - pytest.skip("not tested yet") - with pytest.raises(AssertionError): - _run_generate(variant, generation_env, _make_sample(status=status)) - - -class TestPayloadStructure: - def test_sampling_params_passed_through(self, variant, generation_env): - result = _run_generate( - variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9} - ) - assert result.requests == [ - expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) - ] - assert listify(result.sample) == [expected_sample(variant)] - - -class TestBoundaryConditions: - def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - pytest.skip("not tested yet") - existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) - sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) - - result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7}) - assert result.requests == [] - assert result.sample == expected_sample( - variant, - response="x" * 10, - response_length=10, - tokens=existing_tokens, - rollout_log_probs=None, - status=Sample.Status.TRUNCATED, - prompt_tokens=0, - ) - - @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True) - def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): - if variant == "old_sglang_rollout": - pytest.skip("old_sglang_rollout does not support rollout_max_context_len") - if variant == "multi_turn_multi_samples": - pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") - result = _run_generate(variant, generation_env) - assert result.requests == [] - tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else [] - assert listify(result.sample) == [ - expected_sample( - variant, - response="", - response_length=0, - tokens=tokens, - rollout_log_probs=None, - status=Sample.Status.TRUNCATED, - prompt_tokens=0, - loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, - ) - ] - - @pytest.mark.parametrize( - "generation_env,expected_max_new_tokens", - [ - ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN), - ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN), - ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS), - ], - indirect=["generation_env"], - ) - def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): - if variant == "old_sglang_rollout": - pytest.skip("old_sglang_rollout does not support rollout_max_context_len") - result = _run_generate(variant, generation_env) - assert len(result.requests) == 1 - assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens - assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"] - assert listify(result.sample) == [expected_sample(variant)] - - @pytest.mark.parametrize( - "generation_env", - [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}], - indirect=True, - ) - def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env): - if variant == "old_sglang_rollout": - pytest.skip("old_sglang_rollout does not support rollout_max_context_len") - if variant == "multi_turn_multi_samples": - pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") - result = _run_generate(variant, generation_env) - assert result.requests == [] - tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] - assert listify(result.sample) == [ - expected_sample( - variant, - response="", - response_length=0, - tokens=tokens, - rollout_log_probs=None, - status=Sample.Status.TRUNCATED, - prompt_tokens=0, - loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, - ) - ] - - -class TestEmptyResponse: - @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) - def test_empty_response(self, variant, generation_env): - result = _run_generate(variant, generation_env) - assert result.requests == [expected_request(variant)] - assert listify(result.sample) == [ - expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[]) - ] - - -VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" - - -class TestMultimodal: - @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) - def test_multimodal_inputs_processed(self, variant, generation_env): - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - pytest.skip("not tested yet") - test_image = Image.new("RGB", (64, 64), color="red") - multimodal_inputs = {"images": [test_image]} - processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) - expected_mti = { - k: v - for k, v in processor(text=PROMPT, **multimodal_inputs).items() - if k not in ["input_ids", "attention_mask"] - } - - result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs)) - - assert result.requests == [ - expected_request( - variant, - input_ids=PROMPT_TOKENS, - image_data=[encode_image_for_rollout_engine(test_image)], - ) - ] - actual_mti = result.sample.multimodal_train_inputs - assert actual_mti is not None - assert set(actual_mti.keys()) == set(expected_mti.keys()) - assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) - assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) - assert result.sample == expected_sample( - variant, - tokens=PROMPT_TOKENS + RESPONSE_TOKENS, - multimodal_inputs=multimodal_inputs, - multimodal_train_inputs=actual_mti, - ) diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py deleted file mode 100644 index 0f2305e753..0000000000 --- a/tests/fast/rollout/generate_hub/test_tool_call_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -import pytest - -from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses - -TOOL_CALL_TEST_MODELS = [ - "Qwen/Qwen2.5-0.5B-Instruct", - "Qwen/Qwen3-0.6B", - "Qwen/Qwen3-4B-Instruct-2507", - "Qwen/Qwen3-Coder-30B-A3B-Instruct", - # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI - "mistralai/Mistral-7B-Instruct-v0.3", - "deepseek-ai/DeepSeek-V3", - "stepfun-ai/step3", - "MiniMaxAI/MiniMax-M2", - "internlm/internlm3-8b-instruct", - "THUDM/glm-4-9b-chat", - "moonshotai/Kimi-K2-Instruct", - "XiaomiMiMo/MiMo-7B-RL", -] - -SINGLE_TOOL_CALL_ONLY_MODELS = [ - # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo -] - -# Models where tokenize->decode produces extra whitespace vs direct string diff -TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ - "THUDM/glm-4-9b-chat", -] - -SAMPLE_TOOL_RESPONSES = [ - { - "role": "tool", - "tool_call_id": "call00000", - "content": '{"year": 2026}', - "name": "get_year", - }, - { - "role": "tool", - "tool_call_id": "call00001", - "content": '{"temperature": 25}', - "name": "get_temperature", - }, -] - - -class TestTokenizeToolResponses: - @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) - def test_snapshot(self, model_name): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) - decoded = tokenizer.decode(token_ids) - - assert decoded == ( - "<|im_start|>user\n" - "\n" - '{"year": 2026}\n' - "\n" - "\n" - '{"temperature": 25}\n' - "<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - @pytest.mark.parametrize("num_tools", [1, 2]) - @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) - def test_tokenize_tool_responses(self, model_name, num_tools): - if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: - pytest.skip(f"{model_name} only supports single tool call") - - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - - tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] - assert len(tool_responses) == num_tools - - actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer) - actual_str = tokenizer.decode(actual_token_ids) - - dummy_assistant = _build_dummy_assistant(tool_responses) - base_messages = [_DUMMY_USER, dummy_assistant] - expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) - - if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: - # Some models produce whitespace differences between tokenize->decode and direct string diff - actual_str = actual_str.replace(" ", "") - expected_str = expected_str.replace(" ", "") - - assert actual_str == expected_str, f"{model_name=}" - - @staticmethod - def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str: - text_with = tokenizer.apply_chat_template( - base_messages + extra_messages, tokenize=False, add_generation_prompt=True - ) - text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False) - return text_with[len(text_without) :] diff --git a/tests/fast/rollout/generate_utils/__init__.py b/tests/fast/rollout/generate_utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/rollout/generate_utils/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py deleted file mode 100644 index c53fbbb56a..0000000000 --- a/tests/fast/rollout/generate_utils/test_sample_utils.py +++ /dev/null @@ -1,156 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from miles.rollout.generate_utils.sample_utils import _merge_sample_pair -from miles.utils.types import Sample - - -@pytest.fixture -def mock_tokenizer(): - tokenizer = MagicMock() - tokenizer.decode = lambda tokens: f"" - return tokenizer - - -def make_sample( - prompt="test_prompt", - tokens=None, - response="", - response_length=0, - loss_mask=None, - rollout_log_probs=None, - status=Sample.Status.COMPLETED, - label="test_label", - reward=1.0, - index=0, - group_index=0, -): - return Sample( - prompt=prompt, - tokens=tokens or [], - response=response, - response_length=response_length, - loss_mask=loss_mask, - rollout_log_probs=rollout_log_probs, - status=status, - label=label, - reward=reward, - index=index, - group_index=group_index, - ) - - -class TestMergeSamples: - def test_basic_merge(self, mock_tokenizer): - a = make_sample( - tokens=[1, 2, 3, 10, 11, 12], - response="response1", - response_length=3, - loss_mask=[1, 1, 1], - rollout_log_probs=[-0.1, -0.2, -0.3], - ) - b = make_sample( - tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32], - response="response2", - response_length=3, - loss_mask=[1, 1, 1], - rollout_log_probs=[-0.4, -0.5, -0.6], - status=Sample.Status.TRUNCATED, - ) - - merged = _merge_sample_pair(a, b, mock_tokenizer) - - assert merged.tokens == b.tokens - assert merged.response_length == 3 + 2 + 3 - assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] - assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6] - assert merged.prompt == a.prompt - assert merged.status == b.status - assert merged.label == a.label - assert merged.index == a.index - assert merged.group_index == a.group_index - assert "response1" in merged.response - assert "response2" in merged.response - assert "" in merged.response - - def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer): - a = make_sample( - tokens=[1, 2, 10], - response_length=1, - loss_mask=None, - rollout_log_probs=None, - ) - b = make_sample( - tokens=[1, 2, 10, 20, 30], - response_length=1, - loss_mask=None, - rollout_log_probs=None, - ) - - merged = _merge_sample_pair(a, b, mock_tokenizer) - - assert merged.loss_mask == [1, 0, 1] - assert merged.rollout_log_probs == [0.0, 0.0, 0.0] - - def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): - a = make_sample( - tokens=[1, 2, 3], - response_length=1, - loss_mask=[1], - ) - b = make_sample( - tokens=[1, 2, 99, 20, 30], - response_length=1, - loss_mask=[1], - ) - - with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): - _merge_sample_pair(a, b, mock_tokenizer) - - def test_field_mismatch_raises(self, mock_tokenizer): - a = make_sample( - tokens=[1, 2, 10], - response_length=1, - loss_mask=[1], - index=0, - ) - b = make_sample( - tokens=[1, 2, 10, 20, 30], - response_length=1, - loss_mask=[1], - index=1, - ) - - with pytest.raises(AssertionError, match="index mismatch"): - _merge_sample_pair(a, b, mock_tokenizer) - - def test_obs_len_invalid_raises(self, mock_tokenizer): - a = make_sample( - tokens=[1, 2, 10], - response_length=1, - loss_mask=[1], - ) - b = make_sample( - tokens=[1, 2, 10, 30], - response_length=1, - loss_mask=[1], - ) - - with pytest.raises(AssertionError, match="obs_len must be > 0"): - _merge_sample_pair(a, b, mock_tokenizer) - - def test_sample_validate_fails_raises(self, mock_tokenizer): - a = make_sample( - tokens=[1, 2, 10, 11], - response_length=2, - loss_mask=[1], - ) - b = make_sample( - tokens=[1, 2, 10, 11, 20, 30], - response_length=1, - loss_mask=[1], - ) - - with pytest.raises(AssertionError, match="loss_mask length"): - _merge_sample_pair(a, b, mock_tokenizer) diff --git a/tests/fast/rollout/inference_rollout/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py deleted file mode 100644 index ca47edeeb6..0000000000 --- a/tests/fast/rollout/inference_rollout/conftest.py +++ /dev/null @@ -1,45 +0,0 @@ -from unittest.mock import patch - -import pytest - -from miles.utils.arguments import parse_args - - -def _build_mock_args(extra_argv: list[str] | None = None): - argv = [ - "pytest", - "--train-backend", - "fsdp", - "--rollout-batch-size", - "2", - "--n-samples-per-prompt", - "1", - "--num-rollout", - "1", - "--rollout-num-gpus", - "4", - "--rollout-num-gpus-per-engine", - "2", - "--hf-checkpoint", - "Qwen/Qwen3-0.6B", - "--prompt-data", - "/dev/null", - "--input-key", - "input", - "--label-key", - "label", - "--rm-type", - "math", - "--use-miles-router", - "--sglang-router-ip", - "127.0.0.1", - "--sglang-router-port", - "30000", - ] + (extra_argv or []) - with patch("sys.argv", argv): - return parse_args() - - -@pytest.fixture -def mock_args(): - return _build_mock_args() diff --git a/tests/fast/rollout/inference_rollout/integration/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py deleted file mode 100644 index 5b791829d5..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_basic.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant -from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig -from tests.fast.rollout.inference_rollout.integration.utils import ( - MODULAR_ROLLOUT_BASE_ARGV, - expected_sample, - load_and_call_train, -) - -from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput -from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function - -_VARIANTS = [ - pytest.param( - RolloutEnvConfig( - extra_argv=[ - "--rollout-function-path", - "miles.rollout.sglang_rollout.generate_rollout", - "--eval-function-path", - "miles.rollout.sglang_rollout.generate_rollout", - "--custom-generate-function-path", - "miles.rollout.sglang_rollout.generate", - ] - ), - id="old_rollout_old_generate", - ), - pytest.param( - RolloutEnvConfig( - extra_argv=[ - "--rollout-function-path", - "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", - "--custom-generate-function-path", - "miles.rollout.sglang_rollout.generate", - ] - ), - id="new_rollout_old_generate", - ), - pytest.param( - RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), - id="new_rollout_new_generate", - ), -] - - -@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) -def test_train(rollout_env): - env = rollout_env - out = load_and_call_train(env.args, env.data_source) - - assert len(out.samples) == env.args.rollout_batch_size - group = out.samples[0] - assert len(group) == env.args.n_samples_per_prompt - assert group[0] == expected_sample(group_index=0) - - -@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) -def test_eval(rollout_env): - env = rollout_env - fn = load_rollout_function( - RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path - ) - out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) - - assert "toy" in out.data - rewards = out.data["toy"]["rewards"] - samples = out.data["toy"]["samples"] - assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt - assert rewards[0] == 1 - assert samples[0] == expected_sample(group_index=None) diff --git a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py deleted file mode 100644 index 69a2359117..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest - -from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train - - -@pytest.mark.parametrize( - "rollout_env,expected_seeds", - [ - pytest.param( - integration_env_config( - [ - "--sglang-enable-deterministic-inference", - "--rollout-seed", - "42", - "--n-samples-per-prompt", - "3", - "--rollout-batch-size", - "1", - ] - ), - {42, 43, 44}, - id="enabled", - ), - pytest.param( - integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), - {None}, - id="disabled", - ), - ], - indirect=["rollout_env"], -) -def test_sampling_seeds(rollout_env, expected_seeds): - env = rollout_env - load_and_call_train(env.args, env.data_source) - - seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} - assert seeds == expected_seeds diff --git a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py deleted file mode 100644 index 0ca5743ac5..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py +++ /dev/null @@ -1,46 +0,0 @@ -from contextlib import nullcontext - -import pytest -from tests.fast.rollout.inference_rollout.integration.utils import ( - MIXED_DATA_ROWS, - filter_by_reward, - integration_env_config, - load_and_call_train, -) - -from miles.utils.misc import function_registry - - -@pytest.mark.parametrize( - "rollout_env,use_filter,expect_all_correct", - [ - pytest.param( - integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), - False, - False, - id="no_filter", - ), - pytest.param( - integration_env_config( - ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], - data_rows=MIXED_DATA_ROWS, - ), - True, - True, - id="with_filter", - ), - ], - indirect=["rollout_env"], -) -def test_filter_effect(rollout_env, use_filter, expect_all_correct): - env = rollout_env - ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() - - with ctx: - out = load_and_call_train(env.args, env.data_source) - - rewards = {group[0].reward for group in out.samples} - if expect_all_correct: - assert rewards == {1}, "Filter should keep only correct samples" - else: - assert 0 in rewards, "Without filter, incorrect samples should be present" diff --git a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py deleted file mode 100644 index afd870c302..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train - - -@pytest.mark.parametrize( - "rollout_env", - [ - pytest.param( - integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), - id="group_rm_enabled", - ), - ], - indirect=True, -) -def test_group_rm_rewards_set(rollout_env): - env = rollout_env - out = load_and_call_train(env.args, env.data_source) - - assert len(out.samples) == env.args.rollout_batch_size - rewards = [sample.reward for group in out.samples for sample in group] - assert all(r in (0, 1) for r in rewards) diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py deleted file mode 100644 index 2b12d3d88f..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig -from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train - -from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.utils.misc import function_registry -from miles.utils.types import Sample - - -async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: - sample = input.sample - s1 = Sample( - prompt=sample.prompt, - response="\\boxed{8}", - response_length=5, - tokens=sample.tokens + [59, 79075, 90, 23, 92], - label=sample.label, - reward=None, - status=Sample.Status.COMPLETED, - ) - s2 = Sample( - prompt=sample.prompt, - response="\\boxed{8}", - response_length=5, - tokens=sample.tokens + [59, 79075, 90, 23, 92], - label=sample.label, - reward=0.5, - status=Sample.Status.COMPLETED, - ) - return GenerateFnOutput(samples=[s1, s2]) - - -@pytest.mark.parametrize( - "rollout_env", - [ - pytest.param( - RolloutEnvConfig( - extra_argv=MODULAR_ROLLOUT_BASE_ARGV - + [ - "--custom-generate-function-path", - "test:multi_sample_generate", - "--rollout-batch-size", - "1", - "--n-samples-per-prompt", - "1", - ], - data_rows=DEFAULT_DATA_ROWS, - ), - id="multi_sample_output", - ), - ], - indirect=True, -) -def test_multi_sample_output_preserves_existing_reward(rollout_env): - env = rollout_env - with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): - out = load_and_call_train(env.args, env.data_source) - - assert len(out.samples) == env.args.rollout_batch_size - group = out.samples[0] - assert isinstance(group[0], list) - samples = group[0] - assert len(samples) == 2 - assert samples[0].reward == 1 - assert samples[1].reward == 0.5 diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py deleted file mode 100644 index c41d713991..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import Any - -import pytest -from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant -from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig -from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout - -from miles.utils.test_utils.mock_tools import TwoTurnStub -from miles.utils.types import Sample - - -TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}] - -_VARIANT_NAMES = [ - "multi_turn_single_sample", - "multi_turn_multi_samples", - "agentic_tool_call_single_sample", - "agentic_tool_call_multi_samples", -] - -_BASE_EXTRA_ARGV = [ - "--rollout-batch-size", - "2", - "--n-samples-per-prompt", - "2", - "--n-samples-per-eval-prompt", - "2", - "--custom-rm-path", - "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function", -] - - -def _config_for_variant(variant: str) -> RolloutEnvConfig: - return RolloutEnvConfig( - extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, - data_rows=TWO_TURN_DATA_ROWS, - ) - - -@pytest.mark.parametrize( - "variant,rollout_env", - [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], - indirect=["rollout_env"], -) -@pytest.mark.parametrize("test_type", ["train", "eval"]) -def test_rollout(rollout_env, variant, test_type): - env = rollout_env - env.mock_server.process_fn = TwoTurnStub.process_fn - - out = load_and_call_rollout(env.args, env.data_source, mode=test_type) - - if test_type == "train": - assert len(out.samples) == env.args.rollout_batch_size - group = out.samples[0] - _verify_samples(variant, group) - else: - assert "toy" in out.data - samples = out.data["toy"]["samples"] - _verify_samples(variant, samples) - - -def _verify_samples(variant: str, samples: list[Any]): - is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples") - - if is_multi_samples: - if len(samples) > 0 and isinstance(samples[0], list): - # Train mode: list[list[Sample]], grouped by prompt - assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" - for group_sample in samples: - assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate" - _verify_group_samples(group_sample) - else: - # Eval mode: list[Sample], flattened - # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples - assert ( - len(samples) == 4 - ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}" - # Group samples by prompt (every 2 samples form a group) - group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)] - for group_samples in group_samples_list: - _verify_group_samples(group_samples) - else: - assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" - for sample in samples: - assert isinstance(sample, Sample), "single_sample variant should return Sample, not list" - _verify_sample(sample) - - -def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2): - assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)" - for i, sample in enumerate(group_samples): - _verify_sample(sample, expect_answer=(i == len(group_samples) - 1)) - - -def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True): - assert sample.status == Sample.Status.COMPLETED - assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}" - if expect_answer: - assert "2008" in sample.response, "Response should contain final answer '2008'" - - -async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]: - if isinstance(samples, list): - # For multi_samples variants, use the last sample's reward - if getattr(args, "generate_multi_samples", False): - return [_check_reward(samples[-1])] * len(samples) - else: - return [_check_reward(sample) for sample in samples] - else: - return _check_reward(samples) - - -def _check_reward(sample: Sample) -> float: - return float(sample.response and (str(sample.label) in sample.response)) diff --git a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py deleted file mode 100644 index 0812962cc7..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -from tests.fast.rollout.inference_rollout.integration.utils import ( - filter_by_reward, - integration_env_config, - load_and_call_train, -) - -from miles.utils.misc import function_registry - -_DATA_ROWS = [ - {"input": "What is 1+7?", "label": "8"}, - {"input": "What is 1+8?", "label": "wrong"}, - {"input": "What is 1+9?", "label": "wrong"}, - {"input": "What is 1+6?", "label": "wrong"}, -] - -_BASE_ARGV = [ - "--over-sampling-batch-size", - "4", - "--dynamic-sampling-filter-path", - "test:filter_by_reward", -] - - -def _over_sampling_config(rollout_batch_size: int): - return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) - - -@pytest.mark.parametrize( - "rollout_env,expected_rounds", - [ - pytest.param(_over_sampling_config(1), 1, id="one_round"), - pytest.param(_over_sampling_config(2), 2, id="two_rounds"), - ], - indirect=["rollout_env"], -) -def test_over_sampling_rounds(rollout_env, expected_rounds): - env = rollout_env - - with function_registry.temporary("test:filter_by_reward", filter_by_reward): - out = load_and_call_train(env.args, env.data_source) - - assert len(out.samples) == env.args.rollout_batch_size - assert all(group[0].reward == 1 for group in out.samples) - - requests_count = len(env.mock_server.request_log) - expected_requests = expected_rounds * env.args.over_sampling_batch_size - assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests" diff --git a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py deleted file mode 100644 index 36e78c16c1..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py +++ /dev/null @@ -1,67 +0,0 @@ -from unittest.mock import Mock - -import pytest -from tests.fast.rollout.inference_rollout.integration.utils import ( - filter_by_reward, - integration_env_config, - load_and_call_train, -) - -from miles.utils.misc import function_registry - -# Data with only 2 reward=1 samples out of 4. -# This ensures all 4 samples must be generated to collect 2 valid ones. -_FILTER_TEST_DATA_ROWS = [ - {"input": "What is 1+7?", "label": "8"}, # reward=1 - {"input": "What is 1+8?", "label": "wrong"}, # reward=0 - {"input": "What is 1+9?", "label": "wrong"}, # reward=0 - {"input": "What is 1+6?", "label": "7"}, # reward=1 -] - - -@pytest.mark.parametrize( - "rollout_env", - [ - pytest.param( - integration_env_config( - [ - "--rollout-batch-size", - "2", - "--over-sampling-batch-size", - "4", - "--dynamic-sampling-filter-path", - "test:filter_by_reward", - "--rollout-sample-filter-path", - "test:sample_filter", - "--rollout-all-samples-process-path", - "test:all_samples_process", - ], - data_rows=_FILTER_TEST_DATA_ROWS, - ), - id="sample_filter_vs_all_samples", - ), - ], - indirect=True, -) -def test_sample_filter_and_all_samples_process(rollout_env): - env = rollout_env - sample_filter_mock = Mock() - all_samples_process_mock = Mock() - - with ( - function_registry.temporary("test:filter_by_reward", filter_by_reward), - function_registry.temporary("test:sample_filter", sample_filter_mock), - function_registry.temporary("test:all_samples_process", all_samples_process_mock), - ): - load_and_call_train(env.args, env.data_source) - - sample_filter_mock.assert_called_once() - _, filtered_data = sample_filter_mock.call_args[0] - rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data] - assert all(r == 1 for r in rewards) - - all_samples_process_mock.assert_called_once() - _, all_samples, data_source = all_samples_process_mock.call_args[0] - assert data_source is not None - - assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter" diff --git a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py deleted file mode 100644 index 889a9ff8ac..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest - -from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train - -_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] -_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] - - -@pytest.mark.parametrize( - "rollout_env,expected_range", - [ - pytest.param( - integration_env_config( - ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 - ), - (1, 1), - id="limit_1", - ), - pytest.param( - integration_env_config( - ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 - ), - (2, 999), - id="no_limit", - ), - ], - indirect=["rollout_env"], -) -def test_max_concurrent(rollout_env, expected_range): - env = rollout_env - load_and_call_train(env.args, env.data_source) - min_expected, max_expected = expected_range - assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py deleted file mode 100644 index ad413cf949..0000000000 --- a/tests/fast/rollout/inference_rollout/integration/utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant -from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig - -from miles.rollout.base_types import ( - RolloutFnConstructorInput, - RolloutFnEvalInput, - RolloutFnOutput, - RolloutFnTrainInput, -) -from miles.rollout.filter_hub.base_types import DynamicFilterOutput -from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function -from miles.utils.types import Sample - - -def expected_sample(*, group_index: int | None) -> Sample: - return Sample( - group_index=group_index, - index=0, - prompt="What is 1+7?", - tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], - multimodal_inputs=None, - multimodal_train_inputs=None, - response="\\boxed{8}", - response_length=5, - label="8", - reward=1, - loss_mask=None, - weight_versions=[], - rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], - rollout_routed_experts=None, - remove_sample=False, - status=Sample.Status.COMPLETED, - metadata={}, - train_metadata=None, - non_generation_time=0.0, - spec_info=Sample.SpecInfo( - spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 - ), - prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), - ) - - -MODULAR_ROLLOUT_BASE_ARGV = [ - "--rollout-function-path", - "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", -] - -MIXED_DATA_ROWS = [ - {"input": "What is 1+7?", "label": "8"}, - {"input": "What is 1+8?", "label": "9"}, - {"input": "What is 1+9?", "label": "wrong"}, - {"input": "What is 1+6?", "label": "7"}, -] - - -def integration_env_config( - extra_argv: list[str], - data_rows: list[dict] | None = None, - latency: float = 0.0, - variant: str = "single_turn", -): - return RolloutEnvConfig( - extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, - data_rows=data_rows, - latency=latency, - ) - - -def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput: - function_path = args.rollout_function_path if mode == "train" else args.eval_function_path - fn = load_rollout_function( - RolloutFnConstructorInput(args=args, data_source=data_source), - function_path, - ) - if mode == "train": - return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) - else: - return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) - - -def load_and_call_train(args, data_source): - return load_and_call_rollout(args, data_source, mode="train") - - -def filter_by_reward(args, samples, **kwargs): - reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward - if reward == 1: - return DynamicFilterOutput(keep=True) - return DynamicFilterOutput(keep=False, reason="reward_zero") diff --git a/tests/fast/rollout/inference_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py deleted file mode 100644 index ddfecd067b..0000000000 --- a/tests/fast/rollout/inference_rollout/test_compatibility.py +++ /dev/null @@ -1,196 +0,0 @@ -import asyncio -from unittest.mock import MagicMock - -import pytest - -from miles.rollout.base_types import ( - GenerateFnInput, - GenerateFnOutput, - RolloutFnConstructorInput, - RolloutFnEvalInput, - RolloutFnEvalOutput, - RolloutFnTrainInput, - RolloutFnTrainOutput, -) -from miles.rollout.inference_rollout.compatibility import ( - LegacyGenerateFnAdapter, - LegacyRolloutFnAdapter, - call_rollout_function, - load_generate_function, - load_rollout_function, -) -from miles.utils.async_utils import run -from miles.utils.misc import function_registry - - -@pytest.fixture -def constructor_input(): - return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") - - -@pytest.fixture -def make_generate_fn_input(): - def _make(evaluation: bool = False): - state = MagicMock() - state.args = MagicMock() - - return GenerateFnInput( - state=state, - sample={"text": "test prompt"}, - sampling_params={"temperature": 0.7}, - evaluation=evaluation, - ) - - return _make - - -class TestSupportedRolloutFormats: - """ - Documentation test to show various supported rollout function formats - """ - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation): - def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): - if evaluation: - return {"metric": {"accuracy": 0.9}} - return [[{"text": "sample"}]] - - with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn): - fn = load_rollout_function(constructor_input, "test:legacy_rollout") - - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) - - assert isinstance(fn, LegacyRolloutFnAdapter) - if evaluation: - assert isinstance(result, RolloutFnEvalOutput) - assert result.data == {"metric": {"accuracy": 0.9}} - else: - assert isinstance(result, RolloutFnTrainOutput) - assert result.samples == [[{"text": "sample"}]] - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): - def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): - if evaluation: - return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) - return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) - - with function_registry.temporary("test:legacy_typed", legacy_rollout_fn): - fn = load_rollout_function(constructor_input, "test:legacy_typed") - - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) - - if evaluation: - assert isinstance(result, RolloutFnEvalOutput) - assert result.data == {"ds": {"acc": 0.95}} - else: - assert isinstance(result, RolloutFnTrainOutput) - assert result.samples == [[{"text": "typed"}]] - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_3_sync_class(self, constructor_input, evaluation): - class SyncRolloutFn: - def __init__(self, input: RolloutFnConstructorInput): - pass - - def __call__(self, input): - if input.evaluation: - return RolloutFnEvalOutput(data={"test": {"score": 1}}) - return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) - - with function_registry.temporary("test:sync_class", SyncRolloutFn): - fn = load_rollout_function(constructor_input, "test:sync_class") - - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) - - assert isinstance(fn, SyncRolloutFn) - expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput - assert isinstance(result, expected_type) - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_4_async_class(self, constructor_input, evaluation): - class AsyncRolloutFn: - def __init__(self, input: RolloutFnConstructorInput): - pass - - async def __call__(self, input): - await asyncio.sleep(0.001) - if input.evaluation: - return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) - return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) - - with function_registry.temporary("test:async_class", AsyncRolloutFn): - fn = load_rollout_function(constructor_input, "test:async_class") - - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) - - assert isinstance(fn, AsyncRolloutFn) - expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput - assert isinstance(result, expected_type) - - -class TestSupportedGenerateFormats: - """ - Documentation test similar to TestSupportedRolloutFormats - """ - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation): - async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): - return "my_sample" - - with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn): - fn = load_generate_function("test:legacy_gen_eval") - - result = run(fn(make_generate_fn_input(evaluation))) - - assert isinstance(fn, LegacyGenerateFnAdapter) - assert isinstance(result, GenerateFnOutput) - assert result.samples == "my_sample" - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): - async def legacy_generate_fn(args, sample, sampling_params): - return "my_sample" - - with function_registry.temporary("test:legacy_gen", legacy_generate_fn): - fn = load_generate_function("test:legacy_gen") - - result = run(fn(make_generate_fn_input(evaluation))) - - assert isinstance(fn, LegacyGenerateFnAdapter) - assert isinstance(result, GenerateFnOutput) - assert result.samples == "my_sample" - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): - async def generate(input: GenerateFnInput) -> GenerateFnOutput: - return GenerateFnOutput(samples="my_sample") - - with function_registry.temporary("test:new_async", generate): - fn = load_generate_function("test:new_async") - - result = run(fn(make_generate_fn_input(evaluation))) - - assert isinstance(result, GenerateFnOutput) - assert result.samples == "my_sample" - - @pytest.mark.parametrize("evaluation", [False, True]) - def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): - class MyGenerateFn: - async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: - return GenerateFnOutput(samples="my_sample") - - with function_registry.temporary("test:new_class", MyGenerateFn): - fn = load_generate_function("test:new_class") - - result = run(fn(make_generate_fn_input(evaluation))) - - assert isinstance(fn, MyGenerateFn) - assert isinstance(result, GenerateFnOutput) - assert result.samples == "my_sample" diff --git a/tests/fast/rollout/rm_hub/__init__.py b/tests/fast/rollout/rm_hub/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py deleted file mode 100644 index bd4c606a68..0000000000 --- a/tests/fast/rollout/rm_hub/test_deepscaler.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward - - -class TestGetDeepscalerRuleBasedReward: - @pytest.mark.parametrize( - "response,label,expected", - [ - (r"Let me analyze...The answer is \boxed{42}", "42", 1), - (r"Thinking...The answer is \boxed{wrong}", "42", 0), - (r"###Response\boxed{42}", "42", 1), - (r"###Response\boxed{wrong}", "42", 0), - (r"The answer is \boxed{42}", "42", 0), - (r"The answer is 42", "42", 0), - (r"\boxed{42}", "", 0), - (r"\boxed{42}", r"\boxed{42}", 1), - (r"\boxed{123}", 123, 1), - (r"\boxed{3.14}", 3.14, 1), - (r"\boxed{1/2}", "0.5", 1), - (r"\boxed{\frac{1}{2}}", "0.5", 1), - (r"First thoughtSecond thought\boxed{42}", "42", 1), - ], - ) - def test_get_deepscaler_rule_based_reward(self, response, label, expected): - assert get_deepscaler_rule_based_reward(response, label) == expected diff --git a/tests/fast/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py deleted file mode 100644 index c9ecf9614d..0000000000 --- a/tests/fast/rollout/rm_hub/test_f1.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from miles.rollout.rm_hub.f1 import f1_score, normalize_answer - - -class TestNormalizeAnswer: - @pytest.mark.parametrize( - "input_str,expected", - [ - ("Hello World", "hello world"), - ("The quick brown fox", "quick brown fox"), - ("A cat and a dog", "cat and dog"), - ("Hello, world!", "hello world"), - (" multiple spaces ", "multiple spaces"), - ("An apple", "apple"), - ("UPPERCASE", "uppercase"), - ], - ) - def test_normalize_answer(self, input_str, expected): - assert normalize_answer(input_str) == expected - - -class TestF1Score: - @pytest.mark.parametrize( - "prediction,ground_truth,expected_f1,expected_prec,expected_recall", - [ - ("hello world", "hello world", 1.0, 1.0, 1.0), - ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3), - ("abc", "xyz", 0, 0, 0), - (None, "anything", 0, 0, 0), - ("yes", "no", 0, 0, 0), - ("no", "yes", 0, 0, 0), - ("yes", "yes", 1.0, 1.0, 1.0), - ("noanswer", "yes", 0, 0, 0), - ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0), - ("hello, world!", "hello world", 1.0, 1.0, 1.0), - ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5), - ], - ) - def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall): - f1, prec, recall = f1_score(prediction, ground_truth) - assert f1 == expected_f1 - assert prec == expected_prec - assert recall == expected_recall diff --git a/tests/fast/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py deleted file mode 100644 index 45cefd2015..0000000000 --- a/tests/fast/rollout/rm_hub/test_gpqa.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest - -from miles.rollout.rm_hub.gpqa import ( - _extract_letter_from_response, - _normalize_text, - _strip_chain_of_thought, - compute_gpqa_reward, -) - - -class TestStripChainOfThought: - @pytest.mark.parametrize( - "text,expected", - [ - ("Let me think...The answer is A", "The answer is A"), - ("The answer is A", "The answer is A"), - ("", ""), - (None, ""), - ], - ) - def test_strip_chain_of_thought(self, text, expected): - assert _strip_chain_of_thought(text) == expected - - -class TestNormalizeText: - @pytest.mark.parametrize( - "input_str,expected", - [ - ("Hello World", "hello world"), - ("Test-123", "test 123"), - ("A, B, C", "a b c"), - ("", ""), - ], - ) - def test_normalize_text(self, input_str, expected): - assert _normalize_text(input_str) == expected - - -class TestExtractLetterFromResponse: - @pytest.mark.parametrize( - "response,expected", - [ - ("The answer is A", "A"), - ("answer: B", "B"), - ("I think C is correct", "C"), - ("final answer: D", "D"), - ("Option A is the best choice", "A"), - ("The answer is B", "B"), - ("After analysis, my choice is C", "C"), - ("A B C D", "D"), - ("No valid letter here", None), - ("", None), - (None, None), - ("The answer is Z", None), - ], - ) - def test_extract_letter(self, response, expected): - assert _extract_letter_from_response(response, "ABCD") == expected - - -class TestComputeGpqaReward: - @pytest.mark.parametrize( - "response,label,metadata,expected", - [ - ("Answer: A", "A", None, 1.0), - ("Answer: A", "B", None, 0.0), - (None, "A", None, 0.0), - ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0), - ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0), - ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), - ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), - ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0), - ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0), - ( - "I believe the answer is Paris", - "", - {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"}, - 1.0, - ), - ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0), - ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0), - ("Let me think step by step...The answer is A", "A", None, 1.0), - ], - ) - def test_compute_gpqa_reward(self, response, label, metadata, expected): - assert compute_gpqa_reward(response, label, metadata=metadata) == expected diff --git a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py deleted file mode 100644 index 56a7f6d1f9..0000000000 --- a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import pytest - -from miles.rollout.rm_hub.math_dapo_utils import ( - compute_score, - is_correct_minerva, - is_correct_strict_box, - last_boxed_only_string, - normalize_final_answer, - remove_boxed, -) - - -class TestLastBoxedOnlyString: - @pytest.mark.parametrize( - "input_str,expected", - [ - (r"The answer is \boxed{42}", r"\boxed{42}"), - (r"\boxed{x^2}", r"\boxed{x^2}"), - (r"No boxed", None), - (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), - ], - ) - def test_last_boxed_only_string(self, input_str, expected): - assert last_boxed_only_string(input_str) == expected - - -class TestRemoveBoxed: - @pytest.mark.parametrize( - "input_str,expected", - [ - (r"\boxed{42}", "42"), - (r"\boxed{x + 1}", "x + 1"), - ], - ) - def test_remove_boxed_valid(self, input_str, expected): - assert remove_boxed(input_str) == expected - - def test_remove_boxed_invalid(self): - with pytest.raises(AssertionError): - remove_boxed("not boxed") - - -class TestNormalizeFinalAnswer: - @pytest.mark.parametrize( - "input_str,expected", - [ - ("42", "42"), - (" 42 ", "42"), - (r"\text{hello}", "hello"), - (r"\textbf{bold}", "bold"), - (r"x = 42", "42"), - (r"100 square", "100"), - (r"$50$ dollars", "50"), - (r"\boxed{42}", "42"), - (r"\frac12", r"\frac{1}{2}"), - (r"\sqrt3", r"\sqrt{3}"), - ("1,000", "1000"), - ("<|im_end|>", ""), - ], - ) - def test_normalize_final_answer(self, input_str, expected): - assert normalize_final_answer(input_str) == expected - - -class TestIsCorrectMinerva: - @pytest.mark.parametrize( - "solution,gt,gt_need_extract,expected_correct", - [ - ("Answer: 42", "42", False, True), - ("Answer: 100", "42", False, False), - ("Answer: wrong", "42", False, False), - ("Answer: 42", r"\boxed{42}", True, True), - ], - ) - def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct): - correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract) - assert correct == expected_correct - - -class TestIsCorrectStrictBox: - @pytest.mark.parametrize( - "pred,gt,expected_score,expected_pred", - [ - (r"blah blah \boxed{42}", "42", 1, "42"), - (r"\boxed{wrong}", "42", -1, "wrong"), - ("no box here", "42", -1, None), - ], - ) - def test_is_correct_strict_box(self, pred, gt, expected_score, expected_pred): - score, extracted = is_correct_strict_box(pred, gt) - assert score == expected_score - assert extracted == expected_pred - - -class TestComputeScore: - @pytest.mark.parametrize( - "solution,gt,strict_box,expected_score,expected_acc", - [ - ("Answer: 42", "42", False, 1.0, True), - ("Answer: wrong", "42", False, -1.0, False), - (r"\boxed{42}", "42", True, 1.0, True), - ("x" * 500 + " Answer: 42", "42", False, 1.0, True), - ], - ) - def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc): - result = compute_score(solution, gt, strict_box_verify=strict_box) - assert result["score"] == expected_score - assert result["acc"] == expected_acc diff --git a/tests/fast/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py deleted file mode 100644 index 2423ed4acc..0000000000 --- a/tests/fast/rollout/rm_hub/test_math_utils.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest - -from miles.rollout.rm_hub.math_utils import ( - _normalize, - extract_answer, - grade_answer_mathd, - grade_answer_sympy, - grade_answer_verl, - last_boxed_only_string, - remove_boxed, -) - - -class TestLastBoxedOnlyString: - @pytest.mark.parametrize( - "input_str,expected", - [ - (r"The answer is \boxed{42}", r"\boxed{42}"), - (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"), - (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"), - (r"No boxed here", None), - (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), - (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"), - (r"\fbox{fbox content}", r"\fbox{fbox content}"), - ("", None), - ], - ) - def test_last_boxed_only_string(self, input_str, expected): - assert last_boxed_only_string(input_str) == expected - - -class TestRemoveBoxed: - @pytest.mark.parametrize( - "input_str,expected", - [ - (r"\boxed{42}", "42"), - (r"\boxed{x^2 + 1}", "x^2 + 1"), - (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"), - ("not boxed", None), - ], - ) - def test_remove_boxed(self, input_str, expected): - assert remove_boxed(input_str) == expected - - -class TestExtractAnswer: - @pytest.mark.parametrize( - "input_str,expected", - [ - (r"The answer is \boxed{42}", "42"), - (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"), - (r"Multiple \boxed{1} then \boxed{final}", "final"), - (r"No boxed here", None), - ("", None), - ], - ) - def test_extract_answer(self, input_str, expected): - assert extract_answer(input_str) == expected - - -class TestNormalize: - @pytest.mark.parametrize( - "input_str,expected", - [ - ("1,000", "1000"), - (r"\text{hello}", "hello"), - (" 42 ", "42"), - (r"100%", "100"), - (r"\$50", "50"), - ("HELLO", "hello"), - ("1,234,567", "1234567"), - (None, None), - ], - ) - def test_normalize(self, input_str, expected): - assert _normalize(input_str) == expected - - -class TestGradeAnswerMathd: - @pytest.mark.parametrize( - "given,ground_truth,expected", - [ - ("42", "42", True), - (" 42 ", "42", True), - (r"\frac{1}{2}", r"\frac{1}{2}", True), - ("wrong", "42", False), - ("", "42", False), - ], - ) - def test_grade_answer_mathd(self, given, ground_truth, expected): - assert grade_answer_mathd(given, ground_truth) == expected - - -class TestGradeAnswerSympy: - @pytest.mark.parametrize( - "given,ground_truth,expected", - [ - ("42", "42", True), - ("x^2", "x^2", True), - ("1/2", "0.5", True), - (r"\frac{1}{2}", "0.5", True), - ("wrong", "42", False), - ("", "42", False), - ("(1,2)", "(1,2)", True), - ("(1,2,3)", "(1,2)", False), - ("42", None, False), - ], - ) - def test_grade_answer_sympy(self, given, ground_truth, expected): - assert grade_answer_sympy(given, ground_truth) == expected - - -class TestGradeAnswerVerl: - @pytest.mark.parametrize( - "solution,ground_truth,expected", - [ - (r"\boxed{42}", "42", True), - (r"The answer is \boxed{42}", "42", True), - (r"\boxed{1/2}", r"\frac{1}{2}", True), - (r"\boxed{wrong}", "42", False), - ("no boxed", "42", False), - (r"\boxed{42}", r"\boxed{42}", True), - ("", "42", False), - (r"\boxed{42}", "", False), - (r"\boxed{42}", None, False), - ], - ) - def test_grade_answer_verl(self, solution, ground_truth, expected): - assert grade_answer_verl(solution, ground_truth) == expected diff --git a/tests/fast/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py deleted file mode 100644 index a3dadbdaf0..0000000000 --- a/tests/fast/rollout/rm_hub/test_rm_hub.py +++ /dev/null @@ -1,126 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from miles.rollout.rm_hub import async_rm, batched_async_rm -from miles.utils.async_utils import run -from miles.utils.types import Sample - - -@pytest.fixture -def mock_args(): - args = MagicMock() - args.custom_rm_path = None - args.rm_type = None - args.rm_url = None - return args - - -class TestAsyncRm: - @pytest.mark.parametrize( - "rm_type,response,label,expected", - [ - ("math", r"\boxed{42}", "42", 1), - ("math", r"\boxed{wrong}", "42", 0), - ("f1", "hello world", "hello world", 1.0), - ("dapo", "Answer: 42", "42", {"score": 1.0}), - ("deepscaler", r"\boxed{42}", "42", 1), - ("gpqa", "Answer: A", "A", 1.0), - ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0), - ], - ) - def test_rm_types(self, mock_args, rm_type, response, label, expected): - mock_args.rm_type = rm_type - sample = Sample(prompt="", response=response, label=label) - reward = run(async_rm(mock_args, sample)) - if isinstance(expected, dict): - for k, v in expected.items(): - assert reward[k] == v - else: - assert reward == expected - - def test_f1_rm_partial(self, mock_args): - mock_args.rm_type = "f1" - sample = Sample(prompt="", response="hello", label="hello world") - reward = run(async_rm(mock_args, sample)) - assert 0 < reward < 1 - - def test_random_rm(self, mock_args): - mock_args.rm_type = "random" - sample = Sample(prompt="", response="anything", label="anything") - reward = run(async_rm(mock_args, sample)) - assert reward in [0, 1] - - def test_rm_type_from_metadata(self, mock_args): - mock_args.rm_type = None - sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}) - reward = run(async_rm(mock_args, sample)) - assert reward == 1 - - @pytest.mark.parametrize( - "rm_type,match", - [ - ("unknown_type", "not implemented"), - ("", "not specified"), - ], - ) - def test_invalid_rm_type_raises(self, mock_args, rm_type, match): - mock_args.rm_type = rm_type - sample = Sample(prompt="", response="test", label="test") - with pytest.raises(NotImplementedError, match=match): - run(async_rm(mock_args, sample)) - - -class TestBatchedAsyncRm: - @pytest.mark.parametrize( - "rm_type,samples_data,expected", - [ - ( - "math", - [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")], - [1, 1, 0], - ), - ( - "f1", - [("hello world", "hello world"), ("different", "something else")], - [1.0, 0], - ), - ], - ) - def test_batched_rm(self, mock_args, rm_type, samples_data, expected): - mock_args.rm_type = rm_type - samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data] - rewards = run(batched_async_rm(mock_args, samples)) - assert rewards == expected - - def test_inplace_set_reward_field(self, mock_args): - mock_args.rm_type = "math" - samples = [ - Sample(prompt="", response=r"\boxed{42}", label="42"), - Sample(prompt="", response=r"\boxed{100}", label="100"), - ] - result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) - assert result is None - assert samples[0].reward == 1 - assert samples[1].reward == 1 - - def test_inplace_raises_on_existing_reward(self, mock_args): - mock_args.rm_type = "math" - samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)] - with pytest.raises(AssertionError, match="Overriding"): - run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) - - def test_empty_samples(self, mock_args): - mock_args.rm_type = "math" - rewards = run(batched_async_rm(mock_args, [])) - assert rewards == [] - - def test_mixed_rm_types_via_metadata(self, mock_args): - mock_args.rm_type = None - samples = [ - Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}), - Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}), - ] - rewards = run(batched_async_rm(mock_args, samples)) - assert rewards[0] == 1 - assert rewards[1] == 1.0 diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/router/test_router.py b/tests/fast/router/test_router.py deleted file mode 100644 index 7c645fe304..0000000000 --- a/tests/fast/router/test_router.py +++ /dev/null @@ -1,204 +0,0 @@ -import asyncio -from argparse import Namespace - -import pytest -import requests - -from miles.router.router import MilesRouter -from miles.utils.http_utils import find_available_port -from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn -from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer - - -def make_router_args(router_port: int, **overrides) -> Namespace: - defaults = dict( - sglang_router_ip="127.0.0.1", - sglang_router_port=router_port, - rollout_health_check_interval=1.0, - miles_router_health_check_failure_threshold=3, - miles_router_max_connections=100, - miles_router_timeout=None, - miles_router_middleware_paths=[], - ) - defaults.update(overrides) - return Namespace(**defaults) - - -def create_mock_worker(start_port: int = 30000) -> MockSGLangServer: - port = find_available_port(start_port) - return MockSGLangServer( - model_name="Qwen/Qwen3-0.6B", - process_fn=default_process_fn, - host="127.0.0.1", - port=port, - latency=0.0, - ) - - -class RouterEnv: - def __init__(self, router: MilesRouter, server: UvicornThreadServer): - self.router = router - self.server = server - - @property - def url(self) -> str: - return self.server.url - - -@pytest.fixture -def router_env(): - args = make_router_args(find_available_port(20000)) - router = MilesRouter(args, verbose=False) - server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) - server.start() - yield RouterEnv(router, server) - server.stop() - - -@pytest.fixture -def mock_worker(): - server = create_mock_worker() - server.start() - yield server - server.stop() - - -@pytest.fixture -def mock_worker_factory(): - servers = [] - - def _create(): - start_port = 30000 + len(servers) * 100 - server = create_mock_worker(start_port) - server.start() - servers.append(server) - return server - - yield _create - for s in servers: - s.stop() - - -@pytest.fixture -def router_factory(): - def _create(**overrides) -> MilesRouter: - args = make_router_args(find_available_port(20000), **overrides) - return MilesRouter(args, verbose=False) - - return _create - - -class TestWorkerManagement: - def test_add_worker_via_query_param(self, router_env: RouterEnv): - worker_url = "http://127.0.0.1:30001" - r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0) - r.raise_for_status() - - assert r.json()["status"] == "success" - assert worker_url in router_env.router.worker_request_counts - assert router_env.router.worker_request_counts[worker_url] == 0 - - def test_add_worker_via_body(self, router_env: RouterEnv): - worker_url = "http://127.0.0.1:30002" - r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0) - r.raise_for_status() - - assert r.json()["status"] == "success" - assert worker_url in router_env.router.worker_request_counts - - def test_add_worker_duplicate(self, router_env: RouterEnv): - worker_url = "http://127.0.0.1:30003" - requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() - requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() - - assert len(router_env.router.worker_request_counts) == 1 - assert worker_url in router_env.router.worker_request_counts - - def test_add_worker_missing_url(self, router_env: RouterEnv): - r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0) - assert r.status_code == 400 - assert "error" in r.json() - - def test_list_workers(self, router_env: RouterEnv): - worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"] - for url in worker_urls: - requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0) - - r = requests.get(f"{router_env.url}/list_workers", timeout=5.0) - r.raise_for_status() - assert set(r.json()["urls"]) == set(worker_urls) - - -class TestLoadBalancing: - def test_use_url_selects_min_load(self, router_factory): - router = router_factory() - router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} - - selected = router._use_url() - assert selected == "http://w2:8000" - assert router.worker_request_counts["http://w2:8000"] == 3 - - def test_use_url_excludes_dead_workers(self, router_factory): - router = router_factory() - router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3} - router.dead_workers = {"http://w2:8000"} - - selected = router._use_url() - assert selected == "http://w3:8000" - assert router.worker_request_counts["http://w3:8000"] == 4 - - def test_use_url_raises_when_all_dead(self, router_factory): - router = router_factory() - router.worker_request_counts = {"http://w1:8000": 0} - router.dead_workers = {"http://w1:8000"} - - with pytest.raises(RuntimeError, match="No healthy workers"): - router._use_url() - - -# TODO: extract main body inside `_health_check_loop`, then can test that function -class TestHealthCheck: - def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer): - router = router_factory() - url, healthy = asyncio.run(router._check_worker_health(mock_worker.url)) - assert url == mock_worker.url - assert healthy is True - - def test_check_worker_health_failure(self, router_factory): - router = router_factory() - url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999")) - assert url == "http://127.0.0.1:59999" - assert healthy is False - - -class TestProxyIntegration: - def test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer): - requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status() - - payload = {"input_ids": [1, 2, 3], "return_logprob": True} - r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0) - r.raise_for_status() - - assert "text" in r.json() - assert len(mock_worker.request_log) == 1 - assert mock_worker.request_log[0] == payload - - def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory): - worker1, worker2 = mock_worker_factory(), mock_worker_factory() - requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0) - requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0) - - payload = {"input_ids": [1, 2, 3], "return_logprob": True} - for _ in range(4): - requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status() - - all_requests = worker1.request_log + worker2.request_log - assert len(all_requests) == 4 - assert all(req == payload for req in all_requests) - - def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer): - requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0) - - r = requests.get(f"{router_env.url}/health", timeout=5.0) - r.raise_for_status() - assert r.json()["status"] == "ok" diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py deleted file mode 100644 index 5c6edafe20..0000000000 --- a/tests/fast/router/test_sessions.py +++ /dev/null @@ -1,195 +0,0 @@ -from types import SimpleNamespace - -import pytest -import requests - -from miles.router.router import MilesRouter -from miles.router.sessions import SessionManager, SessionRecord -from miles.utils.http_utils import find_available_port -from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server -from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer - - -class TestSessionManager: - def test_create_session(self): - manager = SessionManager() - session_id = manager.create_session() - assert session_id is not None - assert len(session_id) == 32 - assert session_id in manager.sessions - assert manager.sessions[session_id] == [] - - def test_get_session_exists(self): - manager = SessionManager() - session_id = manager.create_session() - records = manager.get_session(session_id) - assert records == [] - - def test_get_session_not_exists(self): - manager = SessionManager() - records = manager.get_session("nonexistent") - assert records is None - - def test_delete_session_exists(self): - manager = SessionManager() - session_id = manager.create_session() - records = manager.delete_session(session_id) - assert records == [] - assert session_id not in manager.sessions - - def test_delete_session_not_exists(self): - manager = SessionManager() - with pytest.raises(AssertionError): - manager.delete_session("nonexistent") - - def test_add_record(self): - manager = SessionManager() - session_id = manager.create_session() - record = SessionRecord( - timestamp=1234567890.0, - method="POST", - path="generate", - request={"prompt": "hello"}, - response={"text": "world"}, - status_code=200, - ) - manager.add_record(session_id, record) - assert len(manager.sessions[session_id]) == 1 - assert manager.sessions[session_id][0] == record - - def test_add_record_nonexistent_session(self): - manager = SessionManager() - record = SessionRecord( - timestamp=1234567890.0, - method="POST", - path="generate", - request={}, - response={}, - status_code=200, - ) - with pytest.raises(AssertionError): - manager.add_record("nonexistent", record) - - -@pytest.fixture(scope="class") -def router_url(): - def process_fn(prompt: str) -> ProcessResult: - return ProcessResult(text=f"echo: {prompt}", finish_reason="stop") - - with with_mock_server(process_fn=process_fn) as backend: - args = SimpleNamespace( - miles_router_max_connections=10, - miles_router_timeout=30, - miles_router_middleware_paths=[], - rollout_health_check_interval=60, - miles_router_health_check_failure_threshold=3, - hf_checkpoint="Qwen/Qwen3-0.6B", - ) - router = MilesRouter(args) - - port = find_available_port(31000) - server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) - server.start() - - url = f"http://127.0.0.1:{port}" - requests.post(f"{url}/add_worker", json={"url": backend.url}) - - try: - yield url - finally: - server.stop() - - -class TestSessionRoutes: - def test_create_session(self, router_url): - response = requests.post(f"{router_url}/sessions") - assert response.status_code == 200 - data = response.json() - assert "session_id" in data - assert len(data["session_id"]) == 32 - - def test_get_session(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - get_resp = requests.get(f"{router_url}/sessions/{session_id}") - assert get_resp.status_code == 200 - data = get_resp.json() - assert data["session_id"] == session_id - assert data["records"] == [] - - def test_get_session_not_found(self, router_url): - response = requests.get(f"{router_url}/sessions/nonexistent") - assert response.status_code == 404 - assert response.json()["error"] == "session not found" - - def test_get_with_records(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - requests.post( - f"{router_url}/sessions/{session_id}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, - ) - - get_resp = requests.get(f"{router_url}/sessions/{session_id}") - assert get_resp.status_code == 200 - data = get_resp.json() - assert data["session_id"] == session_id - assert len(data["records"]) == 1 - - def test_delete_session(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") - assert delete_resp.status_code == 204 - assert delete_resp.text == "" - - assert requests.delete(f"{router_url}/sessions/{session_id}").status_code == 404 - - def test_delete_session_not_found(self, router_url): - response = requests.delete(f"{router_url}/sessions/nonexistent") - assert response.status_code == 404 - assert response.json()["error"] == "session not found" - - -class TestSessionProxy: - def test_proxy_session_not_found(self, router_url): - response = requests.post(f"{router_url}/sessions/nonexistent/generate", json={}) - assert response.status_code == 404 - assert response.json()["error"] == "session not found" - - def test_proxy_records_request_response(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - resp = requests.post( - f"{router_url}/sessions/{session_id}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, - ) - assert resp.status_code == 200 - assert "text" in resp.json() - - get_resp = requests.get(f"{router_url}/sessions/{session_id}") - records = get_resp.json()["records"] - assert len(records) == 1 - assert records[0]["method"] == "POST" - assert records[0]["path"] == "generate" - assert records[0]["request"]["input_ids"] == [1, 2, 3] - assert "text" in records[0]["response"] - - delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") - assert delete_resp.status_code == 204 - - def test_proxy_accumulates_records(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - for _ in range(3): - requests.post( - f"{router_url}/sessions/{session_id}/generate", - json={"input_ids": [1], "sampling_params": {}, "return_logprob": True}, - ) - - get_resp = requests.get(f"{router_url}/sessions/{session_id}") - records = get_resp.json()["records"] - assert len(records) == 3 - - delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") - assert delete_resp.status_code == 204 diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py deleted file mode 100644 index 9bd1a620d6..0000000000 --- a/tests/fast/utils/test_arguments.py +++ /dev/null @@ -1,58 +0,0 @@ -import argparse -import sys -from unittest.mock import patch - -import pytest - -from miles.utils.arguments import get_miles_extra_args_provider -from miles.utils.misc import function_registry - -PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] -REQUIRED_ARGS = ["--rollout-batch-size", "64"] - - -def make_class_with_add_arguments(): - class MyFn: - @classmethod - def add_arguments(cls, parser): - parser.add_argument("--my-custom-arg", type=int, default=42) - - return MyFn - - -def make_function_with_add_arguments(): - def my_fn(): - pass - - my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42) - return my_fn - - -def make_function_without_add_arguments(): - def my_fn(): - pass - - return my_fn - - -@pytest.mark.parametrize("path_arg", PATH_ARGS) -class TestAddArgumentsSupport: - - @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments]) - def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory): - fn = fn_factory() - with function_registry.temporary("test:fn", fn), patch.object( - sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS - ): - parser = argparse.ArgumentParser() - get_miles_extra_args_provider()(parser) - args, _ = parser.parse_known_args() - assert args.my_custom_arg == 100 - - def test_skips_function_without_add_arguments(self, path_arg): - fn = make_function_without_add_arguments() - with function_registry.temporary("test:fn", fn), patch.object( - sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS - ): - parser = argparse.ArgumentParser() - get_miles_extra_args_provider()(parser) diff --git a/tests/fast/utils/test_misc.py b/tests/fast/utils/test_misc.py deleted file mode 100644 index 810c2b67c7..0000000000 --- a/tests/fast/utils/test_misc.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import pytest - -from miles.utils.misc import FunctionRegistry, function_registry, load_function - - -def _fn_a(): - return "a" - - -def _fn_b(): - return "b" - - -class TestFunctionRegistry: - def test_register_and_get(self): - registry = FunctionRegistry() - with registry.temporary("my_fn", _fn_a): - assert registry.get("my_fn") is _fn_a - - def test_register_duplicate_raises(self): - registry = FunctionRegistry() - with registry.temporary("my_fn", _fn_a): - with pytest.raises(AssertionError): - with registry.temporary("my_fn", _fn_b): - pass - - def test_unregister(self): - registry = FunctionRegistry() - with registry.temporary("my_fn", _fn_a): - assert registry.get("my_fn") is _fn_a - assert registry.get("my_fn") is None - - def test_temporary_cleanup_on_exception(self): - registry = FunctionRegistry() - with pytest.raises(RuntimeError): - with registry.temporary("temp_fn", _fn_a): - raise RuntimeError("test") - assert registry.get("temp_fn") is None - - -class TestLoadFunction: - def test_load_from_module(self): - import os.path - - assert load_function("os.path.join") is os.path.join - - def test_load_none_returns_none(self): - assert load_function(None) is None - - def test_load_from_registry(self): - with function_registry.temporary("test:my_fn", _fn_a): - assert load_function("test:my_fn") is _fn_a - - def test_registry_takes_precedence(self): - with function_registry.temporary("os.path.join", _fn_b): - assert load_function("os.path.join") is _fn_b - assert load_function("os.path.join") is os.path.join diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py deleted file mode 100644 index 6633678da1..0000000000 --- a/tests/fast/utils/test_utils/test_mock_sglang_server.py +++ /dev/null @@ -1,409 +0,0 @@ -import asyncio -import concurrent.futures -import time - -import pytest -import requests - -from miles.utils.test_utils.mock_sglang_server import ( - Counter, - ProcessResult, - ProcessResultMetaInfo, - default_process_fn, - with_mock_server, -) -from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub - - -def expected_logprobs(tokenizer, text: str) -> list[dict]: - output_ids = tokenizer.encode(text, add_special_tokens=False) - return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)] - - -@pytest.fixture(scope="module") -def mock_server(): - with with_mock_server() as server: - yield server - - -class TestProcessResultMetaInfo: - def test_to_dict_empty(self): - assert ProcessResultMetaInfo().to_dict() == {} - - def test_to_dict_single_field(self): - assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} - - def test_to_dict_partial_fields(self): - assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { - "weight_version": "v1", - "spec_accept_token_num": 10, - } - - def test_to_dict_all_fields(self): - assert ProcessResultMetaInfo( - weight_version="v1", - routed_experts="abc", - spec_accept_token_num=10, - spec_draft_token_num=15, - spec_verify_ct=3, - ).to_dict() == { - "weight_version": "v1", - "routed_experts": "abc", - "spec_accept_token_num": 10, - "spec_draft_token_num": 15, - "spec_verify_ct": 3, - } - - -class TestDefaultProcessFn: - def test_math_question(self): - assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") - assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") - - def test_unknown_question(self): - assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") - - -class TestCounter: - def test_tracks_max(self): - counter = Counter() - assert counter.max_value == 0 - - with counter.track(): - assert counter.max_value == 1 - with counter.track(): - assert counter.max_value == 2 - - counter.reset() - assert counter.max_value == 0 - - def test_concurrent_tasks(self): - counter = Counter() - - async def task(): - with counter.track(): - await asyncio.sleep(0.1) - - async def run_all(): - await asyncio.gather(task(), task(), task()) - - asyncio.run(run_all()) - assert counter.max_value == 3 - - -class TestMockServerBasic: - def test_start_stop(self, mock_server): - assert mock_server.port > 0 - assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url - - def test_request_log_and_reset_stats(self, mock_server): - mock_server.reset_stats() - assert len(mock_server.request_log) == 0 - - payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} - requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) - assert len(mock_server.request_log) == 1 - assert mock_server.request_log[0] == payload - - mock_server.reset_stats() - assert len(mock_server.request_log) == 0 - assert mock_server.max_concurrent == 0 - - @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) - def test_latency(self, latency, min_time, max_time): - with with_mock_server(latency=latency) as server: - start = time.time() - requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) - elapsed = time.time() - start - assert min_time <= elapsed < max_time - - def test_max_concurrent_with_latency(self): - with with_mock_server(latency=0.1) as server: - - def send_request(): - requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) - - with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: - futures = [executor.submit(send_request) for _ in range(3)] - concurrent.futures.wait(futures) - - assert server.max_concurrent == 3 - - def test_health_endpoint(self, mock_server): - response = requests.get(f"{mock_server.url}/health", timeout=5.0) - assert response.status_code == 200 - assert response.json() == {"status": "ok"} - - def test_abort_request_endpoint(self, mock_server): - response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0) - assert response.status_code == 200 - assert response.json() == {"status": "ok"} - - -class TestGenerateEndpoint: - def test_basic(self, mock_server): - prompt = "What is 1+7?" - input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) - assert input_ids == [3838, 374, 220, 16, 10, 22, 30] - - response = requests.post( - f"{mock_server.url}/generate", - json={ - "input_ids": input_ids, - "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, - "return_logprob": True, - }, - timeout=5.0, - ) - assert response.status_code == 200 - assert response.json() == { - "text": "\\boxed{8}", - "meta_info": { - "finish_reason": {"type": "stop"}, - "prompt_tokens": len(input_ids), - "cached_tokens": 0, - "completion_tokens": 5, - "output_token_logprobs": [ - [-0.0, 59], - [-0.0078125, 79075], - [-0.015625, 90], - [-0.0234375, 23], - [-0.03125, 92], - ], - }, - } - - def test_with_meta_info(self): - def process_fn(_: str) -> ProcessResult: - return ProcessResult( - text="ok", - finish_reason="stop", - cached_tokens=5, - meta_info=ProcessResultMetaInfo( - weight_version="v2.0", - routed_experts="encoded_data", - spec_accept_token_num=10, - spec_draft_token_num=15, - spec_verify_ct=3, - ), - ) - - with with_mock_server(process_fn=process_fn) as server: - response = requests.post( - f"{server.url}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, - timeout=5.0, - ) - - assert response.json() == { - "text": "ok", - "meta_info": { - "finish_reason": {"type": "stop"}, - "prompt_tokens": 3, - "cached_tokens": 5, - "completion_tokens": 1, - "output_token_logprobs": [[-0.0, 562]], - "weight_version": "v2.0", - "routed_experts": "encoded_data", - "spec_accept_token_num": 10, - "spec_draft_token_num": 15, - "spec_verify_ct": 3, - }, - } - - def test_finish_reason_length(self): - def process_fn(_: str) -> ProcessResult: - return ProcessResult(text="truncated output", finish_reason="length") - - with with_mock_server(process_fn=process_fn) as server: - response = requests.post( - f"{server.url}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, - timeout=5.0, - ) - data = response.json() - - finish_reason = data["meta_info"]["finish_reason"] - assert finish_reason["type"] == "length" - assert finish_reason["length"] == data["meta_info"]["completion_tokens"] - - -class TestChatCompletionsEndpoint: - def test_basic(self, mock_server): - response = requests.post( - f"{mock_server.url}/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "What is 1+5?"}], - }, - timeout=5.0, - ) - assert response.status_code == 200 - data = response.json() - - assert data["id"].startswith("chatcmpl-") - assert isinstance(data["created"], int) - assert data == { - "id": data["id"], - "object": "chat.completion", - "created": data["created"], - "model": "mock-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, - "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, - "finish_reason": "stop", - } - ], - } - - def test_with_tool_calls(self): - tool_call_response = 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n' - - def process_fn(_: str) -> ProcessResult: - return ProcessResult(text=tool_call_response, finish_reason="stop") - - with with_mock_server(process_fn=process_fn) as server: - response = requests.post( - f"{server.url}/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "What year is it?"}], - "tools": SAMPLE_TOOLS, - }, - timeout=5.0, - ) - data = response.json() - - assert data["choices"][0] == { - "index": 0, - "message": { - "role": "assistant", - "content": "Let me check for you.", - "tool_calls": [ - {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}} - ], - }, - "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, - "finish_reason": "tool_calls", - } - - def test_with_tools_but_no_tool_call(self): - response_text = "The weather is sunny today." - - def process_fn(_: str) -> ProcessResult: - return ProcessResult(text=response_text, finish_reason="stop") - - with with_mock_server(process_fn=process_fn) as server: - response = requests.post( - f"{server.url}/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "What's the weather?"}], - "tools": SAMPLE_TOOLS, - }, - timeout=5.0, - ) - data = response.json() - - assert data["choices"][0] == { - "index": 0, - "message": {"role": "assistant", "content": response_text, "tool_calls": None}, - "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, - "finish_reason": "stop", - } - - def test_with_multiple_tool_calls(self): - multi_tool_response = ( - "I will get year and temperature.\n" - '\n{"name": "get_year", "arguments": {}}\n\n' - '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n' - ) - - def process_fn(_: str) -> ProcessResult: - return ProcessResult(text=multi_tool_response, finish_reason="stop") - - with with_mock_server(process_fn=process_fn) as server: - response = requests.post( - f"{server.url}/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "What year and temperature?"}], - "tools": SAMPLE_TOOLS, - }, - timeout=5.0, - ) - data = response.json() - - assert data["choices"][0] == { - "index": 0, - "message": { - "role": "assistant", - "content": "I will get year and temperature.", - "tool_calls": [ - {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, - { - "id": "call00001", - "type": "function", - "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'}, - }, - ], - }, - "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, - "finish_reason": "tool_calls", - } - - -class TestMultiTurnToolCallProcessFn: - @pytest.mark.parametrize( - "prompt,expected_response", - [ - pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"), - pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"), - ], - ) - def test_generate_endpoint(self, prompt, expected_response): - with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: - input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) - response = requests.post( - f"{server.url}/generate", - json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True}, - timeout=5.0, - ) - assert response.status_code == 200 - data = response.json() - assert data["text"] == expected_response - assert data["meta_info"]["finish_reason"] == {"type": "stop"} - - @pytest.mark.parametrize( - "messages,expected_content,expected_tool_calls,expected_finish_reason", - [ - pytest.param( - TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, - TwoTurnStub.FIRST_RESPONSE_CONTENT, - TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT, - "tool_calls", - id="first_turn", - ), - pytest.param( - TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, - TwoTurnStub.SECOND_RESPONSE, - None, - "stop", - id="second_turn", - ), - ], - ) - def test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): - with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: - response = requests.post( - f"{server.url}/v1/chat/completions", - json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, - timeout=5.0, - ) - assert response.status_code == 200 - data = response.json() - assert data["choices"][0]["message"]["content"] == expected_content - assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls - assert data["choices"][0]["finish_reason"] == expected_finish_reason diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py deleted file mode 100644 index 3f2116ec01..0000000000 --- a/tests/fast/utils/test_utils/test_mock_tools.py +++ /dev/null @@ -1,111 +0,0 @@ -import asyncio - -import pytest -from pydantic import TypeAdapter -from sglang.srt.entrypoints.openai.protocol import Tool -from sglang.srt.function_call.core_types import ToolCallItem -from sglang.srt.function_call.function_call_parser import FunctionCallParser - -from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call - - -class TestExecuteToolCall: - def test_execute_get_year(self): - result = asyncio.run(execute_tool_call("get_year", {})) - assert result == '{"year": 2026}' - - def test_execute_get_temperature(self): - result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"})) - assert result == '{"temperature": -60}' - - -class TestApplyChatTemplateWithTools: - EXPECTED_PROMPT_WITHOUT_TOOLS = ( - "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n" - ) - - EXPECTED_PROMPT_WITH_TOOLS = ( - "<|im_start|>system\n" - "# Tools\n\n" - "You may call one or more functions to assist with the user query.\n\n" - "You are provided with function signatures within XML tags:\n" - "\n" - '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' - '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' - "\n\n" - "For each function call, return a json object with function name and arguments within XML tags:\n" - "\n" - '{"name": , "arguments": }\n' - "<|im_end|>\n" - "<|im_start|>user\n" - "What's the weather in Paris?<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - @pytest.mark.parametrize( - "tools,expected", - [ - pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"), - pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"), - ], - ) - def test_apply_chat_template(self, tools, expected): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) - messages = [{"role": "user", "content": "What's the weather in Paris?"}] - - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) - - assert prompt == expected - - -class TestSGLangFunctionCallParser: - """Test to demonstrate and ensure SGLang function call parser have features we need without breaking changes.""" - - @pytest.mark.parametrize( - "model_output,expected", - [ - pytest.param( - 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', - ( - "Let me check for you.", - [ToolCallItem(tool_index=-1, name="get_year", parameters="{}")], - ), - id="single_tool_call", - ), - pytest.param( - "I will get year and temperature.\n" - '\n{"name": "get_year", "arguments": {}}\n\n' - '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n', - ( - "I will get year and temperature.", - [ - ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), - ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'), - ], - ), - id="multi_tool_calls", - ), - pytest.param( - "The weather is sunny today.", - ("The weather is sunny today.", []), - id="no_tool_call", - ), - pytest.param( - TwoTurnStub.FIRST_RESPONSE, - ( - "Let me get the year and temperature first.", - [ - ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), - ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'), - ], - ), - id="multi_turn_first_response", - ), - ], - ) - def test_parse_non_stream(self, model_output, expected): - tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS) - parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25") - assert parser.parse_non_stream(model_output) == expected diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index 9b6e69c295..c5c0838c53 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -126,7 +126,6 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=_launch_background, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py index d90a2d7a71..97c76ace5a 100644 --- a/tests/test_mimo_7B_mtp_only_grad.py +++ b/tests/test_mimo_7B_mtp_only_grad.py @@ -135,7 +135,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py index c35943ec15..b1255982ed 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/test_moonlight_16B_A3B.py @@ -113,7 +113,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index ae3c383ae8..15ca8ce5fe 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -115,7 +115,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index 4d7f034f6c..dcdbd58347 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -120,7 +120,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 32b60f5937..dcaaf5e1f7 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -120,7 +120,6 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py index b1954a4e83..90cd15cb68 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -118,7 +118,6 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/test_qwen2.5_0.5B_gsm8k_short.py index 86e21eac8d..867fdcad60 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_short.py @@ -117,7 +117,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d4768e420..3d19b48ced 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -93,7 +93,6 @@ def execute(): train_args=train_args, num_gpus_per_node=2, megatron_model_type=None, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index fcd7772882..3d70f3e4ce 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -95,7 +95,6 @@ def execute(): num_gpus_per_node=2 if FEW_GPU else 4, megatron_model_type=None, train_script="train_async.py", - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/test_qwen3_0.6B_megatron_fsdp_align.py index b89a2f283b..1431d8c3d4 100644 --- a/tests/test_qwen3_0.6B_megatron_fsdp_align.py +++ b/tests/test_qwen3_0.6B_megatron_fsdp_align.py @@ -97,7 +97,6 @@ def execute(): train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -110,7 +109,6 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -137,7 +135,6 @@ def execute(): "--debug-train-only " ), num_gpus_per_node=NUM_GPUS, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, megatron_model_type=MODEL_TYPE, ) diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py index d0ad283d15..44f5c42fa5 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -95,7 +95,6 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) # 8 GPU CPU 1 for num_gpus in [8, 4, 2]: @@ -125,7 +124,6 @@ def execute(): train_args=args, num_gpus_per_node=num_gpus, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) train_args += "--calculate-per-token-loss " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index b30eeed8e5..adff108043 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -139,7 +139,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py index 0df4492e10..22fb2b5fc3 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/test_qwen3_4B_ckpt.py @@ -124,7 +124,6 @@ def execute(mode: str = ""): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py index 03ba4094e9..7c975c7cc2 100644 --- a/tests/test_qwen3_4B_fsdp_true_on_policy.py +++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py @@ -95,7 +95,6 @@ def execute(): "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", "CUBLAS_WORKSPACE_CONFIG": ":4096:8", "CUDA_DEVICE_MAX_CONNECTIONS": "1", - "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py index d4c1ac273a..962f610fac 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/test_qwen3_4B_ppo.py @@ -122,7 +122,6 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, - extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/test_qwen3_vl_4B_fsdp.py index bc4ef3293c..fbdffd237e 100644 --- a/tests/test_qwen3_vl_4B_fsdp.py +++ b/tests/test_qwen3_vl_4B_fsdp.py @@ -92,7 +92,6 @@ def execute(): extra_env_vars = { "CUDA_DEVICE_MAX_CONNECTIONS": "1", - "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/fast/utils/test_mask_utils.py b/tests/utils/test_mask_utils.py similarity index 100% rename from tests/fast/utils/test_mask_utils.py rename to tests/utils/test_mask_utils.py From df8721103fb0102ea561fa4b2c4b1f7e56dfb24d Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:57:41 +0800 Subject: [PATCH 04/77] Add new API with extensibility and compatibility adapters (#432) --- miles/ray/rollout.py | 12 +- miles/rollout/base_types.py | 50 ++++++-- miles/rollout/modular_rollout/__init__.py | 0 .../rollout/modular_rollout/compatibility.py | 50 ++++++++ tests/rollout/__init__.py | 0 tests/rollout/modular_rollout/__init__.py | 0 .../modular_rollout/test_compatibility.py | 112 ++++++++++++++++++ 7 files changed, 212 insertions(+), 12 deletions(-) create mode 100644 miles/rollout/modular_rollout/__init__.py create mode 100644 miles/rollout/modular_rollout/compatibility.py create mode 100644 tests/rollout/__init__.py create mode 100644 tests/rollout/modular_rollout/__init__.py create mode 100644 tests/rollout/modular_rollout/test_compatibility.py diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 79c6649be0..1cba8b7e00 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -13,7 +13,8 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import call_rollout_fn +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnTrainInput +from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client @@ -53,8 +54,9 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) + input = RolloutFnConstructorInput(args=args, data_source=self.data_source) + self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) + self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -142,7 +144,7 @@ def eval(self, rollout_id): return self.health_monitoring_resume() - result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True) + result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -224,7 +226,7 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False) + data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index faa85c7269..d6eb1e8f0d 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,9 +1,43 @@ +from argparse import Namespace +from collections.abc import Awaitable from dataclasses import dataclass -from typing import Any +from typing import Any, Protocol, runtime_checkable +from miles.rollout.data_source import DataSource from miles.utils.types import Sample +@dataclass(frozen=True) +class RolloutFnConstructorInput: + args: Namespace + # TODO may refactor DataSource API + data_source: DataSource + + +@dataclass(frozen=True) +class RolloutFnBaseInput: + rollout_id: int + + @property + def evaluation(self): + raise NotImplementedError + + +# subclassing for different data in the future +@dataclass(frozen=True) +class RolloutFnTrainInput(RolloutFnBaseInput): + @property + def evaluation(self): + return False + + +@dataclass(frozen=True) +class RolloutFnEvalInput(RolloutFnBaseInput): + @property + def evaluation(self): + return True + + @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] @@ -16,11 +50,13 @@ class RolloutFnEvalOutput: metrics: dict[str, Any] = None -def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): - output = fn(*args, **kwargs, evaluation=evaluation) +RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput +RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput - # compatibility for legacy version - if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): - output = RolloutFnEvalOutput(data=output) if evaluation else RolloutFnTrainOutput(samples=output) - return output +# TODO: may add add_arguments +# TODO: may add save/load if need it to be stateful +# Duck typing, users do not need to extend this class +@runtime_checkable +class RolloutFnProtocol(Protocol): + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput | Awaitable[RolloutFnOutput]: ... diff --git a/miles/rollout/modular_rollout/__init__.py b/miles/rollout/modular_rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/rollout/modular_rollout/compatibility.py b/miles/rollout/modular_rollout/compatibility.py new file mode 100644 index 0000000000..7d1a70e79c --- /dev/null +++ b/miles/rollout/modular_rollout/compatibility.py @@ -0,0 +1,50 @@ +import inspect +from collections.abc import Callable + +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnProtocol, + RolloutFnTrainOutput, +) +from miles.utils.async_utils import run +from miles.utils.misc import load_function + + +class LegacyRolloutFnAdapter: + def __init__(self, input: RolloutFnConstructorInput, fn: Callable): + self.args = input.args + self.data_source = input.data_source + self.fn = fn + + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation) + + # compatibility for legacy version + if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): + output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output) + + return output + + +assert issubclass(LegacyRolloutFnAdapter, RolloutFnProtocol) + + +def load_rollout_function(input: RolloutFnConstructorInput, path: str): + fn = load_function(path) + + if inspect.isclass(fn): + return fn(input) + else: + return LegacyRolloutFnAdapter(input, fn) + + +def call_rollout_function(fn: RolloutFnProtocol, input: RolloutFnInput) -> RolloutFnOutput: + output = fn(input) + + if inspect.iscoroutine(output): + output = run(output) + + return output diff --git a/tests/rollout/__init__.py b/tests/rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rollout/modular_rollout/__init__.py b/tests/rollout/modular_rollout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rollout/modular_rollout/test_compatibility.py b/tests/rollout/modular_rollout/test_compatibility.py new file mode 100644 index 0000000000..596fa76270 --- /dev/null +++ b/tests/rollout/modular_rollout/test_compatibility.py @@ -0,0 +1,112 @@ +import asyncio +from unittest.mock import patch + +import pytest + +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.modular_rollout.compatibility import ( + LegacyRolloutFnAdapter, + call_rollout_function, + load_rollout_function, +) + + +@pytest.fixture +def constructor_input(): + return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") + + +class TestSupportedRolloutFormats: + """ + Documentation test to show various supported rollout function formats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return {"metric": {"accuracy": 0.9}} + return [[{"text": "sample"}]] + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "path.to.fn") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, LegacyRolloutFnAdapter) + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"metric": {"accuracy": 0.9}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "sample"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) + return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "path.to.fn") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"ds": {"acc": 0.95}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "typed"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_sync_class(self, constructor_input, evaluation): + class SyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + def __call__(self, input): + if input.evaluation: + return RolloutFnEvalOutput(data={"test": {"score": 1}}) + return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=SyncRolloutFn): + fn = load_rollout_function(constructor_input, "path.to.SyncRolloutFn") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, SyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_async_class(self, constructor_input, evaluation): + class AsyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + async def __call__(self, input): + await asyncio.sleep(0.001) + if input.evaluation: + return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) + return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=AsyncRolloutFn): + fn = load_rollout_function(constructor_input, "path.to.AsyncRolloutFn") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, AsyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) From 2dbe0d7053287fd3f7db827f26b8db447554750b Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:58:13 +0800 Subject: [PATCH 05/77] Copy and split sglang_rollout.py to modular_rollout (#433) --- .../modular_rollout/inference_wrapper.py | 101 ++++++++++ .../modular_rollout/orchestration_common.py | 176 +++++++++++++++++ .../modular_rollout/orchestration_eval.py | 132 +++++++++++++ .../modular_rollout/orchestration_train.py | 178 ++++++++++++++++++ 4 files changed, 587 insertions(+) create mode 100644 miles/rollout/modular_rollout/inference_wrapper.py create mode 100644 miles/rollout/modular_rollout/orchestration_common.py create mode 100644 miles/rollout/modular_rollout/orchestration_eval.py create mode 100644 miles/rollout/modular_rollout/orchestration_train.py diff --git a/miles/rollout/modular_rollout/inference_wrapper.py b/miles/rollout/modular_rollout/inference_wrapper.py new file mode 100644 index 0000000000..d27311c1d3 --- /dev/null +++ b/miles/rollout/modular_rollout/inference_wrapper.py @@ -0,0 +1,101 @@ +from argparse import Namespace +from typing import Any + +import numpy as np +import pybase64 + +from miles.utils.http_utils import post +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: + """Generate using traditional SGLang router with token-based workflow""" + + if args.ci_test: + assert isinstance(sample.prompt, str) + + from miles.rollout.modular_rollout.orchestration_common import GenerateState + + state = GenerateState(args) + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + assert ( + sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED + ), f"Sample status is {sample.status}" + + if state.processor: + processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + else: + prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) + + if len(sample.response) > 0: + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert ( + sampling_params["max_new_tokens"] >= 0 + ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return sample + + # Prepare payload for sglang server + payload = { + "sampling_params": sampling_params, + "return_logprob": True, + } + + if args.use_rollout_routing_replay: + payload["return_routed_experts"] = True + + if sample.multimodal_inputs and sample.multimodal_inputs["images"]: + image_data = sample.multimodal_inputs["images"] + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + # Use existing tokens for multi-turn or tokenize the new prompt + if len(sample.response) > 0: + payload["input_ids"] = sample.tokens + else: + payload["input_ids"] = prompt_ids + if not sample.tokens: # Initialize sample.tokens for the first turn + sample.tokens = prompt_ids + + output = await post(url, payload) + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + sample = await postprocess_sample_with_radix_tree(args, sample, output) + else: + if "output_token_logprobs" in output["meta_info"]: + new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if "routed_experts" in output["meta_info"]: + sample.rollout_routed_experts = np.frombuffer( + pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), + dtype=np.int32, + ).reshape( + len(sample.tokens) - 1, + args.num_layers, + args.moe_router_topk, + ) + + sample.update_from_meta_info(args, output["meta_info"]) + + return sample diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py new file mode 100644 index 0000000000..2c8c681ae7 --- /dev/null +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -0,0 +1,176 @@ +import asyncio +import inspect +import logging +from argparse import Namespace +from contextlib import contextmanager +from typing import Any + +import numpy as np + +from miles.rollout.modular_rollout.inference_wrapper import generate +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.misc import SingletonMeta, load_function +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class GenerateState(metaclass=SingletonMeta): + """ + The global state for the generation process. + """ + + def __init__(self, args: Namespace) -> None: + # persistent state for the generation process + self.args = args + self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + + self.semaphore = asyncio.Semaphore( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + self.sampling_params: dict[str, Any] = dict( + temperature=args.rollout_temperature, + top_p=args.rollout_top_p, + top_k=args.rollout_top_k, + max_new_tokens=args.rollout_max_response_len, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_seed_base = args.rollout_seed + self.group_sampling_seeds = [sampling_seed_base + i for i in range(args.n_samples_per_prompt)] + + # dp rank balancing + self.dp_counts = [0] * (args.sglang_dp_size or 1) + self.dp_rank = 0 + + self.reset() + + @contextmanager + def dp_rank_context(self): + candidates = [i for i, count in enumerate(self.dp_counts) if count == min(self.dp_counts)] + dp_rank = int(np.random.choice(candidates)) + self.dp_counts[dp_rank] += 1 + self.dp_rank = dp_rank + try: + yield dp_rank + finally: + self.dp_counts[dp_rank] -= 1 + assert self.dp_counts[dp_rank] >= 0 + + def reset(self) -> None: + self.remaining_batch_size = 0 + self.pendings = set() + self.aborted = False + + def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: + for group in samples: + self.pendings.add( + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + self.args, + group, + sampling_params=self.sampling_params.copy(), + evaluation=False, + ) + ) + ) + self.remaining_batch_size += len(samples) + + +async def generate_and_rm( + args: Namespace, + sample: Sample | list[Sample], + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Sample | list[Sample]: + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + + # For samples with existing response, check if they're complete + if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: + assert sample.response is not None + if not args.group_rm: + assert sample.reward is not None + return sample + + state = GenerateState(args) + + # generate + async with state.semaphore: + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + + with state.dp_rank_context() as _: + if args.custom_generate_function_path is not None: + custom_generate_func = load_function(args.custom_generate_function_path) + # if signature has evaluation, pass evaluation + if "evaluation" in inspect.signature(custom_generate_func).parameters: + sample = await custom_generate_func(args, sample, sampling_params, evaluation=evaluation) + else: + sample = await custom_generate_func(args, sample, sampling_params) + else: + sample = await generate(args, sample, sampling_params) + + # for the rm that need the whole group, we will not do the rm here + if args.group_rm: + return sample + + # multi samples + if isinstance(sample, list): + samples = sample + if any([sample.status == Sample.Status.ABORTED for sample in samples]): + return samples + + # for multi agent system, the reward of some sample is calculated during generation. + samples_need_reward = [sample for sample in samples if sample.reward is None] + rewards = await batched_async_rm(args, samples_need_reward) + for sample, reward in zip(samples_need_reward, rewards, strict=False): + sample.reward = reward + return samples + else: + if sample.status == Sample.Status.ABORTED: + return sample + # for multi-turn environment, a reward could be assigned to the agent. + if sample.reward is None: + sample.reward = await async_rm(args, sample) + + return sample + + +async def generate_and_rm_group( + args: Namespace, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False +) -> list[Sample]: + state = GenerateState(args) + + if state.aborted: + return group + + tasks = [] + for idx, sample in enumerate(group): + current_sampling_params = sampling_params.copy() + if getattr(args, "sglang_enable_deterministic_inference", False): + seed = state.group_sampling_seeds[idx] + current_sampling_params["sampling_seed"] = seed + tasks.append( + asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation)) + ) + + group = await asyncio.gather(*tasks) + + # for the rm that need the whole group, we will do the rm here + if not state.aborted and args.group_rm: + rewards = await batched_async_rm(args, group) + for sample, reward in zip(group, rewards, strict=False): + sample.reward = reward + + return group diff --git a/miles/rollout/modular_rollout/orchestration_eval.py b/miles/rollout/modular_rollout/orchestration_eval.py new file mode 100644 index 0000000000..76afe265a6 --- /dev/null +++ b/miles/rollout/modular_rollout/orchestration_eval.py @@ -0,0 +1,132 @@ +import asyncio +import copy +import logging +from argparse import Namespace +from typing import Any + +from tqdm import tqdm + +from miles.rollout.base_types import RolloutFnEvalOutput +from miles.rollout.modular_rollout.orchestration_common import generate_and_rm +from miles.utils.data import Dataset +from miles.utils.eval_config import EvalDatasetConfig +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + +EVAL_PROMPT_DATASET = {} + + +async def eval_rollout(args: Namespace, rollout_id: int) -> tuple[dict[str, dict[str, list[Any]]], list[list[Sample]]]: + assert not args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(args, rollout_id, dataset_cfg)) + results_list = await asyncio.gather(*coros) + results = {} + for r in results_list: + results.update(r) + return RolloutFnEvalOutput(data=results), [] + + +async def eval_rollout_single_dataset( + args: Namespace, rollout_id: int, dataset_cfg: EvalDatasetConfig +) -> dict[str, dict[str, list[Any]]]: + """An example to implement the eval_rollout function for an rule based rm rollout generation. + + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + dataset_cfg: configuration of the dataset + """ + assert not args.group_rm, "Group RM is not supported for eval rollout" + + global EVAL_PROMPT_DATASET + + cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) + if cache_key not in EVAL_PROMPT_DATASET: + tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + EVAL_PROMPT_DATASET[cache_key] = Dataset( + path=dataset_cfg.path, + tokenizer=tokenizer, + processor=processor, + max_length=args.eval_max_prompt_len, + prompt_key=dataset_cfg.input_key, + label_key=dataset_cfg.label_key, + multimodal_keys=args.multimodal_keys, + metadata_key=dataset_cfg.metadata_key, + tool_key=dataset_cfg.tool_key, + apply_chat_template=args.apply_chat_template, + apply_chat_template_kwargs=args.apply_chat_template_kwargs, + ) + dataset = EVAL_PROMPT_DATASET[cache_key] + + base_sampling_params = dict( + temperature=dataset_cfg.temperature, + top_p=dataset_cfg.top_p, + top_k=dataset_cfg.top_k, + max_new_tokens=dataset_cfg.max_response_len, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + tasks = [] + # do multiple samples for eval prompts + sample_index = 0 + for _i, prompt_sample in enumerate(dataset.samples): + for j in range(dataset_cfg.n_samples_per_eval_prompt): + # use the same prompt for multiple samples + sample = copy.deepcopy(prompt_sample) + sample.index = sample_index + sample_index += 1 + sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) + sampling_params = base_sampling_params + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_params = base_sampling_params.copy() + sampling_params["sampling_seed"] = args.rollout_seed + j + tasks.append( + asyncio.create_task( + generate_and_rm( + args, + sample, + sampling_params=sampling_params, + evaluation=True, + ) + ) + ) + + data = [] + do_print = True + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) + for coro in asyncio.as_completed(tasks): + sample = await coro + if do_print: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(sample.prompt) + sample.response]} " + f"reward={sample.reward}" + ) + do_print = False + if isinstance(sample, list): + data.extend(sample) + else: + data.append(sample) + pbar.update(1) + pbar.close() + + data.sort(key=lambda sample: sample.index) + + reward_key = args.eval_reward_key or args.reward_key + return { + dataset_cfg.name: { + "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], + "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], + "samples": data, + } + } diff --git a/miles/rollout/modular_rollout/orchestration_train.py b/miles/rollout/modular_rollout/orchestration_train.py new file mode 100644 index 0000000000..7682c4fdaf --- /dev/null +++ b/miles/rollout/modular_rollout/orchestration_train.py @@ -0,0 +1,178 @@ +import asyncio +import logging +from argparse import Namespace +from collections.abc import Callable +from typing import Any + +import sglang_router +from packaging.version import parse +from tqdm import tqdm + +from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput +from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter +from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.rollout.modular_rollout.orchestration_eval import eval_rollout +from miles.utils.async_utils import run +from miles.utils.http_utils import get, post +from miles.utils.misc import load_function +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: + aborted_samples = [] + + state = GenerateState(args) + assert not state.aborted + state.aborted = True + + if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + urls = response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + urls = [worker["url"] for worker in response["workers"]] + + logger.info(f"Abort request for {urls}") + await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) + + # make sure all the pending tasks are finished + count = 0 + while state.pendings: + done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) + + if not args.partial_rollout: + continue + + # for partial rollout, collect the partial samples into the data buffer + for task in done: + group = task.result() + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) + count += len(group) + + if args.partial_rollout: + logger.info(f"Collected {count} partial samples into the data buffer") + + return aborted_samples + + +async def generate_rollout_async( + args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + """An example to implement the generate_rollout function for an rule based rm rollout generation. + + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + data_source: the data source to fetch + + Returns: + tuple[RolloutFnTrainOutput, list[list[Sample]]]: + - data: a list of groups of samples generated by the rollout, length equals `rollout_batch_size` + - aborted_samples: any partial groups collected during abort when partial_rollout is enabled + """ + assert args.rollout_global_dataset + + state = GenerateState(args) + + # instantiate data filters + dynamic_filter = ( + load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None + ) + + metric_gatherer = MetricGatherer() + + # target_data_size is the total number of valid samples to get + target_data_size = args.rollout_batch_size + + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") + while len(data) < target_data_size: + while state.remaining_batch_size < target_data_size: + # get samples from the buffer and submit the generation requests. + samples = data_source(args.over_sampling_batch_size) + state.submit_generate_tasks(samples) + + # wait for the generation to finish + done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group: list[Sample] = task.result() + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) == args.n_samples_per_prompt + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + state.remaining_batch_size -= 1 + continue + + # add the samples to the data + # NOTE: here we have not stored all the unused samples back to the data buffer. + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + + pbar.close() + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + logger.info( + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + + # there are still some unfinished requests, abort them + aborted_samples = await abort(args, rollout_id) + + assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" + data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + + # reset the global state to prevent effects on the next rollout or eval. + state.reset() + if args.rollout_sample_filter_path is not None: + filter_func = load_function(args.rollout_sample_filter_path) + filter_func(args, data) + + # There can be circumstances where users want to process all samples including filtered ones. + if args.rollout_all_samples_process_path is not None: + process_func = load_function(args.rollout_all_samples_process_path) + process_func(args, all_samples, data_source) + + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples + + +def generate_rollout( + args: Namespace, rollout_id: int, data_source: Any, evaluation: bool = False +) -> RolloutFnTrainOutput | RolloutFnEvalOutput: + """An example to implement the generate_rollout function for an rule based rm rollout generation. + + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + data_buffer: the data buffer to store the generated samples + evaluation: bool, whether the rollout is for evaluation or not + + Returns: + list[list[Sample]]: a list of list of samples generated by the rollout + """ + assert args.rollout_global_dataset + if evaluation: + output, _ = run(eval_rollout(args, rollout_id)) + return output + + output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) + data_source.add_samples(aborted_samples) + return output From 4adb662949d3f9556e1fc88c74eabf0e2f5df3d1 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:59:49 +0800 Subject: [PATCH 06/77] Use new rollout function API for modular rollout (#434) --- .../modular_rollout/orchestration_eval.py | 32 +++++++++-------- .../modular_rollout/orchestration_train.py | 34 ++++++------------- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/miles/rollout/modular_rollout/orchestration_eval.py b/miles/rollout/modular_rollout/orchestration_eval.py index 76afe265a6..e89b2f2edb 100644 --- a/miles/rollout/modular_rollout/orchestration_eval.py +++ b/miles/rollout/modular_rollout/orchestration_eval.py @@ -6,7 +6,7 @@ from tqdm import tqdm -from miles.rollout.base_types import RolloutFnEvalOutput +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnEvalOutput from miles.rollout.modular_rollout.orchestration_common import generate_and_rm from miles.utils.data import Dataset from miles.utils.eval_config import EvalDatasetConfig @@ -18,19 +18,6 @@ EVAL_PROMPT_DATASET = {} -async def eval_rollout(args: Namespace, rollout_id: int) -> tuple[dict[str, dict[str, list[Any]]], list[list[Sample]]]: - assert not args.group_rm, "Group RM is not supported for eval rollout" - - coros = [] - for dataset_cfg in getattr(args, "eval_datasets", []) or []: - coros.append(eval_rollout_single_dataset(args, rollout_id, dataset_cfg)) - results_list = await asyncio.gather(*coros) - results = {} - for r in results_list: - results.update(r) - return RolloutFnEvalOutput(data=results), [] - - async def eval_rollout_single_dataset( args: Namespace, rollout_id: int, dataset_cfg: EvalDatasetConfig ) -> dict[str, dict[str, list[Any]]]: @@ -130,3 +117,20 @@ async def eval_rollout_single_dataset( "samples": data, } } + + +class SimpleEvalRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.args = input.args + + async def __call__(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: + assert not self.args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(self.args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(self.args, input.rollout_id, dataset_cfg)) + results_list = await asyncio.gather(*coros) + results = {} + for r in results_list: + results.update(r) + return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/modular_rollout/orchestration_train.py b/miles/rollout/modular_rollout/orchestration_train.py index 7682c4fdaf..cd9549df48 100644 --- a/miles/rollout/modular_rollout/orchestration_train.py +++ b/miles/rollout/modular_rollout/orchestration_train.py @@ -2,17 +2,14 @@ import logging from argparse import Namespace from collections.abc import Callable -from typing import Any import sglang_router from packaging.version import parse from tqdm import tqdm -from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnTrainInput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from miles.rollout.modular_rollout.orchestration_common import GenerateState -from miles.rollout.modular_rollout.orchestration_eval import eval_rollout -from miles.utils.async_utils import run from miles.utils.http_utils import get, post from miles.utils.misc import load_function from miles.utils.types import Sample @@ -154,25 +151,14 @@ async def generate_rollout_async( return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples -def generate_rollout( - args: Namespace, rollout_id: int, data_source: Any, evaluation: bool = False -) -> RolloutFnTrainOutput | RolloutFnEvalOutput: - """An example to implement the generate_rollout function for an rule based rm rollout generation. - - Args: - args: the whole args - rollout_id: int, the id of the rollout, used for deterministic data generation - data_buffer: the data buffer to store the generated samples - evaluation: bool, whether the rollout is for evaluation or not +class SimpleTrainRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.args = input.args + self.data_source = input.data_source - Returns: - list[list[Sample]]: a list of list of samples generated by the rollout - """ - assert args.rollout_global_dataset - if evaluation: - output, _ = run(eval_rollout(args, rollout_id)) + async def __call__(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: + output, aborted_samples = await generate_rollout_async( + self.args, input.rollout_id, self.data_source.get_samples + ) + self.data_source.add_samples(aborted_samples) return output - - output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) - data_source.add_samples(aborted_samples) - return output From d08836c3e4e14bebb0c814fafff638ba3f392922 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:00:11 +0800 Subject: [PATCH 07/77] Add mock SGLang server (#435) --- miles/utils/test_utils/__init__.py | 0 miles/utils/test_utils/mock_sglang_server.py | 122 ++++++++++++++++++ .../utils/test_utils/uvicorn_thread_server.py | 49 +++++++ .../test_utils/test_mock_sglang_server.py | 79 ++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 miles/utils/test_utils/__init__.py create mode 100644 miles/utils/test_utils/mock_sglang_server.py create mode 100644 miles/utils/test_utils/uvicorn_thread_server.py create mode 100644 tests/utils/test_utils/test_mock_sglang_server.py diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py new file mode 100644 index 0000000000..6d4144fc1f --- /dev/null +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -0,0 +1,122 @@ +import re +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from transformers import AutoTokenizer + +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class ProcessResult: + text: str + finish_reason: str + + +ProcessFn = Callable[[str], ProcessResult] + + +class MockSGLangServer: + def __init__( + self, + model_name: str, + process_fn: ProcessFn, + host: str, + port: int, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.process_fn = process_fn + self.host = host + self.port = port or find_available_port(30000) + + self.app = FastAPI() + self._server: UvicornThreadServer | None = None + + self._setup_routes() + + def _setup_routes(self): + @self.app.post("/generate") + async def generate(request: Request): + payload = await request.json() + + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) + + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) + + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens + + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + + response = { + "text": process_result.text, + "meta_info": { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": 0, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + }, + } + + return JSONResponse(content=response) + + @self.app.get("/health") + async def health(): + return JSONResponse(content={"status": "ok"}) + + @self.app.post("/abort_request") + async def abort_request(_request: Request): + return JSONResponse(content={"status": "ok"}) + + def start(self): + self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) + self._server.start() + + def stop(self): + if self._server is not None: + self._server.stop() + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + +def default_process_fn(prompt: str) -> ProcessResult: + match = re.search(r"What is 1\+(\d+)\?", prompt) + if match: + num = int(match.group(1)) + ans = 1 + num + return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop") + return ProcessResult(text="I don't understand.", finish_reason="stop") + + +@contextmanager +def with_mock_server( + model_name: str = "Qwen/Qwen3-0.6B", + process_fn: ProcessFn = default_process_fn, + host: str = "127.0.0.1", + port: int | None = None, +): + server = MockSGLangServer( + model_name=model_name, + process_fn=process_fn, + host=host, + port=port, + ) + try: + server.start() + yield server + finally: + server.stop() diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py new file mode 100644 index 0000000000..904343c984 --- /dev/null +++ b/miles/utils/test_utils/uvicorn_thread_server.py @@ -0,0 +1,49 @@ +import asyncio +import socket +import threading +import time + +import uvicorn + + +class UvicornThreadServer: + def __init__(self, app, host: str, port: int): + self._app = app + self.host = host + self.port = port + self._server: uvicorn.Server | None = None + self._thread: threading.Thread | None = None + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def start(self) -> None: + config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info") + self._server = uvicorn.Server(config) + + def run() -> None: + asyncio.run(self._server.serve()) + + self._thread = threading.Thread(target=run, daemon=True) + self._thread.start() + self._wait_for_port_open() + + def stop(self) -> None: + if self._server is not None: + self._server.should_exit = True + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2.0) + + def _wait_for_port_open(self) -> None: + for _ in range(50): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex((self.host, self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + raise RuntimeError(f"Failed to start server on {self.url}") diff --git a/tests/utils/test_utils/test_mock_sglang_server.py b/tests/utils/test_utils/test_mock_sglang_server.py new file mode 100644 index 0000000000..6163e68bda --- /dev/null +++ b/tests/utils/test_utils/test_mock_sglang_server.py @@ -0,0 +1,79 @@ +import pytest +import requests + +from miles.utils.test_utils.mock_sglang_server import ProcessResult, default_process_fn, with_mock_server + + +@pytest.fixture(scope="module") +def mock_server(): + with with_mock_server() as server: + yield server + + +def test_basic_server_start_stop(mock_server): + assert mock_server.port > 0 + assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url + + +def test_generate_endpoint_basic(mock_server): + prompt = "What is 1+7?" + input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) + assert input_ids == [3838, 374, 220, 16, 10, 22, 30] + + response = requests.post( + f"{mock_server.url}/generate", + json={ + "input_ids": input_ids, + "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, + "return_logprob": True, + }, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + + assert data == { + "text": "\\boxed{8}", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": len(input_ids), + "cached_tokens": 0, + "completion_tokens": 5, + "output_token_logprobs": [ + [-0.0, 59], + [-0.0078125, 79075], + [-0.015625, 90], + [-0.0234375, 23], + [-0.03125, 92], + ], + }, + } + + +def test_process_fn_receives_decoded_prompt(mock_server): + received_prompts = [] + + def process_fn(prompt: str) -> ProcessResult: + received_prompts.append(prompt) + return ProcessResult(text="response", finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + input_ids = [1, 2, 3] + requests.post(f"{server.url}/generate", json={"input_ids": input_ids, "sampling_params": {}}, timeout=5.0) + + assert len(received_prompts) == 1 + assert isinstance(received_prompts[0], str) + + +def test_default_process_fn(): + result = default_process_fn("What is 1+5?") + assert result.text == "\\boxed{6}" + assert result.finish_reason == "stop" + + result = default_process_fn("What is 1+10?") + assert result.text == "\\boxed{11}" + assert result.finish_reason == "stop" + + result = default_process_fn("Hello") + assert result.text == "I don't understand." + assert result.finish_reason == "stop" From e4c2dbfcddfe6ee93ad30a4bae9f0da4f78fb918 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:00:31 +0800 Subject: [PATCH 08/77] Add integration test for rollout generation for several combinations with compatibility test (#436) --- tests/__init__.py | 1 + tests/conftest.py | 3 + tests/fixtures/__init__.py | 1 + tests/fixtures/rollout_integration.py | 108 ++++++++++++++++++ .../modular_rollout/test_integration.py | 98 ++++++++++++++++ 5 files changed, 211 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/rollout_integration.py create mode 100644 tests/rollout/modular_rollout/test_integration.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..6697bd0b90 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,3 @@ +from tests.fixtures.rollout_integration import rollout_integration_env + +_ = rollout_integration_env diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/fixtures/rollout_integration.py b/tests/fixtures/rollout_integration.py new file mode 100644 index 0000000000..079147d289 --- /dev/null +++ b/tests/fixtures/rollout_integration.py @@ -0,0 +1,108 @@ +import json +from argparse import Namespace +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.data_source import RolloutDataSourceWithBuffer +from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.router.router import MilesRouter +from miles.utils.arguments import parse_args +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + data_path, + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--eval-prompt-data", + "toy", + data_path, + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + (extra_argv or []) + with patch("sys.argv", argv): + args = parse_args() + args.miles_router_middleware_paths = [] + init_http_client(args) + return args + + +@contextmanager +def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]: + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + try: + server.start() + yield server + finally: + server.stop() + + +def _write_jsonl(path: str, rows: list[dict]) -> None: + Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") + + +def _cleanup_legacy_singleton(): + SingletonMeta._instances.pop(GenerateState, None) + + +@pytest.fixture +def rollout_integration_env(tmp_path, request): + extra_argv = request.param + assert isinstance(extra_argv, list) + + data_path = str(tmp_path / "data.jsonl") + _write_jsonl(data_path, [{"input": "What is 1+7?", "label": "8"}]) + + router_port = find_available_port(20000) + args = _build_args(data_path=data_path, router_port=router_port, extra_argv=extra_argv) + + _cleanup_legacy_singleton() + + with with_mock_server(model_name=args.hf_checkpoint) as mock_server: + with _with_miles_router(args) as router_server: + r = requests.post( + f"{router_server.url}/add_worker", + params={"url": mock_server.url}, + timeout=5.0, + ) + r.raise_for_status() + + data_source = RolloutDataSourceWithBuffer(args) + yield args, data_source + + _cleanup_legacy_singleton() diff --git a/tests/rollout/modular_rollout/test_integration.py b/tests/rollout/modular_rollout/test_integration.py new file mode 100644 index 0000000000..ed21ceee51 --- /dev/null +++ b/tests/rollout/modular_rollout/test_integration.py @@ -0,0 +1,98 @@ +import pytest + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnTrainInput +from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.utils.types import Sample + + +def _expected_sample(*, group_index: int | None) -> Sample: + return Sample( + group_index=group_index, + index=0, + prompt="What is 1+7?", + tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], + multimodal_inputs=None, + multimodal_train_inputs=None, + response="\\boxed{8}", + response_length=5, + label="8", + reward=1, + loss_mask=None, + weight_versions=[], + rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], + rollout_routed_experts=None, + remove_sample=False, + status=Sample.Status.COMPLETED, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=Sample.SpecInfo( + spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 + ), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), + ) + + +_ROLLOUT_ARGV_VARIANTS = [ + pytest.param( + [ + "--rollout-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--eval-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ], + id="old_rollout_old_generate", + ), + pytest.param( + [ + "--rollout-function-path", + "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", + "--eval-function-path", + "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ], + id="new_rollout_old_generate", + ), + pytest.param( + [ + "--rollout-function-path", + "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", + "--eval-function-path", + "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", + "--custom-generate-function-path", + "miles.rollout.modular_rollout.inference_wrapper.generate", + ], + id="new_rollout_new_generate", + ), +] + + +@pytest.mark.parametrize("rollout_integration_env", _ROLLOUT_ARGV_VARIANTS, indirect=True) +def test_simple_train_rollout_fn_integration(rollout_integration_env): + args, data_source = rollout_integration_env + fn = load_rollout_function( + RolloutFnConstructorInput(args=args, data_source=data_source), args.rollout_function_path + ) + out = call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + + assert len(out.samples) == args.rollout_batch_size + group = out.samples[0] + assert len(group) == args.n_samples_per_prompt + assert group[0] == _expected_sample(group_index=0) + + +@pytest.mark.parametrize("rollout_integration_env", _ROLLOUT_ARGV_VARIANTS, indirect=True) +def test_simple_eval_rollout_fn_integration(rollout_integration_env): + args, data_source = rollout_integration_env + fn = load_rollout_function(RolloutFnConstructorInput(args=args, data_source=data_source), args.eval_function_path) + out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + assert "toy" in out.data + rewards = out.data["toy"]["rewards"] + samples = out.data["toy"]["samples"] + assert len(rewards) == len(samples) == args.n_samples_per_eval_prompt + assert rewards[0] == 1 + assert samples[0] == _expected_sample(group_index=None) From 79368ffc00adf0d84797d20ade19377b2e1dfef7 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:02:33 +0800 Subject: [PATCH 09/77] Add new sample generation API (#437) --- miles/rollout/base_types.py | 34 +++++++- .../rollout/modular_rollout/compatibility.py | 36 ++++++++ .../modular_rollout/inference_wrapper.py | 13 ++- .../modular_rollout/orchestration_common.py | 19 +++-- .../modular_rollout/test_compatibility.py | 85 ++++++++++++++++++- 5 files changed, 169 insertions(+), 18 deletions(-) diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index d6eb1e8f0d..9b276c0dce 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,11 +1,16 @@ +from __future__ import annotations + from argparse import Namespace from collections.abc import Awaitable from dataclasses import dataclass -from typing import Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from miles.rollout.data_source import DataSource from miles.utils.types import Sample +if TYPE_CHECKING: + from miles.rollout.modular_rollout.orchestration_common import GenerateState + @dataclass(frozen=True) class RolloutFnConstructorInput: @@ -38,12 +43,14 @@ def evaluation(self): return True +# TODO make it frozen @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] metrics: dict[str, Any] = None +# TODO make it frozen @dataclass class RolloutFnEvalOutput: data: dict[str, dict[str, Any]] @@ -60,3 +67,28 @@ class RolloutFnEvalOutput: @runtime_checkable class RolloutFnProtocol(Protocol): def __call__(self, input: RolloutFnInput) -> RolloutFnOutput | Awaitable[RolloutFnOutput]: ... + + +# TODO maybe put to modular_rollout folder depending on overall folder structure +@dataclass(frozen=True) +class GenerateFnInput: + state: GenerateState + sample: Sample + sampling_params: dict[str, Any] + evaluation: bool + + @property + def args(self) -> Namespace: + return self.state.args + + +@dataclass(frozen=True) +class GenerateFnOutput: + sample: Sample | list[Sample] + + +# TODO: may add add_arguments +# TODO: may add save/load if need it to be stateful +@runtime_checkable +class GenerateFnProtocol(Protocol): + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: ... diff --git a/miles/rollout/modular_rollout/compatibility.py b/miles/rollout/modular_rollout/compatibility.py index 7d1a70e79c..f4455a8b80 100644 --- a/miles/rollout/modular_rollout/compatibility.py +++ b/miles/rollout/modular_rollout/compatibility.py @@ -2,6 +2,8 @@ from collections.abc import Callable from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, RolloutFnConstructorInput, RolloutFnEvalOutput, RolloutFnInput, @@ -48,3 +50,37 @@ def call_rollout_function(fn: RolloutFnProtocol, input: RolloutFnInput) -> Rollo output = run(output) return output + + +class LegacyGenerateFnAdapter: + def __init__(self, fn: Callable): + self.fn = fn + self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters + + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + if self._has_evaluation_param: + output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation) + else: + output = await self.fn(input.args, input.sample, input.sampling_params) + + if not isinstance(output, GenerateFnOutput): + output = GenerateFnOutput(sample=output) + + return output + + +def load_generate_function(path: str): + fn = load_function(path) + + if inspect.isclass(fn): + return fn() + elif _is_legacy_generate_fn(fn): + return LegacyGenerateFnAdapter(fn) + else: + return fn + + +def _is_legacy_generate_fn(fn: Callable) -> bool: + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + return len(params) >= 3 and params[0] != "input" diff --git a/miles/rollout/modular_rollout/inference_wrapper.py b/miles/rollout/modular_rollout/inference_wrapper.py index d27311c1d3..a457992d5a 100644 --- a/miles/rollout/modular_rollout/inference_wrapper.py +++ b/miles/rollout/modular_rollout/inference_wrapper.py @@ -1,23 +1,22 @@ -from argparse import Namespace -from typing import Any - import numpy as np import pybase64 +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput from miles.utils.http_utils import post from miles.utils.processing_utils import encode_image_for_rollout_engine from miles.utils.types import Sample -async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: +async def generate(input: GenerateFnInput) -> GenerateFnOutput: """Generate using traditional SGLang router with token-based workflow""" + state = input.state + args = input.args + sample = input.sample + sampling_params = input.sampling_params if args.ci_test: assert isinstance(sample.prompt, str) - from miles.rollout.modular_rollout.orchestration_common import GenerateState - - state = GenerateState(args) url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" assert ( diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index 2c8c681ae7..22d9f1d0e0 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -1,5 +1,4 @@ import asyncio -import inspect import logging from argparse import Namespace from contextlib import contextmanager @@ -7,9 +6,11 @@ import numpy as np +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.modular_rollout.compatibility import load_generate_function from miles.rollout.modular_rollout.inference_wrapper import generate from miles.rollout.rm_hub import async_rm, batched_async_rm -from miles.utils.misc import SingletonMeta, load_function +from miles.utils.misc import SingletonMeta from miles.utils.processing_utils import load_processor, load_tokenizer from miles.utils.types import Sample @@ -111,15 +112,15 @@ async def generate_and_rm( return sample with state.dp_rank_context() as _: + # TODO load function only once during whole lifetime if args.custom_generate_function_path is not None: - custom_generate_func = load_function(args.custom_generate_function_path) - # if signature has evaluation, pass evaluation - if "evaluation" in inspect.signature(custom_generate_func).parameters: - sample = await custom_generate_func(args, sample, sampling_params, evaluation=evaluation) - else: - sample = await custom_generate_func(args, sample, sampling_params) + fn = load_generate_function(args.custom_generate_function_path) else: - sample = await generate(args, sample, sampling_params) + fn = generate + output = await fn( + GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params, evaluation=evaluation) + ) + sample = output.sample # for the rm that need the whole group, we will not do the rm here if args.group_rm: diff --git a/tests/rollout/modular_rollout/test_compatibility.py b/tests/rollout/modular_rollout/test_compatibility.py index 596fa76270..c3beba996b 100644 --- a/tests/rollout/modular_rollout/test_compatibility.py +++ b/tests/rollout/modular_rollout/test_compatibility.py @@ -1,9 +1,11 @@ import asyncio -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnEvalOutput, @@ -11,10 +13,13 @@ RolloutFnTrainOutput, ) from miles.rollout.modular_rollout.compatibility import ( + LegacyGenerateFnAdapter, LegacyRolloutFnAdapter, call_rollout_function, + load_generate_function, load_rollout_function, ) +from miles.utils.async_utils import run @pytest.fixture @@ -22,6 +27,22 @@ def constructor_input(): return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") +@pytest.fixture +def make_generate_fn_input(): + def _make(evaluation: bool = False): + state = MagicMock() + state.args = MagicMock() + + return GenerateFnInput( + state=state, + sample={"text": "test prompt"}, + sampling_params={"temperature": 0.7}, + evaluation=evaluation, + ) + + return _make + + class TestSupportedRolloutFormats: """ Documentation test to show various supported rollout function formats @@ -110,3 +131,65 @@ async def __call__(self, input): assert isinstance(fn, AsyncRolloutFn) expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput assert isinstance(result, expected_type) + + +class TestSupportedGenerateFormats: + """ + Documentation test similar to TestSupportedRolloutFormats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): + return "my_sample" + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_generate_fn): + fn = load_generate_function("path.to.fn") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.sample == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params): + return "my_sample" + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_generate_fn): + fn = load_generate_function("path.to.fn") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.sample == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): + async def generate(input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(sample="my_sample") + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=generate): + fn = load_generate_function("path.to.fn") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(result, GenerateFnOutput) + assert result.sample == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): + class MyGenerateFn: + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(sample="my_sample") + + with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=MyGenerateFn): + fn = load_generate_function("path.to.fn") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, MyGenerateFn) + assert isinstance(result, GenerateFnOutput) + assert result.sample == "my_sample" From fd7a755d2c3f6bc351b41e1ee4b76e41b8c04a00 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:05:32 +0800 Subject: [PATCH 10/77] Remove global variables in modular rollout (#438) --- .../modular_rollout/inference_wrapper.py | 4 +- .../modular_rollout/orchestration_common.py | 40 +++++++++---------- .../modular_rollout/orchestration_eval.py | 31 ++++++-------- .../modular_rollout/orchestration_train.py | 28 ++++--------- 4 files changed, 42 insertions(+), 61 deletions(-) diff --git a/miles/rollout/modular_rollout/inference_wrapper.py b/miles/rollout/modular_rollout/inference_wrapper.py index a457992d5a..56529c7da4 100644 --- a/miles/rollout/modular_rollout/inference_wrapper.py +++ b/miles/rollout/modular_rollout/inference_wrapper.py @@ -40,7 +40,7 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" if sampling_params["max_new_tokens"] == 0: sample.status = Sample.Status.TRUNCATED - return sample + return GenerateFnOutput(sample=sample) # Prepare payload for sglang server payload = { @@ -97,4 +97,4 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: sample.update_from_meta_info(args, output["meta_info"]) - return sample + return GenerateFnOutput(sample=sample) diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index 22d9f1d0e0..a97cf68d6e 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -10,18 +10,13 @@ from miles.rollout.modular_rollout.compatibility import load_generate_function from miles.rollout.modular_rollout.inference_wrapper import generate from miles.rollout.rm_hub import async_rm, batched_async_rm -from miles.utils.misc import SingletonMeta from miles.utils.processing_utils import load_processor, load_tokenizer from miles.utils.types import Sample logger = logging.getLogger(__name__) -class GenerateState(metaclass=SingletonMeta): - """ - The global state for the generation process. - """ - +class GenerateState: def __init__(self, args: Namespace) -> None: # persistent state for the generation process self.args = args @@ -51,6 +46,11 @@ def __init__(self, args: Namespace) -> None: self.dp_counts = [0] * (args.sglang_dp_size or 1) self.dp_rank = 0 + if args.custom_generate_function_path is not None: + self.generate_function = load_generate_function(args.custom_generate_function_path) + else: + self.generate_function = generate + self.reset() @contextmanager @@ -76,7 +76,7 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: asyncio.create_task( # submit a group of samples as a single task. generate_and_rm_group( - self.args, + self, group, sampling_params=self.sampling_params.copy(), evaluation=False, @@ -87,11 +87,13 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: async def generate_and_rm( - args: Namespace, + state: GenerateState, sample: Sample | list[Sample], sampling_params: dict[str, Any], evaluation: bool = False, ) -> Sample | list[Sample]: + args = state.args + # mask previous off-policy generation for partial rollout if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: sample.loss_mask = [0] * sample.response_length @@ -103,8 +105,6 @@ async def generate_and_rm( assert sample.reward is not None return sample - state = GenerateState(args) - # generate async with state.semaphore: if state.aborted: @@ -112,13 +112,13 @@ async def generate_and_rm( return sample with state.dp_rank_context() as _: - # TODO load function only once during whole lifetime - if args.custom_generate_function_path is not None: - fn = load_generate_function(args.custom_generate_function_path) - else: - fn = generate - output = await fn( - GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params, evaluation=evaluation) + output = await state.generate_function( + GenerateFnInput( + state=state, + sample=sample, + sampling_params=sampling_params, + evaluation=evaluation, + ) ) sample = output.sample @@ -149,9 +149,9 @@ async def generate_and_rm( async def generate_and_rm_group( - args: Namespace, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False + state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False ) -> list[Sample]: - state = GenerateState(args) + args = state.args if state.aborted: return group @@ -163,7 +163,7 @@ async def generate_and_rm_group( seed = state.group_sampling_seeds[idx] current_sampling_params["sampling_seed"] = seed tasks.append( - asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation)) + asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) ) group = await asyncio.gather(*tasks) diff --git a/miles/rollout/modular_rollout/orchestration_eval.py b/miles/rollout/modular_rollout/orchestration_eval.py index e89b2f2edb..cb76901ef0 100644 --- a/miles/rollout/modular_rollout/orchestration_eval.py +++ b/miles/rollout/modular_rollout/orchestration_eval.py @@ -1,13 +1,12 @@ import asyncio import copy import logging -from argparse import Namespace from typing import Any from tqdm import tqdm from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnEvalOutput -from miles.rollout.modular_rollout.orchestration_common import generate_and_rm +from miles.rollout.modular_rollout.orchestration_common import GenerateState, generate_and_rm from miles.utils.data import Dataset from miles.utils.eval_config import EvalDatasetConfig from miles.utils.processing_utils import load_processor, load_tokenizer @@ -15,28 +14,20 @@ logger = logging.getLogger(__name__) -EVAL_PROMPT_DATASET = {} - async def eval_rollout_single_dataset( - args: Namespace, rollout_id: int, dataset_cfg: EvalDatasetConfig + state: GenerateState, + dataset_cfg: EvalDatasetConfig, + prompt_dataset_cache: dict[Any, Dataset], ) -> dict[str, dict[str, list[Any]]]: - """An example to implement the eval_rollout function for an rule based rm rollout generation. - - Args: - args: the whole args - rollout_id: int, the id of the rollout, used for deterministic data generation - dataset_cfg: configuration of the dataset - """ + args = state.args assert not args.group_rm, "Group RM is not supported for eval rollout" - global EVAL_PROMPT_DATASET - cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) - if cache_key not in EVAL_PROMPT_DATASET: + if cache_key not in prompt_dataset_cache: tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - EVAL_PROMPT_DATASET[cache_key] = Dataset( + prompt_dataset_cache[cache_key] = Dataset( path=dataset_cfg.path, tokenizer=tokenizer, processor=processor, @@ -49,7 +40,7 @@ async def eval_rollout_single_dataset( apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, ) - dataset = EVAL_PROMPT_DATASET[cache_key] + dataset = prompt_dataset_cache[cache_key] base_sampling_params = dict( temperature=dataset_cfg.temperature, @@ -80,7 +71,7 @@ async def eval_rollout_single_dataset( tasks.append( asyncio.create_task( generate_and_rm( - args, + state, sample, sampling_params=sampling_params, evaluation=True, @@ -122,13 +113,15 @@ async def eval_rollout_single_dataset( class SimpleEvalRolloutFn: def __init__(self, input: RolloutFnConstructorInput): self.args = input.args + self.prompt_dataset_cache = {} + self.state = GenerateState(self.args) async def __call__(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: assert not self.args.group_rm, "Group RM is not supported for eval rollout" coros = [] for dataset_cfg in getattr(self.args, "eval_datasets", []) or []: - coros.append(eval_rollout_single_dataset(self.args, input.rollout_id, dataset_cfg)) + coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.prompt_dataset_cache)) results_list = await asyncio.gather(*coros) results = {} for r in results_list: diff --git a/miles/rollout/modular_rollout/orchestration_train.py b/miles/rollout/modular_rollout/orchestration_train.py index cd9549df48..605541b7d1 100644 --- a/miles/rollout/modular_rollout/orchestration_train.py +++ b/miles/rollout/modular_rollout/orchestration_train.py @@ -1,6 +1,5 @@ import asyncio import logging -from argparse import Namespace from collections.abc import Callable import sglang_router @@ -17,10 +16,11 @@ logger = logging.getLogger(__name__) -async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: +async def abort(state: GenerateState, rollout_id: int) -> list[list[Sample]]: + args = state.args + aborted_samples = [] - state = GenerateState(args) assert not state.aborted state.aborted = True @@ -58,24 +58,11 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: async def generate_rollout_async( - args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] + state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] ) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: - """An example to implement the generate_rollout function for an rule based rm rollout generation. - - Args: - args: the whole args - rollout_id: int, the id of the rollout, used for deterministic data generation - data_source: the data source to fetch - - Returns: - tuple[RolloutFnTrainOutput, list[list[Sample]]]: - - data: a list of groups of samples generated by the rollout, length equals `rollout_batch_size` - - aborted_samples: any partial groups collected during abort when partial_rollout is enabled - """ + args = state.args assert args.rollout_global_dataset - state = GenerateState(args) - # instantiate data filters dynamic_filter = ( load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None @@ -129,7 +116,7 @@ async def generate_rollout_async( ) # there are still some unfinished requests, abort them - aborted_samples = await abort(args, rollout_id) + aborted_samples = await abort(state, rollout_id) assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) @@ -155,10 +142,11 @@ class SimpleTrainRolloutFn: def __init__(self, input: RolloutFnConstructorInput): self.args = input.args self.data_source = input.data_source + self.state = GenerateState(self.args) async def __call__(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: output, aborted_samples = await generate_rollout_async( - self.args, input.rollout_id, self.data_source.get_samples + self.state, input.rollout_id, self.data_source.get_samples ) self.data_source.add_samples(aborted_samples) return output From d6b23e03f42aa5b6b6c518512f5a875b9410401a Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:07:05 +0800 Subject: [PATCH 11/77] Remove misplaced fields in GenerateState (#439) --- .../modular_rollout/orchestration_common.py | 17 ---------- .../modular_rollout/orchestration_train.py | 33 ++++++++++++++----- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index a97cf68d6e..2d25d871ed 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -66,25 +66,8 @@ def dp_rank_context(self): assert self.dp_counts[dp_rank] >= 0 def reset(self) -> None: - self.remaining_batch_size = 0 - self.pendings = set() self.aborted = False - def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: - for group in samples: - self.pendings.add( - asyncio.create_task( - # submit a group of samples as a single task. - generate_and_rm_group( - self, - group, - sampling_params=self.sampling_params.copy(), - evaluation=False, - ) - ) - ) - self.remaining_batch_size += len(samples) - async def generate_and_rm( state: GenerateState, diff --git a/miles/rollout/modular_rollout/orchestration_train.py b/miles/rollout/modular_rollout/orchestration_train.py index 605541b7d1..b373e5c5db 100644 --- a/miles/rollout/modular_rollout/orchestration_train.py +++ b/miles/rollout/modular_rollout/orchestration_train.py @@ -8,7 +8,7 @@ from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnTrainInput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter -from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.rollout.modular_rollout.orchestration_common import GenerateState, generate_and_rm_group from miles.utils.http_utils import get, post from miles.utils.misc import load_function from miles.utils.types import Sample @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -async def abort(state: GenerateState, rollout_id: int) -> list[list[Sample]]: +async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: args = state.args aborted_samples = [] @@ -36,8 +36,8 @@ async def abort(state: GenerateState, rollout_id: int) -> list[list[Sample]]: # make sure all the pending tasks are finished count = 0 - while state.pendings: - done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) + while pendings: + done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) if not args.partial_rollout: continue @@ -57,6 +57,21 @@ async def abort(state: GenerateState, rollout_id: int) -> list[list[Sample]]: return aborted_samples +def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): + return [ + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + state, + group, + sampling_params=state.sampling_params.copy(), + evaluation=False, + ) + ) + for group in samples + ] + + async def generate_rollout_async( state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] ) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: @@ -73,18 +88,19 @@ async def generate_rollout_async( # target_data_size is the total number of valid samples to get target_data_size = args.rollout_batch_size + pendings = set() data = [] all_data = [] do_print = True pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") while len(data) < target_data_size: - while state.remaining_batch_size < target_data_size: + while len(data) + len(pendings) < target_data_size: # get samples from the buffer and submit the generation requests. samples = data_source(args.over_sampling_batch_size) - state.submit_generate_tasks(samples) + pendings |= submit_generate_tasks(state, samples) # wait for the generation to finish - done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) + done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) for task in done: group: list[Sample] = task.result() @@ -100,7 +116,6 @@ async def generate_rollout_async( dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) if not dynamic_filter_output.keep: metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) - state.remaining_batch_size -= 1 continue # add the samples to the data @@ -116,7 +131,7 @@ async def generate_rollout_async( ) # there are still some unfinished requests, abort them - aborted_samples = await abort(state, rollout_id) + aborted_samples = await abort(state, pendings, rollout_id) assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) From 8b1fe7fda2b7d704449dcb3d7b8ad21ea428cce0 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:07:32 +0800 Subject: [PATCH 12/77] Temporarily remove DP rank balancing in generate state (#440) --- .../modular_rollout/orchestration_common.py | 35 +++++-------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index 2d25d871ed..e142ff4e0e 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -1,10 +1,8 @@ import asyncio import logging from argparse import Namespace -from contextlib import contextmanager from typing import Any -import numpy as np from miles.rollout.base_types import GenerateFnInput from miles.rollout.modular_rollout.compatibility import load_generate_function @@ -42,10 +40,6 @@ def __init__(self, args: Namespace) -> None: sampling_seed_base = args.rollout_seed self.group_sampling_seeds = [sampling_seed_base + i for i in range(args.n_samples_per_prompt)] - # dp rank balancing - self.dp_counts = [0] * (args.sglang_dp_size or 1) - self.dp_rank = 0 - if args.custom_generate_function_path is not None: self.generate_function = load_generate_function(args.custom_generate_function_path) else: @@ -53,18 +47,6 @@ def __init__(self, args: Namespace) -> None: self.reset() - @contextmanager - def dp_rank_context(self): - candidates = [i for i, count in enumerate(self.dp_counts) if count == min(self.dp_counts)] - dp_rank = int(np.random.choice(candidates)) - self.dp_counts[dp_rank] += 1 - self.dp_rank = dp_rank - try: - yield dp_rank - finally: - self.dp_counts[dp_rank] -= 1 - assert self.dp_counts[dp_rank] >= 0 - def reset(self) -> None: self.aborted = False @@ -94,16 +76,15 @@ async def generate_and_rm( sample.status = Sample.Status.ABORTED return sample - with state.dp_rank_context() as _: - output = await state.generate_function( - GenerateFnInput( - state=state, - sample=sample, - sampling_params=sampling_params, - evaluation=evaluation, - ) + output = await state.generate_function( + GenerateFnInput( + state=state, + sample=sample, + sampling_params=sampling_params, + evaluation=evaluation, ) - sample = output.sample + ) + sample = output.sample # for the rm that need the whole group, we will not do the rm here if args.group_rm: From e436f1c7f3b5c59f32c9701e9857f7c6e92300b1 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:07:59 +0800 Subject: [PATCH 13/77] Cleanup and shorten modular rollout code (#441) --- miles/rollout/base_types.py | 4 +- .../rollout/modular_rollout/compatibility.py | 4 +- .../modular_rollout/inference_wrapper.py | 4 +- .../modular_rollout/orchestration_common.py | 63 +++++++++++-------- .../modular_rollout/orchestration_eval.py | 25 +++----- .../modular_rollout/orchestration_train.py | 60 ++++++++---------- miles/rollout/rm_hub/__init__.py | 12 +++- miles/utils/misc.py | 9 +++ 8 files changed, 99 insertions(+), 82 deletions(-) diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index 9b276c0dce..e4aa454302 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -84,7 +84,9 @@ def args(self) -> Namespace: @dataclass(frozen=True) class GenerateFnOutput: - sample: Sample | list[Sample] + # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or + # multi-turn with removing thinking tokens. + samples: Sample | list[Sample] # TODO: may add add_arguments diff --git a/miles/rollout/modular_rollout/compatibility.py b/miles/rollout/modular_rollout/compatibility.py index f4455a8b80..41427d0ed0 100644 --- a/miles/rollout/modular_rollout/compatibility.py +++ b/miles/rollout/modular_rollout/compatibility.py @@ -64,13 +64,15 @@ async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: output = await self.fn(input.args, input.sample, input.sampling_params) if not isinstance(output, GenerateFnOutput): - output = GenerateFnOutput(sample=output) + output = GenerateFnOutput(samples=output) return output def load_generate_function(path: str): fn = load_function(path) + if fn is None: + return None if inspect.isclass(fn): return fn() diff --git a/miles/rollout/modular_rollout/inference_wrapper.py b/miles/rollout/modular_rollout/inference_wrapper.py index 56529c7da4..3a09d3dfdd 100644 --- a/miles/rollout/modular_rollout/inference_wrapper.py +++ b/miles/rollout/modular_rollout/inference_wrapper.py @@ -40,7 +40,7 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" if sampling_params["max_new_tokens"] == 0: sample.status = Sample.Status.TRUNCATED - return GenerateFnOutput(sample=sample) + return GenerateFnOutput(samples=sample) # Prepare payload for sglang server payload = { @@ -97,4 +97,4 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: sample.update_from_meta_info(args, output["meta_info"]) - return GenerateFnOutput(sample=sample) + return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index e142ff4e0e..da9e90654b 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -21,29 +21,18 @@ def __init__(self, args: Namespace) -> None: self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - self.semaphore = asyncio.Semaphore( + self.generate_fn_semaphore = asyncio.Semaphore( args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) - self.sampling_params: dict[str, Any] = dict( + self.sampling_params: dict[str, Any] = compute_sampling_params( + args, temperature=args.rollout_temperature, top_p=args.rollout_top_p, top_k=args.rollout_top_k, max_new_tokens=args.rollout_max_response_len, - stop=args.rollout_stop, - stop_token_ids=args.rollout_stop_token_ids, - skip_special_tokens=args.rollout_skip_special_tokens, - no_stop_trim=True, - spaces_between_special_tokens=False, ) - if getattr(args, "sglang_enable_deterministic_inference", False): - sampling_seed_base = args.rollout_seed - self.group_sampling_seeds = [sampling_seed_base + i for i in range(args.n_samples_per_prompt)] - - if args.custom_generate_function_path is not None: - self.generate_function = load_generate_function(args.custom_generate_function_path) - else: - self.generate_function = generate + self.generate_function = load_generate_function(args.custom_generate_function_path) or generate self.reset() @@ -71,7 +60,7 @@ async def generate_and_rm( return sample # generate - async with state.semaphore: + async with state.generate_fn_semaphore: if state.aborted: sample.status = Sample.Status.ABORTED return sample @@ -84,12 +73,14 @@ async def generate_and_rm( evaluation=evaluation, ) ) - sample = output.sample + sample = output.samples + # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below # for the rm that need the whole group, we will not do the rm here if args.group_rm: return sample + # TODO: unify the two branches into one if we decide to use list as output type # multi samples if isinstance(sample, list): samples = sample @@ -98,9 +89,7 @@ async def generate_and_rm( # for multi agent system, the reward of some sample is calculated during generation. samples_need_reward = [sample for sample in samples if sample.reward is None] - rewards = await batched_async_rm(args, samples_need_reward) - for sample, reward in zip(samples_need_reward, rewards, strict=False): - sample.reward = reward + await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True) return samples else: if sample.status == Sample.Status.ABORTED: @@ -124,18 +113,38 @@ async def generate_and_rm_group( for idx, sample in enumerate(group): current_sampling_params = sampling_params.copy() if getattr(args, "sglang_enable_deterministic_inference", False): - seed = state.group_sampling_seeds[idx] - current_sampling_params["sampling_seed"] = seed + current_sampling_params["sampling_seed"] = args.rollout_seed + idx tasks.append( asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) ) group = await asyncio.gather(*tasks) + if state.aborted: + return group - # for the rm that need the whole group, we will do the rm here - if not state.aborted and args.group_rm: - rewards = await batched_async_rm(args, group) - for sample, reward in zip(group, rewards, strict=False): - sample.reward = reward + if args.group_rm: + await batched_async_rm(args, group, inplace_set_reward_field=True) return group + + +def compute_sampling_params( + args, + *, + # after unifying configuration, this can be further refactored + temperature, + top_p, + top_k, + max_new_tokens, +): + return dict( + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) diff --git a/miles/rollout/modular_rollout/orchestration_eval.py b/miles/rollout/modular_rollout/orchestration_eval.py index cb76901ef0..5d95c54d49 100644 --- a/miles/rollout/modular_rollout/orchestration_eval.py +++ b/miles/rollout/modular_rollout/orchestration_eval.py @@ -6,9 +6,10 @@ from tqdm import tqdm from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnEvalOutput -from miles.rollout.modular_rollout.orchestration_common import GenerateState, generate_and_rm +from miles.rollout.modular_rollout.orchestration_common import GenerateState, compute_sampling_params, generate_and_rm from miles.utils.data import Dataset from miles.utils.eval_config import EvalDatasetConfig +from miles.utils.misc import as_completed_async from miles.utils.processing_utils import load_processor, load_tokenizer from miles.utils.types import Sample @@ -42,16 +43,12 @@ async def eval_rollout_single_dataset( ) dataset = prompt_dataset_cache[cache_key] - base_sampling_params = dict( + base_sampling_params = compute_sampling_params( + args, temperature=dataset_cfg.temperature, top_p=dataset_cfg.top_p, top_k=dataset_cfg.top_k, max_new_tokens=dataset_cfg.max_response_len, - stop=args.rollout_stop, - stop_token_ids=args.rollout_stop_token_ids, - skip_special_tokens=args.rollout_skip_special_tokens, - no_stop_trim=True, - spaces_between_special_tokens=False, ) tasks = [] @@ -82,8 +79,7 @@ async def eval_rollout_single_dataset( data = [] do_print = True pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) - for coro in asyncio.as_completed(tasks): - sample = await coro + async for sample in as_completed_async(tasks): if do_print: logger.info( "eval_rollout_single_dataset example data: " @@ -112,18 +108,15 @@ async def eval_rollout_single_dataset( class SimpleEvalRolloutFn: def __init__(self, input: RolloutFnConstructorInput): - self.args = input.args self.prompt_dataset_cache = {} - self.state = GenerateState(self.args) + self.state = GenerateState(input.args) async def __call__(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: - assert not self.args.group_rm, "Group RM is not supported for eval rollout" + assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" coros = [] - for dataset_cfg in getattr(self.args, "eval_datasets", []) or []: + for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.prompt_dataset_cache)) results_list = await asyncio.gather(*coros) - results = {} - for r in results_list: - results.update(r) + results = {k: v for r in results_list for k, v in r.items()} return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/modular_rollout/orchestration_train.py b/miles/rollout/modular_rollout/orchestration_train.py index b373e5c5db..2adfa2dce1 100644 --- a/miles/rollout/modular_rollout/orchestration_train.py +++ b/miles/rollout/modular_rollout/orchestration_train.py @@ -1,5 +1,6 @@ import asyncio import logging +from argparse import Namespace from collections.abc import Callable import sglang_router @@ -10,7 +11,7 @@ from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from miles.rollout.modular_rollout.orchestration_common import GenerateState, generate_and_rm_group from miles.utils.http_utils import get, post -from miles.utils.misc import load_function +from miles.utils.misc import as_completed_async, load_function from miles.utils.types import Sample logger = logging.getLogger(__name__) @@ -19,44 +20,40 @@ async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: args = state.args - aborted_samples = [] - assert not state.aborted state.aborted = True - if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") - urls = response["urls"] - else: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") - urls = [worker["url"] for worker in response["workers"]] - + urls = await get_worker_urls(args) logger.info(f"Abort request for {urls}") await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) # make sure all the pending tasks are finished - count = 0 - while pendings: - done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) - + aborted_samples = [] + async for group in as_completed_async(pendings): if not args.partial_rollout: continue # for partial rollout, collect the partial samples into the data buffer - for task in done: - group = task.result() - for sample in group: - if sample.response and "start_rollout_id" not in sample.metadata: - sample.metadata["start_rollout_id"] = rollout_id - aborted_samples.append(group) - count += len(group) + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) if args.partial_rollout: - logger.info(f"Collected {count} partial samples into the data buffer") + logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer") return aborted_samples +async def get_worker_urls(args: Namespace): + if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + return response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + return [worker["url"] for worker in response["workers"]] + + def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): return [ asyncio.create_task( @@ -79,9 +76,7 @@ async def generate_rollout_async( assert args.rollout_global_dataset # instantiate data filters - dynamic_filter = ( - load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None - ) + dynamic_filter = load_function(args.dynamic_sampling_filter_path) metric_gatherer = MetricGatherer() @@ -97,7 +92,7 @@ async def generate_rollout_async( while len(data) + len(pendings) < target_data_size: # get samples from the buffer and submit the generation requests. samples = data_source(args.over_sampling_batch_size) - pendings |= submit_generate_tasks(state, samples) + pendings.update(submit_generate_tasks(state, samples)) # wait for the generation to finish done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) @@ -141,23 +136,20 @@ async def generate_rollout_async( # reset the global state to prevent effects on the next rollout or eval. state.reset() - if args.rollout_sample_filter_path is not None: - filter_func = load_function(args.rollout_sample_filter_path) - filter_func(args, data) + if f := load_function(args.rollout_sample_filter_path): + f(args, data) # There can be circumstances where users want to process all samples including filtered ones. - if args.rollout_all_samples_process_path is not None: - process_func = load_function(args.rollout_all_samples_process_path) - process_func(args, all_samples, data_source) + if f := load_function(args.rollout_all_samples_process_path): + f(args, all_samples, data_source) return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples class SimpleTrainRolloutFn: def __init__(self, input: RolloutFnConstructorInput): - self.args = input.args self.data_source = input.data_source - self.state = GenerateState(self.args) + self.state = GenerateState(input.args) async def __call__(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: output, aborted_samples = await generate_rollout_async( diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py index 62b253ddee..e9ee29db41 100644 --- a/miles/rollout/rm_hub/__init__.py +++ b/miles/rollout/rm_hub/__init__.py @@ -69,8 +69,18 @@ async def async_rm(args, sample: Sample, **kwargs): async def batched_async_rm( args, samples: list[Sample], + inplace_set_reward_field: bool = False, **kwargs, -) -> list[int | float]: +) -> list[int | float] | None: + if inplace_set_reward_field: + rewards = await batched_async_rm(args, samples, **kwargs) + for sample, reward in zip(samples, rewards, strict=True): + assert ( + sample.reward is None + ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?" + sample.reward = reward + return None + if args.custom_rm_path is not None: # Ensure the custom reward function is implemented in batch mode rm_function = load_function(args.custom_rm_path) diff --git a/miles/utils/misc.py b/miles/utils/misc.py index c0a96d6366..823738a56f 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,3 +1,4 @@ +import asyncio import importlib import subprocess @@ -12,6 +13,9 @@ def load_function(path): :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. """ + if path is None: + return None + module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) @@ -92,3 +96,8 @@ def should_run_periodic_action( step = rollout_id + 1 return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) + + +async def as_completed_async(tasks): + for coro in asyncio.as_completed(tasks): + yield await coro From b915eb37c8c8892c4d6d9ce411622a45978c7d9d Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:08:19 +0800 Subject: [PATCH 14/77] Add tests for all reward functions (#443) --- tests/rollout/rm_hub/__init__.py | 0 tests/rollout/rm_hub/test_deepscaler.py | 26 ++++ tests/rollout/rm_hub/test_f1.py | 44 +++++++ tests/rollout/rm_hub/test_gpqa.py | 86 +++++++++++++ tests/rollout/rm_hub/test_math_dapo_utils.py | 108 ++++++++++++++++ tests/rollout/rm_hub/test_math_utils.py | 129 +++++++++++++++++++ tests/rollout/rm_hub/test_rm_hub.py | 126 ++++++++++++++++++ 7 files changed, 519 insertions(+) create mode 100644 tests/rollout/rm_hub/__init__.py create mode 100644 tests/rollout/rm_hub/test_deepscaler.py create mode 100644 tests/rollout/rm_hub/test_f1.py create mode 100644 tests/rollout/rm_hub/test_gpqa.py create mode 100644 tests/rollout/rm_hub/test_math_dapo_utils.py create mode 100644 tests/rollout/rm_hub/test_math_utils.py create mode 100644 tests/rollout/rm_hub/test_rm_hub.py diff --git a/tests/rollout/rm_hub/__init__.py b/tests/rollout/rm_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rollout/rm_hub/test_deepscaler.py b/tests/rollout/rm_hub/test_deepscaler.py new file mode 100644 index 0000000000..bd4c606a68 --- /dev/null +++ b/tests/rollout/rm_hub/test_deepscaler.py @@ -0,0 +1,26 @@ +import pytest + +from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward + + +class TestGetDeepscalerRuleBasedReward: + @pytest.mark.parametrize( + "response,label,expected", + [ + (r"Let me analyze...The answer is \boxed{42}", "42", 1), + (r"Thinking...The answer is \boxed{wrong}", "42", 0), + (r"###Response\boxed{42}", "42", 1), + (r"###Response\boxed{wrong}", "42", 0), + (r"The answer is \boxed{42}", "42", 0), + (r"The answer is 42", "42", 0), + (r"\boxed{42}", "", 0), + (r"\boxed{42}", r"\boxed{42}", 1), + (r"\boxed{123}", 123, 1), + (r"\boxed{3.14}", 3.14, 1), + (r"\boxed{1/2}", "0.5", 1), + (r"\boxed{\frac{1}{2}}", "0.5", 1), + (r"First thoughtSecond thought\boxed{42}", "42", 1), + ], + ) + def test_get_deepscaler_rule_based_reward(self, response, label, expected): + assert get_deepscaler_rule_based_reward(response, label) == expected diff --git a/tests/rollout/rm_hub/test_f1.py b/tests/rollout/rm_hub/test_f1.py new file mode 100644 index 0000000000..c9ecf9614d --- /dev/null +++ b/tests/rollout/rm_hub/test_f1.py @@ -0,0 +1,44 @@ +import pytest + +from miles.rollout.rm_hub.f1 import f1_score, normalize_answer + + +class TestNormalizeAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("The quick brown fox", "quick brown fox"), + ("A cat and a dog", "cat and dog"), + ("Hello, world!", "hello world"), + (" multiple spaces ", "multiple spaces"), + ("An apple", "apple"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_normalize_answer(self, input_str, expected): + assert normalize_answer(input_str) == expected + + +class TestF1Score: + @pytest.mark.parametrize( + "prediction,ground_truth,expected_f1,expected_prec,expected_recall", + [ + ("hello world", "hello world", 1.0, 1.0, 1.0), + ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3), + ("abc", "xyz", 0, 0, 0), + (None, "anything", 0, 0, 0), + ("yes", "no", 0, 0, 0), + ("no", "yes", 0, 0, 0), + ("yes", "yes", 1.0, 1.0, 1.0), + ("noanswer", "yes", 0, 0, 0), + ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0), + ("hello, world!", "hello world", 1.0, 1.0, 1.0), + ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5), + ], + ) + def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall): + f1, prec, recall = f1_score(prediction, ground_truth) + assert f1 == expected_f1 + assert prec == expected_prec + assert recall == expected_recall diff --git a/tests/rollout/rm_hub/test_gpqa.py b/tests/rollout/rm_hub/test_gpqa.py new file mode 100644 index 0000000000..45cefd2015 --- /dev/null +++ b/tests/rollout/rm_hub/test_gpqa.py @@ -0,0 +1,86 @@ +import pytest + +from miles.rollout.rm_hub.gpqa import ( + _extract_letter_from_response, + _normalize_text, + _strip_chain_of_thought, + compute_gpqa_reward, +) + + +class TestStripChainOfThought: + @pytest.mark.parametrize( + "text,expected", + [ + ("Let me think...The answer is A", "The answer is A"), + ("The answer is A", "The answer is A"), + ("", ""), + (None, ""), + ], + ) + def test_strip_chain_of_thought(self, text, expected): + assert _strip_chain_of_thought(text) == expected + + +class TestNormalizeText: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("Test-123", "test 123"), + ("A, B, C", "a b c"), + ("", ""), + ], + ) + def test_normalize_text(self, input_str, expected): + assert _normalize_text(input_str) == expected + + +class TestExtractLetterFromResponse: + @pytest.mark.parametrize( + "response,expected", + [ + ("The answer is A", "A"), + ("answer: B", "B"), + ("I think C is correct", "C"), + ("final answer: D", "D"), + ("Option A is the best choice", "A"), + ("The answer is B", "B"), + ("After analysis, my choice is C", "C"), + ("A B C D", "D"), + ("No valid letter here", None), + ("", None), + (None, None), + ("The answer is Z", None), + ], + ) + def test_extract_letter(self, response, expected): + assert _extract_letter_from_response(response, "ABCD") == expected + + +class TestComputeGpqaReward: + @pytest.mark.parametrize( + "response,label,metadata,expected", + [ + ("Answer: A", "A", None, 1.0), + ("Answer: A", "B", None, 0.0), + (None, "A", None, 0.0), + ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0), + ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0), + ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0), + ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0), + ( + "I believe the answer is Paris", + "", + {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"}, + 1.0, + ), + ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0), + ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0), + ("Let me think step by step...The answer is A", "A", None, 1.0), + ], + ) + def test_compute_gpqa_reward(self, response, label, metadata, expected): + assert compute_gpqa_reward(response, label, metadata=metadata) == expected diff --git a/tests/rollout/rm_hub/test_math_dapo_utils.py b/tests/rollout/rm_hub/test_math_dapo_utils.py new file mode 100644 index 0000000000..56a7f6d1f9 --- /dev/null +++ b/tests/rollout/rm_hub/test_math_dapo_utils.py @@ -0,0 +1,108 @@ +import pytest + +from miles.rollout.rm_hub.math_dapo_utils import ( + compute_score, + is_correct_minerva, + is_correct_strict_box, + last_boxed_only_string, + normalize_final_answer, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2}", r"\boxed{x^2}"), + (r"No boxed", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x + 1}", "x + 1"), + ], + ) + def test_remove_boxed_valid(self, input_str, expected): + assert remove_boxed(input_str) == expected + + def test_remove_boxed_invalid(self): + with pytest.raises(AssertionError): + remove_boxed("not boxed") + + +class TestNormalizeFinalAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("42", "42"), + (" 42 ", "42"), + (r"\text{hello}", "hello"), + (r"\textbf{bold}", "bold"), + (r"x = 42", "42"), + (r"100 square", "100"), + (r"$50$ dollars", "50"), + (r"\boxed{42}", "42"), + (r"\frac12", r"\frac{1}{2}"), + (r"\sqrt3", r"\sqrt{3}"), + ("1,000", "1000"), + ("<|im_end|>", ""), + ], + ) + def test_normalize_final_answer(self, input_str, expected): + assert normalize_final_answer(input_str) == expected + + +class TestIsCorrectMinerva: + @pytest.mark.parametrize( + "solution,gt,gt_need_extract,expected_correct", + [ + ("Answer: 42", "42", False, True), + ("Answer: 100", "42", False, False), + ("Answer: wrong", "42", False, False), + ("Answer: 42", r"\boxed{42}", True, True), + ], + ) + def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct): + correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract) + assert correct == expected_correct + + +class TestIsCorrectStrictBox: + @pytest.mark.parametrize( + "pred,gt,expected_score,expected_pred", + [ + (r"blah blah \boxed{42}", "42", 1, "42"), + (r"\boxed{wrong}", "42", -1, "wrong"), + ("no box here", "42", -1, None), + ], + ) + def test_is_correct_strict_box(self, pred, gt, expected_score, expected_pred): + score, extracted = is_correct_strict_box(pred, gt) + assert score == expected_score + assert extracted == expected_pred + + +class TestComputeScore: + @pytest.mark.parametrize( + "solution,gt,strict_box,expected_score,expected_acc", + [ + ("Answer: 42", "42", False, 1.0, True), + ("Answer: wrong", "42", False, -1.0, False), + (r"\boxed{42}", "42", True, 1.0, True), + ("x" * 500 + " Answer: 42", "42", False, 1.0, True), + ], + ) + def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc): + result = compute_score(solution, gt, strict_box_verify=strict_box) + assert result["score"] == expected_score + assert result["acc"] == expected_acc diff --git a/tests/rollout/rm_hub/test_math_utils.py b/tests/rollout/rm_hub/test_math_utils.py new file mode 100644 index 0000000000..2423ed4acc --- /dev/null +++ b/tests/rollout/rm_hub/test_math_utils.py @@ -0,0 +1,129 @@ +import pytest + +from miles.rollout.rm_hub.math_utils import ( + _normalize, + extract_answer, + grade_answer_mathd, + grade_answer_sympy, + grade_answer_verl, + last_boxed_only_string, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"), + (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"), + (r"No boxed here", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"), + (r"\fbox{fbox content}", r"\fbox{fbox content}"), + ("", None), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x^2 + 1}", "x^2 + 1"), + (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"), + ("not boxed", None), + ], + ) + def test_remove_boxed(self, input_str, expected): + assert remove_boxed(input_str) == expected + + +class TestExtractAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", "42"), + (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"), + (r"Multiple \boxed{1} then \boxed{final}", "final"), + (r"No boxed here", None), + ("", None), + ], + ) + def test_extract_answer(self, input_str, expected): + assert extract_answer(input_str) == expected + + +class TestNormalize: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("1,000", "1000"), + (r"\text{hello}", "hello"), + (" 42 ", "42"), + (r"100%", "100"), + (r"\$50", "50"), + ("HELLO", "hello"), + ("1,234,567", "1234567"), + (None, None), + ], + ) + def test_normalize(self, input_str, expected): + assert _normalize(input_str) == expected + + +class TestGradeAnswerMathd: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + (" 42 ", "42", True), + (r"\frac{1}{2}", r"\frac{1}{2}", True), + ("wrong", "42", False), + ("", "42", False), + ], + ) + def test_grade_answer_mathd(self, given, ground_truth, expected): + assert grade_answer_mathd(given, ground_truth) == expected + + +class TestGradeAnswerSympy: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + ("x^2", "x^2", True), + ("1/2", "0.5", True), + (r"\frac{1}{2}", "0.5", True), + ("wrong", "42", False), + ("", "42", False), + ("(1,2)", "(1,2)", True), + ("(1,2,3)", "(1,2)", False), + ("42", None, False), + ], + ) + def test_grade_answer_sympy(self, given, ground_truth, expected): + assert grade_answer_sympy(given, ground_truth) == expected + + +class TestGradeAnswerVerl: + @pytest.mark.parametrize( + "solution,ground_truth,expected", + [ + (r"\boxed{42}", "42", True), + (r"The answer is \boxed{42}", "42", True), + (r"\boxed{1/2}", r"\frac{1}{2}", True), + (r"\boxed{wrong}", "42", False), + ("no boxed", "42", False), + (r"\boxed{42}", r"\boxed{42}", True), + ("", "42", False), + (r"\boxed{42}", "", False), + (r"\boxed{42}", None, False), + ], + ) + def test_grade_answer_verl(self, solution, ground_truth, expected): + assert grade_answer_verl(solution, ground_truth) == expected diff --git a/tests/rollout/rm_hub/test_rm_hub.py b/tests/rollout/rm_hub/test_rm_hub.py new file mode 100644 index 0000000000..a3dadbdaf0 --- /dev/null +++ b/tests/rollout/rm_hub/test_rm_hub.py @@ -0,0 +1,126 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.async_utils import run +from miles.utils.types import Sample + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.custom_rm_path = None + args.rm_type = None + args.rm_url = None + return args + + +class TestAsyncRm: + @pytest.mark.parametrize( + "rm_type,response,label,expected", + [ + ("math", r"\boxed{42}", "42", 1), + ("math", r"\boxed{wrong}", "42", 0), + ("f1", "hello world", "hello world", 1.0), + ("dapo", "Answer: 42", "42", {"score": 1.0}), + ("deepscaler", r"\boxed{42}", "42", 1), + ("gpqa", "Answer: A", "A", 1.0), + ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0), + ], + ) + def test_rm_types(self, mock_args, rm_type, response, label, expected): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response=response, label=label) + reward = run(async_rm(mock_args, sample)) + if isinstance(expected, dict): + for k, v in expected.items(): + assert reward[k] == v + else: + assert reward == expected + + def test_f1_rm_partial(self, mock_args): + mock_args.rm_type = "f1" + sample = Sample(prompt="", response="hello", label="hello world") + reward = run(async_rm(mock_args, sample)) + assert 0 < reward < 1 + + def test_random_rm(self, mock_args): + mock_args.rm_type = "random" + sample = Sample(prompt="", response="anything", label="anything") + reward = run(async_rm(mock_args, sample)) + assert reward in [0, 1] + + def test_rm_type_from_metadata(self, mock_args): + mock_args.rm_type = None + sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}) + reward = run(async_rm(mock_args, sample)) + assert reward == 1 + + @pytest.mark.parametrize( + "rm_type,match", + [ + ("unknown_type", "not implemented"), + ("", "not specified"), + ], + ) + def test_invalid_rm_type_raises(self, mock_args, rm_type, match): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response="test", label="test") + with pytest.raises(NotImplementedError, match=match): + run(async_rm(mock_args, sample)) + + +class TestBatchedAsyncRm: + @pytest.mark.parametrize( + "rm_type,samples_data,expected", + [ + ( + "math", + [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")], + [1, 1, 0], + ), + ( + "f1", + [("hello world", "hello world"), ("different", "something else")], + [1.0, 0], + ), + ], + ) + def test_batched_rm(self, mock_args, rm_type, samples_data, expected): + mock_args.rm_type = rm_type + samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards == expected + + def test_inplace_set_reward_field(self, mock_args): + mock_args.rm_type = "math" + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42"), + Sample(prompt="", response=r"\boxed{100}", label="100"), + ] + result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + assert result is None + assert samples[0].reward == 1 + assert samples[1].reward == 1 + + def test_inplace_raises_on_existing_reward(self, mock_args): + mock_args.rm_type = "math" + samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)] + with pytest.raises(AssertionError, match="Overriding"): + run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + + def test_empty_samples(self, mock_args): + mock_args.rm_type = "math" + rewards = run(batched_async_rm(mock_args, [])) + assert rewards == [] + + def test_mixed_rm_types_via_metadata(self, mock_args): + mock_args.rm_type = None + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}), + Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}), + ] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards[0] == 1 + assert rewards[1] == 1.0 From 3e3ce1a797791a7f8d1bd59da4d0c5b7c519d393 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:08:51 +0800 Subject: [PATCH 15/77] Enhance mock sglang server with concurrency and requests recording and latency (#444) --- miles/utils/test_utils/mock_sglang_server.py | 88 ++++++++++++++----- .../test_utils/test_mock_sglang_server.py | 88 +++++++++++++++---- 2 files changed, 139 insertions(+), 37 deletions(-) diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py index 6d4144fc1f..e0f1673583 100644 --- a/miles/utils/test_utils/mock_sglang_server.py +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -1,3 +1,4 @@ +import asyncio import re from collections.abc import Callable from contextlib import contextmanager @@ -27,50 +28,68 @@ def __init__( process_fn: ProcessFn, host: str, port: int, + latency: float = 0.0, ): self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) self.process_fn = process_fn self.host = host self.port = port or find_available_port(30000) + self.latency = latency self.app = FastAPI() self._server: UvicornThreadServer | None = None + self.request_log: list[dict] = [] + self._concurrency = Counter() + self._setup_routes() + @property + def max_concurrent(self) -> int: + return self._concurrency.max_value + + def reset_stats(self): + self.request_log.clear() + self._concurrency.reset() + def _setup_routes(self): @self.app.post("/generate") async def generate(request: Request): payload = await request.json() + self.request_log.append(payload) + + with self._concurrency.track(): + if self.latency > 0: + await asyncio.sleep(self.latency) - assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" - input_ids = payload.get("input_ids", []) + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) - prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) - process_result = self.process_fn(prompt_str) - output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) - prompt_tokens = len(input_ids) - completion_tokens = len(output_ids) + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) - finish_reason_dict = {"type": process_result.finish_reason} - if process_result.finish_reason == "length": - finish_reason_dict["length"] = completion_tokens + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens - output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] - response = { - "text": process_result.text, - "meta_info": { - "finish_reason": finish_reason_dict, - "prompt_tokens": prompt_tokens, - "cached_tokens": 0, - "completion_tokens": completion_tokens, - "output_token_logprobs": output_token_logprobs, - }, - } + response = { + "text": process_result.text, + "meta_info": { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": 0, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + }, + } - return JSONResponse(content=response) + return JSONResponse(content=response) @self.app.get("/health") async def health(): @@ -93,6 +112,29 @@ def url(self) -> str: return f"http://{self.host}:{self.port}" +class Counter: + def __init__(self): + self._current = 0 + self._max = 0 + + @property + def max_value(self) -> int: + return self._max + + def reset(self): + self._current = 0 + self._max = 0 + + @contextmanager + def track(self): + self._current += 1 + self._max = max(self._max, self._current) + try: + yield + finally: + self._current -= 1 + + def default_process_fn(prompt: str) -> ProcessResult: match = re.search(r"What is 1\+(\d+)\?", prompt) if match: @@ -108,12 +150,14 @@ def with_mock_server( process_fn: ProcessFn = default_process_fn, host: str = "127.0.0.1", port: int | None = None, + latency: float = 0.0, ): server = MockSGLangServer( model_name=model_name, process_fn=process_fn, host=host, port=port, + latency=latency, ) try: server.start() diff --git a/tests/utils/test_utils/test_mock_sglang_server.py b/tests/utils/test_utils/test_mock_sglang_server.py index 6163e68bda..0601307d74 100644 --- a/tests/utils/test_utils/test_mock_sglang_server.py +++ b/tests/utils/test_utils/test_mock_sglang_server.py @@ -1,7 +1,11 @@ +import asyncio +import concurrent.futures +import time + import pytest import requests -from miles.utils.test_utils.mock_sglang_server import ProcessResult, default_process_fn, with_mock_server +from miles.utils.test_utils.mock_sglang_server import Counter, ProcessResult, default_process_fn, with_mock_server @pytest.fixture(scope="module") @@ -50,7 +54,7 @@ def test_generate_endpoint_basic(mock_server): } -def test_process_fn_receives_decoded_prompt(mock_server): +def test_process_fn_receives_decoded_prompt(): received_prompts = [] def process_fn(prompt: str) -> ProcessResult: @@ -58,22 +62,76 @@ def process_fn(prompt: str) -> ProcessResult: return ProcessResult(text="response", finish_reason="stop") with with_mock_server(process_fn=process_fn) as server: - input_ids = [1, 2, 3] - requests.post(f"{server.url}/generate", json={"input_ids": input_ids, "sampling_params": {}}, timeout=5.0) + requests.post(f"{server.url}/generate", json={"input_ids": [1, 2, 3], "sampling_params": {}}, timeout=5.0) - assert len(received_prompts) == 1 - assert isinstance(received_prompts[0], str) + assert len(received_prompts) == 1 + assert isinstance(received_prompts[0], str) def test_default_process_fn(): - result = default_process_fn("What is 1+5?") - assert result.text == "\\boxed{6}" - assert result.finish_reason == "stop" + assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") + assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") + assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") + + +def test_request_log_and_reset_stats(mock_server): + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + + payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} + requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) + assert len(mock_server.request_log) == 1 + assert mock_server.request_log[0] == payload + + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + assert mock_server.max_concurrent == 0 + + +@pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) +def test_latency(latency, min_time, max_time): + with with_mock_server(latency=latency) as server: + start = time.time() + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + elapsed = time.time() - start + assert min_time <= elapsed < max_time + + +def test_max_concurrent_with_latency(): + with with_mock_server(latency=0.1) as server: + + def send_request(): + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(send_request) for _ in range(3)] + concurrent.futures.wait(futures) + + assert server.max_concurrent == 3 + + +def test_counter_tracks_max(): + counter = Counter() + assert counter.max_value == 0 + + with counter.track(): + assert counter.max_value == 1 + with counter.track(): + assert counter.max_value == 2 + + counter.reset() + assert counter.max_value == 0 + + +def test_counter_concurrent_tasks(): + counter = Counter() + + async def task(): + with counter.track(): + await asyncio.sleep(0.1) - result = default_process_fn("What is 1+10?") - assert result.text == "\\boxed{11}" - assert result.finish_reason == "stop" + async def run_all(): + await asyncio.gather(task(), task(), task()) - result = default_process_fn("Hello") - assert result.text == "I don't understand." - assert result.finish_reason == "stop" + asyncio.run(run_all()) + assert counter.max_value == 3 From 274fc42e0fbb4ec1ec3e222591c3c02640d7bee9 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:09:10 +0800 Subject: [PATCH 16/77] Add FunctionRegistry to patch load_function (#445) --- miles/utils/misc.py | 35 ++++- .../modular_rollout/test_compatibility.py | 123 +++++++++--------- tests/utils/test_misc.py | 59 +++++++++ 3 files changed, 155 insertions(+), 62 deletions(-) create mode 100644 tests/utils/test_misc.py diff --git a/miles/utils/misc.py b/miles/utils/misc.py index 823738a56f..fa772b5222 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,21 +1,54 @@ import asyncio import importlib import subprocess +from contextlib import contextmanager import ray from miles.utils.http_utils import is_port_available +# Mainly used for test purpose where `load_function` needs to load many in-flight generated functions +class FunctionRegistry: + def __init__(self): + self._registry: dict[str, object] = {} + + @contextmanager + def temporary(self, name: str, fn: object): + self._register(name, fn) + try: + yield + finally: + self._unregister(name) + + def get(self, name: str) -> object | None: + return self._registry.get(name) + + def _register(self, name: str, fn: object) -> None: + assert name not in self._registry + self._registry[name] = fn + + def _unregister(self, name: str) -> None: + assert name in self._registry + self._registry.pop(name) + + +function_registry = FunctionRegistry() + + def load_function(path): """ - Load a function from a module. + Load a function from registry or module. :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. """ if path is None: return None + registered = function_registry.get(path) + if registered is not None: + return registered + module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) diff --git a/tests/rollout/modular_rollout/test_compatibility.py b/tests/rollout/modular_rollout/test_compatibility.py index c3beba996b..f012cbd490 100644 --- a/tests/rollout/modular_rollout/test_compatibility.py +++ b/tests/rollout/modular_rollout/test_compatibility.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -20,6 +20,7 @@ load_rollout_function, ) from miles.utils.async_utils import run +from miles.utils.misc import function_registry @pytest.fixture @@ -55,19 +56,19 @@ def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): return {"metric": {"accuracy": 0.9}} return [[{"text": "sample"}]] - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_rollout_fn): - fn = load_rollout_function(constructor_input, "path.to.fn") + with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_rollout") - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) - assert isinstance(fn, LegacyRolloutFnAdapter) - if evaluation: - assert isinstance(result, RolloutFnEvalOutput) - assert result.data == {"metric": {"accuracy": 0.9}} - else: - assert isinstance(result, RolloutFnTrainOutput) - assert result.samples == [[{"text": "sample"}]] + assert isinstance(fn, LegacyRolloutFnAdapter) + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"metric": {"accuracy": 0.9}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "sample"}]] @pytest.mark.parametrize("evaluation", [False, True]) def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): @@ -76,18 +77,18 @@ def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_rollout_fn): - fn = load_rollout_function(constructor_input, "path.to.fn") + with function_registry.temporary("test:legacy_typed", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_typed") - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) - if evaluation: - assert isinstance(result, RolloutFnEvalOutput) - assert result.data == {"ds": {"acc": 0.95}} - else: - assert isinstance(result, RolloutFnTrainOutput) - assert result.samples == [[{"text": "typed"}]] + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"ds": {"acc": 0.95}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "typed"}]] @pytest.mark.parametrize("evaluation", [False, True]) def test_format_3_sync_class(self, constructor_input, evaluation): @@ -100,15 +101,15 @@ def __call__(self, input): return RolloutFnEvalOutput(data={"test": {"score": 1}}) return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=SyncRolloutFn): - fn = load_rollout_function(constructor_input, "path.to.SyncRolloutFn") + with function_registry.temporary("test:sync_class", SyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:sync_class") - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) - assert isinstance(fn, SyncRolloutFn) - expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput - assert isinstance(result, expected_type) + assert isinstance(fn, SyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) @pytest.mark.parametrize("evaluation", [False, True]) def test_format_4_async_class(self, constructor_input, evaluation): @@ -122,15 +123,15 @@ async def __call__(self, input): return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=AsyncRolloutFn): - fn = load_rollout_function(constructor_input, "path.to.AsyncRolloutFn") + with function_registry.temporary("test:async_class", AsyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:async_class") - input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput - result = call_rollout_function(fn, input_cls(rollout_id=1)) + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) - assert isinstance(fn, AsyncRolloutFn) - expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput - assert isinstance(result, expected_type) + assert isinstance(fn, AsyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) class TestSupportedGenerateFormats: @@ -143,53 +144,53 @@ def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_i async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): return "my_sample" - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_generate_fn): - fn = load_generate_function("path.to.fn") + with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen_eval") - result = run(fn(make_generate_fn_input(evaluation))) + result = run(fn(make_generate_fn_input(evaluation))) - assert isinstance(fn, LegacyGenerateFnAdapter) - assert isinstance(result, GenerateFnOutput) - assert result.sample == "my_sample" + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" @pytest.mark.parametrize("evaluation", [False, True]) def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): async def legacy_generate_fn(args, sample, sampling_params): return "my_sample" - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=legacy_generate_fn): - fn = load_generate_function("path.to.fn") + with function_registry.temporary("test:legacy_gen", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen") - result = run(fn(make_generate_fn_input(evaluation))) + result = run(fn(make_generate_fn_input(evaluation))) - assert isinstance(fn, LegacyGenerateFnAdapter) - assert isinstance(result, GenerateFnOutput) - assert result.sample == "my_sample" + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" @pytest.mark.parametrize("evaluation", [False, True]) def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): async def generate(input: GenerateFnInput) -> GenerateFnOutput: - return GenerateFnOutput(sample="my_sample") + return GenerateFnOutput(samples="my_sample") - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=generate): - fn = load_generate_function("path.to.fn") + with function_registry.temporary("test:new_async", generate): + fn = load_generate_function("test:new_async") - result = run(fn(make_generate_fn_input(evaluation))) + result = run(fn(make_generate_fn_input(evaluation))) - assert isinstance(result, GenerateFnOutput) - assert result.sample == "my_sample" + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" @pytest.mark.parametrize("evaluation", [False, True]) def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): class MyGenerateFn: async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: - return GenerateFnOutput(sample="my_sample") + return GenerateFnOutput(samples="my_sample") - with patch("miles.rollout.modular_rollout.compatibility.load_function", return_value=MyGenerateFn): - fn = load_generate_function("path.to.fn") + with function_registry.temporary("test:new_class", MyGenerateFn): + fn = load_generate_function("test:new_class") - result = run(fn(make_generate_fn_input(evaluation))) + result = run(fn(make_generate_fn_input(evaluation))) - assert isinstance(fn, MyGenerateFn) - assert isinstance(result, GenerateFnOutput) - assert result.sample == "my_sample" + assert isinstance(fn, MyGenerateFn) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py new file mode 100644 index 0000000000..810c2b67c7 --- /dev/null +++ b/tests/utils/test_misc.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from miles.utils.misc import FunctionRegistry, function_registry, load_function + + +def _fn_a(): + return "a" + + +def _fn_b(): + return "b" + + +class TestFunctionRegistry: + def test_register_and_get(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + + def test_register_duplicate_raises(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + with pytest.raises(AssertionError): + with registry.temporary("my_fn", _fn_b): + pass + + def test_unregister(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + assert registry.get("my_fn") is None + + def test_temporary_cleanup_on_exception(self): + registry = FunctionRegistry() + with pytest.raises(RuntimeError): + with registry.temporary("temp_fn", _fn_a): + raise RuntimeError("test") + assert registry.get("temp_fn") is None + + +class TestLoadFunction: + def test_load_from_module(self): + import os.path + + assert load_function("os.path.join") is os.path.join + + def test_load_none_returns_none(self): + assert load_function(None) is None + + def test_load_from_registry(self): + with function_registry.temporary("test:my_fn", _fn_a): + assert load_function("test:my_fn") is _fn_a + + def test_registry_takes_precedence(self): + with function_registry.temporary("os.path.join", _fn_b): + assert load_function("os.path.join") is _fn_b + assert load_function("os.path.join") is os.path.join From 5b1eb163c9f5d0e04317af4daab8d16d3a46b613 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:09:45 +0800 Subject: [PATCH 17/77] Add integration tests to cover various modes and features in rollout (#446) --- tests/fixtures/rollout_integration.py | 38 +++++++--- tests/rollout/modular_rollout/conftest.py | 45 +++++++++++ .../modular_rollout/integration/__init__.py | 0 .../modular_rollout/integration/test_basic.py | 70 ++++++++++++++++++ .../integration/test_deterministic.py | 37 ++++++++++ .../integration/test_dynamic_filter.py | 46 ++++++++++++ .../integration/test_group_rm.py | 22 ++++++ .../integration/test_multi_sample.py | 65 ++++++++++++++++ .../integration/test_over_sampling.py | 44 +++++++++++ .../integration/test_sample_filter.py | 59 +++++++++++++++ .../integration/test_semaphore.py | 29 ++++++++ .../modular_rollout/integration/utils.py | 74 +++++++++++++++++++ 12 files changed, 520 insertions(+), 9 deletions(-) create mode 100644 tests/rollout/modular_rollout/conftest.py create mode 100644 tests/rollout/modular_rollout/integration/__init__.py create mode 100644 tests/rollout/modular_rollout/integration/test_basic.py create mode 100644 tests/rollout/modular_rollout/integration/test_deterministic.py create mode 100644 tests/rollout/modular_rollout/integration/test_dynamic_filter.py create mode 100644 tests/rollout/modular_rollout/integration/test_group_rm.py create mode 100644 tests/rollout/modular_rollout/integration/test_multi_sample.py create mode 100644 tests/rollout/modular_rollout/integration/test_over_sampling.py create mode 100644 tests/rollout/modular_rollout/integration/test_sample_filter.py create mode 100644 tests/rollout/modular_rollout/integration/test_semaphore.py create mode 100644 tests/rollout/modular_rollout/integration/utils.py diff --git a/tests/fixtures/rollout_integration.py b/tests/fixtures/rollout_integration.py index 079147d289..ea2c3aa0a3 100644 --- a/tests/fixtures/rollout_integration.py +++ b/tests/fixtures/rollout_integration.py @@ -2,22 +2,37 @@ from argparse import Namespace from collections.abc import Iterator from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path from unittest.mock import patch import pytest import requests -from miles.rollout.data_source import RolloutDataSourceWithBuffer +from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer from miles.rollout.modular_rollout.orchestration_common import GenerateState from miles.router.router import MilesRouter from miles.utils.arguments import parse_args from miles.utils.http_utils import find_available_port, init_http_client from miles.utils.misc import SingletonMeta -from miles.utils.test_utils.mock_sglang_server import with_mock_server +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer +@dataclass(frozen=True) +class IntegrationEnvConfig: + extra_argv: list[str] | None = None + data_rows: list[dict] | None = None + latency: float = 0.0 + + +@dataclass(frozen=True) +class IntegrationEnv: + args: Namespace + data_source: DataSource + mock_server: MockSGLangServer + + def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: argv = [ "pytest", @@ -80,20 +95,25 @@ def _cleanup_legacy_singleton(): SingletonMeta._instances.pop(GenerateState, None) +DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] + + @pytest.fixture -def rollout_integration_env(tmp_path, request): - extra_argv = request.param - assert isinstance(extra_argv, list) +def rollout_integration_env(tmp_path, request) -> IntegrationEnv: + config = request.param + assert isinstance(config, IntegrationEnvConfig) + + data_rows = config.data_rows or DEFAULT_DATA_ROWS data_path = str(tmp_path / "data.jsonl") - _write_jsonl(data_path, [{"input": "What is 1+7?", "label": "8"}]) + _write_jsonl(data_path, data_rows) router_port = find_available_port(20000) - args = _build_args(data_path=data_path, router_port=router_port, extra_argv=extra_argv) + args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) _cleanup_legacy_singleton() - with with_mock_server(model_name=args.hf_checkpoint) as mock_server: + with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: with _with_miles_router(args) as router_server: r = requests.post( f"{router_server.url}/add_worker", @@ -103,6 +123,6 @@ def rollout_integration_env(tmp_path, request): r.raise_for_status() data_source = RolloutDataSourceWithBuffer(args) - yield args, data_source + yield IntegrationEnv(args=args, data_source=data_source, mock_server=mock_server) _cleanup_legacy_singleton() diff --git a/tests/rollout/modular_rollout/conftest.py b/tests/rollout/modular_rollout/conftest.py new file mode 100644 index 0000000000..ca47edeeb6 --- /dev/null +++ b/tests/rollout/modular_rollout/conftest.py @@ -0,0 +1,45 @@ +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import parse_args + + +def _build_mock_args(extra_argv: list[str] | None = None): + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "4", + "--rollout-num-gpus-per-engine", + "2", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + "/dev/null", + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + "30000", + ] + (extra_argv or []) + with patch("sys.argv", argv): + return parse_args() + + +@pytest.fixture +def mock_args(): + return _build_mock_args() diff --git a/tests/rollout/modular_rollout/integration/__init__.py b/tests/rollout/modular_rollout/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rollout/modular_rollout/integration/test_basic.py b/tests/rollout/modular_rollout/integration/test_basic.py new file mode 100644 index 0000000000..bbb82ae50e --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_basic.py @@ -0,0 +1,70 @@ +import pytest +from tests.fixtures.rollout_integration import IntegrationEnvConfig +from tests.rollout.modular_rollout.integration.utils import ( + MODULAR_ROLLOUT_BASE_ARGV, + expected_sample, + load_and_call_train, +) + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput +from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function + +_VARIANTS = [ + pytest.param( + IntegrationEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--eval-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="old_rollout_old_generate", + ), + pytest.param( + IntegrationEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", + "--eval-function-path", + "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="new_rollout_old_generate", + ), + pytest.param( + IntegrationEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV), + id="new_rollout_new_generate", + ), +] + + +@pytest.mark.parametrize("rollout_integration_env", _VARIANTS, indirect=True) +def test_train(rollout_integration_env): + env = rollout_integration_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert len(group) == env.args.n_samples_per_prompt + assert group[0] == expected_sample(group_index=0) + + +@pytest.mark.parametrize("rollout_integration_env", _VARIANTS, indirect=True) +def test_eval(rollout_integration_env): + env = rollout_integration_env + fn = load_rollout_function( + RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path + ) + out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + assert "toy" in out.data + rewards = out.data["toy"]["rewards"] + samples = out.data["toy"]["samples"] + assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt + assert rewards[0] == 1 + assert samples[0] == expected_sample(group_index=None) diff --git a/tests/rollout/modular_rollout/integration/test_deterministic.py b/tests/rollout/modular_rollout/integration/test_deterministic.py new file mode 100644 index 0000000000..63316ceb45 --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_deterministic.py @@ -0,0 +1,37 @@ +import pytest + +from tests.rollout.modular_rollout.integration.utils import config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_integration_env,expected_seeds", + [ + pytest.param( + config( + [ + "--sglang-enable-deterministic-inference", + "--rollout-seed", + "42", + "--n-samples-per-prompt", + "3", + "--rollout-batch-size", + "1", + ] + ), + {42, 43, 44}, + id="enabled", + ), + pytest.param( + config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + {None}, + id="disabled", + ), + ], + indirect=["rollout_integration_env"], +) +def test_sampling_seeds(rollout_integration_env, expected_seeds): + env = rollout_integration_env + load_and_call_train(env.args, env.data_source) + + seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} + assert seeds == expected_seeds diff --git a/tests/rollout/modular_rollout/integration/test_dynamic_filter.py b/tests/rollout/modular_rollout/integration/test_dynamic_filter.py new file mode 100644 index 0000000000..c7e86657c5 --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_dynamic_filter.py @@ -0,0 +1,46 @@ +from contextlib import nullcontext + +import pytest +from tests.rollout.modular_rollout.integration.utils import ( + MIXED_DATA_ROWS, + config, + filter_by_reward, + load_and_call_train, +) + +from miles.utils.misc import function_registry + + +@pytest.mark.parametrize( + "rollout_integration_env,use_filter,expect_all_correct", + [ + pytest.param( + config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), + False, + False, + id="no_filter", + ), + pytest.param( + config( + ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], + data_rows=MIXED_DATA_ROWS, + ), + True, + True, + id="with_filter", + ), + ], + indirect=["rollout_integration_env"], +) +def test_filter_effect(rollout_integration_env, use_filter, expect_all_correct): + env = rollout_integration_env + ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() + + with ctx: + out = load_and_call_train(env.args, env.data_source) + + rewards = {group[0].reward for group in out.samples} + if expect_all_correct: + assert rewards == {1}, "Filter should keep only correct samples" + else: + assert 0 in rewards, "Without filter, incorrect samples should be present" diff --git a/tests/rollout/modular_rollout/integration/test_group_rm.py b/tests/rollout/modular_rollout/integration/test_group_rm.py new file mode 100644 index 0000000000..8b8ab269d6 --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_group_rm.py @@ -0,0 +1,22 @@ +import pytest + +from tests.rollout.modular_rollout.integration.utils import config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_integration_env", + [ + pytest.param( + config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + id="group_rm_enabled", + ), + ], + indirect=True, +) +def test_group_rm_rewards_set(rollout_integration_env): + env = rollout_integration_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + rewards = [sample.reward for group in out.samples for sample in group] + assert all(r in (0, 1) for r in rewards) diff --git a/tests/rollout/modular_rollout/integration/test_multi_sample.py b/tests/rollout/modular_rollout/integration/test_multi_sample.py new file mode 100644 index 0000000000..72cdee12b9 --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_multi_sample.py @@ -0,0 +1,65 @@ +import pytest +from tests.fixtures.rollout_integration import DEFAULT_DATA_ROWS, IntegrationEnvConfig +from tests.rollout.modular_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.misc import function_registry +from miles.utils.types import Sample + + +async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: + sample = input.sample + s1 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=None, + status=Sample.Status.COMPLETED, + ) + s2 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=0.5, + status=Sample.Status.COMPLETED, + ) + return GenerateFnOutput(samples=[s1, s2]) + + +@pytest.mark.parametrize( + "rollout_integration_env", + [ + pytest.param( + IntegrationEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV[:4] + + [ + "--custom-generate-function-path", + "test:multi_sample_generate", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + ], + data_rows=DEFAULT_DATA_ROWS, + ), + id="multi_sample_output", + ), + ], + indirect=True, +) +def test_multi_sample_output_preserves_existing_reward(rollout_integration_env): + env = rollout_integration_env + with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert isinstance(group[0], list) + samples = group[0] + assert len(samples) == 2 + assert samples[0].reward == 1 + assert samples[1].reward == 0.5 diff --git a/tests/rollout/modular_rollout/integration/test_over_sampling.py b/tests/rollout/modular_rollout/integration/test_over_sampling.py new file mode 100644 index 0000000000..17ae7cb38f --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_over_sampling.py @@ -0,0 +1,44 @@ +import pytest +from tests.rollout.modular_rollout.integration.utils import config, filter_by_reward, load_and_call_train + +from miles.utils.misc import function_registry + +_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "wrong"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "wrong"}, +] + +_BASE_ARGV = [ + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", +] + + +def _over_sampling_config(rollout_batch_size: int): + return config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) + + +@pytest.mark.parametrize( + "rollout_integration_env,expected_rounds", + [ + pytest.param(_over_sampling_config(1), 1, id="one_round"), + pytest.param(_over_sampling_config(2), 2, id="two_rounds"), + ], + indirect=["rollout_integration_env"], +) +def test_over_sampling_rounds(rollout_integration_env, expected_rounds): + env = rollout_integration_env + + with function_registry.temporary("test:filter_by_reward", filter_by_reward): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + assert all(group[0].reward == 1 for group in out.samples) + + requests_count = len(env.mock_server.request_log) + expected_requests = expected_rounds * env.args.over_sampling_batch_size + assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests" diff --git a/tests/rollout/modular_rollout/integration/test_sample_filter.py b/tests/rollout/modular_rollout/integration/test_sample_filter.py new file mode 100644 index 0000000000..c5c183ba3d --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_sample_filter.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock + +import pytest +from tests.rollout.modular_rollout.integration.utils import ( + MIXED_DATA_ROWS, + config, + filter_by_reward, + load_and_call_train, +) + +from miles.utils.misc import function_registry + + +@pytest.mark.parametrize( + "rollout_integration_env", + [ + pytest.param( + config( + [ + "--rollout-batch-size", + "2", + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", + "--rollout-sample-filter-path", + "test:sample_filter", + "--rollout-all-samples-process-path", + "test:all_samples_process", + ], + data_rows=MIXED_DATA_ROWS, + ), + id="sample_filter_vs_all_samples", + ), + ], + indirect=True, +) +def test_sample_filter_and_all_samples_process(rollout_integration_env): + env = rollout_integration_env + sample_filter_mock = Mock() + all_samples_process_mock = Mock() + + with ( + function_registry.temporary("test:filter_by_reward", filter_by_reward), + function_registry.temporary("test:sample_filter", sample_filter_mock), + function_registry.temporary("test:all_samples_process", all_samples_process_mock), + ): + load_and_call_train(env.args, env.data_source) + + sample_filter_mock.assert_called_once() + _, filtered_data = sample_filter_mock.call_args[0] + rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data] + assert all(r == 1 for r in rewards) + + all_samples_process_mock.assert_called_once() + _, all_samples, data_source = all_samples_process_mock.call_args[0] + assert data_source is not None + + assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter" diff --git a/tests/rollout/modular_rollout/integration/test_semaphore.py b/tests/rollout/modular_rollout/integration/test_semaphore.py new file mode 100644 index 0000000000..bcd09e3559 --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_semaphore.py @@ -0,0 +1,29 @@ +import pytest + +from tests.rollout.modular_rollout.integration.utils import config, load_and_call_train + +_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] +_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] + + +@pytest.mark.parametrize( + "rollout_integration_env,expected_range", + [ + pytest.param( + config(["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05), + (1, 1), + id="limit_1", + ), + pytest.param( + config(["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05), + (2, 999), + id="no_limit", + ), + ], + indirect=["rollout_integration_env"], +) +def test_max_concurrent(rollout_integration_env, expected_range): + env = rollout_integration_env + load_and_call_train(env.args, env.data_source) + min_expected, max_expected = expected_range + assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/rollout/modular_rollout/integration/utils.py b/tests/rollout/modular_rollout/integration/utils.py new file mode 100644 index 0000000000..112409595a --- /dev/null +++ b/tests/rollout/modular_rollout/integration/utils.py @@ -0,0 +1,74 @@ +from tests.fixtures.rollout_integration import IntegrationEnvConfig + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnTrainInput +from miles.rollout.filter_hub.base_types import DynamicFilterOutput +from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.utils.types import Sample + + +def expected_sample(*, group_index: int | None) -> Sample: + return Sample( + group_index=group_index, + index=0, + prompt="What is 1+7?", + tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], + multimodal_inputs=None, + multimodal_train_inputs=None, + response="\\boxed{8}", + response_length=5, + label="8", + reward=1, + loss_mask=None, + weight_versions=[], + rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], + rollout_routed_experts=None, + remove_sample=False, + status=Sample.Status.COMPLETED, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=Sample.SpecInfo( + spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 + ), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), + ) + + +MODULAR_ROLLOUT_BASE_ARGV = [ + "--rollout-function-path", + "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", + "--eval-function-path", + "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", + "--custom-generate-function-path", + "miles.rollout.modular_rollout.inference_wrapper.generate", +] + +MIXED_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "9"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "7"}, +] + + +def config(extra_argv: list[str], data_rows: list[dict] | None = None, latency: float = 0.0): + return IntegrationEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv, + data_rows=data_rows, + latency=latency, + ) + + +def load_and_call_train(args, data_source): + fn = load_rollout_function( + RolloutFnConstructorInput(args=args, data_source=data_source), + args.rollout_function_path, + ) + return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + + +def filter_by_reward(args, samples, **kwargs): + reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward + if reward == 1: + return DynamicFilterOutput(keep=True) + return DynamicFilterOutput(keep=False, reason="reward_zero") From dda50315560addea7851ad9705437ebc3d8fa53a Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:10:01 +0800 Subject: [PATCH 18/77] Support speculative information in mock sglang server (#449) --- miles/utils/test_utils/mock_sglang_server.py | 33 +++++++--- .../test_utils/test_mock_sglang_server.py | 66 ++++++++++++++++++- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py index e0f1673583..d13b5bdf8a 100644 --- a/miles/utils/test_utils/mock_sglang_server.py +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -2,7 +2,7 @@ import re from collections.abc import Callable from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import asdict, dataclass from fastapi import FastAPI, Request from fastapi.responses import JSONResponse @@ -12,10 +12,24 @@ from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer +@dataclass(frozen=True) +class ProcessResultMetaInfo: + weight_version: str | None = None + routed_experts: str | None = None + spec_accept_token_num: int | None = None + spec_draft_token_num: int | None = None + spec_verify_ct: int | None = None + + def to_dict(self) -> dict: + return {k: v for k, v in asdict(self).items() if v is not None} + + @dataclass(frozen=True) class ProcessResult: text: str finish_reason: str + cached_tokens: int = 0 + meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() ProcessFn = Callable[[str], ProcessResult] @@ -78,15 +92,18 @@ async def generate(request: Request): output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + meta_info = { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": process_result.cached_tokens, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + **process_result.meta_info.to_dict(), + } + response = { "text": process_result.text, - "meta_info": { - "finish_reason": finish_reason_dict, - "prompt_tokens": prompt_tokens, - "cached_tokens": 0, - "completion_tokens": completion_tokens, - "output_token_logprobs": output_token_logprobs, - }, + "meta_info": meta_info, } return JSONResponse(content=response) diff --git a/tests/utils/test_utils/test_mock_sglang_server.py b/tests/utils/test_utils/test_mock_sglang_server.py index 0601307d74..9326122b87 100644 --- a/tests/utils/test_utils/test_mock_sglang_server.py +++ b/tests/utils/test_utils/test_mock_sglang_server.py @@ -5,7 +5,13 @@ import pytest import requests -from miles.utils.test_utils.mock_sglang_server import Counter, ProcessResult, default_process_fn, with_mock_server +from miles.utils.test_utils.mock_sglang_server import ( + Counter, + ProcessResult, + ProcessResultMetaInfo, + default_process_fn, + with_mock_server, +) @pytest.fixture(scope="module") @@ -74,6 +80,64 @@ def test_default_process_fn(): assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") +def test_process_result_meta_info_to_dict(): + assert ProcessResultMetaInfo().to_dict() == {} + assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} + assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { + "weight_version": "v1", + "spec_accept_token_num": 10, + } + assert ProcessResultMetaInfo( + weight_version="v1", routed_experts="abc", spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3 + ).to_dict() == { + "weight_version": "v1", + "routed_experts": "abc", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + } + + +def test_generate_endpoint_with_meta_info(): + def process_fn(_: str) -> ProcessResult: + return ProcessResult( + text="ok", + finish_reason="stop", + cached_tokens=5, + meta_info=ProcessResultMetaInfo( + weight_version="v2.0", + routed_experts="encoded_data", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ), + ) + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + data = response.json() + + assert data == { + "text": "ok", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": 3, + "cached_tokens": 5, + "completion_tokens": 1, + "output_token_logprobs": [[-0.0, 562]], + "weight_version": "v2.0", + "routed_experts": "encoded_data", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + }, + } + + def test_request_log_and_reset_stats(mock_server): mock_server.reset_stats() assert len(mock_server.request_log) == 0 From 4cc51e14f0a4df6755b68c7c4cfafd19b3e9b736 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:10:24 +0800 Subject: [PATCH 19/77] Add thorough test for single turn generate function (#450) --- miles/rollout/generate_hub/single_turn.py | 100 ++++ .../modular_rollout/orchestration_common.py | 3 +- miles/utils/misc.py | 5 +- tests/fixtures/rollout_integration.py | 9 +- tests/rollout/generate_hub/__init__.py | 0 .../rollout/generate_hub/test_single_turn.py | 430 ++++++++++++++++++ .../modular_rollout/integration/utils.py | 2 +- 7 files changed, 537 insertions(+), 12 deletions(-) create mode 100644 miles/rollout/generate_hub/single_turn.py create mode 100644 tests/rollout/generate_hub/__init__.py create mode 100644 tests/rollout/generate_hub/test_single_turn.py diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py new file mode 100644 index 0000000000..3a09d3dfdd --- /dev/null +++ b/miles/rollout/generate_hub/single_turn.py @@ -0,0 +1,100 @@ +import numpy as np +import pybase64 + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.http_utils import post +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + """Generate using traditional SGLang router with token-based workflow""" + state = input.state + args = input.args + sample = input.sample + sampling_params = input.sampling_params + + if args.ci_test: + assert isinstance(sample.prompt, str) + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + assert ( + sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED + ), f"Sample status is {sample.status}" + + if state.processor: + processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + else: + prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) + + if len(sample.response) > 0: + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert ( + sampling_params["max_new_tokens"] >= 0 + ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return GenerateFnOutput(samples=sample) + + # Prepare payload for sglang server + payload = { + "sampling_params": sampling_params, + "return_logprob": True, + } + + if args.use_rollout_routing_replay: + payload["return_routed_experts"] = True + + if sample.multimodal_inputs and sample.multimodal_inputs["images"]: + image_data = sample.multimodal_inputs["images"] + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + # Use existing tokens for multi-turn or tokenize the new prompt + if len(sample.response) > 0: + payload["input_ids"] = sample.tokens + else: + payload["input_ids"] = prompt_ids + if not sample.tokens: # Initialize sample.tokens for the first turn + sample.tokens = prompt_ids + + output = await post(url, payload) + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + sample = await postprocess_sample_with_radix_tree(args, sample, output) + else: + if "output_token_logprobs" in output["meta_info"]: + new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if "routed_experts" in output["meta_info"]: + sample.rollout_routed_experts = np.frombuffer( + pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), + dtype=np.int32, + ).reshape( + len(sample.tokens) - 1, + args.num_layers, + args.moe_router_topk, + ) + + sample.update_from_meta_info(args, output["meta_info"]) + + return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index da9e90654b..ab0f55f2b2 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -3,10 +3,9 @@ from argparse import Namespace from typing import Any - from miles.rollout.base_types import GenerateFnInput +from miles.rollout.generate_hub.single_turn import generate from miles.rollout.modular_rollout.compatibility import load_generate_function -from miles.rollout.modular_rollout.inference_wrapper import generate from miles.rollout.rm_hub import async_rm, batched_async_rm from miles.utils.processing_utils import load_processor, load_tokenizer from miles.utils.types import Sample diff --git a/miles/utils/misc.py b/miles/utils/misc.py index fa772b5222..88e2213518 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -67,8 +67,9 @@ def __call__(cls, *args, **kwargs): cls._instances[cls] = instance return cls._instances[cls] - def clear_instances(cls): - cls._instances = {} + @staticmethod + def clear_all_instances(): + SingletonMeta._instances.clear() def exec_command(cmd: str, capture_output: bool = False) -> str | None: diff --git a/tests/fixtures/rollout_integration.py b/tests/fixtures/rollout_integration.py index ea2c3aa0a3..74ce0b5134 100644 --- a/tests/fixtures/rollout_integration.py +++ b/tests/fixtures/rollout_integration.py @@ -10,7 +10,6 @@ import requests from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer -from miles.rollout.modular_rollout.orchestration_common import GenerateState from miles.router.router import MilesRouter from miles.utils.arguments import parse_args from miles.utils.http_utils import find_available_port, init_http_client @@ -91,10 +90,6 @@ def _write_jsonl(path: str, rows: list[dict]) -> None: Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") -def _cleanup_legacy_singleton(): - SingletonMeta._instances.pop(GenerateState, None) - - DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] @@ -111,7 +106,7 @@ def rollout_integration_env(tmp_path, request) -> IntegrationEnv: router_port = find_available_port(20000) args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) - _cleanup_legacy_singleton() + SingletonMeta.clear_all_instances() with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: with _with_miles_router(args) as router_server: @@ -125,4 +120,4 @@ def rollout_integration_env(tmp_path, request) -> IntegrationEnv: data_source = RolloutDataSourceWithBuffer(args) yield IntegrationEnv(args=args, data_source=data_source, mock_server=mock_server) - _cleanup_legacy_singleton() + SingletonMeta.clear_all_instances() diff --git a/tests/rollout/generate_hub/__init__.py b/tests/rollout/generate_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py new file mode 100644 index 0000000000..f9a63716bd --- /dev/null +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -0,0 +1,430 @@ +from argparse import Namespace +from dataclasses import dataclass +from typing import Any +from unittest.mock import patch + +import numpy as np +import pybase64 +import pytest +import torch +from PIL import Image +from transformers import AutoProcessor + +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.utils.async_utils import run +from miles.utils.http_utils import init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.types import Sample + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +PROMPT = "What is 1+7?" +PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] +RESPONSE_TOKENS = [59, 79075, 90, 23, 92] +RESPONSE_TEXT = "\\boxed{8}" +RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} + + +@pytest.fixture(params=["sglang_rollout", "modular_rollout"]) +def variant(request): + return request.param + + +def expected_request( + variant: str, + *, + input_ids: list[int] | None = None, + sampling_params: dict | None = None, + return_routed_experts: bool = False, + image_data: list[str] | None = None, +) -> dict: + result = { + "input_ids": input_ids or PROMPT_TOKENS, + "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "return_logprob": True, + } + if variant == "modular_rollout" or return_routed_experts: + result["return_routed_experts"] = return_routed_experts + if image_data is not None: + result["image_data"] = image_data + return result + + +def expected_sample( + *, + prompt: str = PROMPT, + response: str = RESPONSE_TEXT, + response_length: int = 5, + tokens: list[int] | None = None, + rollout_log_probs: list[float] | None = None, + status: Sample.Status = Sample.Status.COMPLETED, + cached_tokens: int = 0, + prompt_tokens: int = 7, + weight_versions: list[str] | None = None, + rollout_routed_experts: np.ndarray | None = None, + spec_info: Sample.SpecInfo | None = None, + multimodal_inputs: dict | None = None, + multimodal_train_inputs: dict | None = None, +) -> Sample: + return Sample( + group_index=None, + index=None, + prompt=prompt, + tokens=tokens if tokens is not None else PROMPT_TOKENS + RESPONSE_TOKENS, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=multimodal_train_inputs, + response=response, + response_length=response_length, + label=None, + reward=None, + loss_mask=None, + weight_versions=weight_versions or [], + rollout_log_probs=rollout_log_probs if rollout_log_probs is not None else RESPONSE_LOG_PROBS, + rollout_routed_experts=rollout_routed_experts, + remove_sample=False, + status=status, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=spec_info or Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens), + ) + + +def make_args( + *, + router_port: int, + use_rollout_routing_replay: bool = False, + sglang_speculative_algorithm: str | None = None, + model_name: str = MODEL_NAME, +) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + model_name, + "--prompt-data", + "/dev/null", + "--rm-type", + "math", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + if use_rollout_routing_replay: + argv.append("--use-rollout-routing-replay") + if sglang_speculative_algorithm: + argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) + + from miles.utils.arguments import parse_args + + with patch("sys.argv", argv): + args = parse_args() + + init_http_client(args) + return args + + +async def call_generate(variant: str, args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: + if variant == "sglang_rollout": + from miles.rollout.sglang_rollout import generate + + return await generate(args, sample, sampling_params.copy()) + elif variant == "modular_rollout": + from miles.rollout.generate_hub.single_turn import generate + + state = GenerateState(args) + output = await generate( + GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) + ) + return output.samples + else: + raise NotImplementedError + + +@dataclass +class GenerateEnv: + args: Namespace + mock_server: Any + + +@dataclass +class GenerateResult: + sample: Sample + requests: list[dict] + + +@pytest.fixture +def env(request): + SingletonMeta.clear_all_instances() + params = getattr(request, "param", {}) + args_kwargs = params.get("args_kwargs", {}) + model_name = args_kwargs.get("model_name", MODEL_NAME) + + def process_fn(_): + x = params.get("process_fn_kwargs", {}) + return ProcessResult( + text=x.get("response_text", RESPONSE_TEXT), + finish_reason=x.get("finish_reason", "stop"), + cached_tokens=x.get("cached_tokens", 0), + meta_info=ProcessResultMetaInfo( + weight_version=x.get("weight_version"), + routed_experts=x.get("routed_experts"), + spec_accept_token_num=x.get("spec_accept_token_num"), + spec_draft_token_num=x.get("spec_draft_token_num"), + spec_verify_ct=x.get("spec_verify_ct"), + ), + ) + + with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args(router_port=mock_server.port, model_name=model_name, **other_args_kwargs) + yield GenerateEnv(args=args, mock_server=mock_server) + + SingletonMeta.clear_all_instances() + + +def make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): + return Sample( + prompt=PROMPT, + tokens=tokens or [], + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +def run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): + env.mock_server.request_log.clear() + result_sample = run( + call_generate(variant, env.args, sample or make_sample(), sampling_params or DEFAULT_SAMPLING_PARAMS) + ) + return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicGeneration: + def test_basic_generation(self, variant, env): + result = run_generate(variant, env) + assert result.requests == [expected_request(variant)] + assert result.sample == expected_sample() + + +class TestResumedSingleTurn: + def test_two_consecutive_calls_on_same_sample(self, variant, env): + partial_text = "\\boxed" + partial_tokens = [59, 79075] + partial_log_probs = [-0.0, -0.0078125] + + remaining_text = "{8}" + remaining_tokens = [90, 23, 92] + remaining_log_probs = [-0.0, -0.0078125, -0.015625] + + env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") + sample = make_sample() + result1 = run_generate(variant, env, sample) + assert result1.requests == [expected_request(variant)] + assert result1.sample == expected_sample( + response=partial_text, + response_length=2, + tokens=PROMPT_TOKENS + partial_tokens, + rollout_log_probs=partial_log_probs, + status=Sample.Status.ABORTED, + ) + + env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") + result2 = run_generate(variant, env, result1.sample) + tokens_after_turn1 = PROMPT_TOKENS + partial_tokens + assert result2.requests == [ + expected_request( + variant, + input_ids=tokens_after_turn1, + sampling_params={"max_new_tokens": 14, "temperature": 0.7}, + ) + ] + assert result2.sample == expected_sample( + response=partial_text + remaining_text, + response_length=2 + 3, + tokens=tokens_after_turn1 + remaining_tokens, + rollout_log_probs=partial_log_probs + remaining_log_probs, + prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1), + status=Sample.Status.COMPLETED, + ) + + +class TestFinishReason: + @pytest.mark.parametrize( + "env,expected_status", + [ + ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), + ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), + ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED), + ], + indirect=["env"], + ) + def test_finish_reason_sets_status(self, variant, env, expected_status): + result = run_generate(variant, env) + assert result.requests == [expected_request(variant)] + assert result.sample == expected_sample(status=expected_status) + + +class TestRoutedExperts: + @pytest.mark.parametrize( + "env", + [ + { + "args_kwargs": {"use_rollout_routing_replay": True}, + "process_fn_kwargs": {"routed_experts": "placeholder"}, + } + ], + indirect=True, + ) + def test_routed_experts_enabled_and_parsed(self, variant, env): + num_layers, moe_router_topk = 2, 4 + num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) + routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( + num_tokens - 1, num_layers, moe_router_topk + ) + + env.args.num_layers = num_layers + env.args.moe_router_topk = moe_router_topk + routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") + env.mock_server.process_fn = lambda _: ProcessResult( + text=RESPONSE_TEXT, + finish_reason="stop", + meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), + ) + + result = run_generate(variant, env) + assert result.requests == [expected_request(variant, return_routed_experts=True)] + assert result.sample.rollout_routed_experts is not None + assert result.sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) + np.testing.assert_array_equal(result.sample.rollout_routed_experts, routed_experts_array) + + +class TestMetaInfo: + @pytest.mark.parametrize( + "env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True + ) + def test_meta_info_fields_updated(self, variant, env): + result = run_generate(variant, env) + assert result.requests == [expected_request(variant)] + assert result.sample == expected_sample(cached_tokens=3, weight_versions=["v1.0"]) + + @pytest.mark.parametrize( + "env", + [ + { + "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, + "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3}, + } + ], + indirect=True, + ) + def test_spec_info_updated(self, variant, env): + result = run_generate(variant, env) + assert result.requests == [expected_request(variant)] + assert result.sample == expected_sample( + spec_info=Sample.SpecInfo( + spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 + ) + ) + + +class TestInputStatusValidation: + @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) + def test_allowed_statuses(self, variant, env, status): + result = run_generate(variant, env, make_sample(status=status)) + assert result.requests == [expected_request(variant)] + assert result.sample.status == Sample.Status.COMPLETED + + @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) + def test_rejected_statuses(self, variant, env, status): + with pytest.raises(AssertionError): + run_generate(variant, env, make_sample(status=status)) + + +class TestPayloadStructure: + def test_sampling_params_passed_through(self, variant, env): + result = run_generate(variant, env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + assert result.requests == [ + expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + ] + assert result.sample == expected_sample() + + +class TestBoundaryConditions: + def test_max_new_tokens_zero_returns_truncated(self, variant, env): + existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) + sample = make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) + + result = run_generate(variant, env, sample, {"max_new_tokens": 10, "temperature": 0.7}) + assert result.requests == [] + assert result.sample.status == Sample.Status.TRUNCATED + + +class TestEmptyResponse: + @pytest.mark.parametrize("env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) + def test_empty_response(self, variant, env): + result = run_generate(variant, env) + assert result.requests == [expected_request(variant)] + assert result.sample == expected_sample( + response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[] + ) + + +VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" + + +class TestMultimodal: + @pytest.mark.parametrize("env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) + def test_multimodal_inputs_processed(self, variant, env): + test_image = Image.new("RGB", (64, 64), color="red") + multimodal_inputs = {"images": [test_image]} + processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) + expected_mti = { + k: v + for k, v in processor(text=PROMPT, **multimodal_inputs).items() + if k not in ["input_ids", "attention_mask"] + } + + result = run_generate(variant, env, make_sample(multimodal_inputs=multimodal_inputs)) + + assert result.requests == [ + expected_request( + variant, + input_ids=PROMPT_TOKENS, + image_data=[encode_image_for_rollout_engine(test_image)], + ) + ] + actual_mti = result.sample.multimodal_train_inputs + assert actual_mti is not None + assert set(actual_mti.keys()) == set(expected_mti.keys()) + assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) + assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) + assert result.sample == expected_sample( + tokens=PROMPT_TOKENS + RESPONSE_TOKENS, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=actual_mti, + ) diff --git a/tests/rollout/modular_rollout/integration/utils.py b/tests/rollout/modular_rollout/integration/utils.py index 112409595a..260b3f1516 100644 --- a/tests/rollout/modular_rollout/integration/utils.py +++ b/tests/rollout/modular_rollout/integration/utils.py @@ -40,7 +40,7 @@ def expected_sample(*, group_index: int | None) -> Sample: "--eval-function-path", "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", "--custom-generate-function-path", - "miles.rollout.modular_rollout.inference_wrapper.generate", + "miles.rollout.generate_hub.single_turn.generate", ] MIXED_DATA_ROWS = [ From d1f29ed253f0b0f198e1d331db7ed73ff956cbe4 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:10:43 +0800 Subject: [PATCH 20/77] Refactor single turn generate function (#451) --- .../generate_hub/generate_endpoint_wrapper.py | 94 ++++++++++++++++++ miles/rollout/generate_hub/single_turn.py | 95 +++---------------- 2 files changed, 107 insertions(+), 82 deletions(-) create mode 100644 miles/rollout/generate_hub/generate_endpoint_wrapper.py diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py new file mode 100644 index 0000000000..c927c05794 --- /dev/null +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -0,0 +1,94 @@ +""" +Wrapper to integrate SGLang's `/generate` endpoint with RL things like Sample. +""" + +import numpy as np +import pybase64 + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +# Make this an isolated function because users may want to compute their own +async def compute_prompt_ids_from_sample(state, sample): + if state.processor: + processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + # TODO shall we move it to other places? then can make this function immutable + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + return prompt_ids + else: + return state.tokenizer.encode(sample.prompt, add_special_tokens=False) + + +async def compute_request_payload(state, sample, prompt_ids: list[int], sampling_params: dict): + assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" + + max_new_tokens = sampling_params.pop("max_new_tokens") + if len(sample.response) > 0: + max_new_tokens -= len(sample.tokens) - len(prompt_ids) + + # Prepare payload for sglang server + payload = { + # Use existing tokens for multi-turn or tokenize the new prompt + "input_ids": sample.tokens if len(sample.response) > 0 else prompt_ids, + "sampling_params": { + **sampling_params, + "max_new_tokens": max_new_tokens, + }, + "return_logprob": True, + "return_routed_experts": state.args.use_rollout_routing_replay, + } + if image_data := (sample.multimodal_inputs or {}).get("images"): + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + assert payload["sampling_params"]["max_new_tokens"] >= 0 + + if payload["sampling_params"]["max_new_tokens"] == 0: + return None, Sample.Status.TRUNCATED + + return payload, None + + +async def update_sample_from_response(args, sample: Sample, payload: dict, output: dict): + # Initialize sample.tokens for the first turn + if (len(sample.response) == 0) and not sample.tokens: + sample.tokens = payload["input_ids"] + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + # TODO may rename to match + await postprocess_sample_with_radix_tree(args, sample, output) + else: + if x := output["meta_info"].get("output_token_logprobs"): + new_response_tokens = [item[1] for item in x] + new_response_log_probs = [item[0] for item in x] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) + + # TODO may unify (currently there are both methods inside Sample and separate functions) + sample.update_from_meta_info(args, output["meta_info"]) + + +def _get_rollout_routed_experts_from_response(args, sample, output): + info = output["meta_info"].get("routed_experts") + if info is None: + return None + + x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32) + x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk) + return x diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py index 3a09d3dfdd..f8c52d490d 100644 --- a/miles/rollout/generate_hub/single_turn.py +++ b/miles/rollout/generate_hub/single_turn.py @@ -1,100 +1,31 @@ -import numpy as np -import pybase64 +""" +Simple single-turn generation. +""" from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_hub.generate_endpoint_wrapper import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) from miles.utils.http_utils import post -from miles.utils.processing_utils import encode_image_for_rollout_engine -from miles.utils.types import Sample async def generate(input: GenerateFnInput) -> GenerateFnOutput: - """Generate using traditional SGLang router with token-based workflow""" - state = input.state args = input.args sample = input.sample - sampling_params = input.sampling_params - - if args.ci_test: - assert isinstance(sample.prompt, str) url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - assert ( - sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED - ), f"Sample status is {sample.status}" - - if state.processor: - processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) - prompt_ids = processor_output["input_ids"][0] - sample.multimodal_train_inputs = { - k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] - } or None - else: - prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) - - if len(sample.response) > 0: - sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + prompt_ids = await compute_prompt_ids_from_sample(input.state, sample) + payload, halt_status = await compute_request_payload(input.state, sample, prompt_ids, input.sampling_params) - assert ( - sampling_params["max_new_tokens"] >= 0 - ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" - if sampling_params["max_new_tokens"] == 0: - sample.status = Sample.Status.TRUNCATED + if payload is None: + sample.status = halt_status return GenerateFnOutput(samples=sample) - # Prepare payload for sglang server - payload = { - "sampling_params": sampling_params, - "return_logprob": True, - } - - if args.use_rollout_routing_replay: - payload["return_routed_experts"] = True - - if sample.multimodal_inputs and sample.multimodal_inputs["images"]: - image_data = sample.multimodal_inputs["images"] - payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - - # Use existing tokens for multi-turn or tokenize the new prompt - if len(sample.response) > 0: - payload["input_ids"] = sample.tokens - else: - payload["input_ids"] = prompt_ids - if not sample.tokens: # Initialize sample.tokens for the first turn - sample.tokens = prompt_ids - output = await post(url, payload) - if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: - from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree - - sample = await postprocess_sample_with_radix_tree(args, sample, output) - else: - if "output_token_logprobs" in output["meta_info"]: - new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - else: - new_response_tokens, new_response_log_probs = [], [] - - # Update sample with tokens directly - avoiding re-tokenization - sample.tokens = sample.tokens + new_response_tokens - sample.response_length += len(new_response_tokens) - sample.response += output["text"] - - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += new_response_log_probs - - if "routed_experts" in output["meta_info"]: - sample.rollout_routed_experts = np.frombuffer( - pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), - dtype=np.int32, - ).reshape( - len(sample.tokens) - 1, - args.num_layers, - args.moe_router_topk, - ) - - sample.update_from_meta_info(args, output["meta_info"]) + await update_sample_from_response(args, sample, payload=payload, output=output) return GenerateFnOutput(samples=sample) From 491e71be4787e705168b4e0cde645ef7a8ce3a64 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:11:05 +0800 Subject: [PATCH 21/77] Allow user-provided function to add extra arguments (#452) --- miles/utils/arguments.py | 16 ++++++++++ tests/utils/test_arguments.py | 58 +++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 tests/utils/test_arguments.py diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 79b2c419ca..41ebaf00fe 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -12,6 +12,7 @@ from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger +from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -1344,6 +1345,20 @@ def add_ci_arguments(parser): ) return parser + def add_user_provided_function_arguments(parser): + args_partial, _ = parser.parse_known_args() + for path in [ + args_partial.rollout_function_path, + args_partial.custom_generate_function_path, + ]: + try: + fn = load_function(path) + except (ModuleNotFoundError, ValueError): + continue + if fn is not None and callable(getattr(fn, "add_arguments", None)): + fn.add_arguments(parser) + return parser + def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) @@ -1374,6 +1389,7 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) + parser = add_user_provided_function_arguments(parser) reset_arg( parser, "--custom-config-path", diff --git a/tests/utils/test_arguments.py b/tests/utils/test_arguments.py new file mode 100644 index 0000000000..9bd1a620d6 --- /dev/null +++ b/tests/utils/test_arguments.py @@ -0,0 +1,58 @@ +import argparse +import sys +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import get_miles_extra_args_provider +from miles.utils.misc import function_registry + +PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] +REQUIRED_ARGS = ["--rollout-batch-size", "64"] + + +def make_class_with_add_arguments(): + class MyFn: + @classmethod + def add_arguments(cls, parser): + parser.add_argument("--my-custom-arg", type=int, default=42) + + return MyFn + + +def make_function_with_add_arguments(): + def my_fn(): + pass + + my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42) + return my_fn + + +def make_function_without_add_arguments(): + def my_fn(): + pass + + return my_fn + + +@pytest.mark.parametrize("path_arg", PATH_ARGS) +class TestAddArgumentsSupport: + + @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments]) + def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory): + fn = fn_factory() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) + args, _ = parser.parse_known_args() + assert args.my_custom_arg == 100 + + def test_skips_function_without_add_arguments(self, path_arg): + fn = make_function_without_add_arguments() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) From f75a9a7e34c92bd50b743261285a09e0b6611564 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:11:23 +0800 Subject: [PATCH 22/77] Copy core of retool example into multi_turn.py and adapt to new API (#453) --- .../generate_hub/multi_turn_single_sample.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 miles/rollout/generate_hub/multi_turn_single_sample.py diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py new file mode 100644 index 0000000000..a6b049ead8 --- /dev/null +++ b/miles/rollout/generate_hub/multi_turn_single_sample.py @@ -0,0 +1,136 @@ +""" +Simple multi-turn generation with tool calling. +""" + +from typing import Any + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.http_utils import post +from miles.utils.types import Sample + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + tokenizer = input.state.tokenizer + + assert not args.partial_rollout, "Partial rollout is not supported for " "this function at the moment." + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + # Set up the initial prompt with system prompt and tools (outside the loop) + tool_specs = tool_registry.get_tool_specs() + prompt = format_conversation_with_tools(prompt=sample.prompt, tools=tool_specs) + + prompt_tokens_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] + response = "" + response_token_ids = [] + loss_masks = [] + tool_call_count = 0 # Track actual tool call rounds + + for turn in range(TOOL_CONFIGS["max_turns"]): + # Check if total length exceeds max context length + total_length = len(prompt_tokens_ids) + len(response_token_ids) + if args.rollout_max_context_len is not None: + max_context_length = args.rollout_max_context_len + else: + max_context_length = args.context_parallel_size * args.max_tokens_per_gpu + if total_length >= max_context_length: + sample.status = Sample.Status.TRUNCATED + break + + # Use token IDs instead of text + current_token_ids = prompt_tokens_ids + response_token_ids + payload = { + "input_ids": current_token_ids, + "sampling_params": input.sampling_params, + "return_logprob": True, # Request log probabilities for training + } + + output = await post(url, payload) + + # Handle abort + if output["meta_info"]["finish_reason"]["type"] == "abort": + sample.status = Sample.Status.ABORTED + return GenerateFnOutput(samples=sample) + + if "output_token_logprobs" in output["meta_info"]: + cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + cur_response = tokenizer.decode(cur_response_token_ids) + cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += cur_log_probs + + else: + cur_response = output["text"] + cur_response = postprocess_responses(cur_response) + cur_response_token_ids = tokenizer(cur_response, add_special_tokens=False)["input_ids"] + + response += cur_response + response_token_ids += cur_response_token_ids + loss_masks += [1] * len(cur_response_token_ids) + + # Check length limit + if output["meta_info"]["finish_reason"]["type"] == "length": + break + + next_obs, done = await execute_predictions(cur_response) + if done: + break + + # Count tool calls (when we get interpreter output, it means a tool + # was called) + if "" in next_obs: + tool_call_count += 1 + + assert next_obs != "", "Next observation should not be empty." + obs_tokens_ids = tokenizer(next_obs, add_special_tokens=False)["input_ids"] + response += next_obs + response_token_ids += obs_tokens_ids + loss_masks += [0] * len(obs_tokens_ids) + + # Add dummy log probs for observation tokens (they won't be used due to loss_mask=0) + # Check if maximum tool call count reached + if sample.rollout_log_probs is not None: + sample.rollout_log_probs += [0.0] * len(obs_tokens_ids) + + assert len(response_token_ids) == len( + sample.rollout_log_probs + ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" + + if turn >= TOOL_CONFIGS["max_tool_calls"]: + break + + # Set sample attributes + sample.tokens = prompt_tokens_ids + response_token_ids + sample.response_length = len(response_token_ids) + sample.response = response + sample.loss_mask = loss_masks + + # Set status + sample.update_from_meta_info(args, output["meta_info"]) + + return GenerateFnOutput(samples=sample) + + +def format_conversation_with_tools( + prompt: str, tools: list[dict[str, Any]] = None, system_prompt: str = None, messages: list[dict[str, Any]] = None +) -> str: + return TODO + + +def postprocess_predictions(prediction: str): + """Extract action and content from prediction string""" + return TODO, TODO + + +def postprocess_responses(resp: str) -> str: + return TODO + + +async def execute_predictions(prediction: str) -> str: + """Execute predictions and return results""" + action, content = postprocess_predictions(prediction) + next_obs, done = TODO + return next_obs, done From 7157ebab8af3341db49a4c492f7dcc70ef56d7ec Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:11:50 +0800 Subject: [PATCH 23/77] Support tool response tokenization logic (#454) --- miles/rollout/generate_hub/tool_call_utils.py | 55 +++++++++++++ .../generate_hub/test_tool_call_utils.py | 80 +++++++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 miles/rollout/generate_hub/tool_call_utils.py create mode 100644 tests/rollout/generate_hub/test_tool_call_utils.py diff --git a/miles/rollout/generate_hub/tool_call_utils.py b/miles/rollout/generate_hub/tool_call_utils.py new file mode 100644 index 0000000000..d8a1ca574e --- /dev/null +++ b/miles/rollout/generate_hub/tool_call_utils.py @@ -0,0 +1,55 @@ +from typing import Any + + +_DUMMY_USER = {"role": "user", "content": "dummy"} + + +# TODO: very naive implementation, need the to-be-implemented e2e test to validate. +def tokenize_tool_responses( + tool_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + return _tokenize_postfix_messages(tool_messages, tokenizer) + + +def _tokenize_postfix_messages( + postfix_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + dummy_assistant = _build_dummy_assistant(postfix_messages) + base_messages = [_DUMMY_USER, dummy_assistant] + + messages_without = base_messages + messages_with = base_messages + postfix_messages + + tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True) + tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False) + + assert tokens_with[: len(tokens_without)] == tokens_without, ( + f"Fail to tokenize_tool_responses caused by token prefix mismatch. " + f"This can happen for thinking model or models with special chat template, " + f"and this simple example does not support it yet, " + f"since this means we cannot have a append-only token id list. " + f"{tokens_with=} {tokens_without=} " + f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} " + ) + return tokens_with[len(tokens_without) :] + + +def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: + return { + "role": "assistant", + "content": "", + "reasoning_content": " ", + "tool_calls": [ + { + "id": resp.get("tool_call_id", f"call0000{i}"), + "type": "function", + "function": { + "name": resp.get("name", "dummy_func"), + "arguments": {}, + }, + } + for i, resp in enumerate(tool_responses) + ], + } diff --git a/tests/rollout/generate_hub/test_tool_call_utils.py b/tests/rollout/generate_hub/test_tool_call_utils.py new file mode 100644 index 0000000000..26d1330ae6 --- /dev/null +++ b/tests/rollout/generate_hub/test_tool_call_utils.py @@ -0,0 +1,80 @@ +import pytest + +from miles.rollout.generate_hub.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "XiaomiMiMo/MiMo-7B-RL", +] + +SINGLE_TOOL_CALL_ONLY_MODELS = [ + "meta-llama/Llama-3.2-1B-Instruct", +] + +# Models where tokenize->decode produces extra whitespace vs direct string diff +TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ + "THUDM/glm-4-9b-chat", +] + +SAMPLE_TOOL_RESPONSES = [ + { + "role": "tool", + "tool_call_id": "call00000", + "content": '{"year": 2026}', + "name": "get_year", + }, + { + "role": "tool", + "tool_call_id": "call00001", + "content": '{"temperature": 25}', + "name": "get_temperature", + }, +] + + +class TestTokenizeToolResponses: + @pytest.mark.parametrize("num_tools", [1, 2]) + @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) + def test_tokenize_tool_responses(self, model_name, num_tools): + if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: + pytest.skip(f"{model_name} only supports single tool call") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] + assert len(tool_responses) == num_tools + + actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer) + actual_str = tokenizer.decode(actual_token_ids) + + dummy_assistant = _build_dummy_assistant(tool_responses) + base_messages = [_DUMMY_USER, dummy_assistant] + expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) + + if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: + # Some models produce whitespace differences between tokenize->decode and direct string diff + actual_str = actual_str.replace(" ", "") + expected_str = expected_str.replace(" ", "") + + assert actual_str == expected_str, f"{model_name=}" + + @staticmethod + def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str: + text_with = tokenizer.apply_chat_template( + base_messages + extra_messages, tokenize=False, add_generation_prompt=True + ) + text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False) + return text_with[len(text_without) :] From 1e74a8a0d41c3b5fd08aeadc876d3e6e39748094 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:12:13 +0800 Subject: [PATCH 24/77] Support mock tools and corresponding server replies (#456) --- miles/utils/test_utils/mock_tools.py | 131 ++++++++++++++++++++++ tests/utils/test_utils/test_mock_tools.py | 111 ++++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 miles/utils/test_utils/mock_tools.py create mode 100644 tests/utils/test_utils/test_mock_tools.py diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py new file mode 100644 index 0000000000..83f1d94327 --- /dev/null +++ b/miles/utils/test_utils/mock_tools.py @@ -0,0 +1,131 @@ +import json + +from miles.utils.test_utils.mock_sglang_server import ProcessResult + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_year", + "description": "Get current year", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_temperature", + "description": "Get temperature for a location", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + }, +] + + +def _get_year(params: dict) -> str: + assert len(params) == 0 + return json.dumps({"year": 2026}) + + +def _get_temperature(params: dict) -> str: + assert params.get("location") == "Mars" + return json.dumps({"temperature": -60}) + + +TOOL_EXECUTORS = { + "get_year": _get_year, + "get_temperature": _get_temperature, +} + + +async def execute_tool_call(name: str, params: dict) -> str: + return TOOL_EXECUTORS[name](params) + + +MULTI_TURN_FIRST_PROMPT = ( + "<|im_start|>system\n" + "# Tools\n" + "\n" + "You may call one or more functions to assist with the user query.\n" + "\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n" + "\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What is 42 + year + temperature?<|im_end|>\n" + "<|im_start|>assistant\n" +) +MULTI_TURN_FIRST_RESPONSE = ( + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "" +) + +MULTI_TURN_SECOND_PROMPT = ( + "<|im_start|>system\n" + "# Tools\n" + "\n" + "You may call one or more functions to assist with the user query.\n" + "\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n" + "\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What is 42 + year + temperature?<|im_end|>\n" + "<|im_start|>assistant\n" + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "" + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" +) +MULTI_TURN_SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." + + +def multi_turn_tool_call_process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + MULTI_TURN_FIRST_PROMPT: MULTI_TURN_FIRST_RESPONSE, + MULTI_TURN_SECOND_PROMPT: MULTI_TURN_SECOND_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") diff --git a/tests/utils/test_utils/test_mock_tools.py b/tests/utils/test_utils/test_mock_tools.py new file mode 100644 index 0000000000..0a77a2a31f --- /dev/null +++ b/tests/utils/test_utils/test_mock_tools.py @@ -0,0 +1,111 @@ +import asyncio + +import pytest +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.test_utils.mock_tools import MULTI_TURN_FIRST_RESPONSE, SAMPLE_TOOLS, execute_tool_call + + +class TestExecuteToolCall: + def test_execute_get_year(self): + result = asyncio.run(execute_tool_call("get_year", {})) + assert result == '{"year": 2026}' + + def test_execute_get_temperature(self): + result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"})) + assert result == '{"temperature": -60}' + + +class TestApplyChatTemplateWithTools: + EXPECTED_PROMPT_WITHOUT_TOOLS = ( + "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n" + ) + + EXPECTED_PROMPT_WITH_TOOLS = ( + "<|im_start|>system\n" + "# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What's the weather in Paris?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize( + "tools,expected", + [ + pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"), + pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"), + ], + ) + def test_apply_chat_template(self, tools, expected): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + + assert prompt == expected + + +class TestSGLangFunctionCallParser: + """Test to demonstrate and ensure SGLang function call parser have features we need without breaking changes.""" + + @pytest.mark.parametrize( + "model_output,expected", + [ + pytest.param( + 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', + ( + "Let me check for you.", + [ToolCallItem(tool_index=0, name="get_year", parameters="{}")], + ), + id="single_tool_call", + ), + pytest.param( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n', + ( + "I will get year and temperature.", + [ + ToolCallItem(tool_index=0, name="get_year", parameters="{}"), + ToolCallItem(tool_index=1, name="get_temperature", parameters='{"location": "Shanghai"}'), + ], + ), + id="multi_tool_calls", + ), + pytest.param( + "The weather is sunny today.", + ("The weather is sunny today.", []), + id="no_tool_call", + ), + pytest.param( + MULTI_TURN_FIRST_RESPONSE, + ( + "Let me get the year and temperature first.", + [ + ToolCallItem(tool_index=0, name="get_year", parameters="{}"), + ToolCallItem(tool_index=1, name="get_temperature", parameters='{"location": "Mars"}'), + ], + ), + id="multi_turn_first_response", + ), + ], + ) + def test_parse_non_stream(self, model_output, expected): + tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS) + parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25") + assert parser.parse_non_stream(model_output) == expected From b609350eaf593796ed054bebd94f1036e1ef1519 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:12:35 +0800 Subject: [PATCH 25/77] Refactor to extract generation fixtures (#458) --- tests/conftest.py | 3 +- tests/fixtures/generation_fixtures.py | 181 ++++++++++++++ tests/fixtures/rollout_integration.py | 7 + .../rollout/generate_hub/test_single_turn.py | 222 +++++------------- 4 files changed, 244 insertions(+), 169 deletions(-) create mode 100644 tests/fixtures/generation_fixtures.py diff --git a/tests/conftest.py b/tests/conftest.py index 6697bd0b90..b04dc6bd0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from tests.fixtures.generation_fixtures import generation_env from tests.fixtures.rollout_integration import rollout_integration_env -_ = rollout_integration_env +_ = rollout_integration_env, generation_env diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py new file mode 100644 index 0000000000..caae309f94 --- /dev/null +++ b/tests/fixtures/generation_fixtures.py @@ -0,0 +1,181 @@ +""" +Fixtures to test custom-generate-function +""" + +from argparse import Namespace +from dataclasses import dataclass +from typing import Any +from unittest.mock import patch + +import pytest + +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.modular_rollout.compatibility import load_generate_function +from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.utils.async_utils import run +from miles.utils.http_utils import init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.types import Sample + +MODEL_NAME = "Qwen/Qwen3-0.6B" +RESPONSE_TEXT = "\\boxed{8}" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} + +VARIANT_TO_GENERATE_FN_PATH = { + "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", + "single_turn": "miles.rollout.generate_hub.single_turn.generate", + "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn_single_sample.generate", +} + + +def make_sample( + *, + prompt: str | list[dict] = "What is 1+7?", + tokens: list[int] | None = None, + response: str = "", + response_length: int = 0, + status: Sample.Status = Sample.Status.PENDING, + multimodal_inputs: dict | None = None, +) -> Sample: + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +@dataclass +class GenerateEnv: + args: Namespace + mock_server: Any + + +@dataclass +class GenerateResult: + sample: Sample + requests: list[dict] + + +def run_generate( + env: GenerateEnv, + sample: Sample, + sampling_params: dict[str, Any] | None = None, + *, + variant: str = "single_turn", +) -> GenerateResult: + env.mock_server.request_log.clear() + result_sample = run( + _call_generate( + env.args, + sample, + sampling_params or DEFAULT_SAMPLING_PARAMS, + variant=variant, + ) + ) + return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) + + +async def _call_generate( + args: Namespace, + sample: Sample, + sampling_params: dict[str, Any], + *, + variant: str = "single_turn", +) -> Sample: + generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant]) + state = GenerateState(args) + input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) + output = await generate_fn(input) + return output.samples + + +def make_args( + *, + router_port: int, + use_rollout_routing_replay: bool = False, + sglang_speculative_algorithm: str | None = None, + model_name: str = MODEL_NAME, + extra_argv: list[str] | None = None, + custom_generate_function_path: str | None = None, +) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + model_name, + "--prompt-data", + "/dev/null", + "--rm-type", + "math", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + if use_rollout_routing_replay: + argv.append("--use-rollout-routing-replay") + if sglang_speculative_algorithm: + argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) + if custom_generate_function_path: + argv.extend(["--custom-generate-function-path", custom_generate_function_path]) + if extra_argv: + argv.extend(extra_argv) + + from miles.utils.arguments import parse_args + + with patch("sys.argv", argv): + args = parse_args() + + init_http_client(args) + return args + + +@pytest.fixture +def generation_env(request, variant): + SingletonMeta.clear_all_instances() + params = getattr(request, "param", {}) + args_kwargs = params.get("args_kwargs", {}) + model_name = args_kwargs.get("model_name", MODEL_NAME) + custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant] + + def process_fn(_): + x = params.get("process_fn_kwargs", {}) + return ProcessResult( + text=x.get("response_text", RESPONSE_TEXT), + finish_reason=x.get("finish_reason", "stop"), + cached_tokens=x.get("cached_tokens", 0), + meta_info=ProcessResultMetaInfo( + weight_version=x.get("weight_version"), + routed_experts=x.get("routed_experts"), + spec_accept_token_num=x.get("spec_accept_token_num"), + spec_draft_token_num=x.get("spec_draft_token_num"), + spec_verify_ct=x.get("spec_verify_ct"), + ), + ) + + with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args( + router_port=mock_server.port, + model_name=model_name, + custom_generate_function_path=custom_generate_function_path, + **other_args_kwargs, + ) + yield GenerateEnv(args=args, mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fixtures/rollout_integration.py b/tests/fixtures/rollout_integration.py index 74ce0b5134..60dd4b7d65 100644 --- a/tests/fixtures/rollout_integration.py +++ b/tests/fixtures/rollout_integration.py @@ -1,3 +1,8 @@ +""" +Fixtures to test rollout-function +""" + +# TODO may rename to rollout_fixutres.py to be aligned import json from argparse import Namespace from collections.abc import Iterator @@ -25,6 +30,7 @@ class IntegrationEnvConfig: latency: float = 0.0 +# TODO may rename to RolloutEnv @dataclass(frozen=True) class IntegrationEnv: args: Namespace @@ -93,6 +99,7 @@ def _write_jsonl(path: str, rows: list[dict]) -> None: DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] +# TODO may rename to rollout_env @pytest.fixture def rollout_integration_env(tmp_path, request) -> IntegrationEnv: config = request.param diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index f9a63716bd..3c7d0954e6 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -1,24 +1,17 @@ -from argparse import Namespace -from dataclasses import dataclass -from typing import Any -from unittest.mock import patch - import numpy as np import pybase64 import pytest import torch from PIL import Image +from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, make_sample, run_generate from transformers import AutoProcessor -from miles.rollout.base_types import GenerateFnInput -from miles.rollout.modular_rollout.orchestration_common import GenerateState -from miles.utils.async_utils import run -from miles.utils.http_utils import init_http_client -from miles.utils.misc import SingletonMeta from miles.utils.processing_utils import encode_image_for_rollout_engine -from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo from miles.utils.types import Sample +_ = generation_env + # ------------------------------------ fixtures and consts ---------------------------------------- @@ -28,10 +21,10 @@ RESPONSE_TOKENS = [59, 79075, 90, 23, 92] RESPONSE_TEXT = "\\boxed{8}" RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] -DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} +SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} -@pytest.fixture(params=["sglang_rollout", "modular_rollout"]) +@pytest.fixture(params=["old_sglang_rollout", "single_turn"]) def variant(request): return request.param @@ -46,10 +39,10 @@ def expected_request( ) -> dict: result = { "input_ids": input_ids or PROMPT_TOKENS, - "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "sampling_params": sampling_params or SAMPLING_PARAMS, "return_logprob": True, } - if variant == "modular_rollout" or return_routed_experts: + if variant == "single_turn" or return_routed_experts: result["return_routed_experts"] = return_routed_experts if image_data is not None: result["image_data"] = image_data @@ -97,115 +90,10 @@ def expected_sample( ) -def make_args( - *, - router_port: int, - use_rollout_routing_replay: bool = False, - sglang_speculative_algorithm: str | None = None, - model_name: str = MODEL_NAME, -) -> Namespace: - argv = [ - "pytest", - "--train-backend", - "fsdp", - "--rollout-batch-size", - "1", - "--num-rollout", - "1", - "--rollout-num-gpus", - "1", - "--rollout-num-gpus-per-engine", - "1", - "--hf-checkpoint", - model_name, - "--prompt-data", - "/dev/null", - "--rm-type", - "math", - "--sglang-router-ip", - "127.0.0.1", - "--sglang-router-port", - str(router_port), - "--rollout-max-response-len", - "16", - ] - if use_rollout_routing_replay: - argv.append("--use-rollout-routing-replay") - if sglang_speculative_algorithm: - argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) - - from miles.utils.arguments import parse_args - - with patch("sys.argv", argv): - args = parse_args() - - init_http_client(args) - return args - - -async def call_generate(variant: str, args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: - if variant == "sglang_rollout": - from miles.rollout.sglang_rollout import generate - - return await generate(args, sample, sampling_params.copy()) - elif variant == "modular_rollout": - from miles.rollout.generate_hub.single_turn import generate - - state = GenerateState(args) - output = await generate( - GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) - ) - return output.samples - else: - raise NotImplementedError - - -@dataclass -class GenerateEnv: - args: Namespace - mock_server: Any - - -@dataclass -class GenerateResult: - sample: Sample - requests: list[dict] - - -@pytest.fixture -def env(request): - SingletonMeta.clear_all_instances() - params = getattr(request, "param", {}) - args_kwargs = params.get("args_kwargs", {}) - model_name = args_kwargs.get("model_name", MODEL_NAME) - - def process_fn(_): - x = params.get("process_fn_kwargs", {}) - return ProcessResult( - text=x.get("response_text", RESPONSE_TEXT), - finish_reason=x.get("finish_reason", "stop"), - cached_tokens=x.get("cached_tokens", 0), - meta_info=ProcessResultMetaInfo( - weight_version=x.get("weight_version"), - routed_experts=x.get("routed_experts"), - spec_accept_token_num=x.get("spec_accept_token_num"), - spec_draft_token_num=x.get("spec_draft_token_num"), - spec_verify_ct=x.get("spec_verify_ct"), - ), - ) - - with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: - other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} - args = make_args(router_port=mock_server.port, model_name=model_name, **other_args_kwargs) - yield GenerateEnv(args=args, mock_server=mock_server) - - SingletonMeta.clear_all_instances() - - -def make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): - return Sample( +def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): + return make_sample( prompt=PROMPT, - tokens=tokens or [], + tokens=tokens, response=response, response_length=response_length, status=status, @@ -213,26 +101,22 @@ def make_sample(tokens=None, response="", response_length=0, status=Sample.Statu ) -def run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): - env.mock_server.request_log.clear() - result_sample = run( - call_generate(variant, env.args, sample or make_sample(), sampling_params or DEFAULT_SAMPLING_PARAMS) - ) - return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) +def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): + return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant) # ------------------------------------ tests ---------------------------------------- class TestBasicGeneration: - def test_basic_generation(self, variant, env): - result = run_generate(variant, env) + def test_basic_generation(self, variant, generation_env): + result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample() class TestResumedSingleTurn: - def test_two_consecutive_calls_on_same_sample(self, variant, env): + def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): partial_text = "\\boxed" partial_tokens = [59, 79075] partial_log_probs = [-0.0, -0.0078125] @@ -241,9 +125,9 @@ def test_two_consecutive_calls_on_same_sample(self, variant, env): remaining_tokens = [90, 23, 92] remaining_log_probs = [-0.0, -0.0078125, -0.015625] - env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") - sample = make_sample() - result1 = run_generate(variant, env, sample) + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") + sample = _make_sample() + result1 = _run_generate(variant, generation_env, sample) assert result1.requests == [expected_request(variant)] assert result1.sample == expected_sample( response=partial_text, @@ -253,8 +137,8 @@ def test_two_consecutive_calls_on_same_sample(self, variant, env): status=Sample.Status.ABORTED, ) - env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") - result2 = run_generate(variant, env, result1.sample) + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") + result2 = _run_generate(variant, generation_env, result1.sample) tokens_after_turn1 = PROMPT_TOKENS + partial_tokens assert result2.requests == [ expected_request( @@ -275,23 +159,23 @@ def test_two_consecutive_calls_on_same_sample(self, variant, env): class TestFinishReason: @pytest.mark.parametrize( - "env,expected_status", + "generation_env,expected_status", [ ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED), ], - indirect=["env"], + indirect=["generation_env"], ) - def test_finish_reason_sets_status(self, variant, env, expected_status): - result = run_generate(variant, env) + def test_finish_reason_sets_status(self, variant, generation_env, expected_status): + result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample(status=expected_status) class TestRoutedExperts: @pytest.mark.parametrize( - "env", + "generation_env", [ { "args_kwargs": {"use_rollout_routing_replay": True}, @@ -300,23 +184,23 @@ class TestRoutedExperts: ], indirect=True, ) - def test_routed_experts_enabled_and_parsed(self, variant, env): + def test_routed_experts_enabled_and_parsed(self, variant, generation_env): num_layers, moe_router_topk = 2, 4 num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( num_tokens - 1, num_layers, moe_router_topk ) - env.args.num_layers = num_layers - env.args.moe_router_topk = moe_router_topk + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") - env.mock_server.process_fn = lambda _: ProcessResult( + generation_env.mock_server.process_fn = lambda _: ProcessResult( text=RESPONSE_TEXT, finish_reason="stop", meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), ) - result = run_generate(variant, env) + result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant, return_routed_experts=True)] assert result.sample.rollout_routed_experts is not None assert result.sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) @@ -325,15 +209,15 @@ def test_routed_experts_enabled_and_parsed(self, variant, env): class TestMetaInfo: @pytest.mark.parametrize( - "env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True + "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True ) - def test_meta_info_fields_updated(self, variant, env): - result = run_generate(variant, env) + def test_meta_info_fields_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample(cached_tokens=3, weight_versions=["v1.0"]) @pytest.mark.parametrize( - "env", + "generation_env", [ { "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, @@ -342,8 +226,8 @@ def test_meta_info_fields_updated(self, variant, env): ], indirect=True, ) - def test_spec_info_updated(self, variant, env): - result = run_generate(variant, env) + def test_spec_info_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample( spec_info=Sample.SpecInfo( @@ -354,20 +238,22 @@ def test_spec_info_updated(self, variant, env): class TestInputStatusValidation: @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) - def test_allowed_statuses(self, variant, env, status): - result = run_generate(variant, env, make_sample(status=status)) + def test_allowed_statuses(self, variant, generation_env, status): + result = _run_generate(variant, generation_env, _make_sample(status=status)) assert result.requests == [expected_request(variant)] assert result.sample.status == Sample.Status.COMPLETED @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) - def test_rejected_statuses(self, variant, env, status): + def test_rejected_statuses(self, variant, generation_env, status): with pytest.raises(AssertionError): - run_generate(variant, env, make_sample(status=status)) + _run_generate(variant, generation_env, _make_sample(status=status)) class TestPayloadStructure: - def test_sampling_params_passed_through(self, variant, env): - result = run_generate(variant, env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + def test_sampling_params_passed_through(self, variant, generation_env): + result = _run_generate( + variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9} + ) assert result.requests == [ expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) ] @@ -375,19 +261,19 @@ def test_sampling_params_passed_through(self, variant, env): class TestBoundaryConditions: - def test_max_new_tokens_zero_returns_truncated(self, variant, env): + def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) - sample = make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) + sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) - result = run_generate(variant, env, sample, {"max_new_tokens": 10, "temperature": 0.7}) + result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7}) assert result.requests == [] assert result.sample.status == Sample.Status.TRUNCATED class TestEmptyResponse: - @pytest.mark.parametrize("env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) - def test_empty_response(self, variant, env): - result = run_generate(variant, env) + @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) + def test_empty_response(self, variant, generation_env): + result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample( response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[] @@ -398,8 +284,8 @@ def test_empty_response(self, variant, env): class TestMultimodal: - @pytest.mark.parametrize("env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) - def test_multimodal_inputs_processed(self, variant, env): + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) + def test_multimodal_inputs_processed(self, variant, generation_env): test_image = Image.new("RGB", (64, 64), color="red") multimodal_inputs = {"images": [test_image]} processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) @@ -409,7 +295,7 @@ def test_multimodal_inputs_processed(self, variant, env): if k not in ["input_ids", "attention_mask"] } - result = run_generate(variant, env, make_sample(multimodal_inputs=multimodal_inputs)) + result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs)) assert result.requests == [ expected_request( From ab03942815d730ab128b5326dc6cd8c59a8de97e Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:13:02 +0800 Subject: [PATCH 26/77] Support multi-turn testing with snapshot test utils (#459) --- tests/rollout/generate_hub/test_multi_turn.py | 207 ++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 tests/rollout/generate_hub/test_multi_turn.py diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py new file mode 100644 index 0000000000..d292815e0d --- /dev/null +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -0,0 +1,207 @@ +from copy import deepcopy +from dataclasses import dataclass, replace +from itertools import groupby + +import pytest +from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, make_sample, run_generate +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult +from miles.utils.test_utils.mock_tools import ( + MULTI_TURN_FIRST_RESPONSE, + MULTI_TURN_SECOND_RESPONSE, + SAMPLE_TOOLS, + multi_turn_tool_call_process_fn, +) +from miles.utils.types import Sample + +_ = generation_env, SAMPLE_TOOLS, multi_turn_tool_call_process_fn + + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + +MULTI_TURN_EXTRA_ARGV = [ + "--generate-max-turns", + "4", + "--generate-max-tool-calls", + "4", + "--generate-tool-specs-path", + "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + "--generate-tool-call-parser", + "qwen25", + "--generate-execute-tool-function-path", + "miles.utils.test_utils.mock_tools.execute_tool_call", + "--rollout-max-context-len", + "4096", +] + + +@pytest.fixture(params=["multi_turn_single_sample"]) +def variant(request): + return request.param + + +@dataclass(frozen=True) +class SampleParsedChunk: + tokens_decoded_str: str + loss_mask_value: int + rollout_log_probs: list[float] + + +def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: + prompt_len = len(sample.tokens) - sample.response_length + response_tokens = sample.tokens[prompt_len:] + loss_mask = sample.loss_mask + log_probs = sample.rollout_log_probs + + chunks = [] + idx = 0 + for mask_val, group in groupby(loss_mask): + group_len = len(list(group)) + sli = slice(idx, idx + group_len) + chunks.append( + SampleParsedChunk( + tokens_decoded_str=tokenizer.decode(response_tokens[sli]), + loss_mask_value=mask_val, + rollout_log_probs=log_probs[sli], + ) + ) + idx += group_len + return chunks + + +def expected_partial_sample( + *, + prompt: list[dict], + response: str, + response_length: int, + status: Sample.Status = Sample.Status.COMPLETED, +) -> Sample: + return Sample( + prompt=prompt, + response=response, + response_length=response_length, + status=status, + tokens=[], + loss_mask=[], + rollout_log_probs=[], + weight_versions=[], + spec_info=Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + + +def verify_sample( + actual: Sample, + *, + expected_chunks: list[SampleParsedChunk], + expected_partial_sample: Sample, +): + actual_chunks = parse_sample_into_chunks(actual, TOKENIZER) + assert actual_chunks == expected_chunks + + actual_partial = replace( + deepcopy(actual), + tokens=[], + loss_mask=[], + rollout_log_probs=[], + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + assert actual_partial == expected_partial_sample + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): + return run_generate(env, sample, sampling_params, variant=variant) + + +SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] +SINGLE_TURN_RESPONSE = "The answer is 2." + +TWO_TURN_USER_QUESTION = "What is 42 + year + temperature?" +TWO_TURN_PROMPT = [{"role": "user", "content": TWO_TURN_USER_QUESTION}] +TWO_TURN_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" +) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicMultiTurn: + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"extra_argv": MULTI_TURN_EXTRA_ARGV}}], + indirect=True, + ) + def test_single_turn_no_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="stop" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + assert len(result.requests) == 1 + verify_sample( + result.sample, + expected_chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + expected_partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + ), + ) + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"extra_argv": MULTI_TURN_EXTRA_ARGV}}], + indirect=True, + ) + def test_two_turns_with_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + + assert len(result.requests) == 2 + verify_sample( + result.sample, + expected_chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + SampleParsedChunk( + tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, + loss_mask_value=0, + rollout_log_probs=[0.0] * 31, + ), + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_SECOND_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(24)], + ), + ], + expected_partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE + MULTI_TURN_SECOND_RESPONSE, + response_length=45 + 31 + 24, + ), + ) From a1dca85d61017b5dbf6e393ae68fd87dfd424337 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:13:20 +0800 Subject: [PATCH 27/77] Update multi turn single sample implementation to use standard tooling (#460) --- .../generate_hub/multi_turn_single_sample.py | 123 ++++++++++-------- 1 file changed, 67 insertions(+), 56 deletions(-) diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py index a6b049ead8..cd53ac5c48 100644 --- a/miles/rollout/generate_hub/multi_turn_single_sample.py +++ b/miles/rollout/generate_hub/multi_turn_single_sample.py @@ -2,10 +2,18 @@ Simple multi-turn generation with tool calling. """ -from typing import Any +import argparse +import json +import uuid + +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.function_call_parser import FunctionCallParser from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_hub.tool_call_utils import tokenize_tool_responses from miles.utils.http_utils import post +from miles.utils.misc import load_function from miles.utils.types import Sample @@ -18,17 +26,25 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + execute_tool_function = load_function(args.generate_execute_tool_function_path) + + tool_specs = load_function(args.generate_tool_specs_path) + assert isinstance(tool_specs, list) + + tool_call_parser = FunctionCallParser( + tools=(TypeAdapter(list[Tool]).validate_python(tool_specs)), + tool_call_parser=args.generate_tool_call_parser, + ) + # Set up the initial prompt with system prompt and tools (outside the loop) - tool_specs = tool_registry.get_tool_specs() - prompt = format_conversation_with_tools(prompt=sample.prompt, tools=tool_specs) + prompt = tokenizer.apply_chat_template(sample.prompt, tokenize=False, add_generation_prompt=True, tools=tool_specs) prompt_tokens_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] response = "" response_token_ids = [] loss_masks = [] - tool_call_count = 0 # Track actual tool call rounds - for turn in range(TOOL_CONFIGS["max_turns"]): + for turn in range(args.generate_max_turns): # Check if total length exceeds max context length total_length = len(prompt_tokens_ids) + len(response_token_ids) if args.rollout_max_context_len is not None: @@ -54,18 +70,12 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: sample.status = Sample.Status.ABORTED return GenerateFnOutput(samples=sample) - if "output_token_logprobs" in output["meta_info"]: - cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - cur_response = tokenizer.decode(cur_response_token_ids) - cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += cur_log_probs - - else: - cur_response = output["text"] - cur_response = postprocess_responses(cur_response) - cur_response_token_ids = tokenizer(cur_response, add_special_tokens=False)["input_ids"] + cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + cur_response = tokenizer.decode(cur_response_token_ids) + cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += cur_log_probs response += cur_response response_token_ids += cur_response_token_ids @@ -75,31 +85,25 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: if output["meta_info"]["finish_reason"]["type"] == "length": break - next_obs, done = await execute_predictions(cur_response) - if done: + _, parsed_tool_calls = tool_call_parser.parse_non_stream(cur_response) + if len(parsed_tool_calls) == 0: break - # Count tool calls (when we get interpreter output, it means a tool - # was called) - if "" in next_obs: - tool_call_count += 1 + tool_messages = await _execute_tool_calls(parsed_tool_calls, execute_tool_function) - assert next_obs != "", "Next observation should not be empty." - obs_tokens_ids = tokenizer(next_obs, add_special_tokens=False)["input_ids"] - response += next_obs - response_token_ids += obs_tokens_ids - loss_masks += [0] * len(obs_tokens_ids) + next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) + # TODO is this ok? + response += tokenizer.decode(next_obs_tokens_ids) + response_token_ids += next_obs_tokens_ids + loss_masks += [0] * len(next_obs_tokens_ids) - # Add dummy log probs for observation tokens (they won't be used due to loss_mask=0) - # Check if maximum tool call count reached - if sample.rollout_log_probs is not None: - sample.rollout_log_probs += [0.0] * len(obs_tokens_ids) + sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) - assert len(response_token_ids) == len( - sample.rollout_log_probs - ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" + assert len(response_token_ids) == len( + sample.rollout_log_probs + ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" - if turn >= TOOL_CONFIGS["max_tool_calls"]: + if turn >= args.generate_max_tool_calls: break # Set sample attributes @@ -114,23 +118,30 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: return GenerateFnOutput(samples=sample) -def format_conversation_with_tools( - prompt: str, tools: list[dict[str, Any]] = None, system_prompt: str = None, messages: list[dict[str, Any]] = None -) -> str: - return TODO - - -def postprocess_predictions(prediction: str): - """Extract action and content from prediction string""" - return TODO, TODO - - -def postprocess_responses(resp: str) -> str: - return TODO - - -async def execute_predictions(prediction: str) -> str: - """Execute predictions and return results""" - action, content = postprocess_predictions(prediction) - next_obs, done = TODO - return next_obs, done +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-max-tool-calls", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-tool-call-parser", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + + +generate.add_arguments = _add_arguments + + +async def _execute_tool_calls(parsed_tool_calls, execute_one) -> list[dict]: + tool_messages = [] + for call in parsed_tool_calls: + params = json.loads(call.parameters) if call.parameters else {} + result = await execute_one(call.name, params) + assert isinstance(result, str) + tool_messages.append( + { + "role": "tool", + # src: serving_chat.py :: _process_tool_call_id + "tool_call_id": f"call_{uuid.uuid4().hex[:24]}", + "content": result, + "name": call.name, + } + ) + return tool_messages From 0c0cc8a866ec22fa86a8cfa83d07d6cdb842ceb3 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:13:39 +0800 Subject: [PATCH 28/77] Support comparison tests and edge case tests for multi-turn-single-sample (#463) --- .../generate_hub/multi_turn_single_sample.py | 7 +- tests/fixtures/generation_fixtures.py | 16 ++ tests/rollout/generate_hub/test_multi_turn.py | 142 ++++++++++++++---- .../rollout/generate_hub/test_single_turn.py | 35 ++++- 4 files changed, 160 insertions(+), 40 deletions(-) diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py index cd53ac5c48..9a09630126 100644 --- a/miles/rollout/generate_hub/multi_turn_single_sample.py +++ b/miles/rollout/generate_hub/multi_turn_single_sample.py @@ -36,10 +36,11 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: tool_call_parser=args.generate_tool_call_parser, ) - # Set up the initial prompt with system prompt and tools (outside the loop) - prompt = tokenizer.apply_chat_template(sample.prompt, tokenize=False, add_generation_prompt=True, tools=tool_specs) - + prompt = sample.prompt + if not isinstance(prompt, str): + prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, tools=tool_specs) prompt_tokens_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] + response = "" response_token_ids = [] loss_masks = [] diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index caae309f94..d00424b829 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -95,12 +95,19 @@ async def _call_generate( def make_args( *, + variant: str, router_port: int, use_rollout_routing_replay: bool = False, sglang_speculative_algorithm: str | None = None, model_name: str = MODEL_NAME, extra_argv: list[str] | None = None, custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_max_tool_calls: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + rollout_max_context_len: int = 4096, ) -> Namespace: argv = [ "pytest", @@ -133,6 +140,14 @@ def make_args( argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) if custom_generate_function_path: argv.extend(["--custom-generate-function-path", custom_generate_function_path]) + + if variant == "multi_turn_single_sample": + argv.extend(["--generate-max-turns", str(generate_max_turns)]) + argv.extend(["--generate-max-tool-calls", str(generate_max_tool_calls)]) + argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) + argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) + argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) + argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) if extra_argv: argv.extend(extra_argv) @@ -171,6 +186,7 @@ def process_fn(_): with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} args = make_args( + variant=variant, router_port=mock_server.port, model_name=model_name, custom_generate_function_path=custom_generate_function_path, diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index d292815e0d..6b105b4f7e 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -8,7 +8,9 @@ from miles.utils.test_utils.mock_sglang_server import ProcessResult from miles.utils.test_utils.mock_tools import ( + MULTI_TURN_FIRST_PROMPT, MULTI_TURN_FIRST_RESPONSE, + MULTI_TURN_SECOND_PROMPT, MULTI_TURN_SECOND_RESPONSE, SAMPLE_TOOLS, multi_turn_tool_call_process_fn, @@ -24,21 +26,8 @@ MODEL_NAME = "Qwen/Qwen3-0.6B" DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) - -MULTI_TURN_EXTRA_ARGV = [ - "--generate-max-turns", - "4", - "--generate-max-tool-calls", - "4", - "--generate-tool-specs-path", - "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", - "--generate-tool-call-parser", - "qwen25", - "--generate-execute-tool-function-path", - "miles.utils.test_utils.mock_tools.execute_tool_call", - "--rollout-max-context-len", - "4096", -] +FIRST_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_FIRST_PROMPT, add_special_tokens=False)["input_ids"] +SECOND_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_SECOND_PROMPT, add_special_tokens=False)["input_ids"] @pytest.fixture(params=["multi_turn_single_sample"]) @@ -56,8 +45,8 @@ class SampleParsedChunk: def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: prompt_len = len(sample.tokens) - sample.response_length response_tokens = sample.tokens[prompt_len:] - loss_mask = sample.loss_mask - log_probs = sample.rollout_log_probs + loss_mask = sample.loss_mask or [] + log_probs = sample.rollout_log_probs or [] chunks = [] idx = 0 @@ -119,8 +108,21 @@ def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_param return run_generate(env, sample, sampling_params, variant=variant) +def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict: + return { + "input_ids": input_ids, + "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "return_logprob": True, + } + + SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] SINGLE_TURN_RESPONSE = "The answer is 2." +_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( + SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS +) +SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] +SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) TWO_TURN_USER_QUESTION = "What is 42 + year + temperature?" TWO_TURN_PROMPT = [{"role": "user", "content": TWO_TURN_USER_QUESTION}] @@ -140,11 +142,6 @@ def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_param class TestBasicMultiTurn: - @pytest.mark.parametrize( - "generation_env", - [{"args_kwargs": {"extra_argv": MULTI_TURN_EXTRA_ARGV}}], - indirect=True, - ) def test_single_turn_no_tool_call(self, variant, generation_env): generation_env.mock_server.process_fn = lambda _: ProcessResult( text=SINGLE_TURN_RESPONSE, finish_reason="stop" @@ -152,7 +149,7 @@ def test_single_turn_no_tool_call(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - assert len(result.requests) == 1 + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] verify_sample( result.sample, expected_chunks=[ @@ -169,17 +166,15 @@ def test_single_turn_no_tool_call(self, variant, generation_env): ), ) - @pytest.mark.parametrize( - "generation_env", - [{"args_kwargs": {"extra_argv": MULTI_TURN_EXTRA_ARGV}}], - indirect=True, - ) def test_two_turns_with_tool_call(self, variant, generation_env): generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - assert len(result.requests) == 2 + assert result.requests == [ + expected_request(FIRST_PROMPT_TOKEN_IDS), + expected_request(SECOND_PROMPT_TOKEN_IDS), + ] verify_sample( result.sample, expected_chunks=[ @@ -205,3 +200,92 @@ def test_two_turns_with_tool_call(self, variant, generation_env): response_length=45 + 31 + 24, ), ) + + +class TestExitConditions: + def test_partial_rollout_not_supported(self, variant, generation_env): + generation_env.args.partial_rollout = True + + with pytest.raises(AssertionError, match="Partial rollout is not supported"): + _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + def test_abort_preserves_content(self, variant, generation_env): + pytest.skip("TODO: support") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="abort" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_sample( + result.sample, + expected_chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + expected_partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + status=Sample.Status.ABORTED, + ), + ) + + def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=MULTI_TURN_FIRST_RESPONSE, finish_reason="length" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + + assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + verify_sample( + result.sample, + expected_chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + ], + expected_partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE, + response_length=45, + status=Sample.Status.TRUNCATED, + ), + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) + def test_max_turns_reached(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=MULTI_TURN_FIRST_RESPONSE, finish_reason="stop" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + + assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + verify_sample( + result.sample, + expected_chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + SampleParsedChunk( + tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, + loss_mask_value=0, + rollout_log_probs=[0.0] * 31, + ), + ], + expected_partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, + response_length=45 + 31, + ), + ) diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index 3c7d0954e6..b3de35341d 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -24,7 +24,7 @@ SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} -@pytest.fixture(params=["old_sglang_rollout", "single_turn"]) +@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample"]) def variant(request): return request.param @@ -50,6 +50,7 @@ def expected_request( def expected_sample( + variant: str, *, prompt: str = PROMPT, response: str = RESPONSE_TEXT, @@ -65,6 +66,8 @@ def expected_sample( multimodal_inputs: dict | None = None, multimodal_train_inputs: dict | None = None, ) -> Sample: + actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) + loss_mask = [1] * actual_response_length if variant == "multi_turn_single_sample" else None return Sample( group_index=None, index=None, @@ -76,7 +79,7 @@ def expected_sample( response_length=response_length, label=None, reward=None, - loss_mask=None, + loss_mask=loss_mask, weight_versions=weight_versions or [], rollout_log_probs=rollout_log_probs if rollout_log_probs is not None else RESPONSE_LOG_PROBS, rollout_routed_experts=rollout_routed_experts, @@ -112,11 +115,13 @@ class TestBasicGeneration: def test_basic_generation(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample() + assert result.sample == expected_sample(variant) class TestResumedSingleTurn: def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): + if variant == "multi_turn_single_sample": + pytest.skip("not tested yet") partial_text = "\\boxed" partial_tokens = [59, 79075] partial_log_probs = [-0.0, -0.0078125] @@ -130,6 +135,7 @@ def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): result1 = _run_generate(variant, generation_env, sample) assert result1.requests == [expected_request(variant)] assert result1.sample == expected_sample( + variant, response=partial_text, response_length=2, tokens=PROMPT_TOKENS + partial_tokens, @@ -148,6 +154,7 @@ def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): ) ] assert result2.sample == expected_sample( + variant, response=partial_text + remaining_text, response_length=2 + 3, tokens=tokens_after_turn1 + remaining_tokens, @@ -168,9 +175,11 @@ class TestFinishReason: indirect=["generation_env"], ) def test_finish_reason_sets_status(self, variant, generation_env, expected_status): + if variant == "multi_turn_single_sample" and expected_status == Sample.Status.ABORTED: + pytest.skip("TODO: support") result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample(status=expected_status) + assert result.sample == expected_sample(variant, status=expected_status) class TestRoutedExperts: @@ -185,6 +194,8 @@ class TestRoutedExperts: indirect=True, ) def test_routed_experts_enabled_and_parsed(self, variant, generation_env): + if variant == "multi_turn_single_sample": + pytest.skip("TODO: support") num_layers, moe_router_topk = 2, 4 num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( @@ -214,7 +225,7 @@ class TestMetaInfo: def test_meta_info_fields_updated(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample(cached_tokens=3, weight_versions=["v1.0"]) + assert result.sample == expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"]) @pytest.mark.parametrize( "generation_env", @@ -230,9 +241,10 @@ def test_spec_info_updated(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample( + variant, spec_info=Sample.SpecInfo( spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 - ) + ), ) @@ -245,6 +257,8 @@ def test_allowed_statuses(self, variant, generation_env, status): @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) def test_rejected_statuses(self, variant, generation_env, status): + if variant == "multi_turn_single_sample": + pytest.skip("not tested yet") with pytest.raises(AssertionError): _run_generate(variant, generation_env, _make_sample(status=status)) @@ -257,11 +271,13 @@ def test_sampling_params_passed_through(self, variant, generation_env): assert result.requests == [ expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) ] - assert result.sample == expected_sample() + assert result.sample == expected_sample(variant) class TestBoundaryConditions: def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant == "multi_turn_single_sample": + pytest.skip("not tested yet") existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) @@ -276,7 +292,7 @@ def test_empty_response(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample( - response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[] + variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[] ) @@ -286,6 +302,8 @@ def test_empty_response(self, variant, generation_env): class TestMultimodal: @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) def test_multimodal_inputs_processed(self, variant, generation_env): + if variant == "multi_turn_single_sample": + pytest.skip("not tested yet") test_image = Image.new("RGB", (64, 64), color="red") multimodal_inputs = {"images": [test_image]} processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) @@ -310,6 +328,7 @@ def test_multimodal_inputs_processed(self, variant, generation_env): assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) assert result.sample == expected_sample( + variant, tokens=PROMPT_TOKENS + RESPONSE_TOKENS, multimodal_inputs=multimodal_inputs, multimodal_train_inputs=actual_mti, From cc46921d6b7aab47416d7af0cc35070bb6e98678 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:14:04 +0800 Subject: [PATCH 29/77] Change behavior of multi turn single sample to match single turn in degenerated cases (#464) --- .../generate_hub/multi_turn_single_sample.py | 23 ++++++------------- tests/fixtures/generation_fixtures.py | 2 -- tests/rollout/generate_hub/test_multi_turn.py | 1 - .../rollout/generate_hub/test_single_turn.py | 3 +-- 4 files changed, 8 insertions(+), 21 deletions(-) diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py index 9a09630126..d012704f4c 100644 --- a/miles/rollout/generate_hub/multi_turn_single_sample.py +++ b/miles/rollout/generate_hub/multi_turn_single_sample.py @@ -39,7 +39,7 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: prompt = sample.prompt if not isinstance(prompt, str): prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, tools=tool_specs) - prompt_tokens_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] + prompt_tokens_ids = tokenizer.encode(prompt, add_special_tokens=False) response = "" response_token_ids = [] @@ -66,13 +66,8 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: output = await post(url, payload) - # Handle abort - if output["meta_info"]["finish_reason"]["type"] == "abort": - sample.status = Sample.Status.ABORTED - return GenerateFnOutput(samples=sample) - cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - cur_response = tokenizer.decode(cur_response_token_ids) + cur_response = output["text"] cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] if sample.rollout_log_probs is None: sample.rollout_log_probs = [] @@ -82,8 +77,11 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: response_token_ids += cur_response_token_ids loss_masks += [1] * len(cur_response_token_ids) - # Check length limit - if output["meta_info"]["finish_reason"]["type"] == "length": + # Set status + sample.update_from_meta_info(args, output["meta_info"]) + + finish_reason_type = output["meta_info"]["finish_reason"]["type"] + if finish_reason_type in ("abort", "length"): break _, parsed_tool_calls = tool_call_parser.parse_non_stream(cur_response) @@ -104,24 +102,17 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: sample.rollout_log_probs ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" - if turn >= args.generate_max_tool_calls: - break - # Set sample attributes sample.tokens = prompt_tokens_ids + response_token_ids sample.response_length = len(response_token_ids) sample.response = response sample.loss_mask = loss_masks - # Set status - sample.update_from_meta_info(args, output["meta_info"]) - return GenerateFnOutput(samples=sample) def _add_arguments(parser: argparse.ArgumentParser): parser.add_argument("--generate-max-turns", type=int, default=16) - parser.add_argument("--generate-max-tool-calls", type=int, default=16) parser.add_argument("--generate-tool-specs-path", type=str) parser.add_argument("--generate-tool-call-parser", type=str) parser.add_argument("--generate-execute-tool-function-path", type=str) diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index d00424b829..0b030da895 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -103,7 +103,6 @@ def make_args( extra_argv: list[str] | None = None, custom_generate_function_path: str | None = None, generate_max_turns: int = 16, - generate_max_tool_calls: int = 16, generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", generate_tool_call_parser: str = "qwen25", generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", @@ -143,7 +142,6 @@ def make_args( if variant == "multi_turn_single_sample": argv.extend(["--generate-max-turns", str(generate_max_turns)]) - argv.extend(["--generate-max-tool-calls", str(generate_max_tool_calls)]) argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index 6b105b4f7e..f13a23954c 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -210,7 +210,6 @@ def test_partial_rollout_not_supported(self, variant, generation_env): _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) def test_abort_preserves_content(self, variant, generation_env): - pytest.skip("TODO: support") generation_env.mock_server.process_fn = lambda _: ProcessResult( text=SINGLE_TURN_RESPONSE, finish_reason="abort" ) diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index b3de35341d..bb3f697b75 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -175,8 +175,6 @@ class TestFinishReason: indirect=["generation_env"], ) def test_finish_reason_sets_status(self, variant, generation_env, expected_status): - if variant == "multi_turn_single_sample" and expected_status == Sample.Status.ABORTED: - pytest.skip("TODO: support") result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] assert result.sample == expected_sample(variant, status=expected_status) @@ -196,6 +194,7 @@ class TestRoutedExperts: def test_routed_experts_enabled_and_parsed(self, variant, generation_env): if variant == "multi_turn_single_sample": pytest.skip("TODO: support") + num_layers, moe_router_topk = 2, 4 num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( From 0626e73ca8529f1821ecfd85eb4ab0f837b4fe15 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:14:22 +0800 Subject: [PATCH 30/77] Refactor and unify multi-turn-single-sample and single-turn (#466) --- .../generate_hub/generate_endpoint_wrapper.py | 67 +++++----- .../generate_hub/multi_turn_single_sample.py | 114 +++++------------- miles/rollout/generate_hub/single_turn.py | 25 ++-- miles/rollout/generate_hub/tool_call_utils.py | 43 +++++++ miles/utils/misc.py | 1 + .../rollout/generate_hub/test_single_turn.py | 2 +- 6 files changed, 130 insertions(+), 122 deletions(-) diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py index c927c05794..39fd419aac 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -2,6 +2,8 @@ Wrapper to integrate SGLang's `/generate` endpoint with RL things like Sample. """ +from typing import Any + import numpy as np import pybase64 @@ -10,49 +12,52 @@ # Make this an isolated function because users may want to compute their own -async def compute_prompt_ids_from_sample(state, sample): +def compute_prompt_ids_from_sample(state, sample, tools=None): + prompt = sample.prompt + if state.processor: - processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) + processor_output = state.processor(text=prompt, **sample.multimodal_inputs) prompt_ids = processor_output["input_ids"][0] + # TODO shall we move it to other places? then can make this function immutable sample.multimodal_train_inputs = { k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] } or None + return prompt_ids else: - return state.tokenizer.encode(sample.prompt, add_special_tokens=False) - - -async def compute_request_payload(state, sample, prompt_ids: list[int], sampling_params: dict): - assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" - - max_new_tokens = sampling_params.pop("max_new_tokens") - if len(sample.response) > 0: - max_new_tokens -= len(sample.tokens) - len(prompt_ids) - - # Prepare payload for sglang server + if not isinstance(prompt, str): + prompt = state.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True, tools=tools + ) + + return state.tokenizer.encode(prompt, add_special_tokens=False) + + +# Thin wrapper to construct request payload. +# Make it a function to allow adding logics like `return_routed_experts` in the future +# without requiring users to change their code. +def compute_request_payload( + args, + input_ids: list[int], + sampling_params: dict, + multimodal_inputs: dict | None = None, +) -> dict[str, Any]: payload = { - # Use existing tokens for multi-turn or tokenize the new prompt - "input_ids": sample.tokens if len(sample.response) > 0 else prompt_ids, - "sampling_params": { - **sampling_params, - "max_new_tokens": max_new_tokens, - }, + "input_ids": input_ids, + "sampling_params": sampling_params, "return_logprob": True, - "return_routed_experts": state.args.use_rollout_routing_replay, + "return_routed_experts": args.use_rollout_routing_replay, } - if image_data := (sample.multimodal_inputs or {}).get("images"): + if image_data := (multimodal_inputs or {}).get("images"): payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - assert payload["sampling_params"]["max_new_tokens"] >= 0 - - if payload["sampling_params"]["max_new_tokens"] == 0: - return None, Sample.Status.TRUNCATED + return payload - return payload, None - -async def update_sample_from_response(args, sample: Sample, payload: dict, output: dict): +async def update_sample_from_response( + args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False +): # Initialize sample.tokens for the first turn if (len(sample.response) == 0) and not sample.tokens: sample.tokens = payload["input_ids"] @@ -62,6 +67,8 @@ async def update_sample_from_response(args, sample: Sample, payload: dict, outpu # TODO may rename to match await postprocess_sample_with_radix_tree(args, sample, output) + + assert not update_loss_mask, "This code branch has not implemented update_loss_mask" else: if x := output["meta_info"].get("output_token_logprobs"): new_response_tokens = [item[1] for item in x] @@ -78,6 +85,10 @@ async def update_sample_from_response(args, sample: Sample, payload: dict, outpu sample.rollout_log_probs = [] sample.rollout_log_probs += new_response_log_probs + if update_loss_mask: + sample.loss_mask += [1] * len(new_response_tokens) + + # TODO handle multi-turn cases (may need concat instead of assignment) sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) # TODO may unify (currently there are both methods inside Sample and separate functions) diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py index d012704f4c..852ef9159e 100644 --- a/miles/rollout/generate_hub/multi_turn_single_sample.py +++ b/miles/rollout/generate_hub/multi_turn_single_sample.py @@ -3,15 +3,18 @@ """ import argparse -import json -import uuid - -from pydantic import TypeAdapter -from sglang.srt.entrypoints.openai.protocol import Tool -from sglang.srt.function_call.function_call_parser import FunctionCallParser from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_hub.tool_call_utils import tokenize_tool_responses +from miles.rollout.generate_hub.generate_endpoint_wrapper import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.rollout.generate_hub.tool_call_utils import ( + create_tool_call_parser, + execute_tool_calls, + update_sample_with_tool_responses, +) from miles.utils.http_utils import post from miles.utils.misc import load_function from miles.utils.types import Sample @@ -21,33 +24,24 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: args = input.args sample = input.sample tokenizer = input.state.tokenizer - - assert not args.partial_rollout, "Partial rollout is not supported for " "this function at the moment." + assert not args.partial_rollout url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" execute_tool_function = load_function(args.generate_execute_tool_function_path) tool_specs = load_function(args.generate_tool_specs_path) - assert isinstance(tool_specs, list) + tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) - tool_call_parser = FunctionCallParser( - tools=(TypeAdapter(list[Tool]).validate_python(tool_specs)), - tool_call_parser=args.generate_tool_call_parser, - ) + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) - prompt = sample.prompt - if not isinstance(prompt, str): - prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, tools=tool_specs) - prompt_tokens_ids = tokenizer.encode(prompt, add_special_tokens=False) + sample.loss_mask = [] + sample.tokens = prompt_tokens_ids.copy() - response = "" - response_token_ids = [] - loss_masks = [] - - for turn in range(args.generate_max_turns): + for _turn in range(args.generate_max_turns): + # TODO handle separately # Check if total length exceeds max context length - total_length = len(prompt_tokens_ids) + len(response_token_ids) + total_length = len(sample.tokens) if args.rollout_max_context_len is not None: max_context_length = args.rollout_max_context_len else: @@ -56,57 +50,23 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: sample.status = Sample.Status.TRUNCATED break - # Use token IDs instead of text - current_token_ids = prompt_tokens_ids + response_token_ids - payload = { - "input_ids": current_token_ids, - "sampling_params": input.sampling_params, - "return_logprob": True, # Request log probabilities for training - } + # ----------------------- Call inference endpoint ------------------------- + payload = compute_request_payload(args, sample.tokens, input.sampling_params) output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output) - cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - cur_response = output["text"] - cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += cur_log_probs - - response += cur_response - response_token_ids += cur_response_token_ids - loss_masks += [1] * len(cur_response_token_ids) - - # Set status - sample.update_from_meta_info(args, output["meta_info"]) - - finish_reason_type = output["meta_info"]["finish_reason"]["type"] - if finish_reason_type in ("abort", "length"): + if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): break - _, parsed_tool_calls = tool_call_parser.parse_non_stream(cur_response) - if len(parsed_tool_calls) == 0: - break + # ----------------------- Execute tools ------------------------- - tool_messages = await _execute_tool_calls(parsed_tool_calls, execute_tool_function) - - next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) - # TODO is this ok? - response += tokenizer.decode(next_obs_tokens_ids) - response_token_ids += next_obs_tokens_ids - loss_masks += [0] * len(next_obs_tokens_ids) - - sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) - - assert len(response_token_ids) == len( - sample.rollout_log_probs - ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" + _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) + if len(tool_calls) == 0: + break - # Set sample attributes - sample.tokens = prompt_tokens_ids + response_token_ids - sample.response_length = len(response_token_ids) - sample.response = response - sample.loss_mask = loss_masks + tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) + update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) return GenerateFnOutput(samples=sample) @@ -119,21 +79,3 @@ def _add_arguments(parser: argparse.ArgumentParser): generate.add_arguments = _add_arguments - - -async def _execute_tool_calls(parsed_tool_calls, execute_one) -> list[dict]: - tool_messages = [] - for call in parsed_tool_calls: - params = json.loads(call.parameters) if call.parameters else {} - result = await execute_one(call.name, params) - assert isinstance(result, str) - tool_messages.append( - { - "role": "tool", - # src: serving_chat.py :: _process_tool_call_id - "tool_call_id": f"call_{uuid.uuid4().hex[:24]}", - "content": result, - "name": call.name, - } - ) - return tool_messages diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py index f8c52d490d..8e1d7f2124 100644 --- a/miles/rollout/generate_hub/single_turn.py +++ b/miles/rollout/generate_hub/single_turn.py @@ -9,23 +9,34 @@ update_sample_from_response, ) from miles.utils.http_utils import post +from miles.utils.types import Sample async def generate(input: GenerateFnInput) -> GenerateFnOutput: args = input.args sample = input.sample - + sampling_params = input.sampling_params + assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - prompt_ids = await compute_prompt_ids_from_sample(input.state, sample) - payload, halt_status = await compute_request_payload(input.state, sample, prompt_ids, input.sampling_params) + prompt_ids = compute_prompt_ids_from_sample(input.state, sample) - if payload is None: - sample.status = halt_status - return GenerateFnOutput(samples=sample) + # Handle Partial Rollout resuming + if len(sample.response) > 0: + input_ids = sample.tokens + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) - output = await post(url, payload) + assert sampling_params["max_new_tokens"] >= 0 + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return GenerateFnOutput(samples=sample) + else: + input_ids = prompt_ids + payload = compute_request_payload( + args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs + ) + output = await post(url, payload) await update_sample_from_response(args, sample, payload=payload, output=output) return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/generate_hub/tool_call_utils.py b/miles/rollout/generate_hub/tool_call_utils.py index d8a1ca574e..12ce362c06 100644 --- a/miles/rollout/generate_hub/tool_call_utils.py +++ b/miles/rollout/generate_hub/tool_call_utils.py @@ -1,9 +1,52 @@ +import json +import uuid +from collections.abc import Callable from typing import Any +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.types import Sample _DUMMY_USER = {"role": "user", "content": "dummy"} +def create_tool_call_parser(tool_specs, tool_call_parser): + return FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tool_specs), + tool_call_parser=tool_call_parser, + ) + + +async def execute_tool_calls(tool_calls: list[ToolCallItem], execute_one: Callable) -> list[dict[str, Any]]: + tool_messages = [] + for call in tool_calls: + params = json.loads(call.parameters) if call.parameters else {} + result = await execute_one(call.name, params) + assert isinstance(result, str) + tool_messages.append( + { + "role": "tool", + # src: serving_chat.py :: _process_tool_call_id + "tool_call_id": f"call_{uuid.uuid4().hex[:24]}", + "content": result, + "name": call.name, + } + ) + return tool_messages + + +def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): + next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) + sample.response += tokenizer.decode(next_obs_tokens_ids) + sample.response_length += len(next_obs_tokens_ids) + sample.tokens += next_obs_tokens_ids + sample.loss_mask += [0] * len(next_obs_tokens_ids) + sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) + + # TODO: very naive implementation, need the to-be-implemented e2e test to validate. def tokenize_tool_responses( tool_messages: list[dict[str, Any]], diff --git a/miles/utils/misc.py b/miles/utils/misc.py index 88e2213518..bae72ec0d7 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -36,6 +36,7 @@ def _unregister(self, name: str) -> None: function_registry = FunctionRegistry() +# TODO may rename to `load_object` since it can be used to load things like tool_specs def load_function(path): """ Load a function from registry or module. diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index bb3f697b75..eb85a854ae 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -42,7 +42,7 @@ def expected_request( "sampling_params": sampling_params or SAMPLING_PARAMS, "return_logprob": True, } - if variant == "single_turn" or return_routed_experts: + if variant in ("single_turn", "multi_turn_single_sample") or return_routed_experts: result["return_routed_experts"] = return_routed_experts if image_data is not None: result["image_data"] = image_data From c309dddb396de1b42f147e1db61597167c5bbd03 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:14:43 +0800 Subject: [PATCH 31/77] Support and refactor rollout-max-context-len (#468) --- .../generate_hub/generate_endpoint_wrapper.py | 9 +++- .../generate_hub/multi_turn_single_sample.py | 24 ++++----- miles/rollout/generate_hub/single_turn.py | 6 ++- tests/fixtures/generation_fixtures.py | 6 ++- tests/rollout/generate_hub/test_multi_turn.py | 53 +++++++++++++++++++ .../rollout/generate_hub/test_single_turn.py | 42 +++++++++++++-- 6 files changed, 116 insertions(+), 24 deletions(-) diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py index 39fd419aac..858a2550a5 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -42,7 +42,12 @@ def compute_request_payload( input_ids: list[int], sampling_params: dict, multimodal_inputs: dict | None = None, -) -> dict[str, Any]: +) -> tuple[dict[str, Any] | None, Sample.Status | None]: + # TODO need to adjust sampling_params.max_new_tokens when input is moderately long + max_context_length = args.rollout_max_context_len or float("inf") + if len(input_ids) >= max_context_length: + return None, Sample.Status.TRUNCATED + payload = { "input_ids": input_ids, "sampling_params": sampling_params, @@ -52,7 +57,7 @@ def compute_request_payload( if image_data := (multimodal_inputs or {}).get("images"): payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - return payload + return payload, None async def update_sample_from_response( diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py index 852ef9159e..2f969cef69 100644 --- a/miles/rollout/generate_hub/multi_turn_single_sample.py +++ b/miles/rollout/generate_hub/multi_turn_single_sample.py @@ -17,10 +17,11 @@ ) from miles.utils.http_utils import post from miles.utils.misc import load_function -from miles.utils.types import Sample async def generate(input: GenerateFnInput) -> GenerateFnOutput: + # ----------------------- Setup ------------------------- + args = input.args sample = input.sample tokenizer = input.state.tokenizer @@ -33,28 +34,23 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: tool_specs = load_function(args.generate_tool_specs_path) tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) + # ----------------------- Initial prompts ------------------------- + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) sample.loss_mask = [] sample.tokens = prompt_tokens_ids.copy() for _turn in range(args.generate_max_turns): - # TODO handle separately - # Check if total length exceeds max context length - total_length = len(sample.tokens) - if args.rollout_max_context_len is not None: - max_context_length = args.rollout_max_context_len - else: - max_context_length = args.context_parallel_size * args.max_tokens_per_gpu - if total_length >= max_context_length: - sample.status = Sample.Status.TRUNCATED - break - # ----------------------- Call inference endpoint ------------------------- - payload = compute_request_payload(args, sample.tokens, input.sampling_params) + payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) + if payload is None: + sample.status = halt_status + break + output = await post(url, payload) - await update_sample_from_response(args, sample, payload=payload, output=output) + await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): break diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py index 8e1d7f2124..ff976e29dd 100644 --- a/miles/rollout/generate_hub/single_turn.py +++ b/miles/rollout/generate_hub/single_turn.py @@ -33,9 +33,13 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: else: input_ids = prompt_ids - payload = compute_request_payload( + payload, halt_status = compute_request_payload( args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs ) + if payload is None: + sample.status = halt_status + return GenerateFnOutput(samples=sample) + output = await post(url, payload) await update_sample_from_response(args, sample, payload=payload, output=output) diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index 0b030da895..f9131c8391 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -106,7 +106,7 @@ def make_args( generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", generate_tool_call_parser: str = "qwen25", generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", - rollout_max_context_len: int = 4096, + rollout_max_context_len: int | None = None, ) -> Namespace: argv = [ "pytest", @@ -139,13 +139,15 @@ def make_args( argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) if custom_generate_function_path: argv.extend(["--custom-generate-function-path", custom_generate_function_path]) + if rollout_max_context_len is not None: + argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) if variant == "multi_turn_single_sample": argv.extend(["--generate-max-turns", str(generate_max_turns)]) argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) - argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) + if extra_argv: argv.extend(extra_argv) diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index f13a23954c..4a836cbce8 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -113,6 +113,7 @@ def expected_request(input_ids: list[int], sampling_params: dict | None = None) "input_ids": input_ids, "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, "return_logprob": True, + "return_routed_experts": False, } @@ -288,3 +289,55 @@ def test_max_turns_reached(self, variant, generation_env): response_length=45 + 31, ), ) + + +class TestRespectMaxContextLen: + @pytest.mark.parametrize( + "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True + ) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + assert result.requests == [] + verify_sample( + result.sample, + expected_chunks=[], + expected_partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response="", + response_length=0, + status=Sample.Status.TRUNCATED, + ), + ) + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"rollout_max_context_len": len(FIRST_PROMPT_TOKEN_IDS) + 45 + 31}}], + indirect=True, + ) + def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + + assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + verify_sample( + result.sample, + expected_chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + SampleParsedChunk( + tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, + loss_mask_value=0, + rollout_log_probs=[0.0] * 31, + ), + ], + expected_partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, + response_length=45 + 31, + status=Sample.Status.TRUNCATED, + ), + ) diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index eb85a854ae..077f1665bd 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -49,14 +49,21 @@ def expected_request( return result +class _Unset: + pass + + +_UNSET = _Unset() + + def expected_sample( variant: str, *, prompt: str = PROMPT, response: str = RESPONSE_TEXT, response_length: int = 5, - tokens: list[int] | None = None, - rollout_log_probs: list[float] | None = None, + tokens: list[int] | None | _Unset = _UNSET, + rollout_log_probs: list[float] | None | _Unset = _UNSET, status: Sample.Status = Sample.Status.COMPLETED, cached_tokens: int = 0, prompt_tokens: int = 7, @@ -72,7 +79,7 @@ def expected_sample( group_index=None, index=None, prompt=prompt, - tokens=tokens if tokens is not None else PROMPT_TOKENS + RESPONSE_TOKENS, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens, multimodal_inputs=multimodal_inputs, multimodal_train_inputs=multimodal_train_inputs, response=response, @@ -81,7 +88,7 @@ def expected_sample( reward=None, loss_mask=loss_mask, weight_versions=weight_versions or [], - rollout_log_probs=rollout_log_probs if rollout_log_probs is not None else RESPONSE_LOG_PROBS, + rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs, rollout_routed_experts=rollout_routed_experts, remove_sample=False, status=status, @@ -282,7 +289,32 @@ def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7}) assert result.requests == [] - assert result.sample.status == Sample.Status.TRUNCATED + assert result.sample == expected_sample( + variant, + response="x" * 10, + response_length=10, + tokens=existing_tokens, + rollout_log_probs=[], + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] + assert result.sample == expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + ) class TestEmptyResponse: From 4c60a346a1c782aad2301a9d43e8fbcff922eec0 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:15:01 +0800 Subject: [PATCH 32/77] Support multiple output samples in addition to single sample in multi-turn (#469) --- .../generate_hub/generate_endpoint_wrapper.py | 2 + miles/rollout/generate_hub/multi_turn.py | 88 +++++ tests/fixtures/generation_fixtures.py | 13 +- tests/rollout/generate_hub/test_multi_turn.py | 332 +++++++++++------- .../rollout/generate_hub/test_single_turn.py | 82 +++-- 5 files changed, 357 insertions(+), 160 deletions(-) create mode 100644 miles/rollout/generate_hub/multi_turn.py diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py index 858a2550a5..c6c7803f92 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -91,6 +91,8 @@ async def update_sample_from_response( sample.rollout_log_probs += new_response_log_probs if update_loss_mask: + if sample.loss_mask is None: + sample.loss_mask = [] sample.loss_mask += [1] * len(new_response_tokens) # TODO handle multi-turn cases (may need concat instead of assignment) diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py new file mode 100644 index 0000000000..2c01a8ba2d --- /dev/null +++ b/miles/rollout/generate_hub/multi_turn.py @@ -0,0 +1,88 @@ +""" +Simple multi-turn generation with tool calling. +""" + +import argparse +from copy import deepcopy + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_hub.generate_endpoint_wrapper import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.rollout.generate_hub.tool_call_utils import ( + create_tool_call_parser, + execute_tool_calls, + update_sample_with_tool_responses, +) +from miles.utils.http_utils import post +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + # ----------------------- Setup ------------------------- + + args = input.args + sample = deepcopy(input.sample) + tokenizer = input.state.tokenizer + assert not args.partial_rollout, "Partial rollout is not supported" + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + execute_tool_function = load_function(args.generate_execute_tool_function_path) + + tool_specs = load_function(args.generate_tool_specs_path) + tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) + + multi_samples = [] + + # ----------------------- Initial prompts ------------------------- + + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) + + sample.tokens = prompt_tokens_ids.copy() + + for _turn in range(args.generate_max_turns): + # ----------------------- Call inference endpoint ------------------------- + + payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) + if payload is None: + sample.status = halt_status + if args.generate_multi_samples and multi_samples: + multi_samples[-1].status = halt_status + break + + if args.generate_multi_samples: + sample = deepcopy(input.sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) + + if args.generate_multi_samples: + multi_samples.append(deepcopy(sample)) + + if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): + break + + # ----------------------- Execute tools ------------------------- + + _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) + if len(tool_calls) == 0: + break + + tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) + update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) + + return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-tool-call-parser", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index f9131c8391..9ce618bbdb 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -25,10 +25,15 @@ VARIANT_TO_GENERATE_FN_PATH = { "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", "single_turn": "miles.rollout.generate_hub.single_turn.generate", - "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn_single_sample.generate", + "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", + "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", } +def listify(x): + return x if isinstance(x, list) else [x] + + def make_sample( *, prompt: str | list[dict] = "What is 1+7?", @@ -56,7 +61,7 @@ class GenerateEnv: @dataclass class GenerateResult: - sample: Sample + sample: Sample | list[Sample] requests: list[dict] @@ -142,11 +147,13 @@ def make_args( if rollout_max_context_len is not None: argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): argv.extend(["--generate-max-turns", str(generate_max_turns)]) argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) + if variant == "multi_turn_multi_samples": + argv.append("--generate-multi-samples") if extra_argv: argv.extend(extra_argv) diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index 4a836cbce8..dfdde99b34 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -3,7 +3,7 @@ from itertools import groupby import pytest -from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, make_sample, run_generate +from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate from transformers import AutoTokenizer from miles.utils.test_utils.mock_sglang_server import ProcessResult @@ -30,7 +30,7 @@ SECOND_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_SECOND_PROMPT, add_special_tokens=False)["input_ids"] -@pytest.fixture(params=["multi_turn_single_sample"]) +@pytest.fixture(params=["multi_turn_single_sample", "multi_turn_multi_samples"]) def variant(request): return request.param @@ -42,6 +42,12 @@ class SampleParsedChunk: rollout_log_probs: list[float] +@dataclass +class ExpectedSampleInfo: + chunks: list[SampleParsedChunk] + partial_sample: Sample + + def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: prompt_len = len(sample.tokens) - sample.response_length response_tokens = sample.tokens[prompt_len:] @@ -85,23 +91,22 @@ def expected_partial_sample( ) -def verify_sample( - actual: Sample, - *, - expected_chunks: list[SampleParsedChunk], - expected_partial_sample: Sample, -): - actual_chunks = parse_sample_into_chunks(actual, TOKENIZER) - assert actual_chunks == expected_chunks - - actual_partial = replace( - deepcopy(actual), - tokens=[], - loss_mask=[], - rollout_log_probs=[], - prefix_cache_info=Sample.PrefixCacheInfo(), - ) - assert actual_partial == expected_partial_sample +def verify_samples(actual: Sample | list[Sample], expected: list[ExpectedSampleInfo]): + actual = listify(actual) + assert len(actual) == len(expected) + + for actual_item, expected_item in zip(actual, expected, strict=True): + actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER) + assert actual_chunks == expected_item.chunks + + actual_partial = replace( + deepcopy(actual_item), + tokens=[], + loss_mask=[], + rollout_log_probs=[], + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + assert actual_partial == expected_item.partial_sample def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): @@ -151,20 +156,22 @@ def test_single_turn_no_tool_call(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] - verify_sample( + verify_samples( result.sample, - expected_chunks=[ - SampleParsedChunk( - tokens_decoded_str=SINGLE_TURN_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(6)], + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6 + ), ), ], - expected_partial_sample=expected_partial_sample( - prompt=SINGLE_TURN_PROMPT, - response=SINGLE_TURN_RESPONSE, - response_length=6, - ), ) def test_two_turns_with_tool_call(self, variant, generation_env): @@ -176,31 +183,63 @@ def test_two_turns_with_tool_call(self, variant, generation_env): expected_request(FIRST_PROMPT_TOKEN_IDS), expected_request(SECOND_PROMPT_TOKEN_IDS), ] - verify_sample( - result.sample, - expected_chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + SampleParsedChunk( + tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 + ), + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_SECOND_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(24)], + ), + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE + MULTI_TURN_SECOND_RESPONSE, + response_length=45 + 31 + 24, + ), ), - SampleParsedChunk( - tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, - loss_mask_value=0, - rollout_log_probs=[0.0] * 31, + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ) + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE, + response_length=45, + ), ), - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_SECOND_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(24)], + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_SECOND_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(24)], + ) + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_SECOND_RESPONSE, + response_length=24, + ), ), - ], - expected_partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE + MULTI_TURN_SECOND_RESPONSE, - response_length=45 + 31 + 24, - ), - ) + ] + verify_samples(result.sample, expected) class TestExitConditions: @@ -218,21 +257,25 @@ def test_abort_preserves_content(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] - verify_sample( + verify_samples( result.sample, - expected_chunks=[ - SampleParsedChunk( - tokens_decoded_str=SINGLE_TURN_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(6)], + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + status=Sample.Status.ABORTED, + ), ), ], - expected_partial_sample=expected_partial_sample( - prompt=SINGLE_TURN_PROMPT, - response=SINGLE_TURN_RESPONSE, - response_length=6, - status=Sample.Status.ABORTED, - ), ) def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): @@ -243,21 +286,25 @@ def test_finish_reason_length_exits_and_preserves_content(self, variant, generat result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] - verify_sample( + verify_samples( result.sample, - expected_chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ) + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE, + response_length=45, + status=Sample.Status.TRUNCATED, + ), ), ], - expected_partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE, - response_length=45, - status=Sample.Status.TRUNCATED, - ), ) @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) @@ -269,26 +316,44 @@ def test_max_turns_reached(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] - verify_sample( - result.sample, - expected_chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + SampleParsedChunk( + tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 + ), + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, + response_length=45 + 31, + ), ), - SampleParsedChunk( - tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, - loss_mask_value=0, - rollout_log_probs=[0.0] * 31, + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ) + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE, + response_length=45, + ), ), - ], - expected_partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, - response_length=45 + 31, - ), - ) + ] + verify_samples(result.sample, expected) class TestRespectMaxContextLen: @@ -298,16 +363,18 @@ class TestRespectMaxContextLen: def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) assert result.requests == [] - verify_sample( - result.sample, - expected_chunks=[], - expected_partial_sample=expected_partial_sample( - prompt=SINGLE_TURN_PROMPT, - response="", - response_length=0, - status=Sample.Status.TRUNCATED, - ), - ) + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED + ), + ) + ] + else: + expected = [] + verify_samples(result.sample, expected) @pytest.mark.parametrize( "generation_env", @@ -320,24 +387,43 @@ def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, ge result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] - verify_sample( - result.sample, - expected_chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ), + SampleParsedChunk( + tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 + ), + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, + response_length=45 + 31, + status=Sample.Status.TRUNCATED, + ), ), - SampleParsedChunk( - tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, - loss_mask_value=0, - rollout_log_probs=[0.0] * 31, + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(45)], + ) + ], + partial_sample=expected_partial_sample( + prompt=TWO_TURN_PROMPT, + response=MULTI_TURN_FIRST_RESPONSE, + response_length=45, + status=Sample.Status.TRUNCATED, + ), ), - ], - expected_partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, - response_length=45 + 31, - status=Sample.Status.TRUNCATED, - ), - ) + ] + verify_samples(result.sample, expected) diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index 077f1665bd..824014276d 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -3,7 +3,7 @@ import pytest import torch from PIL import Image -from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, make_sample, run_generate +from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate from transformers import AutoProcessor from miles.utils.processing_utils import encode_image_for_rollout_engine @@ -24,7 +24,7 @@ SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} -@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample"]) +@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) def variant(request): return request.param @@ -42,7 +42,7 @@ def expected_request( "sampling_params": sampling_params or SAMPLING_PARAMS, "return_logprob": True, } - if variant in ("single_turn", "multi_turn_single_sample") or return_routed_experts: + if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts: result["return_routed_experts"] = return_routed_experts if image_data is not None: result["image_data"] = image_data @@ -72,9 +72,16 @@ def expected_sample( spec_info: Sample.SpecInfo | None = None, multimodal_inputs: dict | None = None, multimodal_train_inputs: dict | None = None, + loss_mask: list[int] | None | _Unset = _UNSET, ) -> Sample: actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) - loss_mask = [1] * actual_response_length if variant == "multi_turn_single_sample" else None + if isinstance(loss_mask, _Unset): + loss_mask = ( + [1] * actual_response_length + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") + else None + ) + return Sample( group_index=None, index=None, @@ -122,12 +129,12 @@ class TestBasicGeneration: def test_basic_generation(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample(variant) + assert listify(result.sample) == [expected_sample(variant)] class TestResumedSingleTurn: def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): pytest.skip("not tested yet") partial_text = "\\boxed" partial_tokens = [59, 79075] @@ -184,7 +191,7 @@ class TestFinishReason: def test_finish_reason_sets_status(self, variant, generation_env, expected_status): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample(variant, status=expected_status) + assert listify(result.sample) == [expected_sample(variant, status=expected_status)] class TestRoutedExperts: @@ -199,7 +206,7 @@ class TestRoutedExperts: indirect=True, ) def test_routed_experts_enabled_and_parsed(self, variant, generation_env): - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): pytest.skip("TODO: support") num_layers, moe_router_topk = 2, 4 @@ -231,7 +238,7 @@ class TestMetaInfo: def test_meta_info_fields_updated(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"]) + assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])] @pytest.mark.parametrize( "generation_env", @@ -246,12 +253,14 @@ def test_meta_info_fields_updated(self, variant, generation_env): def test_spec_info_updated(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample( - variant, - spec_info=Sample.SpecInfo( - spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 - ), - ) + assert listify(result.sample) == [ + expected_sample( + variant, + spec_info=Sample.SpecInfo( + spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 + ), + ) + ] class TestInputStatusValidation: @@ -259,11 +268,11 @@ class TestInputStatusValidation: def test_allowed_statuses(self, variant, generation_env, status): result = _run_generate(variant, generation_env, _make_sample(status=status)) assert result.requests == [expected_request(variant)] - assert result.sample.status == Sample.Status.COMPLETED + assert listify(result.sample) == [expected_sample(variant)] @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) def test_rejected_statuses(self, variant, generation_env, status): - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): pytest.skip("not tested yet") with pytest.raises(AssertionError): _run_generate(variant, generation_env, _make_sample(status=status)) @@ -277,12 +286,12 @@ def test_sampling_params_passed_through(self, variant, generation_env): assert result.requests == [ expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) ] - assert result.sample == expected_sample(variant) + assert listify(result.sample) == [expected_sample(variant)] class TestBoundaryConditions: def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): pytest.skip("not tested yet") existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) @@ -294,7 +303,7 @@ def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): response="x" * 10, response_length=10, tokens=existing_tokens, - rollout_log_probs=[], + rollout_log_probs=None, status=Sample.Status.TRUNCATED, prompt_tokens=0, ) @@ -303,18 +312,23 @@ def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): if variant == "old_sglang_rollout": pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") result = _run_generate(variant, generation_env) assert result.requests == [] - tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] - assert result.sample == expected_sample( - variant, - response="", - response_length=0, - tokens=tokens, - rollout_log_probs=None, - status=Sample.Status.TRUNCATED, - prompt_tokens=0, - ) + tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] class TestEmptyResponse: @@ -322,9 +336,9 @@ class TestEmptyResponse: def test_empty_response(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant)] - assert result.sample == expected_sample( - variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[] - ) + assert listify(result.sample) == [ + expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[]) + ] VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" @@ -333,7 +347,7 @@ def test_empty_response(self, variant, generation_env): class TestMultimodal: @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) def test_multimodal_inputs_processed(self, variant, generation_env): - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): pytest.skip("not tested yet") test_image = Image.new("RGB", (64, 64), color="red") multimodal_inputs = {"images": [test_image]} From c4667447371df74f6ffca74fb7f5052cabca42fa Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:15:21 +0800 Subject: [PATCH 33/77] Fix mock tool response with stop tokens (#471) --- miles/utils/test_utils/mock_tools.py | 4 +-- tests/rollout/generate_hub/test_multi_turn.py | 30 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py index 83f1d94327..faf8e0941e 100644 --- a/miles/utils/test_utils/mock_tools.py +++ b/miles/utils/test_utils/mock_tools.py @@ -77,7 +77,7 @@ async def execute_tool_call(name: str, params: dict) -> str: "\n" "\n" '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' - "" + "<|im_end|>\n" ) MULTI_TURN_SECOND_PROMPT = ( @@ -105,7 +105,7 @@ async def execute_tool_call(name: str, params: dict) -> str: "\n" "\n" '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' - "" + "<|im_end|>\n" "<|im_start|>user\n" "\n" '{"year": 2026}\n' diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index dfdde99b34..8aff6bf148 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -190,7 +190,7 @@ def test_two_turns_with_tool_call(self, variant, generation_env): SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ), SampleParsedChunk( tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 @@ -204,7 +204,7 @@ def test_two_turns_with_tool_call(self, variant, generation_env): partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE + MULTI_TURN_SECOND_RESPONSE, - response_length=45 + 31 + 24, + response_length=47 + 31 + 24, ), ), ] @@ -215,13 +215,13 @@ def test_two_turns_with_tool_call(self, variant, generation_env): SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ) ], partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE, - response_length=45, + response_length=47, ), ), ExpectedSampleInfo( @@ -294,13 +294,13 @@ def test_finish_reason_length_exits_and_preserves_content(self, variant, generat SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ) ], partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE, - response_length=45, + response_length=47, status=Sample.Status.TRUNCATED, ), ), @@ -323,7 +323,7 @@ def test_max_turns_reached(self, variant, generation_env): SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ), SampleParsedChunk( tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 @@ -332,7 +332,7 @@ def test_max_turns_reached(self, variant, generation_env): partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, - response_length=45 + 31, + response_length=47 + 31, ), ), ] @@ -343,13 +343,13 @@ def test_max_turns_reached(self, variant, generation_env): SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ) ], partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE, - response_length=45, + response_length=47, ), ), ] @@ -378,7 +378,7 @@ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generat @pytest.mark.parametrize( "generation_env", - [{"args_kwargs": {"rollout_max_context_len": len(FIRST_PROMPT_TOKEN_IDS) + 45 + 31}}], + [{"args_kwargs": {"rollout_max_context_len": len(FIRST_PROMPT_TOKEN_IDS) + 47 + 31}}], indirect=True, ) def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): @@ -394,7 +394,7 @@ def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, ge SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ), SampleParsedChunk( tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 @@ -403,7 +403,7 @@ def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, ge partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, - response_length=45 + 31, + response_length=47 + 31, status=Sample.Status.TRUNCATED, ), ), @@ -415,13 +415,13 @@ def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, ge SampleParsedChunk( tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(45)], + rollout_log_probs=[-1 / 128 * i for i in range(47)], ) ], partial_sample=expected_partial_sample( prompt=TWO_TURN_PROMPT, response=MULTI_TURN_FIRST_RESPONSE, - response_length=45, + response_length=47, status=Sample.Status.TRUNCATED, ), ), From 9e45a1fff9ff15f8a6da15d50857e57387602b5f Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:15:42 +0800 Subject: [PATCH 34/77] Add integration test for router (#472) --- tests/router/__init__.py | 0 tests/router/test_router.py | 204 ++++++++++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100644 tests/router/__init__.py create mode 100644 tests/router/test_router.py diff --git a/tests/router/__init__.py b/tests/router/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/router/test_router.py b/tests/router/test_router.py new file mode 100644 index 0000000000..7c645fe304 --- /dev/null +++ b/tests/router/test_router.py @@ -0,0 +1,204 @@ +import asyncio +from argparse import Namespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +def make_router_args(router_port: int, **overrides) -> Namespace: + defaults = dict( + sglang_router_ip="127.0.0.1", + sglang_router_port=router_port, + rollout_health_check_interval=1.0, + miles_router_health_check_failure_threshold=3, + miles_router_max_connections=100, + miles_router_timeout=None, + miles_router_middleware_paths=[], + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def create_mock_worker(start_port: int = 30000) -> MockSGLangServer: + port = find_available_port(start_port) + return MockSGLangServer( + model_name="Qwen/Qwen3-0.6B", + process_fn=default_process_fn, + host="127.0.0.1", + port=port, + latency=0.0, + ) + + +class RouterEnv: + def __init__(self, router: MilesRouter, server: UvicornThreadServer): + self.router = router + self.server = server + + @property + def url(self) -> str: + return self.server.url + + +@pytest.fixture +def router_env(): + args = make_router_args(find_available_port(20000)) + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + server.start() + yield RouterEnv(router, server) + server.stop() + + +@pytest.fixture +def mock_worker(): + server = create_mock_worker() + server.start() + yield server + server.stop() + + +@pytest.fixture +def mock_worker_factory(): + servers = [] + + def _create(): + start_port = 30000 + len(servers) * 100 + server = create_mock_worker(start_port) + server.start() + servers.append(server) + return server + + yield _create + for s in servers: + s.stop() + + +@pytest.fixture +def router_factory(): + def _create(**overrides) -> MilesRouter: + args = make_router_args(find_available_port(20000), **overrides) + return MilesRouter(args, verbose=False) + + return _create + + +class TestWorkerManagement: + def test_add_worker_via_query_param(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30001" + r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + assert router_env.router.worker_request_counts[worker_url] == 0 + + def test_add_worker_via_body(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30002" + r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_duplicate(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30003" + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + + assert len(router_env.router.worker_request_counts) == 1 + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_missing_url(self, router_env: RouterEnv): + r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0) + assert r.status_code == 400 + assert "error" in r.json() + + def test_list_workers(self, router_env: RouterEnv): + worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"] + for url in worker_urls: + requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/list_workers", timeout=5.0) + r.raise_for_status() + assert set(r.json()["urls"]) == set(worker_urls) + + +class TestLoadBalancing: + def test_use_url_selects_min_load(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} + + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_use_url_excludes_dead_workers(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3} + router.dead_workers = {"http://w2:8000"} + + selected = router._use_url() + assert selected == "http://w3:8000" + assert router.worker_request_counts["http://w3:8000"] == 4 + + def test_use_url_raises_when_all_dead(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 0} + router.dead_workers = {"http://w1:8000"} + + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +# TODO: extract main body inside `_health_check_loop`, then can test that function +class TestHealthCheck: + def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health(mock_worker.url)) + assert url == mock_worker.url + assert healthy is True + + def test_check_worker_health_failure(self, router_factory): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999")) + assert url == "http://127.0.0.1:59999" + assert healthy is False + + +class TestProxyIntegration: + def test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status() + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0) + r.raise_for_status() + + assert "text" in r.json() + assert len(mock_worker.request_log) == 1 + assert mock_worker.request_log[0] == payload + + def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory): + worker1, worker2 = mock_worker_factory(), mock_worker_factory() + requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0) + requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0) + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + for _ in range(4): + requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status() + + all_requests = worker1.request_log + worker2.request_log + assert len(all_requests) == 4 + assert all(req == payload for req in all_requests) + + def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/health", timeout=5.0) + r.raise_for_status() + assert r.json()["status"] == "ok" From 61c8d8b535f6bad91767d04e5cd501511cdac80f Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:15:59 +0800 Subject: [PATCH 35/77] Support other http actions for http request utility (#473) --- miles/utils/http_utils.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 2b3e6e192f..9641cbe0ec 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -162,11 +162,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60): +async def _post(client, url, payload, max_retries=60, action="post"): retry_count = 0 while retry_count < max_retries: try: - response = await client.post(url, json=payload or {}) + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) response.raise_for_status() try: output = response.json() @@ -240,8 +244,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60): - return await _post(self._client, url, payload, max_retries) + async def do_post(self, url, payload, max_retries=60, action="post"): + return await _post(self._client, url, payload, max_retries, action=action) # Create actors per node created = [] @@ -265,7 +269,7 @@ async def do_post(self, url, payload, max_retries=60): _post_actors = created -async def post(url, payload, max_retries=60): +async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -274,13 +278,13 @@ async def post(url, payload, max_retries=60): actor = _next_actor() if actor is not None: # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries) + obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) return await asyncio.to_thread(ray.get, obj_ref) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries) + return await _post(_http_client, url, payload, max_retries, action=action) async def get(url): From 759fc8b8b7b18726fc5cf2c55b8fd39cadc4897e Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:16:16 +0800 Subject: [PATCH 36/77] Support OpenAI format for tool execution (#474) --- miles/rollout/generate_hub/tool_call_utils.py | 39 ++++++++++++------- .../generate_hub/test_tool_call_utils.py | 19 +++++++++ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/miles/rollout/generate_hub/tool_call_utils.py b/miles/rollout/generate_hub/tool_call_utils.py index 12ce362c06..fd755f6353 100644 --- a/miles/rollout/generate_hub/tool_call_utils.py +++ b/miles/rollout/generate_hub/tool_call_utils.py @@ -3,6 +3,7 @@ from collections.abc import Callable from typing import Any +from openai.types.chat import ChatCompletionMessageToolCall from pydantic import TypeAdapter from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.core_types import ToolCallItem @@ -20,24 +21,36 @@ def create_tool_call_parser(tool_specs, tool_call_parser): ) -async def execute_tool_calls(tool_calls: list[ToolCallItem], execute_one: Callable) -> list[dict[str, Any]]: +async def execute_tool_calls( + tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall], + execute_one: Callable, +) -> list[dict[str, Any]]: tool_messages = [] for call in tool_calls: - params = json.loads(call.parameters) if call.parameters else {} - result = await execute_one(call.name, params) - assert isinstance(result, str) - tool_messages.append( - { - "role": "tool", - # src: serving_chat.py :: _process_tool_call_id - "tool_call_id": f"call_{uuid.uuid4().hex[:24]}", - "content": result, - "name": call.name, - } - ) + tool_messages.append(await _execute_tool_call(call, execute_one)) return tool_messages +async def _execute_tool_call( + call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable +) -> dict[str, Any]: + if isinstance(call, ChatCompletionMessageToolCall): + name = call.function.name + params = json.loads(call.function.arguments) if call.function.arguments else {} + tool_call_id = call.id + elif isinstance(call, ToolCallItem): + name = call.name + params = json.loads(call.parameters) if call.parameters else {} + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + else: + raise TypeError(f"Unsupported tool call type: {type(call)}") + + result = await execute_one(name, params) + assert isinstance(result, str) + + return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name} + + def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) sample.response += tokenizer.decode(next_obs_tokens_ids) diff --git a/tests/rollout/generate_hub/test_tool_call_utils.py b/tests/rollout/generate_hub/test_tool_call_utils.py index 26d1330ae6..8f06756e64 100644 --- a/tests/rollout/generate_hub/test_tool_call_utils.py +++ b/tests/rollout/generate_hub/test_tool_call_utils.py @@ -44,6 +44,25 @@ class TestTokenizeToolResponses: + @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) + def test_snapshot(self, model_name): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) + decoded = tokenizer.decode(token_ids) + + assert decoded == ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": 25}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + @pytest.mark.parametrize("num_tools", [1, 2]) @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) def test_tokenize_tool_responses(self, model_name, num_tools): From 2072ac8ff13f6568d2d6fae30d36444850599412 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:16:39 +0800 Subject: [PATCH 37/77] Support openai endpoint for mock sglang server (#475) --- miles/utils/test_utils/mock_sglang_server.py | 161 ++++-- miles/utils/test_utils/mock_tools.py | 49 ++ .../test_utils/test_mock_sglang_server.py | 525 +++++++++++++----- 3 files changed, 534 insertions(+), 201 deletions(-) diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py index d13b5bdf8a..f8f233d208 100644 --- a/miles/utils/test_utils/mock_sglang_server.py +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -1,11 +1,16 @@ import asyncio import re +import time +import uuid from collections.abc import Callable from contextlib import contextmanager from dataclasses import asdict, dataclass from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.function_call_parser import FunctionCallParser from transformers import AutoTokenizer from miles.utils.http_utils import find_available_port @@ -66,47 +71,26 @@ def reset_stats(self): self.request_log.clear() self._concurrency.reset() - def _setup_routes(self): - @self.app.post("/generate") - async def generate(request: Request): - payload = await request.json() - self.request_log.append(payload) - - with self._concurrency.track(): - if self.latency > 0: - await asyncio.sleep(self.latency) - - assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" - input_ids = payload.get("input_ids", []) - - prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) - process_result = self.process_fn(prompt_str) - output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) - - prompt_tokens = len(input_ids) - completion_tokens = len(output_ids) - - finish_reason_dict = {"type": process_result.finish_reason} - if process_result.finish_reason == "length": - finish_reason_dict["length"] = completion_tokens + def start(self): + self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) + self._server.start() - output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + def stop(self): + if self._server is not None: + self._server.stop() - meta_info = { - "finish_reason": finish_reason_dict, - "prompt_tokens": prompt_tokens, - "cached_tokens": process_result.cached_tokens, - "completion_tokens": completion_tokens, - "output_token_logprobs": output_token_logprobs, - **process_result.meta_info.to_dict(), - } + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" - response = { - "text": process_result.text, - "meta_info": meta_info, - } + def _setup_routes(self): + @self.app.post("/generate") + async def generate(request: Request): + return await self._handle_generate_like_request(request, self._compute_generate_response) - return JSONResponse(content=response) + @self.app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return await self._handle_generate_like_request(request, self._compute_chat_completions_response) @self.app.get("/health") async def health(): @@ -116,17 +100,98 @@ async def health(): async def abort_request(_request: Request): return JSONResponse(content={"status": "ok"}) - def start(self): - self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) - self._server.start() - - def stop(self): - if self._server is not None: - self._server.stop() - - @property - def url(self) -> str: - return f"http://{self.host}:{self.port}" + async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]): + payload = await request.json() + self.request_log.append(payload) + with self._concurrency.track(): + if self.latency > 0: + await asyncio.sleep(self.latency) + response = compute_fn(payload) + return JSONResponse(content=response) + + def _compute_generate_response(self, payload: dict) -> dict: + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) + + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) + + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens + + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + + meta_info = { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": process_result.cached_tokens, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + **process_result.meta_info.to_dict(), + } + + return {"text": process_result.text, "meta_info": meta_info} + + def _compute_chat_completions_response(self, payload: dict) -> dict: + messages = payload.get("messages", []) + tools = payload.get("tools") + + prompt_str = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + logprobs_content = [ + {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i} + for i, tid in enumerate(output_ids) + ] + + finish_reason = process_result.finish_reason + tool_calls = None + if tools and finish_reason == "stop": + parser = FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tools), + tool_call_parser="qwen25", + ) + message_content, parsed_calls = parser.parse_non_stream(process_result.text) + if parsed_calls: + finish_reason = "tool_calls" + tool_calls = [ + { + "id": f"call{i:05d}", + "type": "function", + "function": {"name": call.name, "arguments": call.parameters or "{}"}, + } + for i, call in enumerate(parsed_calls) + ] + else: + message_content = process_result.text + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message_content, + "tool_calls": tool_calls, + }, + "logprobs": {"content": logprobs_content}, + "finish_reason": finish_reason, + } + ], + } class Counter: diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py index faf8e0941e..220bd2bc01 100644 --- a/miles/utils/test_utils/mock_tools.py +++ b/miles/utils/test_utils/mock_tools.py @@ -117,6 +117,55 @@ async def execute_tool_call(name: str, params: dict) -> str: ) MULTI_TURN_SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." +MULTI_TURN_USER_QUESTION = "What is 42 + year + temperature?" +MULTI_TURN_FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." +MULTI_TURN_FIRST_TOOL_CALLS = [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, + { + "id": "call00001", + "type": "function", + "function": {"name": "get_temperature", "arguments": '{"location": "Mars"}'}, + }, +] +MULTI_TURN_FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, +] + +MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN = [ + {"role": "user", "content": MULTI_TURN_USER_QUESTION}, +] + +MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN = [ + {"role": "user", "content": MULTI_TURN_USER_QUESTION}, + { + "role": "assistant", + "content": MULTI_TURN_FIRST_RESPONSE_CONTENT, + "tool_calls": MULTI_TURN_FIRST_TOOL_CALLS, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}'}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}'}, +] + +MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = [ + {"role": "user", "content": MULTI_TURN_USER_QUESTION}, + { + "content": MULTI_TURN_FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": MULTI_TURN_FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, +] + def multi_turn_tool_call_process_fn(prompt: str) -> ProcessResult: prompt_response_pairs = { diff --git a/tests/utils/test_utils/test_mock_sglang_server.py b/tests/utils/test_utils/test_mock_sglang_server.py index 9326122b87..b7ed21f365 100644 --- a/tests/utils/test_utils/test_mock_sglang_server.py +++ b/tests/utils/test_utils/test_mock_sglang_server.py @@ -12,6 +12,23 @@ default_process_fn, with_mock_server, ) +from miles.utils.test_utils.mock_tools import ( + MULTI_TURN_FIRST_PROMPT, + MULTI_TURN_FIRST_RESPONSE, + MULTI_TURN_FIRST_RESPONSE_CONTENT, + MULTI_TURN_FIRST_TOOL_CALLS, + MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN, + MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN, + MULTI_TURN_SECOND_PROMPT, + MULTI_TURN_SECOND_RESPONSE, + SAMPLE_TOOLS, + multi_turn_tool_call_process_fn, +) + + +def expected_logprobs(tokenizer, text: str) -> list[dict]: + output_ids = tokenizer.encode(text, add_special_tokens=False) + return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)] @pytest.fixture(scope="module") @@ -20,182 +37,384 @@ def mock_server(): yield server -def test_basic_server_start_stop(mock_server): - assert mock_server.port > 0 - assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url - - -def test_generate_endpoint_basic(mock_server): - prompt = "What is 1+7?" - input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) - assert input_ids == [3838, 374, 220, 16, 10, 22, 30] - - response = requests.post( - f"{mock_server.url}/generate", - json={ - "input_ids": input_ids, - "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, - "return_logprob": True, - }, - timeout=5.0, - ) - assert response.status_code == 200 - data = response.json() - - assert data == { - "text": "\\boxed{8}", - "meta_info": { - "finish_reason": {"type": "stop"}, - "prompt_tokens": len(input_ids), - "cached_tokens": 0, - "completion_tokens": 5, - "output_token_logprobs": [ - [-0.0, 59], - [-0.0078125, 79075], - [-0.015625, 90], - [-0.0234375, 23], - [-0.03125, 92], - ], - }, - } - - -def test_process_fn_receives_decoded_prompt(): - received_prompts = [] - - def process_fn(prompt: str) -> ProcessResult: - received_prompts.append(prompt) - return ProcessResult(text="response", finish_reason="stop") - - with with_mock_server(process_fn=process_fn) as server: - requests.post(f"{server.url}/generate", json={"input_ids": [1, 2, 3], "sampling_params": {}}, timeout=5.0) - - assert len(received_prompts) == 1 - assert isinstance(received_prompts[0], str) - - -def test_default_process_fn(): - assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") - assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") - assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") - - -def test_process_result_meta_info_to_dict(): - assert ProcessResultMetaInfo().to_dict() == {} - assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} - assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { - "weight_version": "v1", - "spec_accept_token_num": 10, - } - assert ProcessResultMetaInfo( - weight_version="v1", routed_experts="abc", spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3 - ).to_dict() == { - "weight_version": "v1", - "routed_experts": "abc", - "spec_accept_token_num": 10, - "spec_draft_token_num": 15, - "spec_verify_ct": 3, - } - - -def test_generate_endpoint_with_meta_info(): - def process_fn(_: str) -> ProcessResult: - return ProcessResult( - text="ok", - finish_reason="stop", - cached_tokens=5, - meta_info=ProcessResultMetaInfo( - weight_version="v2.0", - routed_experts="encoded_data", - spec_accept_token_num=10, - spec_draft_token_num=15, - spec_verify_ct=3, - ), - ) +class TestProcessResultMetaInfo: + def test_to_dict_empty(self): + assert ProcessResultMetaInfo().to_dict() == {} - with with_mock_server(process_fn=process_fn) as server: - response = requests.post( - f"{server.url}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, - timeout=5.0, - ) - data = response.json() + def test_to_dict_single_field(self): + assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} - assert data == { - "text": "ok", - "meta_info": { - "finish_reason": {"type": "stop"}, - "prompt_tokens": 3, - "cached_tokens": 5, - "completion_tokens": 1, - "output_token_logprobs": [[-0.0, 562]], - "weight_version": "v2.0", - "routed_experts": "encoded_data", + def test_to_dict_partial_fields(self): + assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { + "weight_version": "v1", + "spec_accept_token_num": 10, + } + + def test_to_dict_all_fields(self): + assert ProcessResultMetaInfo( + weight_version="v1", + routed_experts="abc", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ).to_dict() == { + "weight_version": "v1", + "routed_experts": "abc", "spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3, - }, - } + } -def test_request_log_and_reset_stats(mock_server): - mock_server.reset_stats() - assert len(mock_server.request_log) == 0 +class TestDefaultProcessFn: + def test_math_question(self): + assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") + assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") + + def test_unknown_question(self): + assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") + + +class TestCounter: + def test_tracks_max(self): + counter = Counter() + assert counter.max_value == 0 + + with counter.track(): + assert counter.max_value == 1 + with counter.track(): + assert counter.max_value == 2 - payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} - requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) - assert len(mock_server.request_log) == 1 - assert mock_server.request_log[0] == payload + counter.reset() + assert counter.max_value == 0 - mock_server.reset_stats() - assert len(mock_server.request_log) == 0 - assert mock_server.max_concurrent == 0 + def test_concurrent_tasks(self): + counter = Counter() + async def task(): + with counter.track(): + await asyncio.sleep(0.1) -@pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) -def test_latency(latency, min_time, max_time): - with with_mock_server(latency=latency) as server: - start = time.time() - requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) - elapsed = time.time() - start - assert min_time <= elapsed < max_time + async def run_all(): + await asyncio.gather(task(), task(), task()) + asyncio.run(run_all()) + assert counter.max_value == 3 -def test_max_concurrent_with_latency(): - with with_mock_server(latency=0.1) as server: - def send_request(): +class TestMockServerBasic: + def test_start_stop(self, mock_server): + assert mock_server.port > 0 + assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url + + def test_request_log_and_reset_stats(self, mock_server): + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + + payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} + requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) + assert len(mock_server.request_log) == 1 + assert mock_server.request_log[0] == payload + + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + assert mock_server.max_concurrent == 0 + + @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) + def test_latency(self, latency, min_time, max_time): + with with_mock_server(latency=latency) as server: + start = time.time() requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + elapsed = time.time() - start + assert min_time <= elapsed < max_time - with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: - futures = [executor.submit(send_request) for _ in range(3)] - concurrent.futures.wait(futures) + def test_max_concurrent_with_latency(self): + with with_mock_server(latency=0.1) as server: - assert server.max_concurrent == 3 + def send_request(): + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(send_request) for _ in range(3)] + concurrent.futures.wait(futures) -def test_counter_tracks_max(): - counter = Counter() - assert counter.max_value == 0 + assert server.max_concurrent == 3 - with counter.track(): - assert counter.max_value == 1 - with counter.track(): - assert counter.max_value == 2 + def test_health_endpoint(self, mock_server): + response = requests.get(f"{mock_server.url}/health", timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} - counter.reset() - assert counter.max_value == 0 + def test_abort_request_endpoint(self, mock_server): + response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} -def test_counter_concurrent_tasks(): - counter = Counter() +class TestGenerateEndpoint: + def test_basic(self, mock_server): + prompt = "What is 1+7?" + input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) + assert input_ids == [3838, 374, 220, 16, 10, 22, 30] - async def task(): - with counter.track(): - await asyncio.sleep(0.1) + response = requests.post( + f"{mock_server.url}/generate", + json={ + "input_ids": input_ids, + "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, + "return_logprob": True, + }, + timeout=5.0, + ) + assert response.status_code == 200 + assert response.json() == { + "text": "\\boxed{8}", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": len(input_ids), + "cached_tokens": 0, + "completion_tokens": 5, + "output_token_logprobs": [ + [-0.0, 59], + [-0.0078125, 79075], + [-0.015625, 90], + [-0.0234375, 23], + [-0.03125, 92], + ], + }, + } + + def test_with_meta_info(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult( + text="ok", + finish_reason="stop", + cached_tokens=5, + meta_info=ProcessResultMetaInfo( + weight_version="v2.0", + routed_experts="encoded_data", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ), + ) + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + + assert response.json() == { + "text": "ok", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": 3, + "cached_tokens": 5, + "completion_tokens": 1, + "output_token_logprobs": [[-0.0, 562]], + "weight_version": "v2.0", + "routed_experts": "encoded_data", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + }, + } + + def test_finish_reason_length(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text="truncated output", finish_reason="length") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + data = response.json() + + finish_reason = data["meta_info"]["finish_reason"] + assert finish_reason["type"] == "length" + assert finish_reason["length"] == data["meta_info"]["completion_tokens"] + + +class TestChatCompletionsEndpoint: + def test_basic(self, mock_server): + response = requests.post( + f"{mock_server.url}/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "What is 1+5?"}], + }, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() - async def run_all(): - await asyncio.gather(task(), task(), task()) + assert data["id"].startswith("chatcmpl-") + assert isinstance(data["created"], int) + assert data == { + "id": data["id"], + "object": "chat.completion", + "created": data["created"], + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, + "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, + "finish_reason": "stop", + } + ], + } + + def test_with_tool_calls(self): + tool_call_response = 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n' + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=tool_call_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year is it?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "Let me check for you.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}} + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, + "finish_reason": "tool_calls", + } + + def test_with_tools_but_no_tool_call(self): + response_text = "The weather is sunny today." + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=response_text, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": {"role": "assistant", "content": response_text, "tool_calls": None}, + "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, + "finish_reason": "stop", + } + + def test_with_multiple_tool_calls(self): + multi_tool_response = ( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n' + ) - asyncio.run(run_all()) - assert counter.max_value == 3 + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=multi_tool_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year and temperature?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "I will get year and temperature.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, + { + "id": "call00001", + "type": "function", + "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'}, + }, + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, + "finish_reason": "tool_calls", + } + + +class TestMultiTurnToolCallProcessFn: + @pytest.mark.parametrize( + "prompt,expected_response", + [ + pytest.param(MULTI_TURN_FIRST_PROMPT, MULTI_TURN_FIRST_RESPONSE, id="first_turn"), + pytest.param(MULTI_TURN_SECOND_PROMPT, MULTI_TURN_SECOND_RESPONSE, id="second_turn"), + ], + ) + def test_generate_endpoint(self, prompt, expected_response): + with with_mock_server(process_fn=multi_turn_tool_call_process_fn) as server: + input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) + response = requests.post( + f"{server.url}/generate", + json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["text"] == expected_response + assert data["meta_info"]["finish_reason"] == {"type": "stop"} + + @pytest.mark.parametrize( + "messages,expected_content,expected_tool_calls,expected_finish_reason", + [ + pytest.param( + MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN, + MULTI_TURN_FIRST_RESPONSE_CONTENT, + MULTI_TURN_FIRST_TOOL_CALLS, + "tool_calls", + id="first_turn", + ), + pytest.param( + MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN, + MULTI_TURN_SECOND_RESPONSE, + None, + "stop", + id="second_turn", + ), + ], + ) + def test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): + with with_mock_server(process_fn=multi_turn_tool_call_process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == expected_content + assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls + assert data["choices"][0]["finish_reason"] == expected_finish_reason From b197e9995091560e1c4994a3d6308e9d71e25dbe Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:16:57 +0800 Subject: [PATCH 38/77] Support session based API in router with tracing (#476) --- miles/router/router.py | 47 +++++----- miles/router/sessions.py | 103 ++++++++++++++++++++++ tests/router/test_sessions.py | 159 ++++++++++++++++++++++++++++++++++ 3 files changed, 288 insertions(+), 21 deletions(-) create mode 100644 miles/router/sessions.py create mode 100644 tests/router/test_sessions.py diff --git a/miles/router/router.py b/miles/router/router.py index 2e8ecfc41f..7d3ecd9806 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from starlette.responses import Response +from miles.router.sessions import setup_session_routes from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -69,6 +70,8 @@ def _setup_routes(self): self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) self.app.post("/retrieve_from_text")(self.retrieve_from_text) + # Session routes - must be registered before catch-all + setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) @@ -130,39 +133,41 @@ async def _health_check_loop(self): async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" - # Forward all other paths to SGLang router + result = await self._do_proxy(request, path) + return self._build_proxy_response(result) + + async def _do_proxy(self, request: Request, path: str) -> dict: + """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" worker_url = self._use_url() url = f"{worker_url}/{path}" - # Get request body and headers body = await request.body() headers = dict(request.headers) try: response = await self.client.request(request.method, url, content=body, headers=headers) - # Eagerly read content so we can return JSON (not streaming) content = await response.aread() - content_type = response.headers.get("content-type", "") - try: - # Prefer parsing JSON if possible - data = json.loads(content) - return JSONResponse( - content=data, - status_code=response.status_code, - headers=dict(response.headers), - ) - except Exception: - # Fall back to raw body with original content type - return Response( - content=content, - status_code=response.status_code, - headers=dict(response.headers), - media_type=content_type or None, - ) - + return { + "request_body": body, + "response_body": content, + "status_code": response.status_code, + "headers": dict(response.headers), + } finally: self._finish_url(worker_url) + def _build_proxy_response(self, result: dict) -> Response: + """Build HTTP response from proxy result.""" + content = result["response_body"] + status_code = result["status_code"] + headers = result["headers"] + content_type = headers.get("content-type", "") + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + async def add_worker(self, request: Request): """Add a new worker to the router. Supports providing the URL via query string or JSON body. diff --git a/miles/router/sessions.py b/miles/router/sessions.py new file mode 100644 index 0000000000..f52cc33ef0 --- /dev/null +++ b/miles/router/sessions.py @@ -0,0 +1,103 @@ +import json +import time +import uuid +from typing import TYPE_CHECKING + +from fastapi import Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from transformers import AutoTokenizer + +if TYPE_CHECKING: + from miles.router.router import MilesRouter + + +class SessionRecord(BaseModel): + timestamp: float + method: str + path: str + request: dict + response: dict + status_code: int + + +class DeleteSessionResponse(BaseModel): + session_id: str + records: list[SessionRecord] + + +class SessionManager: + def __init__(self): + self.sessions: dict[str, list[SessionRecord]] = {} + + def create_session(self) -> str: + session_id = uuid.uuid4().hex + self.sessions[session_id] = [] + return session_id + + def get_session(self, session_id: str) -> list[SessionRecord] | None: + return self.sessions.get(session_id) + + def delete_session(self, session_id: str) -> list[SessionRecord]: + assert session_id in self.sessions + return self.sessions.pop(session_id) + + def add_record(self, session_id: str, record: SessionRecord): + assert session_id in self.sessions + self.sessions[session_id].append(record) + + +def setup_session_routes(app, router: "MilesRouter"): + manager = SessionManager() + + # TODO temporary hack before @guapisolo implements TITO + # ============================= HACK START =============================== + tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) + # ============================= HACK END =============================== + + @app.post("/sessions") + async def create_session(): + session_id = manager.create_session() + return {"session_id": session_id} + + @app.delete("/sessions/{session_id}") + async def delete_session(session_id: str): + if session_id not in manager.sessions: + return JSONResponse(status_code=404, content={"error": "session not found"}) + records = manager.delete_session(session_id) + return DeleteSessionResponse(session_id=session_id, records=records) + + @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def session_proxy(request: Request, session_id: str, path: str): + if session_id not in manager.sessions: + return JSONResponse(status_code=404, content={"error": "session not found"}) + + result = await router._do_proxy(request, path) + + request_body = json.loads(result["request_body"]) + response_body = json.loads(result["response_body"]) + + # TODO: remove this hack when @guapisolo implements the real TITO + # ============================= HACK START =============================== + request_body["input_ids"] = tokenizer.apply_chat_template( + request_body["messages"], + add_generation_prompt=True, + add_special_tokens=False, + tools=request_body.get("tools"), + ) + logprobs_content = response_body["choices"][0]["logprobs"]["content"] + for item in logprobs_content: + item["token_id"] = tokenizer.convert_tokens_to_ids(item["token"]) + # ============================= HACK END =============================== + + record = SessionRecord( + timestamp=time.time(), + method=request.method, + path=path, + request=request_body, + response=response_body, + status_code=result["status_code"], + ) + manager.add_record(session_id, record) + + return router._build_proxy_response(result) diff --git a/tests/router/test_sessions.py b/tests/router/test_sessions.py new file mode 100644 index 0000000000..0b37aa5c93 --- /dev/null +++ b/tests/router/test_sessions.py @@ -0,0 +1,159 @@ +from types import SimpleNamespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.router.sessions import SessionManager, SessionRecord +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +class TestSessionManager: + def test_create_session(self): + manager = SessionManager() + session_id = manager.create_session() + assert session_id is not None + assert len(session_id) == 32 + assert session_id in manager.sessions + assert manager.sessions[session_id] == [] + + def test_get_session_exists(self): + manager = SessionManager() + session_id = manager.create_session() + records = manager.get_session(session_id) + assert records == [] + + def test_get_session_not_exists(self): + manager = SessionManager() + records = manager.get_session("nonexistent") + assert records is None + + def test_delete_session_exists(self): + manager = SessionManager() + session_id = manager.create_session() + records = manager.delete_session(session_id) + assert records == [] + assert session_id not in manager.sessions + + def test_delete_session_not_exists(self): + manager = SessionManager() + with pytest.raises(AssertionError): + manager.delete_session("nonexistent") + + def test_add_record(self): + manager = SessionManager() + session_id = manager.create_session() + record = SessionRecord( + timestamp=1234567890.0, + method="POST", + path="generate", + request={"prompt": "hello"}, + response={"text": "world"}, + status_code=200, + ) + manager.add_record(session_id, record) + assert len(manager.sessions[session_id]) == 1 + assert manager.sessions[session_id][0] == record + + def test_add_record_nonexistent_session(self): + manager = SessionManager() + record = SessionRecord( + timestamp=1234567890.0, + method="POST", + path="generate", + request={}, + response={}, + status_code=200, + ) + with pytest.raises(AssertionError): + manager.add_record("nonexistent", record) + + +@pytest.fixture(scope="class") +def router_url(): + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text=f"echo: {prompt}", finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as backend: + args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + ) + router = MilesRouter(args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend.url}) + + try: + yield url + finally: + server.stop() + + +class TestSessionRoutes: + def test_create_session(self, router_url): + response = requests.post(f"{router_url}/sessions") + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert len(data["session_id"]) == 32 + + def test_delete_session(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 200 + assert delete_resp.json()["session_id"] == session_id + assert delete_resp.json()["records"] == [] + + assert requests.delete(f"{router_url}/sessions/{session_id}").status_code == 404 + + def test_delete_session_not_found(self, router_url): + response = requests.delete(f"{router_url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + +class TestSessionProxy: + def test_proxy_session_not_found(self, router_url): + response = requests.post(f"{router_url}/sessions/nonexistent/generate", json={}) + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_proxy_records_request_response(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + resp = requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + ) + assert resp.status_code == 200 + assert "text" in resp.json() + + records = requests.delete(f"{router_url}/sessions/{session_id}").json()["records"] + assert len(records) == 1 + assert records[0]["method"] == "POST" + assert records[0]["path"] == "generate" + assert records[0]["request_json"]["input_ids"] == [1, 2, 3] + assert "text" in records[0]["response_json"] + + def test_proxy_accumulates_records(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + for _ in range(3): + requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1], "sampling_params": {}, "return_logprob": True}, + ) + + records = requests.delete(f"{router_url}/sessions/{session_id}").json()["records"] + assert len(records) == 3 From 9be8ea7c3bcd27db1270420f616b03be87b09072 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:17:35 +0800 Subject: [PATCH 39/77] Support tracing OpenAI endpoint and converting to Sample (#477) --- .../generate_hub/openai_endpoint_utils.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 miles/rollout/generate_hub/openai_endpoint_utils.py diff --git a/miles/rollout/generate_hub/openai_endpoint_utils.py b/miles/rollout/generate_hub/openai_endpoint_utils.py new file mode 100644 index 0000000000..6293564f4f --- /dev/null +++ b/miles/rollout/generate_hub/openai_endpoint_utils.py @@ -0,0 +1,58 @@ +""" +Utilities for the OpenAI endpoint +""" + +from argparse import Namespace +from copy import deepcopy + +from miles.router.sessions import DeleteSessionResponse, SessionRecord +from miles.utils.http_utils import post +from miles.utils.types import Sample + + +class OpenAIEndpointTracer: + def __init__(self, router_url: str, session_id: str): + self.router_url = router_url + self.session_id = session_id + self.base_url = f"{router_url}/sessions/{session_id}/v1" + + @staticmethod + async def create(args: Namespace): + router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + session_id = (await post(f"{router_url}/sessions", {}))["session_id"] + return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) + + async def collect_records(self) -> list[SessionRecord]: + # TODO: for fault tolerance, we may want to change to GET + DELETE + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") + response = DeleteSessionResponse.model_validate(response) + return response.records + + +def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: + return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records] + + +def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: + # TODO may refine after @guapisolo's implementation + choice = record.response["choices"][0] + output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] + output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] + + sample = deepcopy(input_sample) + sample.tokens = record.request["input_ids"] + output_token_ids + sample.rollout_log_probs = output_log_probs + sample.response = tokenizer.decode(output_token_ids) + sample.response_length = len(output_token_ids) + sample.loss_mask = [1] * len(output_token_ids) + + # TODO unify with Sample.update_from_meta_info + match choice["finish_reason"]: + case "stop" | "tool_calls": + sample.status = Sample.Status.COMPLETED + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + + return sample From dec72fa9fe475eb5f2000d68dd6f976b4abe48ea Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:17:55 +0800 Subject: [PATCH 40/77] Fix sample filter flaky tests (#478) --- .../integration/test_sample_filter.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/rollout/modular_rollout/integration/test_sample_filter.py b/tests/rollout/modular_rollout/integration/test_sample_filter.py index c5c183ba3d..751d689cb4 100644 --- a/tests/rollout/modular_rollout/integration/test_sample_filter.py +++ b/tests/rollout/modular_rollout/integration/test_sample_filter.py @@ -1,15 +1,19 @@ from unittest.mock import Mock import pytest -from tests.rollout.modular_rollout.integration.utils import ( - MIXED_DATA_ROWS, - config, - filter_by_reward, - load_and_call_train, -) +from tests.rollout.modular_rollout.integration.utils import config, filter_by_reward, load_and_call_train from miles.utils.misc import function_registry +# Data with only 2 reward=1 samples out of 4. +# This ensures all 4 samples must be generated to collect 2 valid ones. +_FILTER_TEST_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, # reward=1 + {"input": "What is 1+8?", "label": "wrong"}, # reward=0 + {"input": "What is 1+9?", "label": "wrong"}, # reward=0 + {"input": "What is 1+6?", "label": "7"}, # reward=1 +] + @pytest.mark.parametrize( "rollout_integration_env", @@ -28,7 +32,7 @@ "--rollout-all-samples-process-path", "test:all_samples_process", ], - data_rows=MIXED_DATA_ROWS, + data_rows=_FILTER_TEST_DATA_ROWS, ), id="sample_filter_vs_all_samples", ), From fec580561cf6718f7d1d003ec97e8d17e9dbd94f Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:18:18 +0800 Subject: [PATCH 41/77] Support blackbox agents with tool calling (#479) --- .../rollout/generate_hub/agentic_tool_call.py | 79 +++++++++++++++++++ .../generate_hub/generate_endpoint_wrapper.py | 1 + tests/fixtures/generation_fixtures.py | 59 +++++++++++--- tests/rollout/generate_hub/test_multi_turn.py | 45 +++++++++-- 4 files changed, 163 insertions(+), 21 deletions(-) create mode 100644 miles/rollout/generate_hub/agentic_tool_call.py diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py new file mode 100644 index 0000000000..8022182470 --- /dev/null +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -0,0 +1,79 @@ +""" +Simple agentic demo with tool calling. +""" + +import argparse +from copy import deepcopy +from typing import Any + +from openai import AsyncOpenAI + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_hub.openai_endpoint_utils import OpenAIEndpointTracer, compute_samples_from_openai_records +from miles.rollout.generate_hub.tool_call_utils import execute_tool_calls +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + tracer = await OpenAIEndpointTracer.create(input.args) + + await _run_blackbox_tool_call_agent( + base_url=tracer.base_url, + prompt=input.sample.prompt, + max_turns=input.args.generate_max_turns, + tool_specs_path=input.args.generate_tool_specs_path, + execute_tool_function_path=input.args.generate_execute_tool_function_path, + ) + + records = await tracer.collect_records() + samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + return GenerateFnOutput(samples=samples) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments + + +async def _run_blackbox_tool_call_agent( + base_url: str, + prompt: list[dict[str, Any]], + max_turns: int, + tool_specs_path: str, + execute_tool_function_path: str, +): + """ + Imagine this is a black-box agent, e.g. SWE-agent, which does arbitrarily complex work, + only understands OpenAI compatible API, and never understands Miles or the Sample data structure. + """ + + # ----------------------- Setup ------------------------- + + client = AsyncOpenAI(base_url=base_url, api_key="empty") + execute_tool_function = load_function(execute_tool_function_path) + tool_specs = load_function(tool_specs_path) + + # ----------------------- Initial prompts ------------------------- + + messages = deepcopy(prompt) + + for _turn in range(max_turns): + # ----------------------- Call inference endpoint ------------------------- + + response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs) + + choice = response.choices[0] + messages.append(choice.message.model_dump()) + + if choice.finish_reason in ("stop", "length"): + break + + # ----------------------- Execute tools ------------------------- + + if x := choice.message.tool_calls: + messages += await execute_tool_calls(x, execute_tool_function) diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py index c6c7803f92..8947201de9 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -1,3 +1,4 @@ +# TODO: may rename to generate_endpoint_utils.py """ Wrapper to integrate SGLang's `/generate` endpoint with RL things like Sample. """ diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index 9ce618bbdb..b3cb7fb091 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -3,19 +3,24 @@ """ from argparse import Namespace +from contextlib import contextmanager from dataclasses import dataclass +from types import SimpleNamespace from typing import Any from unittest.mock import patch import pytest +import requests from miles.rollout.base_types import GenerateFnInput from miles.rollout.modular_rollout.compatibility import load_generate_function from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.router.router import MilesRouter from miles.utils.async_utils import run -from miles.utils.http_utils import init_http_client +from miles.utils.http_utils import find_available_port, init_http_client from miles.utils.misc import SingletonMeta from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer from miles.utils.types import Sample MODEL_NAME = "Qwen/Qwen3-0.6B" @@ -27,6 +32,7 @@ "single_turn": "miles.rollout.generate_hub.single_turn.generate", "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", + "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", } @@ -147,12 +153,13 @@ def make_args( if rollout_max_context_len is not None: argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples", "agentic_tool_call_multi_samples"): argv.extend(["--generate-max-turns", str(generate_max_turns)]) argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) - argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) - if variant == "multi_turn_multi_samples": + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) + if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): argv.append("--generate-multi-samples") if extra_argv: @@ -167,6 +174,31 @@ def make_args( return args +@contextmanager +def with_miles_router(backend_url: str, model_name: str): + router_args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + hf_checkpoint=model_name, + ) + router = MilesRouter(router_args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend_url}) + + try: + yield port + finally: + server.stop() + + @pytest.fixture def generation_env(request, variant): SingletonMeta.clear_all_instances() @@ -191,14 +223,15 @@ def process_fn(_): ) with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: - other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} - args = make_args( - variant=variant, - router_port=mock_server.port, - model_name=model_name, - custom_generate_function_path=custom_generate_function_path, - **other_args_kwargs, - ) - yield GenerateEnv(args=args, mock_server=mock_server) + with with_miles_router(mock_server.url, model_name) as router_port: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args( + variant=variant, + router_port=router_port, + model_name=model_name, + custom_generate_function_path=custom_generate_function_path, + **other_args_kwargs, + ) + yield GenerateEnv(args=args, mock_server=mock_server) SingletonMeta.clear_all_instances() diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index 8aff6bf148..89f019342c 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -10,6 +10,8 @@ from miles.utils.test_utils.mock_tools import ( MULTI_TURN_FIRST_PROMPT, MULTI_TURN_FIRST_RESPONSE, + MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN, + MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, MULTI_TURN_SECOND_PROMPT, MULTI_TURN_SECOND_RESPONSE, SAMPLE_TOOLS, @@ -30,7 +32,7 @@ SECOND_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_SECOND_PROMPT, add_special_tokens=False)["input_ids"] -@pytest.fixture(params=["multi_turn_single_sample", "multi_turn_multi_samples"]) +@pytest.fixture(params=["multi_turn_single_sample", "multi_turn_multi_samples", "agentic_tool_call_multi_samples"]) def variant(request): return request.param @@ -122,6 +124,10 @@ def expected_request(input_ids: list[int], sampling_params: dict | None = None) } +def expected_openai_request(messages: list[dict]) -> dict: + return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS} + + SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] SINGLE_TURN_RESPONSE = "The answer is 2." _SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( @@ -155,7 +161,10 @@ def test_single_turn_no_tool_call(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + if variant == "agentic_tool_call_multi_samples": + assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] + else: + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] verify_samples( result.sample, [ @@ -179,10 +188,16 @@ def test_two_turns_with_tool_call(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - assert result.requests == [ - expected_request(FIRST_PROMPT_TOKEN_IDS), - expected_request(SECOND_PROMPT_TOKEN_IDS), - ] + if variant == "agentic_tool_call_multi_samples": + assert result.requests == [ + expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(FIRST_PROMPT_TOKEN_IDS), + expected_request(SECOND_PROMPT_TOKEN_IDS), + ] if variant == "multi_turn_single_sample": expected = [ ExpectedSampleInfo( @@ -244,12 +259,16 @@ def test_two_turns_with_tool_call(self, variant, generation_env): class TestExitConditions: def test_partial_rollout_not_supported(self, variant, generation_env): + if variant == "agentic_tool_call_multi_samples": + pytest.skip("agentic_tool_call does not check partial_rollout flag") generation_env.args.partial_rollout = True with pytest.raises(AssertionError, match="Partial rollout is not supported"): _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) def test_abort_preserves_content(self, variant, generation_env): + if variant == "agentic_tool_call_multi_samples": + pytest.skip("agentic_tool_call does not handle abort finish_reason") generation_env.mock_server.process_fn = lambda _: ProcessResult( text=SINGLE_TURN_RESPONSE, finish_reason="abort" ) @@ -285,7 +304,10 @@ def test_finish_reason_length_exits_and_preserves_content(self, variant, generat result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + if variant == "agentic_tool_call_multi_samples": + assert result.requests == [expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] verify_samples( result.sample, [ @@ -315,7 +337,10 @@ def test_max_turns_reached(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + if variant == "agentic_tool_call_multi_samples": + assert result.requests == [expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] if variant == "multi_turn_single_sample": expected = [ ExpectedSampleInfo( @@ -361,6 +386,8 @@ class TestRespectMaxContextLen: "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True ) def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "agentic_tool_call_multi_samples": + pytest.skip("TODO: implement") result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) assert result.requests == [] if variant == "multi_turn_single_sample": @@ -382,6 +409,8 @@ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generat indirect=True, ) def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "agentic_tool_call_multi_samples": + pytest.skip("TODO: implement") generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) From ad6052bb26b2c1269fa3e02a19cfb2f0ce1c83af Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:18:35 +0800 Subject: [PATCH 42/77] Support merging samples to construct trajectory (#480) --- miles/rollout/generate_hub/sample_utils.py | 114 +++++++++++++ miles/utils/types.py | 14 ++ .../rollout/generate_hub/test_sample_utils.py | 156 ++++++++++++++++++ 3 files changed, 284 insertions(+) create mode 100644 miles/rollout/generate_hub/sample_utils.py create mode 100644 tests/rollout/generate_hub/test_sample_utils.py diff --git a/miles/rollout/generate_hub/sample_utils.py b/miles/rollout/generate_hub/sample_utils.py new file mode 100644 index 0000000000..c71e1ec57b --- /dev/null +++ b/miles/rollout/generate_hub/sample_utils.py @@ -0,0 +1,114 @@ +from copy import deepcopy +from dataclasses import fields + +from miles.utils.types import Sample + + +def merge_samples(samples: list[Sample], tokenizer) -> Sample: + acc = samples[0] + for sample in samples[1:]: + acc = merge_sample_pair(acc, sample, tokenizer=tokenizer) + return acc + + +def merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: + """Merge two samples generated from sibling inference engine calls.""" + a, b = deepcopy(a), deepcopy(b) + + def _merge_equal_value(field): + x = getattr(a, field) + y = getattr(b, field) + assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}" + return x + + def _fill_defaults(sample: Sample): + if sample.loss_mask is None: + sample.loss_mask = [1] * sample.response_length + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [0.0] * sample.response_length + + _fill_defaults(a) + _fill_defaults(b) + + obs_len = len(b.tokens) - len(a.tokens) - b.response_length + obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len] + # TODO: is this acceptable? + obs_text = tokenizer.decode(obs_tokens) + + try: + a.validate() + b.validate() + assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" + assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" + assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" + assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" + + return _create_with_all_fields( + Sample, + group_index=_merge_equal_value("group_index"), + index=_merge_equal_value("index"), + prompt=b.prompt, + tokens=b.tokens, + multimodal_inputs=_merge_equal_value("multimodal_inputs"), + multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"), + response=a.response + obs_text + b.response, + response_length=a.response_length + obs_len + b.response_length, + label=_merge_equal_value("label"), + reward=_merge_equal_value("reward"), + loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, + weight_versions=a.weight_versions + b.weight_versions, + rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, + # TODO should support concat + rollout_routed_experts=_merge_equal_value("rollout_routed_experts"), + remove_sample=_merge_equal_value("remove_sample"), + status=b.status, + metadata=_merge_equal_value("metadata"), + train_metadata=_merge_equal_value("train_metadata"), + non_generation_time=_merge_equal_value("non_generation_time"), + spec_info=_merge_spec_info(a.spec_info, b.spec_info), + prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), + ) + except AssertionError as e: + e.add_note(f"{a=} {b=}") + raise + + +def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.SpecInfo, + spec_accept_token_num=_merge_plus_value("spec_accept_token_num"), + spec_draft_token_num=_merge_plus_value("spec_draft_token_num"), + spec_verify_ct=_merge_plus_value("spec_verify_ct"), + completion_token_num=_merge_plus_value("completion_token_num"), + ) + + +def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.PrefixCacheInfo, + cached_tokens=_merge_plus_value("cached_tokens"), + total_prompt_tokens=_merge_plus_value("total_prompt_tokens"), + ) + + +def _create_with_all_fields(cls, **kwargs): + expected = {f.name for f in fields(cls)} + actual = set(kwargs.keys()) + assert ( + expected == actual + ), f"{cls.__name__} field mismatch. Missing: {expected - actual}, Extra: {actual - expected}" + return cls(**kwargs) + + +def _startswith(*, short, long) -> bool: + if isinstance(short, str) and isinstance(long, str): + return long.startswith(short) + if isinstance(short, list) and isinstance(long, list): + return (len(long) >= len(short)) and (long[: len(short)] == short) + raise NotImplementedError diff --git a/miles/utils/types.py b/miles/utils/types.py index 0a2531a7af..cb690ec600 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -145,6 +145,20 @@ def get_reward_value(self, args) -> float: def effective_response_length(self): return sum(self.loss_mask) if self.loss_mask is not None else self.response_length + def validate(self): + assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}" + assert ( + len(self.tokens) >= self.response_length + ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})" + if self.loss_mask is not None: + assert ( + len(self.loss_mask) == self.response_length + ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})" + if self.rollout_log_probs is not None: + assert ( + len(self.rollout_log_probs) == self.response_length + ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" + def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. diff --git a/tests/rollout/generate_hub/test_sample_utils.py b/tests/rollout/generate_hub/test_sample_utils.py new file mode 100644 index 0000000000..0c49dd433c --- /dev/null +++ b/tests/rollout/generate_hub/test_sample_utils.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.generate_hub.sample_utils import merge_sample_pair +from miles.utils.types import Sample + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + tokenizer.decode = lambda tokens: f"" + return tokenizer + + +def make_sample( + prompt="test_prompt", + tokens=None, + response="", + response_length=0, + loss_mask=None, + rollout_log_probs=None, + status=Sample.Status.COMPLETED, + label="test_label", + reward=1.0, + index=0, + group_index=0, +): + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + loss_mask=loss_mask, + rollout_log_probs=rollout_log_probs, + status=status, + label=label, + reward=reward, + index=index, + group_index=group_index, + ) + + +class TestMergeSamples: + def test_basic_merge(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3, 10, 11, 12], + response="response1", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.1, -0.2, -0.3], + ) + b = make_sample( + tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32], + response="response2", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.4, -0.5, -0.6], + status=Sample.Status.TRUNCATED, + ) + + merged = merge_sample_pair(a, b, mock_tokenizer) + + assert merged.tokens == b.tokens + assert merged.response_length == 3 + 2 + 3 + assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] + assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6] + assert merged.prompt == a.prompt + assert merged.status == b.status + assert merged.label == a.label + assert merged.index == a.index + assert merged.group_index == a.group_index + assert "response1" in merged.response + assert "response2" in merged.response + assert "" in merged.response + + def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + + merged = merge_sample_pair(a, b, mock_tokenizer) + + assert merged.loss_mask == [1, 0, 1] + assert merged.rollout_log_probs == [0.0, 0.0, 0.0] + + def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 99, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): + merge_sample_pair(a, b, mock_tokenizer) + + def test_field_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + index=0, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=[1], + index=1, + ) + + with pytest.raises(AssertionError, match="index mismatch"): + merge_sample_pair(a, b, mock_tokenizer) + + def test_obs_len_invalid_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="obs_len must be > 0"): + merge_sample_pair(a, b, mock_tokenizer) + + def test_sample_validate_fails_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10, 11], + response_length=2, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 11, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="loss_mask length"): + merge_sample_pair(a, b, mock_tokenizer) From 8a67e905ba5e8d6e771d0a4f748745fbecf536a1 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:19:06 +0800 Subject: [PATCH 43/77] Support agentic rollout to generate one single sample for the whole tracjectory (#481) --- .../rollout/generate_hub/agentic_tool_call.py | 3 ++ tests/fixtures/generation_fixtures.py | 8 ++++- tests/rollout/generate_hub/test_multi_turn.py | 31 +++++++++++++------ 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py index 8022182470..82b59d9719 100644 --- a/miles/rollout/generate_hub/agentic_tool_call.py +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -10,6 +10,7 @@ from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput from miles.rollout.generate_hub.openai_endpoint_utils import OpenAIEndpointTracer, compute_samples_from_openai_records +from miles.rollout.generate_hub.sample_utils import merge_samples from miles.rollout.generate_hub.tool_call_utils import execute_tool_calls from miles.utils.misc import load_function @@ -27,6 +28,8 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: records = await tracer.collect_records() samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + if not input.args.generate_multi_samples: + samples = merge_samples(samples, input.state.tokenizer) return GenerateFnOutput(samples=samples) diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index b3cb7fb091..a0af8da9bc 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -32,6 +32,7 @@ "single_turn": "miles.rollout.generate_hub.single_turn.generate", "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", + "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate", "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", } @@ -153,7 +154,12 @@ def make_args( if rollout_max_context_len is not None: argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples", "agentic_tool_call_multi_samples"): + if variant in ( + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ): argv.extend(["--generate-max-turns", str(generate_max_turns)]) argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index 89f019342c..cba3d195d6 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -22,6 +22,10 @@ _ = generation_env, SAMPLE_TOOLS, multi_turn_tool_call_process_fn +def is_agentic_variant(variant: str) -> bool: + return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples") + + # ------------------------------------ fixtures and consts ---------------------------------------- @@ -32,7 +36,14 @@ SECOND_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_SECOND_PROMPT, add_special_tokens=False)["input_ids"] -@pytest.fixture(params=["multi_turn_single_sample", "multi_turn_multi_samples", "agentic_tool_call_multi_samples"]) +@pytest.fixture( + params=[ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ] +) def variant(request): return request.param @@ -161,7 +172,7 @@ def test_single_turn_no_tool_call(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] else: assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] @@ -188,7 +199,7 @@ def test_two_turns_with_tool_call(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): assert result.requests == [ expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN), expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), @@ -198,7 +209,7 @@ def test_two_turns_with_tool_call(self, variant, generation_env): expected_request(FIRST_PROMPT_TOKEN_IDS), expected_request(SECOND_PROMPT_TOKEN_IDS), ] - if variant == "multi_turn_single_sample": + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): expected = [ ExpectedSampleInfo( chunks=[ @@ -259,7 +270,7 @@ def test_two_turns_with_tool_call(self, variant, generation_env): class TestExitConditions: def test_partial_rollout_not_supported(self, variant, generation_env): - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): pytest.skip("agentic_tool_call does not check partial_rollout flag") generation_env.args.partial_rollout = True @@ -267,7 +278,7 @@ def test_partial_rollout_not_supported(self, variant, generation_env): _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) def test_abort_preserves_content(self, variant, generation_env): - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): pytest.skip("agentic_tool_call does not handle abort finish_reason") generation_env.mock_server.process_fn = lambda _: ProcessResult( text=SINGLE_TURN_RESPONSE, finish_reason="abort" @@ -304,7 +315,7 @@ def test_finish_reason_length_exits_and_preserves_content(self, variant, generat result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): assert result.requests == [expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN)] else: assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] @@ -337,7 +348,7 @@ def test_max_turns_reached(self, variant, generation_env): result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): assert result.requests == [expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN)] else: assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] @@ -386,7 +397,7 @@ class TestRespectMaxContextLen: "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True ) def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): pytest.skip("TODO: implement") result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) assert result.requests == [] @@ -409,7 +420,7 @@ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generat indirect=True, ) def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): - if variant == "agentic_tool_call_multi_samples": + if is_agentic_variant(variant): pytest.skip("TODO: implement") generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn From 5d296bf99adde134103a8227023016e944083921 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:19:26 +0800 Subject: [PATCH 44/77] Add three turn integration testing and refactor related stubs (#482) --- miles/utils/test_utils/mock_tools.py | 298 ++++++++++++------ tests/rollout/generate_hub/test_multi_turn.py | 289 +++++++++-------- 2 files changed, 347 insertions(+), 240 deletions(-) diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py index 220bd2bc01..6b99e36739 100644 --- a/miles/utils/test_utils/mock_tools.py +++ b/miles/utils/test_utils/mock_tools.py @@ -1,5 +1,7 @@ import json +from transformers import AutoTokenizer + from miles.utils.test_utils.mock_sglang_server import ProcessResult SAMPLE_TOOLS = [ @@ -36,8 +38,10 @@ def _get_year(params: dict) -> str: def _get_temperature(params: dict) -> str: - assert params.get("location") == "Mars" - return json.dumps({"temperature": -60}) + temps = {"Mars": -60, "Earth": 15} + location = params.get("location") + assert location in temps, f"Unknown location: {location}" + return json.dumps({"temperature": temps[location]}) TOOL_EXECUTORS = { @@ -50,7 +54,7 @@ async def execute_tool_call(name: str, params: dict) -> str: return TOOL_EXECUTORS[name](params) -MULTI_TURN_FIRST_PROMPT = ( +_SYSTEM_PROMPT = ( "<|im_start|>system\n" "# Tools\n" "\n" @@ -66,115 +70,199 @@ async def execute_tool_call(name: str, params: dict) -> str: "\n" '{"name": , "arguments": }\n' "<|im_end|>\n" - "<|im_start|>user\n" - "What is 42 + year + temperature?<|im_end|>\n" - "<|im_start|>assistant\n" -) -MULTI_TURN_FIRST_RESPONSE = ( - "Let me get the year and temperature first.\n" - "\n" - '{"name": "get_year", "arguments": {}}\n' - "\n" - "\n" - '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' - "<|im_end|>\n" ) -MULTI_TURN_SECOND_PROMPT = ( - "<|im_start|>system\n" - "# Tools\n" - "\n" - "You may call one or more functions to assist with the user query.\n" - "\n" - "You are provided with function signatures within XML tags:\n" - "\n" - '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' - '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' - "\n" - "\n" - "For each function call, return a json object with function name and arguments within XML tags:\n" - "\n" - '{"name": , "arguments": }\n' - "<|im_end|>\n" - "<|im_start|>user\n" - "What is 42 + year + temperature?<|im_end|>\n" - "<|im_start|>assistant\n" - "Let me get the year and temperature first.\n" - "\n" - '{"name": "get_year", "arguments": {}}\n' - "\n" - "\n" - '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' - "<|im_end|>\n" - "<|im_start|>user\n" - "\n" - '{"year": 2026}\n' - "\n" - "\n" - '{"temperature": -60}\n' - "<|im_end|>\n" - "<|im_start|>assistant\n" -) -MULTI_TURN_SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." -MULTI_TURN_USER_QUESTION = "What is 42 + year + temperature?" -MULTI_TURN_FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." -MULTI_TURN_FIRST_TOOL_CALLS = [ - {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, - { - "id": "call00001", - "type": "function", - "function": {"name": "get_temperature", "arguments": '{"location": "Mars"}'}, - }, -] -MULTI_TURN_FIRST_TOOL_CALLS_OPENAI_FORMAT = [ - {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, - { - "id": "call00001", - "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, - "type": "function", - }, -] +_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) -MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN = [ - {"role": "user", "content": MULTI_TURN_USER_QUESTION}, -] -MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN = [ - {"role": "user", "content": MULTI_TURN_USER_QUESTION}, - { - "role": "assistant", - "content": MULTI_TURN_FIRST_RESPONSE_CONTENT, - "tool_calls": MULTI_TURN_FIRST_TOOL_CALLS, - }, - {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}'}, - {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}'}, -] +class TwoTurnStub: + """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer""" + + USER_QUESTION = "What is 42 + year + temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE, + TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") -MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = [ - {"role": "user", "content": MULTI_TURN_USER_QUESTION}, - { - "content": MULTI_TURN_FIRST_RESPONSE_CONTENT, - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": MULTI_TURN_FIRST_TOOL_CALLS_OPENAI_FORMAT, - }, - {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, - {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, -] +class ThreeTurnStub: + """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer""" + + USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and Mars temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + SECOND_RESPONSE = ( + "Now let me get Earth temperature.\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"temperature": 15}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature." + SECOND_TOOL_CALLS_OPENAI_FORMAT = [ + { + "id": "call00000", + "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ + { + "content": SECOND_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, + ] -def multi_turn_tool_call_process_fn(prompt: str) -> ProcessResult: - prompt_response_pairs = { - MULTI_TURN_FIRST_PROMPT: MULTI_TURN_FIRST_RESPONSE, - MULTI_TURN_SECOND_PROMPT: MULTI_TURN_SECOND_RESPONSE, - } + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE, + ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE, + ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE, + } - for expect_prompt, response in prompt_response_pairs.items(): - if prompt == expect_prompt: - return ProcessResult(text=response, finish_reason="stop") + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") - raise ValueError(f"Unexpected {prompt=}") + raise ValueError(f"Unexpected {prompt=}") diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index cba3d195d6..a59b1f2325 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -7,19 +7,10 @@ from transformers import AutoTokenizer from miles.utils.test_utils.mock_sglang_server import ProcessResult -from miles.utils.test_utils.mock_tools import ( - MULTI_TURN_FIRST_PROMPT, - MULTI_TURN_FIRST_RESPONSE, - MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN, - MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, - MULTI_TURN_SECOND_PROMPT, - MULTI_TURN_SECOND_RESPONSE, - SAMPLE_TOOLS, - multi_turn_tool_call_process_fn, -) +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub from miles.utils.types import Sample -_ = generation_env, SAMPLE_TOOLS, multi_turn_tool_call_process_fn +_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub def is_agentic_variant(variant: str) -> bool: @@ -32,8 +23,6 @@ def is_agentic_variant(variant: str) -> bool: MODEL_NAME = "Qwen/Qwen3-0.6B" DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) -FIRST_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_FIRST_PROMPT, add_special_tokens=False)["input_ids"] -SECOND_PROMPT_TOKEN_IDS = TOKENIZER(MULTI_TURN_SECOND_PROMPT, add_special_tokens=False)["input_ids"] @pytest.fixture( @@ -61,6 +50,16 @@ class ExpectedSampleInfo: partial_sample: Sample +def token_len(text: str) -> int: + return len(TOKENIZER(text, add_special_tokens=False)["input_ids"]) + + +def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk: + n = token_len(text) + log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n + return SampleParsedChunk(text, loss_mask, log_probs) + + def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: prompt_len = len(sample.tokens) - sample.response_length response_tokens = sample.tokens[prompt_len:] @@ -147,19 +146,6 @@ def expected_openai_request(messages: list[dict]) -> dict: SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) -TWO_TURN_USER_QUESTION = "What is 42 + year + temperature?" -TWO_TURN_PROMPT = [{"role": "user", "content": TWO_TURN_USER_QUESTION}] -TWO_TURN_TOOL_RESPONSE = ( - "<|im_start|>user\n" - "\n" - '{"year": 2026}\n' - "\n" - "\n" - '{"temperature": -60}\n' - "<|im_end|>\n" - "<|im_start|>assistant\n" -) - # ------------------------------------ tests ---------------------------------------- @@ -195,73 +181,53 @@ def test_single_turn_no_tool_call(self, variant, generation_env): ) def test_two_turns_with_tool_call(self, variant, generation_env): - generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn + generation_env.mock_server.process_fn = TwoTurnStub.process_fn - result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + S = TwoTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) if is_agentic_variant(variant): assert result.requests == [ - expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN), - expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), ] else: assert result.requests == [ - expected_request(FIRST_PROMPT_TOKEN_IDS), - expected_request(SECOND_PROMPT_TOKEN_IDS), + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), ] if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE expected = [ ExpectedSampleInfo( chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ), - SampleParsedChunk( - tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 - ), - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_SECOND_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(24)], - ), + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), ], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE + MULTI_TURN_SECOND_RESPONSE, - response_length=47 + 31 + 24, + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), ), ), ] else: expected = [ ExpectedSampleInfo( - chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ) - ], + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE, - response_length=47, + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), ), ), ExpectedSampleInfo( - chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_SECOND_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(24)], - ) - ], + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_SECOND_RESPONSE, - response_length=24, + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), ), ), ] @@ -309,31 +275,24 @@ def test_abort_preserves_content(self, variant, generation_env): ) def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): - generation_env.mock_server.process_fn = lambda _: ProcessResult( - text=MULTI_TURN_FIRST_RESPONSE, finish_reason="length" - ) + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length") - result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) if is_agentic_variant(variant): - assert result.requests == [expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN)] + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] else: - assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] verify_samples( result.sample, [ ExpectedSampleInfo( - chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ) - ], + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE, - response_length=47, + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), status=Sample.Status.TRUNCATED, ), ), @@ -342,50 +301,37 @@ def test_finish_reason_length_exits_and_preserves_content(self, variant, generat @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) def test_max_turns_reached(self, variant, generation_env): - generation_env.mock_server.process_fn = lambda _: ProcessResult( - text=MULTI_TURN_FIRST_RESPONSE, finish_reason="stop" - ) + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop") - result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) if is_agentic_variant(variant): - assert result.requests == [expected_openai_request(MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN)] + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] else: - assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] if variant == "multi_turn_single_sample": expected = [ ExpectedSampleInfo( chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ), - SampleParsedChunk( - tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 - ), + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), ], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, - response_length=47 + 31, + prompt=S.PROMPT, + response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE), ), ), ] else: expected = [ ExpectedSampleInfo( - chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ) - ], + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE, - response_length=47, + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), ), ), ] @@ -416,53 +362,126 @@ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generat @pytest.mark.parametrize( "generation_env", - [{"args_kwargs": {"rollout_max_context_len": len(FIRST_PROMPT_TOKEN_IDS) + 47 + 31}}], + [ + { + "args_kwargs": { + "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS) + + token_len(TwoTurnStub.FIRST_RESPONSE) + + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE) + } + } + ], indirect=True, ) def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): if is_agentic_variant(variant): pytest.skip("TODO: implement") - generation_env.mock_server.process_fn = multi_turn_tool_call_process_fn + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn - result = _run_generate(variant, generation_env, make_sample(prompt=TWO_TURN_PROMPT)) + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) - assert result.requests == [expected_request(FIRST_PROMPT_TOKEN_IDS)] + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] if variant == "multi_turn_single_sample": + partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE expected = [ ExpectedSampleInfo( chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ), - SampleParsedChunk( - tokens_decoded_str=TWO_TURN_TOOL_RESPONSE, loss_mask_value=0, rollout_log_probs=[0.0] * 31 - ), + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), ], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE + TWO_TURN_TOOL_RESPONSE, - response_length=47 + 31, + prompt=S.PROMPT, + response=partial_response, + response_length=token_len(partial_response), status=Sample.Status.TRUNCATED, ), ), ] else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestThreeTurn: + """Need to test 3-turn case besides 2-turn, because e.g. merge_samples may behave differently.""" + + def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): + generation_env.mock_server.process_fn = ThreeTurnStub.process_fn + + S = ThreeTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + expected_request(S.THIRD_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = ( + S.FIRST_RESPONSE + + S.FIRST_TOOL_RESPONSE + + S.SECOND_RESPONSE + + S.SECOND_TOOL_RESPONSE + + S.THIRD_RESPONSE + ) expected = [ ExpectedSampleInfo( chunks=[ - SampleParsedChunk( - tokens_decoded_str=MULTI_TURN_FIRST_RESPONSE, - loss_mask_value=1, - rollout_log_probs=[-1 / 128 * i for i in range(47)], - ) + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + expected_chunk(S.SECOND_TOOL_RESPONSE, 0), + expected_chunk(S.THIRD_RESPONSE, 1), ], partial_sample=expected_partial_sample( - prompt=TWO_TURN_PROMPT, - response=MULTI_TURN_FIRST_RESPONSE, - response_length=47, - status=Sample.Status.TRUNCATED, + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.THIRD_RESPONSE, + response_length=token_len(S.THIRD_RESPONSE), ), ), ] From 7419ca7fe8768ac6cd9934625d7a35b59f1a7cc1 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:19:49 +0800 Subject: [PATCH 45/77] Add rollout level integration test for (multi-turn, agentic) x (single-sample, multi-sample) (#483) --- .../modular_rollout/orchestration_eval.py | 13 +- tests/fixtures/generation_fixtures.py | 61 +++++++--- .../modular_rollout/integration/test_basic.py | 3 +- .../integration/test_deterministic.py | 6 +- .../integration/test_dynamic_filter.py | 6 +- .../integration/test_group_rm.py | 4 +- .../integration/test_multi_sample.py | 2 +- .../integration/test_multi_turn.py | 114 ++++++++++++++++++ .../integration/test_over_sampling.py | 8 +- .../integration/test_sample_filter.py | 8 +- .../integration/test_semaphore.py | 10 +- .../modular_rollout/integration/utils.py | 33 +++-- 12 files changed, 223 insertions(+), 45 deletions(-) create mode 100644 tests/rollout/modular_rollout/integration/test_multi_turn.py diff --git a/miles/rollout/modular_rollout/orchestration_eval.py b/miles/rollout/modular_rollout/orchestration_eval.py index 5d95c54d49..0e215e9711 100644 --- a/miles/rollout/modular_rollout/orchestration_eval.py +++ b/miles/rollout/modular_rollout/orchestration_eval.py @@ -81,11 +81,14 @@ async def eval_rollout_single_dataset( pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) async for sample in as_completed_async(tasks): if do_print: - logger.info( - "eval_rollout_single_dataset example data: " - f"{[str(sample.prompt) + sample.response]} " - f"reward={sample.reward}" - ) + # TODO improve this after enhancing samples' type + s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample + if s is not None: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(s.prompt) + s.response]} " + f"reward={s.reward}" + ) do_print = False if isinstance(sample, list): data.extend(sample) diff --git a/tests/fixtures/generation_fixtures.py b/tests/fixtures/generation_fixtures.py index a0af8da9bc..8c144cfe4c 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fixtures/generation_fixtures.py @@ -37,6 +37,42 @@ } +def extra_argv_for_variant( + variant: str, + *, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", +) -> list[str]: + argv = [ + "--custom-generate-function-path", + custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], + ] + + if variant in ( + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ): + argv += [ + "--generate-max-turns", + str(generate_max_turns), + "--generate-tool-specs-path", + generate_tool_specs_path, + "--generate-execute-tool-function-path", + generate_execute_tool_function_path, + ] + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + argv += ["--generate-tool-call-parser", generate_tool_call_parser] + if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): + argv.append("--generate-multi-samples") + + return argv + + def listify(x): return x if isinstance(x, list) else [x] @@ -149,24 +185,19 @@ def make_args( argv.append("--use-rollout-routing-replay") if sglang_speculative_algorithm: argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) - if custom_generate_function_path: - argv.extend(["--custom-generate-function-path", custom_generate_function_path]) if rollout_max_context_len is not None: argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) - if variant in ( - "multi_turn_single_sample", - "multi_turn_multi_samples", - "agentic_tool_call_single_sample", - "agentic_tool_call_multi_samples", - ): - argv.extend(["--generate-max-turns", str(generate_max_turns)]) - argv.extend(["--generate-tool-specs-path", generate_tool_specs_path]) - argv.extend(["--generate-execute-tool-function-path", generate_execute_tool_function_path]) - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - argv.extend(["--generate-tool-call-parser", generate_tool_call_parser]) - if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): - argv.append("--generate-multi-samples") + argv.extend( + extra_argv_for_variant( + variant, + custom_generate_function_path=custom_generate_function_path, + generate_max_turns=generate_max_turns, + generate_tool_specs_path=generate_tool_specs_path, + generate_tool_call_parser=generate_tool_call_parser, + generate_execute_tool_function_path=generate_execute_tool_function_path, + ) + ) if extra_argv: argv.extend(extra_argv) diff --git a/tests/rollout/modular_rollout/integration/test_basic.py b/tests/rollout/modular_rollout/integration/test_basic.py index bbb82ae50e..bf12cb3735 100644 --- a/tests/rollout/modular_rollout/integration/test_basic.py +++ b/tests/rollout/modular_rollout/integration/test_basic.py @@ -1,4 +1,5 @@ import pytest +from tests.fixtures.generation_fixtures import extra_argv_for_variant from tests.fixtures.rollout_integration import IntegrationEnvConfig from tests.rollout.modular_rollout.integration.utils import ( MODULAR_ROLLOUT_BASE_ARGV, @@ -37,7 +38,7 @@ id="new_rollout_old_generate", ), pytest.param( - IntegrationEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV), + IntegrationEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), id="new_rollout_new_generate", ), ] diff --git a/tests/rollout/modular_rollout/integration/test_deterministic.py b/tests/rollout/modular_rollout/integration/test_deterministic.py index 63316ceb45..5a1dbb4f10 100644 --- a/tests/rollout/modular_rollout/integration/test_deterministic.py +++ b/tests/rollout/modular_rollout/integration/test_deterministic.py @@ -1,13 +1,13 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import config, load_and_call_train +from tests.rollout.modular_rollout.integration.utils import integration_env_config, load_and_call_train @pytest.mark.parametrize( "rollout_integration_env,expected_seeds", [ pytest.param( - config( + integration_env_config( [ "--sglang-enable-deterministic-inference", "--rollout-seed", @@ -22,7 +22,7 @@ id="enabled", ), pytest.param( - config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), {None}, id="disabled", ), diff --git a/tests/rollout/modular_rollout/integration/test_dynamic_filter.py b/tests/rollout/modular_rollout/integration/test_dynamic_filter.py index c7e86657c5..eb25c9c1ad 100644 --- a/tests/rollout/modular_rollout/integration/test_dynamic_filter.py +++ b/tests/rollout/modular_rollout/integration/test_dynamic_filter.py @@ -3,8 +3,8 @@ import pytest from tests.rollout.modular_rollout.integration.utils import ( MIXED_DATA_ROWS, - config, filter_by_reward, + integration_env_config, load_and_call_train, ) @@ -15,13 +15,13 @@ "rollout_integration_env,use_filter,expect_all_correct", [ pytest.param( - config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), + integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), False, False, id="no_filter", ), pytest.param( - config( + integration_env_config( ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], data_rows=MIXED_DATA_ROWS, ), diff --git a/tests/rollout/modular_rollout/integration/test_group_rm.py b/tests/rollout/modular_rollout/integration/test_group_rm.py index 8b8ab269d6..a1811467c2 100644 --- a/tests/rollout/modular_rollout/integration/test_group_rm.py +++ b/tests/rollout/modular_rollout/integration/test_group_rm.py @@ -1,13 +1,13 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import config, load_and_call_train +from tests.rollout.modular_rollout.integration.utils import integration_env_config, load_and_call_train @pytest.mark.parametrize( "rollout_integration_env", [ pytest.param( - config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), id="group_rm_enabled", ), ], diff --git a/tests/rollout/modular_rollout/integration/test_multi_sample.py b/tests/rollout/modular_rollout/integration/test_multi_sample.py index 72cdee12b9..a2e854d9a8 100644 --- a/tests/rollout/modular_rollout/integration/test_multi_sample.py +++ b/tests/rollout/modular_rollout/integration/test_multi_sample.py @@ -35,7 +35,7 @@ async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: [ pytest.param( IntegrationEnvConfig( - extra_argv=MODULAR_ROLLOUT_BASE_ARGV[:4] + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + [ "--custom-generate-function-path", "test:multi_sample_generate", diff --git a/tests/rollout/modular_rollout/integration/test_multi_turn.py b/tests/rollout/modular_rollout/integration/test_multi_turn.py new file mode 100644 index 0000000000..97df120817 --- /dev/null +++ b/tests/rollout/modular_rollout/integration/test_multi_turn.py @@ -0,0 +1,114 @@ +from typing import Any + +import pytest +from tests.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fixtures.rollout_integration import IntegrationEnvConfig +from tests.rollout.modular_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout + +from miles.utils.test_utils.mock_tools import TwoTurnStub +from miles.utils.types import Sample + + +TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}] + +_VARIANT_NAMES = [ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", +] + +_BASE_EXTRA_ARGV = [ + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "2", + "--n-samples-per-eval-prompt", + "2", + "--custom-rm-path", + "tests.rollout.modular_rollout.integration.test_generate_hub._simple_reward_function", +] + + +def _config_for_variant(variant: str) -> IntegrationEnvConfig: + return IntegrationEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, + data_rows=TWO_TURN_DATA_ROWS, + ) + + +@pytest.mark.parametrize( + "variant,rollout_integration_env", + [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], + indirect=["rollout_integration_env"], +) +@pytest.mark.parametrize("test_type", ["train", "eval"]) +def test_rollout(rollout_integration_env, variant, test_type): + env = rollout_integration_env + env.mock_server.process_fn = TwoTurnStub.process_fn + + out = load_and_call_rollout(env.args, env.data_source, mode=test_type) + + if test_type == "train": + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + _verify_samples(variant, group) + else: + assert "toy" in out.data + samples = out.data["toy"]["samples"] + _verify_samples(variant, samples) + + +def _verify_samples(variant: str, samples: list[Any]): + is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples") + + if is_multi_samples: + if len(samples) > 0 and isinstance(samples[0], list): + # Train mode: list[list[Sample]], grouped by prompt + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for group_sample in samples: + assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate" + _verify_group_samples(group_sample) + else: + # Eval mode: list[Sample], flattened + # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples + assert ( + len(samples) == 4 + ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}" + # Group samples by prompt (every 2 samples form a group) + group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)] + for group_samples in group_samples_list: + _verify_group_samples(group_samples) + else: + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for sample in samples: + assert isinstance(sample, Sample), "single_sample variant should return Sample, not list" + _verify_sample(sample) + + +def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2): + assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)" + for i, sample in enumerate(group_samples): + _verify_sample(sample, expect_answer=(i == len(group_samples) - 1)) + + +def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True): + assert sample.status == Sample.Status.COMPLETED + assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}" + if expect_answer: + assert "2008" in sample.response, "Response should contain final answer '2008'" + + +async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]: + if isinstance(samples, list): + # For multi_samples variants, use the last sample's reward + if getattr(args, "generate_multi_samples", False): + return [_check_reward(samples[-1])] * len(samples) + else: + return [_check_reward(sample) for sample in samples] + else: + return _check_reward(samples) + + +def _check_reward(sample: Sample) -> float: + return float(sample.response and (str(sample.label) in sample.response)) diff --git a/tests/rollout/modular_rollout/integration/test_over_sampling.py b/tests/rollout/modular_rollout/integration/test_over_sampling.py index 17ae7cb38f..e4318c88fa 100644 --- a/tests/rollout/modular_rollout/integration/test_over_sampling.py +++ b/tests/rollout/modular_rollout/integration/test_over_sampling.py @@ -1,5 +1,9 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import config, filter_by_reward, load_and_call_train +from tests.rollout.modular_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) from miles.utils.misc import function_registry @@ -19,7 +23,7 @@ def _over_sampling_config(rollout_batch_size: int): - return config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) + return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) @pytest.mark.parametrize( diff --git a/tests/rollout/modular_rollout/integration/test_sample_filter.py b/tests/rollout/modular_rollout/integration/test_sample_filter.py index 751d689cb4..a69f05b352 100644 --- a/tests/rollout/modular_rollout/integration/test_sample_filter.py +++ b/tests/rollout/modular_rollout/integration/test_sample_filter.py @@ -1,7 +1,11 @@ from unittest.mock import Mock import pytest -from tests.rollout.modular_rollout.integration.utils import config, filter_by_reward, load_and_call_train +from tests.rollout.modular_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) from miles.utils.misc import function_registry @@ -19,7 +23,7 @@ "rollout_integration_env", [ pytest.param( - config( + integration_env_config( [ "--rollout-batch-size", "2", diff --git a/tests/rollout/modular_rollout/integration/test_semaphore.py b/tests/rollout/modular_rollout/integration/test_semaphore.py index bcd09e3559..ce42728635 100644 --- a/tests/rollout/modular_rollout/integration/test_semaphore.py +++ b/tests/rollout/modular_rollout/integration/test_semaphore.py @@ -1,6 +1,6 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import config, load_and_call_train +from tests.rollout.modular_rollout.integration.utils import integration_env_config, load_and_call_train _DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] _BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] @@ -10,12 +10,16 @@ "rollout_integration_env,expected_range", [ pytest.param( - config(["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05), + integration_env_config( + ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), (1, 1), id="limit_1", ), pytest.param( - config(["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05), + integration_env_config( + ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), (2, 999), id="no_limit", ), diff --git a/tests/rollout/modular_rollout/integration/utils.py b/tests/rollout/modular_rollout/integration/utils.py index 260b3f1516..511a43bb70 100644 --- a/tests/rollout/modular_rollout/integration/utils.py +++ b/tests/rollout/modular_rollout/integration/utils.py @@ -1,6 +1,12 @@ +from tests.fixtures.generation_fixtures import extra_argv_for_variant from tests.fixtures.rollout_integration import IntegrationEnvConfig -from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnTrainInput +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnOutput, + RolloutFnTrainInput, +) from miles.rollout.filter_hub.base_types import DynamicFilterOutput from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils.types import Sample @@ -39,8 +45,6 @@ def expected_sample(*, group_index: int | None) -> Sample: "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", "--eval-function-path", "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", - "--custom-generate-function-path", - "miles.rollout.generate_hub.single_turn.generate", ] MIXED_DATA_ROWS = [ @@ -51,20 +55,33 @@ def expected_sample(*, group_index: int | None) -> Sample: ] -def config(extra_argv: list[str], data_rows: list[dict] | None = None, latency: float = 0.0): +def integration_env_config( + extra_argv: list[str], + data_rows: list[dict] | None = None, + latency: float = 0.0, + variant: str = "single_turn", +): return IntegrationEnvConfig( - extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv, + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, data_rows=data_rows, latency=latency, ) -def load_and_call_train(args, data_source): +def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput: + function_path = args.rollout_function_path if mode == "train" else args.eval_function_path fn = load_rollout_function( RolloutFnConstructorInput(args=args, data_source=data_source), - args.rollout_function_path, + function_path, ) - return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + if mode == "train": + return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + else: + return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + +def load_and_call_train(args, data_source): + return load_and_call_rollout(args, data_source, mode="train") def filter_by_reward(args, samples, **kwargs): From adf9d722046319fe7b1bfa24c1dcea2badf68047 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:20:10 +0800 Subject: [PATCH 46/77] Add environment variable to guard enabling the new rollout (#484) --- miles/ray/rollout.py | 33 ++++++++++++++++++++++++------ miles/rollout/base_types.py | 40 +++++++++++++++++++++++++++++++++++++ miles/utils/arguments.py | 4 +++- miles/utils/environ.py | 5 +++++ miles/utils/http_utils.py | 2 ++ tests/conftest.py | 11 ++++++++++ 6 files changed, 88 insertions(+), 7 deletions(-) create mode 100644 miles/utils/environ.py diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 1cba8b7e00..1522c6b89e 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -13,9 +13,15 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnTrainInput +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnTrainInput, + call_rollout_fn, +) from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils +from miles.utils.environ import get_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -54,9 +60,14 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - input = RolloutFnConstructorInput(args=args, data_source=self.data_source) - self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) - self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) + self.use_experimental_refactor = get_experimental_rollout_refactor() + if self.use_experimental_refactor: + input = RolloutFnConstructorInput(args=args, data_source=self.data_source) + self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) + self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) + else: + self.generate_rollout = load_function(self.args.rollout_function_path) + self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -144,7 +155,12 @@ def eval(self, rollout_id): return self.health_monitoring_resume() - result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) + if self.use_experimental_refactor: + result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) + else: + result = call_rollout_fn( + self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True + ) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -226,7 +242,12 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) + if self.use_experimental_refactor: + data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) + else: + data = call_rollout_fn( + self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False + ) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index e4aa454302..d4e37c8605 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -61,6 +61,46 @@ class RolloutFnEvalOutput: RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput +# TODO: may add add_arguments +# TODO: may add save/load if need it to be stateful +# Duck typing, users do not need to extend this class +@runtime_checkable +class RolloutFnProtocol(Protocol): + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput | Awaitable[RolloutFnOutput]: ... + + +# TODO maybe put to modular_rollout folder depending on overall folder structure +@dataclass(frozen=True) +class GenerateFnInput: + state: GenerateState + sample: Sample + sampling_params: dict[str, Any] + evaluation: bool + + @property + def args(self) -> Namespace: + return self.state.args + + +@dataclass(frozen=True) +class GenerateFnOutput: + # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or + # multi-turn with removing thinking tokens. + samples: Sample | list[Sample] + + +# TODO: may add add_arguments +# TODO: may add save/load if need it to be stateful +@runtime_checkable +class GenerateFnProtocol(Protocol): + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: ... + + +def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): + """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" + output = fn(*args, **kwargs, evaluation=evaluation) + + # TODO: may add add_arguments # TODO: may add save/load if need it to be stateful # Duck typing, users do not need to extend this class diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 41ebaf00fe..c95f91ae90 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,6 +10,7 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from miles.utils.environ import get_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger from miles.utils.misc import load_function @@ -1389,7 +1390,8 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) - parser = add_user_provided_function_arguments(parser) + if get_experimental_rollout_refactor(): + parser = add_user_provided_function_arguments(parser) reset_arg( parser, "--custom-config-path", diff --git a/miles/utils/environ.py b/miles/utils/environ.py new file mode 100644 index 0000000000..155e3fbf1b --- /dev/null +++ b/miles/utils/environ.py @@ -0,0 +1,5 @@ +import os + + +def get_experimental_rollout_refactor() -> bool: + return bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 9641cbe0ec..0abdbbf59d 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -269,6 +269,7 @@ async def do_post(self, url, payload, max_retries=60, action="post"): _post_actors = created +# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: @@ -287,6 +288,7 @@ async def post(url, payload, max_retries=60, action="post"): return await _post(_http_client, url, payload, max_retries, action=action) +# TODO unify w/ `post` to add retries and remote-execution async def get(url): response = await _http_client.get(url) response.raise_for_status() diff --git a/tests/conftest.py b/tests/conftest.py index b04dc6bd0d..d72eda5f34 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,15 @@ +import os + +import pytest + from tests.fixtures.generation_fixtures import generation_env from tests.fixtures.rollout_integration import rollout_integration_env _ = rollout_integration_env, generation_env + + +@pytest.fixture(autouse=True) +def enable_experimental_rollout_refactor(): + os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1" + yield + os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None) From c410f4297208a5950aad5d048a8ef332e3e3b7d5 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:20:27 +0800 Subject: [PATCH 47/77] Improve fault tolerance for router session retrieval (#485) --- .../generate_hub/openai_endpoint_utils.py | 19 +++++-- miles/router/sessions.py | 49 ++++++++++++------ tests/router/test_sessions.py | 50 ++++++++++++++++--- 3 files changed, 92 insertions(+), 26 deletions(-) diff --git a/miles/rollout/generate_hub/openai_endpoint_utils.py b/miles/rollout/generate_hub/openai_endpoint_utils.py index 6293564f4f..73ba8198bf 100644 --- a/miles/rollout/generate_hub/openai_endpoint_utils.py +++ b/miles/rollout/generate_hub/openai_endpoint_utils.py @@ -2,13 +2,16 @@ Utilities for the OpenAI endpoint """ +import logging from argparse import Namespace from copy import deepcopy -from miles.router.sessions import DeleteSessionResponse, SessionRecord +from miles.router.sessions import GetSessionResponse, SessionRecord from miles.utils.http_utils import post from miles.utils.types import Sample +logger = logging.getLogger(__name__) + class OpenAIEndpointTracer: def __init__(self, router_url: str, session_id: str): @@ -23,10 +26,16 @@ async def create(args: Namespace): return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) async def collect_records(self) -> list[SessionRecord]: - # TODO: for fault tolerance, we may want to change to GET + DELETE - response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") - response = DeleteSessionResponse.model_validate(response) - return response.records + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + response = GetSessionResponse.model_validate(response) + records = response.records + + try: + await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") + except Exception as e: + logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") + + return records def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: diff --git a/miles/router/sessions.py b/miles/router/sessions.py index f52cc33ef0..9d753e5975 100644 --- a/miles/router/sessions.py +++ b/miles/router/sessions.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from fastapi import Request -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, Response from pydantic import BaseModel from transformers import AutoTokenizer @@ -21,7 +21,7 @@ class SessionRecord(BaseModel): status_code: int -class DeleteSessionResponse(BaseModel): +class GetSessionResponse(BaseModel): session_id: str records: list[SessionRecord] @@ -52,7 +52,15 @@ def setup_session_routes(app, router: "MilesRouter"): # TODO temporary hack before @guapisolo implements TITO # ============================= HACK START =============================== - tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) + # Lazy load tokenizer only when needed (for tests that don't have hf_checkpoint) + tokenizer = None + + def get_tokenizer(): + nonlocal tokenizer + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) + return tokenizer + # ============================= HACK END =============================== @app.post("/sessions") @@ -60,12 +68,19 @@ async def create_session(): session_id = manager.create_session() return {"session_id": session_id} + @app.get("/sessions/{session_id}") + async def get_session(session_id: str): + records = manager.get_session(session_id) + if records is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return GetSessionResponse(session_id=session_id, records=records) + @app.delete("/sessions/{session_id}") async def delete_session(session_id: str): if session_id not in manager.sessions: return JSONResponse(status_code=404, content={"error": "session not found"}) - records = manager.delete_session(session_id) - return DeleteSessionResponse(session_id=session_id, records=records) + manager.delete_session(session_id) + return Response(status_code=204) @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) async def session_proxy(request: Request, session_id: str, path: str): @@ -79,15 +94,21 @@ async def session_proxy(request: Request, session_id: str, path: str): # TODO: remove this hack when @guapisolo implements the real TITO # ============================= HACK START =============================== - request_body["input_ids"] = tokenizer.apply_chat_template( - request_body["messages"], - add_generation_prompt=True, - add_special_tokens=False, - tools=request_body.get("tools"), - ) - logprobs_content = response_body["choices"][0]["logprobs"]["content"] - for item in logprobs_content: - item["token_id"] = tokenizer.convert_tokens_to_ids(item["token"]) + if "messages" in request_body and "input_ids" not in request_body: + request_body["input_ids"] = get_tokenizer().apply_chat_template( + request_body["messages"], + add_generation_prompt=True, + add_special_tokens=False, + tools=request_body.get("tools"), + ) + if ( + "logprobs" in response_body.get("choices", [{}])[0] + and "content" in response_body["choices"][0]["logprobs"] + ): + logprobs_content = response_body["choices"][0]["logprobs"]["content"] + for item in logprobs_content: + if "token" in item and "token_id" not in item: + item["token_id"] = get_tokenizer().convert_tokens_to_ids(item["token"]) # ============================= HACK END =============================== record = SessionRecord( diff --git a/tests/router/test_sessions.py b/tests/router/test_sessions.py index 0b37aa5c93..5c6edafe20 100644 --- a/tests/router/test_sessions.py +++ b/tests/router/test_sessions.py @@ -83,6 +83,7 @@ def process_fn(prompt: str) -> ProcessResult: miles_router_middleware_paths=[], rollout_health_check_interval=60, miles_router_health_check_failure_threshold=3, + hf_checkpoint="Qwen/Qwen3-0.6B", ) router = MilesRouter(args) @@ -107,13 +108,40 @@ def test_create_session(self, router_url): assert "session_id" in data assert len(data["session_id"]) == 32 + def test_get_session(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert data["records"] == [] + + def test_get_session_not_found(self, router_url): + response = requests.get(f"{router_url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_get_with_records(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + ) + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert len(data["records"]) == 1 + def test_delete_session(self, router_url): session_id = requests.post(f"{router_url}/sessions").json()["session_id"] delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") - assert delete_resp.status_code == 200 - assert delete_resp.json()["session_id"] == session_id - assert delete_resp.json()["records"] == [] + assert delete_resp.status_code == 204 + assert delete_resp.text == "" assert requests.delete(f"{router_url}/sessions/{session_id}").status_code == 404 @@ -139,12 +167,16 @@ def test_proxy_records_request_response(self, router_url): assert resp.status_code == 200 assert "text" in resp.json() - records = requests.delete(f"{router_url}/sessions/{session_id}").json()["records"] + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + records = get_resp.json()["records"] assert len(records) == 1 assert records[0]["method"] == "POST" assert records[0]["path"] == "generate" - assert records[0]["request_json"]["input_ids"] == [1, 2, 3] - assert "text" in records[0]["response_json"] + assert records[0]["request"]["input_ids"] == [1, 2, 3] + assert "text" in records[0]["response"] + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 def test_proxy_accumulates_records(self, router_url): session_id = requests.post(f"{router_url}/sessions").json()["session_id"] @@ -155,5 +187,9 @@ def test_proxy_accumulates_records(self, router_url): json={"input_ids": [1], "sampling_params": {}, "return_logprob": True}, ) - records = requests.delete(f"{router_url}/sessions/{session_id}").json()["records"] + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + records = get_resp.json()["records"] assert len(records) == 3 + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 From 716c2dddaa0f1e070e0afbfc07c8395ff3a858b3 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:20:48 +0800 Subject: [PATCH 48/77] Support rollout routing replay for multi turn (#486) --- miles/rollout/generate_hub/sample_utils.py | 5 +- miles/utils/types.py | 4 ++ tests/rollout/generate_hub/test_multi_turn.py | 60 ++++++++++++++++++- .../rollout/generate_hub/test_single_turn.py | 12 ++-- 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/miles/rollout/generate_hub/sample_utils.py b/miles/rollout/generate_hub/sample_utils.py index c71e1ec57b..6d82a90a4f 100644 --- a/miles/rollout/generate_hub/sample_utils.py +++ b/miles/rollout/generate_hub/sample_utils.py @@ -41,6 +41,8 @@ def _fill_defaults(sample: Sample): assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" + if a.rollout_routed_experts is not None: + assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0] assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" return _create_with_all_fields( @@ -58,8 +60,7 @@ def _fill_defaults(sample: Sample): loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, weight_versions=a.weight_versions + b.weight_versions, rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, - # TODO should support concat - rollout_routed_experts=_merge_equal_value("rollout_routed_experts"), + rollout_routed_experts=b.rollout_routed_experts, remove_sample=_merge_equal_value("remove_sample"), status=b.status, metadata=_merge_equal_value("metadata"), diff --git a/miles/utils/types.py b/miles/utils/types.py index cb690ec600..5200d625e6 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -158,6 +158,10 @@ def validate(self): assert ( len(self.rollout_log_probs) == self.response_length ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" + if self.rollout_routed_experts is not None: + actual = len(self.rollout_routed_experts) + expect = len(self.tokens) - 1 + assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})" def update_from_meta_info(self, args, meta_info: dict): """ diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index a59b1f2325..a20e7eb41a 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -2,11 +2,13 @@ from dataclasses import dataclass, replace from itertools import groupby +import numpy as np +import pybase64 import pytest from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate from transformers import AutoTokenizer -from miles.utils.test_utils.mock_sglang_server import ProcessResult +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub from miles.utils.types import Sample @@ -486,3 +488,59 @@ def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): ), ] verify_samples(result.sample, expected) + + +class TestRoutedExpertsMultiTurn: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "use_rollout_routing_replay": True, + } + } + ], + indirect=True, + ) + def test_two_turns_routed_experts(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + + S = TwoTurnStub + num_layers, moe_router_topk = 2, 4 + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + + def make_routed_experts(prompt_token_ids, response_text): + total_tokens = len(prompt_token_ids) + token_len(response_text) + routed_experts_len = total_tokens - 1 + return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape( + routed_experts_len, num_layers, moe_router_topk + ) + + first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE) + second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE) + + def process_fn(prompt: str) -> ProcessResult: + if prompt == S.FIRST_PROMPT: + text, routed_experts = S.FIRST_RESPONSE, first_routed_experts + elif prompt == S.SECOND_PROMPT: + text, routed_experts = S.SECOND_RESPONSE, second_routed_experts + else: + raise ValueError(f"Unexpected prompt: {prompt}") + return ProcessResult( + text=text, + finish_reason="stop", + meta_info=ProcessResultMetaInfo( + routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii") + ), + ) + + generation_env.mock_server.process_fn = process_fn + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS) + + sample = result.sample[-1] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == second_routed_experts.shape + np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts) + assert len(sample.tokens) - 1 == second_routed_experts.shape[0] diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index 824014276d..bcbced5de0 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -18,10 +18,12 @@ MODEL_NAME = "Qwen/Qwen3-0.6B" PROMPT = "What is 1+7?" PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] +PROMPT_TOKEN_LEN = len(PROMPT_TOKENS) RESPONSE_TOKENS = [59, 79075, 90, 23, 92] RESPONSE_TEXT = "\\boxed{8}" RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} +DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"] @pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) @@ -206,9 +208,6 @@ class TestRoutedExperts: indirect=True, ) def test_routed_experts_enabled_and_parsed(self, variant, generation_env): - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - pytest.skip("TODO: support") - num_layers, moe_router_topk = 2, 4 num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( @@ -226,9 +225,10 @@ def test_routed_experts_enabled_and_parsed(self, variant, generation_env): result = _run_generate(variant, generation_env) assert result.requests == [expected_request(variant, return_routed_experts=True)] - assert result.sample.rollout_routed_experts is not None - assert result.sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) - np.testing.assert_array_equal(result.sample.rollout_routed_experts, routed_experts_array) + sample = result.sample[0] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) + np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array) class TestMetaInfo: From 1c26271bd4cd0e58b6f805902d14b982eb1c5270 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:21:11 +0800 Subject: [PATCH 49/77] Change max_num_tokens according to rollout_max_context_len (#487) --- .../generate_hub/generate_endpoint_wrapper.py | 9 ++-- tests/rollout/generate_hub/test_multi_turn.py | 26 +++++++++++ .../rollout/generate_hub/test_single_turn.py | 44 +++++++++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py index 8947201de9..52796e9ec6 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -44,14 +44,15 @@ def compute_request_payload( sampling_params: dict, multimodal_inputs: dict | None = None, ) -> tuple[dict[str, Any] | None, Sample.Status | None]: - # TODO need to adjust sampling_params.max_new_tokens when input is moderately long - max_context_length = args.rollout_max_context_len or float("inf") - if len(input_ids) >= max_context_length: + max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) + if x := args.rollout_max_context_len: + max_new_tokens = min(max_new_tokens, x - len(input_ids)) + if max_new_tokens <= 0: return None, Sample.Status.TRUNCATED payload = { "input_ids": input_ids, - "sampling_params": sampling_params, + "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens}, "return_logprob": True, "return_routed_experts": args.use_rollout_routing_replay, } diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/rollout/generate_hub/test_multi_turn.py index a20e7eb41a..18652be7b6 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/rollout/generate_hub/test_multi_turn.py @@ -414,6 +414,32 @@ def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, ge ] verify_samples(result.sample, expected) + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}}, + 10, + ), + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}}, + 64, + ), + ], + indirect=["generation_env"], + ) + def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert len(result.requests) >= 2 + assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"] + class TestThreeTurn: """Need to test 3-turn case besides 2-turn, because e.g. merge_samples may behave differently.""" diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/rollout/generate_hub/test_single_turn.py index bcbced5de0..2d399fe9e0 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/rollout/generate_hub/test_single_turn.py @@ -330,6 +330,50 @@ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generat ) ] + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS), + ], + indirect=["generation_env"], + ) + def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + result = _run_generate(variant, generation_env) + assert len(result.requests) == 1 + assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}], + indirect=True, + ) + def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + class TestEmptyResponse: @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) From 4faeb7abfe50729b7dd16cdb897b2984389c40ec Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:21:30 +0800 Subject: [PATCH 50/77] Minor code and test cleanup (#488) --- .../generate_hub/generate_endpoint_wrapper.py | 3 +- .../modular_rollout/orchestration_common.py | 3 +- miles/utils/test_utils/mock_sglang_server.py | 2 +- .../test_utils/test_mock_sglang_server.py | 31 ++++++------------- tests/utils/test_utils/test_mock_tools.py | 4 +-- 5 files changed, 17 insertions(+), 26 deletions(-) diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_hub/generate_endpoint_wrapper.py index 52796e9ec6..5abce60693 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_hub/generate_endpoint_wrapper.py @@ -2,7 +2,7 @@ """ Wrapper to integrate SGLang's `/generate` endpoint with RL things like Sample. """ - +from copy import deepcopy from typing import Any import numpy as np @@ -44,6 +44,7 @@ def compute_request_payload( sampling_params: dict, multimodal_inputs: dict | None = None, ) -> tuple[dict[str, Any] | None, Sample.Status | None]: + sampling_params = deepcopy(sampling_params) max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) if x := args.rollout_max_context_len: max_new_tokens = min(max_new_tokens, x - len(input_ids)) diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/modular_rollout/orchestration_common.py index ab0f55f2b2..195e39cff8 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/modular_rollout/orchestration_common.py @@ -1,6 +1,7 @@ import asyncio import logging from argparse import Namespace +from copy import deepcopy from typing import Any from miles.rollout.base_types import GenerateFnInput @@ -68,7 +69,7 @@ async def generate_and_rm( GenerateFnInput( state=state, sample=sample, - sampling_params=sampling_params, + sampling_params=deepcopy(sampling_params), evaluation=evaluation, ) ) diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py index f8f233d208..2c0dddfe54 100644 --- a/miles/utils/test_utils/mock_sglang_server.py +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -32,7 +32,7 @@ def to_dict(self) -> dict: @dataclass(frozen=True) class ProcessResult: text: str - finish_reason: str + finish_reason: str = "stop" cached_tokens: int = 0 meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() diff --git a/tests/utils/test_utils/test_mock_sglang_server.py b/tests/utils/test_utils/test_mock_sglang_server.py index b7ed21f365..6633678da1 100644 --- a/tests/utils/test_utils/test_mock_sglang_server.py +++ b/tests/utils/test_utils/test_mock_sglang_server.py @@ -12,18 +12,7 @@ default_process_fn, with_mock_server, ) -from miles.utils.test_utils.mock_tools import ( - MULTI_TURN_FIRST_PROMPT, - MULTI_TURN_FIRST_RESPONSE, - MULTI_TURN_FIRST_RESPONSE_CONTENT, - MULTI_TURN_FIRST_TOOL_CALLS, - MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN, - MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN, - MULTI_TURN_SECOND_PROMPT, - MULTI_TURN_SECOND_RESPONSE, - SAMPLE_TOOLS, - multi_turn_tool_call_process_fn, -) +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub def expected_logprobs(tokenizer, text: str) -> list[dict]: @@ -370,12 +359,12 @@ class TestMultiTurnToolCallProcessFn: @pytest.mark.parametrize( "prompt,expected_response", [ - pytest.param(MULTI_TURN_FIRST_PROMPT, MULTI_TURN_FIRST_RESPONSE, id="first_turn"), - pytest.param(MULTI_TURN_SECOND_PROMPT, MULTI_TURN_SECOND_RESPONSE, id="second_turn"), + pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"), + pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"), ], ) def test_generate_endpoint(self, prompt, expected_response): - with with_mock_server(process_fn=multi_turn_tool_call_process_fn) as server: + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) response = requests.post( f"{server.url}/generate", @@ -391,15 +380,15 @@ def test_generate_endpoint(self, prompt, expected_response): "messages,expected_content,expected_tool_calls,expected_finish_reason", [ pytest.param( - MULTI_TURN_OPENAI_MESSAGES_FIRST_TURN, - MULTI_TURN_FIRST_RESPONSE_CONTENT, - MULTI_TURN_FIRST_TOOL_CALLS, + TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, + TwoTurnStub.FIRST_RESPONSE_CONTENT, + TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT, "tool_calls", id="first_turn", ), pytest.param( - MULTI_TURN_OPENAI_MESSAGES_SECOND_TURN, - MULTI_TURN_SECOND_RESPONSE, + TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, + TwoTurnStub.SECOND_RESPONSE, None, "stop", id="second_turn", @@ -407,7 +396,7 @@ def test_generate_endpoint(self, prompt, expected_response): ], ) def test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): - with with_mock_server(process_fn=multi_turn_tool_call_process_fn) as server: + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: response = requests.post( f"{server.url}/v1/chat/completions", json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, diff --git a/tests/utils/test_utils/test_mock_tools.py b/tests/utils/test_utils/test_mock_tools.py index 0a77a2a31f..b905fa8525 100644 --- a/tests/utils/test_utils/test_mock_tools.py +++ b/tests/utils/test_utils/test_mock_tools.py @@ -6,7 +6,7 @@ from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.function_call_parser import FunctionCallParser -from miles.utils.test_utils.mock_tools import MULTI_TURN_FIRST_RESPONSE, SAMPLE_TOOLS, execute_tool_call +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call class TestExecuteToolCall: @@ -93,7 +93,7 @@ class TestSGLangFunctionCallParser: id="no_tool_call", ), pytest.param( - MULTI_TURN_FIRST_RESPONSE, + TwoTurnStub.FIRST_RESPONSE, ( "Let me get the year and temperature first.", [ From 204ecb1614eae6227100d2758b034474e66a49fe Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:30:30 +0800 Subject: [PATCH 51/77] Cleanup file and folder structure for rollout (#489) --- miles/ray/rollout.py | 2 +- miles/rollout/base_types.py | 21 +--- .../rollout/generate_hub/agentic_tool_call.py | 9 +- miles/rollout/generate_hub/multi_turn.py | 4 +- miles/rollout/generate_hub/single_turn.py | 2 +- .../__init__.py | 0 .../generate_endpoint_utils.py} | 7 +- .../openai_endpoint_utils.py | 0 .../sample_utils.py | 4 +- .../tool_call_utils.py | 4 + miles/rollout/inference_rollout/__init__.py | 2 + .../compatibility.py | 6 +- .../inference_rollout_common.py} | 2 +- .../inference_rollout_eval.py} | 6 +- .../inference_rollout_train.py} | 2 +- .../modular_rollout/inference_wrapper.py | 100 ------------------ tests/e2e/.gitkeep | 1 + tests/{rollout => fast}/__init__.py | 0 tests/{ => fast}/conftest.py | 6 +- tests/{ => fast}/fixtures/__init__.py | 0 .../fixtures/generation_fixtures.py | 4 +- .../fixtures/rollout_fixtures.py} | 13 +-- .../generate_hub => fast/rollout}/__init__.py | 0 .../rollout/generate_hub}/__init__.py | 0 .../rollout/generate_hub/test_multi_turn.py | 2 +- .../rollout/generate_hub/test_single_turn.py | 2 +- .../generate_hub/test_tool_call_utils.py | 2 +- .../rollout/generate_utils}/__init__.py | 0 .../generate_utils}/test_sample_utils.py | 14 +-- .../rollout/inference_rollout}/__init__.py | 0 .../rollout/inference_rollout}/conftest.py | 0 .../integration}/__init__.py | 0 .../integration/test_basic.py | 30 +++--- .../integration/test_deterministic.py | 10 +- .../integration/test_dynamic_filter.py | 10 +- .../integration/test_group_rm.py | 8 +- .../integration/test_multi_sample.py | 12 +-- .../integration/test_multi_turn.py | 20 ++-- .../integration/test_over_sampling.py | 10 +- .../integration/test_sample_filter.py | 8 +- .../integration/test_semaphore.py | 10 +- .../inference_rollout}/integration/utils.py | 12 +-- .../inference_rollout}/test_compatibility.py | 2 +- tests/fast/rollout/rm_hub/__init__.py | 0 .../rollout/rm_hub/test_deepscaler.py | 0 tests/{ => fast}/rollout/rm_hub/test_f1.py | 0 tests/{ => fast}/rollout/rm_hub/test_gpqa.py | 0 .../rollout/rm_hub/test_math_dapo_utils.py | 0 .../rollout/rm_hub/test_math_utils.py | 0 .../{ => fast}/rollout/rm_hub/test_rm_hub.py | 0 tests/fast/router/__init__.py | 0 tests/{ => fast}/router/test_router.py | 0 tests/{ => fast}/router/test_sessions.py | 0 tests/fast/utils/__init__.py | 0 tests/{ => fast}/utils/test_arguments.py | 0 tests/{ => fast}/utils/test_mask_utils.py | 0 tests/{ => fast}/utils/test_misc.py | 0 tests/fast/utils/test_utils/__init__.py | 0 .../test_utils/test_mock_sglang_server.py | 0 .../utils/test_utils/test_mock_tools.py | 0 .../modular_rollout/test_integration.py | 98 ----------------- 61 files changed, 117 insertions(+), 328 deletions(-) rename miles/rollout/{modular_rollout => generate_utils}/__init__.py (100%) rename miles/rollout/{generate_hub/generate_endpoint_wrapper.py => generate_utils/generate_endpoint_utils.py} (93%) rename miles/rollout/{generate_hub => generate_utils}/openai_endpoint_utils.py (100%) rename miles/rollout/{generate_hub => generate_utils}/sample_utils.py (97%) rename miles/rollout/{generate_hub => generate_utils}/tool_call_utils.py (99%) create mode 100644 miles/rollout/inference_rollout/__init__.py rename miles/rollout/{modular_rollout => inference_rollout}/compatibility.py (93%) rename miles/rollout/{modular_rollout/orchestration_common.py => inference_rollout/inference_rollout_common.py} (98%) rename miles/rollout/{modular_rollout/orchestration_eval.py => inference_rollout/inference_rollout_eval.py} (97%) rename miles/rollout/{modular_rollout/orchestration_train.py => inference_rollout/inference_rollout_train.py} (98%) delete mode 100644 miles/rollout/modular_rollout/inference_wrapper.py create mode 100644 tests/e2e/.gitkeep rename tests/{rollout => fast}/__init__.py (100%) rename tests/{ => fast}/conftest.py (57%) rename tests/{ => fast}/fixtures/__init__.py (100%) rename tests/{ => fast}/fixtures/generation_fixtures.py (98%) rename tests/{fixtures/rollout_integration.py => fast/fixtures/rollout_fixtures.py} (89%) rename tests/{rollout/generate_hub => fast/rollout}/__init__.py (100%) rename tests/{rollout/modular_rollout => fast/rollout/generate_hub}/__init__.py (100%) rename tests/{ => fast}/rollout/generate_hub/test_multi_turn.py (99%) rename tests/{ => fast}/rollout/generate_hub/test_single_turn.py (99%) rename tests/{ => fast}/rollout/generate_hub/test_tool_call_utils.py (96%) rename tests/{rollout/modular_rollout/integration => fast/rollout/generate_utils}/__init__.py (100%) rename tests/{rollout/generate_hub => fast/rollout/generate_utils}/test_sample_utils.py (91%) rename tests/{rollout/rm_hub => fast/rollout/inference_rollout}/__init__.py (100%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/conftest.py (100%) rename tests/{router => fast/rollout/inference_rollout/integration}/__init__.py (100%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_basic.py (64%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_deterministic.py (74%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_dynamic_filter.py (80%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_group_rm.py (68%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_multi_sample.py (82%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_multi_turn.py (87%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_over_sampling.py (83%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_sample_filter.py (91%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/test_semaphore.py (74%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/integration/utils.py (86%) rename tests/{rollout/modular_rollout => fast/rollout/inference_rollout}/test_compatibility.py (99%) create mode 100644 tests/fast/rollout/rm_hub/__init__.py rename tests/{ => fast}/rollout/rm_hub/test_deepscaler.py (100%) rename tests/{ => fast}/rollout/rm_hub/test_f1.py (100%) rename tests/{ => fast}/rollout/rm_hub/test_gpqa.py (100%) rename tests/{ => fast}/rollout/rm_hub/test_math_dapo_utils.py (100%) rename tests/{ => fast}/rollout/rm_hub/test_math_utils.py (100%) rename tests/{ => fast}/rollout/rm_hub/test_rm_hub.py (100%) create mode 100644 tests/fast/router/__init__.py rename tests/{ => fast}/router/test_router.py (100%) rename tests/{ => fast}/router/test_sessions.py (100%) create mode 100644 tests/fast/utils/__init__.py rename tests/{ => fast}/utils/test_arguments.py (100%) rename tests/{ => fast}/utils/test_mask_utils.py (100%) rename tests/{ => fast}/utils/test_misc.py (100%) create mode 100644 tests/fast/utils/test_utils/__init__.py rename tests/{ => fast}/utils/test_utils/test_mock_sglang_server.py (100%) rename tests/{ => fast}/utils/test_utils/test_mock_tools.py (100%) delete mode 100644 tests/rollout/modular_rollout/test_integration.py diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 1522c6b89e..6198d62360 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -19,7 +19,7 @@ RolloutFnTrainInput, call_rollout_fn, ) -from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils from miles.utils.environ import get_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index d4e37c8605..daa53634c9 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,15 +1,14 @@ from __future__ import annotations from argparse import Namespace -from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any from miles.rollout.data_source import DataSource from miles.utils.types import Sample if TYPE_CHECKING: - from miles.rollout.modular_rollout.orchestration_common import GenerateState + from miles.rollout.inference_rollout.inference_rollout_common import GenerateState @dataclass(frozen=True) @@ -61,15 +60,6 @@ class RolloutFnEvalOutput: RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput -# TODO: may add add_arguments -# TODO: may add save/load if need it to be stateful -# Duck typing, users do not need to extend this class -@runtime_checkable -class RolloutFnProtocol(Protocol): - def __call__(self, input: RolloutFnInput) -> RolloutFnOutput | Awaitable[RolloutFnOutput]: ... - - -# TODO maybe put to modular_rollout folder depending on overall folder structure @dataclass(frozen=True) class GenerateFnInput: state: GenerateState @@ -89,13 +79,6 @@ class GenerateFnOutput: samples: Sample | list[Sample] -# TODO: may add add_arguments -# TODO: may add save/load if need it to be stateful -@runtime_checkable -class GenerateFnProtocol(Protocol): - async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: ... - - def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py index 82b59d9719..05223a6544 100644 --- a/miles/rollout/generate_hub/agentic_tool_call.py +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -9,9 +9,12 @@ from openai import AsyncOpenAI from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_hub.openai_endpoint_utils import OpenAIEndpointTracer, compute_samples_from_openai_records -from miles.rollout.generate_hub.sample_utils import merge_samples -from miles.rollout.generate_hub.tool_call_utils import execute_tool_calls +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) +from miles.rollout.generate_utils.sample_utils import merge_samples +from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls from miles.utils.misc import load_function diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py index 2c01a8ba2d..97814ecb3d 100644 --- a/miles/rollout/generate_hub/multi_turn.py +++ b/miles/rollout/generate_hub/multi_turn.py @@ -6,12 +6,12 @@ from copy import deepcopy from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_hub.generate_endpoint_wrapper import ( +from miles.rollout.generate_utils.generate_endpoint_utils import ( compute_prompt_ids_from_sample, compute_request_payload, update_sample_from_response, ) -from miles.rollout.generate_hub.tool_call_utils import ( +from miles.rollout.generate_utils.tool_call_utils import ( create_tool_call_parser, execute_tool_calls, update_sample_with_tool_responses, diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py index ff976e29dd..5c0a15b5b4 100644 --- a/miles/rollout/generate_hub/single_turn.py +++ b/miles/rollout/generate_hub/single_turn.py @@ -3,7 +3,7 @@ """ from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_hub.generate_endpoint_wrapper import ( +from miles.rollout.generate_utils.generate_endpoint_utils import ( compute_prompt_ids_from_sample, compute_request_payload, update_sample_from_response, diff --git a/miles/rollout/modular_rollout/__init__.py b/miles/rollout/generate_utils/__init__.py similarity index 100% rename from miles/rollout/modular_rollout/__init__.py rename to miles/rollout/generate_utils/__init__.py diff --git a/miles/rollout/generate_hub/generate_endpoint_wrapper.py b/miles/rollout/generate_utils/generate_endpoint_utils.py similarity index 93% rename from miles/rollout/generate_hub/generate_endpoint_wrapper.py rename to miles/rollout/generate_utils/generate_endpoint_utils.py index 5abce60693..a91d71f1de 100644 --- a/miles/rollout/generate_hub/generate_endpoint_wrapper.py +++ b/miles/rollout/generate_utils/generate_endpoint_utils.py @@ -1,7 +1,7 @@ -# TODO: may rename to generate_endpoint_utils.py """ -Wrapper to integrate SGLang's `/generate` endpoint with RL things like Sample. +Utils to integrate SGLang's `/generate` endpoint with RL things like Sample. """ + from copy import deepcopy from typing import Any @@ -35,9 +35,6 @@ def compute_prompt_ids_from_sample(state, sample, tools=None): return state.tokenizer.encode(prompt, add_special_tokens=False) -# Thin wrapper to construct request payload. -# Make it a function to allow adding logics like `return_routed_experts` in the future -# without requiring users to change their code. def compute_request_payload( args, input_ids: list[int], diff --git a/miles/rollout/generate_hub/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py similarity index 100% rename from miles/rollout/generate_hub/openai_endpoint_utils.py rename to miles/rollout/generate_utils/openai_endpoint_utils.py diff --git a/miles/rollout/generate_hub/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py similarity index 97% rename from miles/rollout/generate_hub/sample_utils.py rename to miles/rollout/generate_utils/sample_utils.py index 6d82a90a4f..6a4e645be5 100644 --- a/miles/rollout/generate_hub/sample_utils.py +++ b/miles/rollout/generate_utils/sample_utils.py @@ -7,11 +7,11 @@ def merge_samples(samples: list[Sample], tokenizer) -> Sample: acc = samples[0] for sample in samples[1:]: - acc = merge_sample_pair(acc, sample, tokenizer=tokenizer) + acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer) return acc -def merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: +def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: """Merge two samples generated from sibling inference engine calls.""" a, b = deepcopy(a), deepcopy(b) diff --git a/miles/rollout/generate_hub/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py similarity index 99% rename from miles/rollout/generate_hub/tool_call_utils.py rename to miles/rollout/generate_utils/tool_call_utils.py index fd755f6353..85ea87aeab 100644 --- a/miles/rollout/generate_hub/tool_call_utils.py +++ b/miles/rollout/generate_utils/tool_call_utils.py @@ -1,3 +1,7 @@ +""" +Utils to handle tool calls. +""" + import json import uuid from collections.abc import Callable diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py new file mode 100644 index 0000000000..33ccf17bfb --- /dev/null +++ b/miles/rollout/inference_rollout/__init__.py @@ -0,0 +1,2 @@ +# This is a refactor of the portions above generate-function in sglang_rollout.py, +# and is give a different name to ensure both code exist at the same time. diff --git a/miles/rollout/modular_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py similarity index 93% rename from miles/rollout/modular_rollout/compatibility.py rename to miles/rollout/inference_rollout/compatibility.py index 41427d0ed0..7711e0dd31 100644 --- a/miles/rollout/modular_rollout/compatibility.py +++ b/miles/rollout/inference_rollout/compatibility.py @@ -8,7 +8,6 @@ RolloutFnEvalOutput, RolloutFnInput, RolloutFnOutput, - RolloutFnProtocol, RolloutFnTrainOutput, ) from miles.utils.async_utils import run @@ -31,9 +30,6 @@ def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: return output -assert issubclass(LegacyRolloutFnAdapter, RolloutFnProtocol) - - def load_rollout_function(input: RolloutFnConstructorInput, path: str): fn = load_function(path) @@ -43,7 +39,7 @@ def load_rollout_function(input: RolloutFnConstructorInput, path: str): return LegacyRolloutFnAdapter(input, fn) -def call_rollout_function(fn: RolloutFnProtocol, input: RolloutFnInput) -> RolloutFnOutput: +def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput: output = fn(input) if inspect.iscoroutine(output): diff --git a/miles/rollout/modular_rollout/orchestration_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py similarity index 98% rename from miles/rollout/modular_rollout/orchestration_common.py rename to miles/rollout/inference_rollout/inference_rollout_common.py index 195e39cff8..5d6f67de24 100644 --- a/miles/rollout/modular_rollout/orchestration_common.py +++ b/miles/rollout/inference_rollout/inference_rollout_common.py @@ -6,7 +6,7 @@ from miles.rollout.base_types import GenerateFnInput from miles.rollout.generate_hub.single_turn import generate -from miles.rollout.modular_rollout.compatibility import load_generate_function +from miles.rollout.inference_rollout.compatibility import load_generate_function from miles.rollout.rm_hub import async_rm, batched_async_rm from miles.utils.processing_utils import load_processor, load_tokenizer from miles.utils.types import Sample diff --git a/miles/rollout/modular_rollout/orchestration_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py similarity index 97% rename from miles/rollout/modular_rollout/orchestration_eval.py rename to miles/rollout/inference_rollout/inference_rollout_eval.py index 0e215e9711..3117598f5c 100644 --- a/miles/rollout/modular_rollout/orchestration_eval.py +++ b/miles/rollout/inference_rollout/inference_rollout_eval.py @@ -6,7 +6,11 @@ from tqdm import tqdm from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnEvalOutput -from miles.rollout.modular_rollout.orchestration_common import GenerateState, compute_sampling_params, generate_and_rm +from miles.rollout.inference_rollout.inference_rollout_common import ( + GenerateState, + compute_sampling_params, + generate_and_rm, +) from miles.utils.data import Dataset from miles.utils.eval_config import EvalDatasetConfig from miles.utils.misc import as_completed_async diff --git a/miles/rollout/modular_rollout/orchestration_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py similarity index 98% rename from miles/rollout/modular_rollout/orchestration_train.py rename to miles/rollout/inference_rollout/inference_rollout_train.py index 2adfa2dce1..b0b7741755 100644 --- a/miles/rollout/modular_rollout/orchestration_train.py +++ b/miles/rollout/inference_rollout/inference_rollout_train.py @@ -9,7 +9,7 @@ from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnTrainInput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter -from miles.rollout.modular_rollout.orchestration_common import GenerateState, generate_and_rm_group +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group from miles.utils.http_utils import get, post from miles.utils.misc import as_completed_async, load_function from miles.utils.types import Sample diff --git a/miles/rollout/modular_rollout/inference_wrapper.py b/miles/rollout/modular_rollout/inference_wrapper.py deleted file mode 100644 index 3a09d3dfdd..0000000000 --- a/miles/rollout/modular_rollout/inference_wrapper.py +++ /dev/null @@ -1,100 +0,0 @@ -import numpy as np -import pybase64 - -from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.utils.http_utils import post -from miles.utils.processing_utils import encode_image_for_rollout_engine -from miles.utils.types import Sample - - -async def generate(input: GenerateFnInput) -> GenerateFnOutput: - """Generate using traditional SGLang router with token-based workflow""" - state = input.state - args = input.args - sample = input.sample - sampling_params = input.sampling_params - - if args.ci_test: - assert isinstance(sample.prompt, str) - - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - - assert ( - sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED - ), f"Sample status is {sample.status}" - - if state.processor: - processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) - prompt_ids = processor_output["input_ids"][0] - sample.multimodal_train_inputs = { - k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] - } or None - else: - prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) - - if len(sample.response) > 0: - sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) - - assert ( - sampling_params["max_new_tokens"] >= 0 - ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" - if sampling_params["max_new_tokens"] == 0: - sample.status = Sample.Status.TRUNCATED - return GenerateFnOutput(samples=sample) - - # Prepare payload for sglang server - payload = { - "sampling_params": sampling_params, - "return_logprob": True, - } - - if args.use_rollout_routing_replay: - payload["return_routed_experts"] = True - - if sample.multimodal_inputs and sample.multimodal_inputs["images"]: - image_data = sample.multimodal_inputs["images"] - payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - - # Use existing tokens for multi-turn or tokenize the new prompt - if len(sample.response) > 0: - payload["input_ids"] = sample.tokens - else: - payload["input_ids"] = prompt_ids - if not sample.tokens: # Initialize sample.tokens for the first turn - sample.tokens = prompt_ids - - output = await post(url, payload) - - if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: - from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree - - sample = await postprocess_sample_with_radix_tree(args, sample, output) - else: - if "output_token_logprobs" in output["meta_info"]: - new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - else: - new_response_tokens, new_response_log_probs = [], [] - - # Update sample with tokens directly - avoiding re-tokenization - sample.tokens = sample.tokens + new_response_tokens - sample.response_length += len(new_response_tokens) - sample.response += output["text"] - - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += new_response_log_probs - - if "routed_experts" in output["meta_info"]: - sample.rollout_routed_experts = np.frombuffer( - pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), - dtype=np.int32, - ).reshape( - len(sample.tokens) - 1, - args.num_layers, - args.moe_router_topk, - ) - - sample.update_from_meta_info(args, output["meta_info"]) - - return GenerateFnOutput(samples=sample) diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep new file mode 100644 index 0000000000..615f2b076c --- /dev/null +++ b/tests/e2e/.gitkeep @@ -0,0 +1 @@ +# TODO: may move e2e tests to this folder \ No newline at end of file diff --git a/tests/rollout/__init__.py b/tests/fast/__init__.py similarity index 100% rename from tests/rollout/__init__.py rename to tests/fast/__init__.py diff --git a/tests/conftest.py b/tests/fast/conftest.py similarity index 57% rename from tests/conftest.py rename to tests/fast/conftest.py index d72eda5f34..4cb30e91fa 100644 --- a/tests/conftest.py +++ b/tests/fast/conftest.py @@ -2,10 +2,10 @@ import pytest -from tests.fixtures.generation_fixtures import generation_env -from tests.fixtures.rollout_integration import rollout_integration_env +from tests.fast.fixtures.generation_fixtures import generation_env +from tests.fast.fixtures.rollout_fixtures import rollout_env -_ = rollout_integration_env, generation_env +_ = rollout_env, generation_env @pytest.fixture(autouse=True) diff --git a/tests/fixtures/__init__.py b/tests/fast/fixtures/__init__.py similarity index 100% rename from tests/fixtures/__init__.py rename to tests/fast/fixtures/__init__.py diff --git a/tests/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py similarity index 98% rename from tests/fixtures/generation_fixtures.py rename to tests/fast/fixtures/generation_fixtures.py index 8c144cfe4c..816371ee3a 100644 --- a/tests/fixtures/generation_fixtures.py +++ b/tests/fast/fixtures/generation_fixtures.py @@ -13,8 +13,8 @@ import requests from miles.rollout.base_types import GenerateFnInput -from miles.rollout.modular_rollout.compatibility import load_generate_function -from miles.rollout.modular_rollout.orchestration_common import GenerateState +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState from miles.router.router import MilesRouter from miles.utils.async_utils import run from miles.utils.http_utils import find_available_port, init_http_client diff --git a/tests/fixtures/rollout_integration.py b/tests/fast/fixtures/rollout_fixtures.py similarity index 89% rename from tests/fixtures/rollout_integration.py rename to tests/fast/fixtures/rollout_fixtures.py index 60dd4b7d65..44d8a50d79 100644 --- a/tests/fixtures/rollout_integration.py +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -2,7 +2,6 @@ Fixtures to test rollout-function """ -# TODO may rename to rollout_fixutres.py to be aligned import json from argparse import Namespace from collections.abc import Iterator @@ -24,15 +23,14 @@ @dataclass(frozen=True) -class IntegrationEnvConfig: +class RolloutEnvConfig: extra_argv: list[str] | None = None data_rows: list[dict] | None = None latency: float = 0.0 -# TODO may rename to RolloutEnv @dataclass(frozen=True) -class IntegrationEnv: +class RolloutEnv: args: Namespace data_source: DataSource mock_server: MockSGLangServer @@ -99,11 +97,10 @@ def _write_jsonl(path: str, rows: list[dict]) -> None: DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] -# TODO may rename to rollout_env @pytest.fixture -def rollout_integration_env(tmp_path, request) -> IntegrationEnv: +def rollout_env(tmp_path, request) -> RolloutEnv: config = request.param - assert isinstance(config, IntegrationEnvConfig) + assert isinstance(config, RolloutEnvConfig) data_rows = config.data_rows or DEFAULT_DATA_ROWS @@ -125,6 +122,6 @@ def rollout_integration_env(tmp_path, request) -> IntegrationEnv: r.raise_for_status() data_source = RolloutDataSourceWithBuffer(args) - yield IntegrationEnv(args=args, data_source=data_source, mock_server=mock_server) + yield RolloutEnv(args=args, data_source=data_source, mock_server=mock_server) SingletonMeta.clear_all_instances() diff --git a/tests/rollout/generate_hub/__init__.py b/tests/fast/rollout/__init__.py similarity index 100% rename from tests/rollout/generate_hub/__init__.py rename to tests/fast/rollout/__init__.py diff --git a/tests/rollout/modular_rollout/__init__.py b/tests/fast/rollout/generate_hub/__init__.py similarity index 100% rename from tests/rollout/modular_rollout/__init__.py rename to tests/fast/rollout/generate_hub/__init__.py diff --git a/tests/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py similarity index 99% rename from tests/rollout/generate_hub/test_multi_turn.py rename to tests/fast/rollout/generate_hub/test_multi_turn.py index 18652be7b6..5d974aaadd 100644 --- a/tests/rollout/generate_hub/test_multi_turn.py +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -5,7 +5,7 @@ import numpy as np import pybase64 import pytest -from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate from transformers import AutoTokenizer from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo diff --git a/tests/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py similarity index 99% rename from tests/rollout/generate_hub/test_single_turn.py rename to tests/fast/rollout/generate_hub/test_single_turn.py index 2d399fe9e0..a58e6fb3c6 100644 --- a/tests/rollout/generate_hub/test_single_turn.py +++ b/tests/fast/rollout/generate_hub/test_single_turn.py @@ -3,7 +3,7 @@ import pytest import torch from PIL import Image -from tests.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate from transformers import AutoProcessor from miles.utils.processing_utils import encode_image_for_rollout_engine diff --git a/tests/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py similarity index 96% rename from tests/rollout/generate_hub/test_tool_call_utils.py rename to tests/fast/rollout/generate_hub/test_tool_call_utils.py index 8f06756e64..a89ebfb408 100644 --- a/tests/rollout/generate_hub/test_tool_call_utils.py +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -1,6 +1,6 @@ import pytest -from miles.rollout.generate_hub.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses +from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses TOOL_CALL_TEST_MODELS = [ "Qwen/Qwen2.5-0.5B-Instruct", diff --git a/tests/rollout/modular_rollout/integration/__init__.py b/tests/fast/rollout/generate_utils/__init__.py similarity index 100% rename from tests/rollout/modular_rollout/integration/__init__.py rename to tests/fast/rollout/generate_utils/__init__.py diff --git a/tests/rollout/generate_hub/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py similarity index 91% rename from tests/rollout/generate_hub/test_sample_utils.py rename to tests/fast/rollout/generate_utils/test_sample_utils.py index 0c49dd433c..c53fbbb56a 100644 --- a/tests/rollout/generate_hub/test_sample_utils.py +++ b/tests/fast/rollout/generate_utils/test_sample_utils.py @@ -2,7 +2,7 @@ import pytest -from miles.rollout.generate_hub.sample_utils import merge_sample_pair +from miles.rollout.generate_utils.sample_utils import _merge_sample_pair from miles.utils.types import Sample @@ -59,7 +59,7 @@ def test_basic_merge(self, mock_tokenizer): status=Sample.Status.TRUNCATED, ) - merged = merge_sample_pair(a, b, mock_tokenizer) + merged = _merge_sample_pair(a, b, mock_tokenizer) assert merged.tokens == b.tokens assert merged.response_length == 3 + 2 + 3 @@ -88,7 +88,7 @@ def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer): rollout_log_probs=None, ) - merged = merge_sample_pair(a, b, mock_tokenizer) + merged = _merge_sample_pair(a, b, mock_tokenizer) assert merged.loss_mask == [1, 0, 1] assert merged.rollout_log_probs == [0.0, 0.0, 0.0] @@ -106,7 +106,7 @@ def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): ) with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): - merge_sample_pair(a, b, mock_tokenizer) + _merge_sample_pair(a, b, mock_tokenizer) def test_field_mismatch_raises(self, mock_tokenizer): a = make_sample( @@ -123,7 +123,7 @@ def test_field_mismatch_raises(self, mock_tokenizer): ) with pytest.raises(AssertionError, match="index mismatch"): - merge_sample_pair(a, b, mock_tokenizer) + _merge_sample_pair(a, b, mock_tokenizer) def test_obs_len_invalid_raises(self, mock_tokenizer): a = make_sample( @@ -138,7 +138,7 @@ def test_obs_len_invalid_raises(self, mock_tokenizer): ) with pytest.raises(AssertionError, match="obs_len must be > 0"): - merge_sample_pair(a, b, mock_tokenizer) + _merge_sample_pair(a, b, mock_tokenizer) def test_sample_validate_fails_raises(self, mock_tokenizer): a = make_sample( @@ -153,4 +153,4 @@ def test_sample_validate_fails_raises(self, mock_tokenizer): ) with pytest.raises(AssertionError, match="loss_mask length"): - merge_sample_pair(a, b, mock_tokenizer) + _merge_sample_pair(a, b, mock_tokenizer) diff --git a/tests/rollout/rm_hub/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py similarity index 100% rename from tests/rollout/rm_hub/__init__.py rename to tests/fast/rollout/inference_rollout/__init__.py diff --git a/tests/rollout/modular_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py similarity index 100% rename from tests/rollout/modular_rollout/conftest.py rename to tests/fast/rollout/inference_rollout/conftest.py diff --git a/tests/router/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py similarity index 100% rename from tests/router/__init__.py rename to tests/fast/rollout/inference_rollout/integration/__init__.py diff --git a/tests/rollout/modular_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py similarity index 64% rename from tests/rollout/modular_rollout/integration/test_basic.py rename to tests/fast/rollout/inference_rollout/integration/test_basic.py index bf12cb3735..a148cdf14c 100644 --- a/tests/rollout/modular_rollout/integration/test_basic.py +++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py @@ -1,18 +1,18 @@ import pytest -from tests.fixtures.generation_fixtures import extra_argv_for_variant -from tests.fixtures.rollout_integration import IntegrationEnvConfig -from tests.rollout.modular_rollout.integration.utils import ( +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import ( MODULAR_ROLLOUT_BASE_ARGV, expected_sample, load_and_call_train, ) from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput -from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function _VARIANTS = [ pytest.param( - IntegrationEnvConfig( + RolloutEnvConfig( extra_argv=[ "--rollout-function-path", "miles.rollout.sglang_rollout.generate_rollout", @@ -25,12 +25,12 @@ id="old_rollout_old_generate", ), pytest.param( - IntegrationEnvConfig( + RolloutEnvConfig( extra_argv=[ "--rollout-function-path", - "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", + "miles.rollout.inference_rollout.inference_rollout_train.SimpleTrainRolloutFn", "--eval-function-path", - "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", + "miles.rollout.inference_rollout.inference_rollout_eval.SimpleEvalRolloutFn", "--custom-generate-function-path", "miles.rollout.sglang_rollout.generate", ] @@ -38,15 +38,15 @@ id="new_rollout_old_generate", ), pytest.param( - IntegrationEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), + RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), id="new_rollout_new_generate", ), ] -@pytest.mark.parametrize("rollout_integration_env", _VARIANTS, indirect=True) -def test_train(rollout_integration_env): - env = rollout_integration_env +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_train(rollout_env): + env = rollout_env out = load_and_call_train(env.args, env.data_source) assert len(out.samples) == env.args.rollout_batch_size @@ -55,9 +55,9 @@ def test_train(rollout_integration_env): assert group[0] == expected_sample(group_index=0) -@pytest.mark.parametrize("rollout_integration_env", _VARIANTS, indirect=True) -def test_eval(rollout_integration_env): - env = rollout_integration_env +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_eval(rollout_env): + env = rollout_env fn = load_rollout_function( RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path ) diff --git a/tests/rollout/modular_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py similarity index 74% rename from tests/rollout/modular_rollout/integration/test_deterministic.py rename to tests/fast/rollout/inference_rollout/integration/test_deterministic.py index 5a1dbb4f10..69a2359117 100644 --- a/tests/rollout/modular_rollout/integration/test_deterministic.py +++ b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py @@ -1,10 +1,10 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import integration_env_config, load_and_call_train +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train @pytest.mark.parametrize( - "rollout_integration_env,expected_seeds", + "rollout_env,expected_seeds", [ pytest.param( integration_env_config( @@ -27,10 +27,10 @@ id="disabled", ), ], - indirect=["rollout_integration_env"], + indirect=["rollout_env"], ) -def test_sampling_seeds(rollout_integration_env, expected_seeds): - env = rollout_integration_env +def test_sampling_seeds(rollout_env, expected_seeds): + env = rollout_env load_and_call_train(env.args, env.data_source) seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} diff --git a/tests/rollout/modular_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py similarity index 80% rename from tests/rollout/modular_rollout/integration/test_dynamic_filter.py rename to tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py index eb25c9c1ad..0ca5743ac5 100644 --- a/tests/rollout/modular_rollout/integration/test_dynamic_filter.py +++ b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py @@ -1,7 +1,7 @@ from contextlib import nullcontext import pytest -from tests.rollout.modular_rollout.integration.utils import ( +from tests.fast.rollout.inference_rollout.integration.utils import ( MIXED_DATA_ROWS, filter_by_reward, integration_env_config, @@ -12,7 +12,7 @@ @pytest.mark.parametrize( - "rollout_integration_env,use_filter,expect_all_correct", + "rollout_env,use_filter,expect_all_correct", [ pytest.param( integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), @@ -30,10 +30,10 @@ id="with_filter", ), ], - indirect=["rollout_integration_env"], + indirect=["rollout_env"], ) -def test_filter_effect(rollout_integration_env, use_filter, expect_all_correct): - env = rollout_integration_env +def test_filter_effect(rollout_env, use_filter, expect_all_correct): + env = rollout_env ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() with ctx: diff --git a/tests/rollout/modular_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py similarity index 68% rename from tests/rollout/modular_rollout/integration/test_group_rm.py rename to tests/fast/rollout/inference_rollout/integration/test_group_rm.py index a1811467c2..afd870c302 100644 --- a/tests/rollout/modular_rollout/integration/test_group_rm.py +++ b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py @@ -1,10 +1,10 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import integration_env_config, load_and_call_train +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train @pytest.mark.parametrize( - "rollout_integration_env", + "rollout_env", [ pytest.param( integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), @@ -13,8 +13,8 @@ ], indirect=True, ) -def test_group_rm_rewards_set(rollout_integration_env): - env = rollout_integration_env +def test_group_rm_rewards_set(rollout_env): + env = rollout_env out = load_and_call_train(env.args, env.data_source) assert len(out.samples) == env.args.rollout_batch_size diff --git a/tests/rollout/modular_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py similarity index 82% rename from tests/rollout/modular_rollout/integration/test_multi_sample.py rename to tests/fast/rollout/inference_rollout/integration/test_multi_sample.py index a2e854d9a8..2b12d3d88f 100644 --- a/tests/rollout/modular_rollout/integration/test_multi_sample.py +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py @@ -1,6 +1,6 @@ import pytest -from tests.fixtures.rollout_integration import DEFAULT_DATA_ROWS, IntegrationEnvConfig -from tests.rollout.modular_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train +from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput from miles.utils.misc import function_registry @@ -31,10 +31,10 @@ async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: @pytest.mark.parametrize( - "rollout_integration_env", + "rollout_env", [ pytest.param( - IntegrationEnvConfig( + RolloutEnvConfig( extra_argv=MODULAR_ROLLOUT_BASE_ARGV + [ "--custom-generate-function-path", @@ -51,8 +51,8 @@ async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: ], indirect=True, ) -def test_multi_sample_output_preserves_existing_reward(rollout_integration_env): - env = rollout_integration_env +def test_multi_sample_output_preserves_existing_reward(rollout_env): + env = rollout_env with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): out = load_and_call_train(env.args, env.data_source) diff --git a/tests/rollout/modular_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py similarity index 87% rename from tests/rollout/modular_rollout/integration/test_multi_turn.py rename to tests/fast/rollout/inference_rollout/integration/test_multi_turn.py index 97df120817..c41d713991 100644 --- a/tests/rollout/modular_rollout/integration/test_multi_turn.py +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py @@ -1,9 +1,9 @@ from typing import Any import pytest -from tests.fixtures.generation_fixtures import extra_argv_for_variant -from tests.fixtures.rollout_integration import IntegrationEnvConfig -from tests.rollout.modular_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout from miles.utils.test_utils.mock_tools import TwoTurnStub from miles.utils.types import Sample @@ -26,25 +26,25 @@ "--n-samples-per-eval-prompt", "2", "--custom-rm-path", - "tests.rollout.modular_rollout.integration.test_generate_hub._simple_reward_function", + "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function", ] -def _config_for_variant(variant: str) -> IntegrationEnvConfig: - return IntegrationEnvConfig( +def _config_for_variant(variant: str) -> RolloutEnvConfig: + return RolloutEnvConfig( extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, data_rows=TWO_TURN_DATA_ROWS, ) @pytest.mark.parametrize( - "variant,rollout_integration_env", + "variant,rollout_env", [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], - indirect=["rollout_integration_env"], + indirect=["rollout_env"], ) @pytest.mark.parametrize("test_type", ["train", "eval"]) -def test_rollout(rollout_integration_env, variant, test_type): - env = rollout_integration_env +def test_rollout(rollout_env, variant, test_type): + env = rollout_env env.mock_server.process_fn = TwoTurnStub.process_fn out = load_and_call_rollout(env.args, env.data_source, mode=test_type) diff --git a/tests/rollout/modular_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py similarity index 83% rename from tests/rollout/modular_rollout/integration/test_over_sampling.py rename to tests/fast/rollout/inference_rollout/integration/test_over_sampling.py index e4318c88fa..0812962cc7 100644 --- a/tests/rollout/modular_rollout/integration/test_over_sampling.py +++ b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py @@ -1,5 +1,5 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import ( +from tests.fast.rollout.inference_rollout.integration.utils import ( filter_by_reward, integration_env_config, load_and_call_train, @@ -27,15 +27,15 @@ def _over_sampling_config(rollout_batch_size: int): @pytest.mark.parametrize( - "rollout_integration_env,expected_rounds", + "rollout_env,expected_rounds", [ pytest.param(_over_sampling_config(1), 1, id="one_round"), pytest.param(_over_sampling_config(2), 2, id="two_rounds"), ], - indirect=["rollout_integration_env"], + indirect=["rollout_env"], ) -def test_over_sampling_rounds(rollout_integration_env, expected_rounds): - env = rollout_integration_env +def test_over_sampling_rounds(rollout_env, expected_rounds): + env = rollout_env with function_registry.temporary("test:filter_by_reward", filter_by_reward): out = load_and_call_train(env.args, env.data_source) diff --git a/tests/rollout/modular_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py similarity index 91% rename from tests/rollout/modular_rollout/integration/test_sample_filter.py rename to tests/fast/rollout/inference_rollout/integration/test_sample_filter.py index a69f05b352..36e78c16c1 100644 --- a/tests/rollout/modular_rollout/integration/test_sample_filter.py +++ b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py @@ -1,7 +1,7 @@ from unittest.mock import Mock import pytest -from tests.rollout.modular_rollout.integration.utils import ( +from tests.fast.rollout.inference_rollout.integration.utils import ( filter_by_reward, integration_env_config, load_and_call_train, @@ -20,7 +20,7 @@ @pytest.mark.parametrize( - "rollout_integration_env", + "rollout_env", [ pytest.param( integration_env_config( @@ -43,8 +43,8 @@ ], indirect=True, ) -def test_sample_filter_and_all_samples_process(rollout_integration_env): - env = rollout_integration_env +def test_sample_filter_and_all_samples_process(rollout_env): + env = rollout_env sample_filter_mock = Mock() all_samples_process_mock = Mock() diff --git a/tests/rollout/modular_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py similarity index 74% rename from tests/rollout/modular_rollout/integration/test_semaphore.py rename to tests/fast/rollout/inference_rollout/integration/test_semaphore.py index ce42728635..889a9ff8ac 100644 --- a/tests/rollout/modular_rollout/integration/test_semaphore.py +++ b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py @@ -1,13 +1,13 @@ import pytest -from tests.rollout.modular_rollout.integration.utils import integration_env_config, load_and_call_train +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train _DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] _BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] @pytest.mark.parametrize( - "rollout_integration_env,expected_range", + "rollout_env,expected_range", [ pytest.param( integration_env_config( @@ -24,10 +24,10 @@ id="no_limit", ), ], - indirect=["rollout_integration_env"], + indirect=["rollout_env"], ) -def test_max_concurrent(rollout_integration_env, expected_range): - env = rollout_integration_env +def test_max_concurrent(rollout_env, expected_range): + env = rollout_env load_and_call_train(env.args, env.data_source) min_expected, max_expected = expected_range assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/rollout/modular_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py similarity index 86% rename from tests/rollout/modular_rollout/integration/utils.py rename to tests/fast/rollout/inference_rollout/integration/utils.py index 511a43bb70..6f3fb1916d 100644 --- a/tests/rollout/modular_rollout/integration/utils.py +++ b/tests/fast/rollout/inference_rollout/integration/utils.py @@ -1,5 +1,5 @@ -from tests.fixtures.generation_fixtures import extra_argv_for_variant -from tests.fixtures.rollout_integration import IntegrationEnvConfig +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig from miles.rollout.base_types import ( RolloutFnConstructorInput, @@ -8,7 +8,7 @@ RolloutFnTrainInput, ) from miles.rollout.filter_hub.base_types import DynamicFilterOutput -from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils.types import Sample @@ -42,9 +42,9 @@ def expected_sample(*, group_index: int | None) -> Sample: MODULAR_ROLLOUT_BASE_ARGV = [ "--rollout-function-path", - "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", + "miles.rollout.inference_rollout.inference_rollout_train.SimpleTrainRolloutFn", "--eval-function-path", - "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", + "miles.rollout.inference_rollout.inference_rollout_eval.SimpleEvalRolloutFn", ] MIXED_DATA_ROWS = [ @@ -61,7 +61,7 @@ def integration_env_config( latency: float = 0.0, variant: str = "single_turn", ): - return IntegrationEnvConfig( + return RolloutEnvConfig( extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, data_rows=data_rows, latency=latency, diff --git a/tests/rollout/modular_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py similarity index 99% rename from tests/rollout/modular_rollout/test_compatibility.py rename to tests/fast/rollout/inference_rollout/test_compatibility.py index f012cbd490..ddfecd067b 100644 --- a/tests/rollout/modular_rollout/test_compatibility.py +++ b/tests/fast/rollout/inference_rollout/test_compatibility.py @@ -12,7 +12,7 @@ RolloutFnTrainInput, RolloutFnTrainOutput, ) -from miles.rollout.modular_rollout.compatibility import ( +from miles.rollout.inference_rollout.compatibility import ( LegacyGenerateFnAdapter, LegacyRolloutFnAdapter, call_rollout_function, diff --git a/tests/fast/rollout/rm_hub/__init__.py b/tests/fast/rollout/rm_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py similarity index 100% rename from tests/rollout/rm_hub/test_deepscaler.py rename to tests/fast/rollout/rm_hub/test_deepscaler.py diff --git a/tests/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py similarity index 100% rename from tests/rollout/rm_hub/test_f1.py rename to tests/fast/rollout/rm_hub/test_f1.py diff --git a/tests/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py similarity index 100% rename from tests/rollout/rm_hub/test_gpqa.py rename to tests/fast/rollout/rm_hub/test_gpqa.py diff --git a/tests/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py similarity index 100% rename from tests/rollout/rm_hub/test_math_dapo_utils.py rename to tests/fast/rollout/rm_hub/test_math_dapo_utils.py diff --git a/tests/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py similarity index 100% rename from tests/rollout/rm_hub/test_math_utils.py rename to tests/fast/rollout/rm_hub/test_math_utils.py diff --git a/tests/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py similarity index 100% rename from tests/rollout/rm_hub/test_rm_hub.py rename to tests/fast/rollout/rm_hub/test_rm_hub.py diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/router/test_router.py b/tests/fast/router/test_router.py similarity index 100% rename from tests/router/test_router.py rename to tests/fast/router/test_router.py diff --git a/tests/router/test_sessions.py b/tests/fast/router/test_sessions.py similarity index 100% rename from tests/router/test_sessions.py rename to tests/fast/router/test_sessions.py diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/test_arguments.py b/tests/fast/utils/test_arguments.py similarity index 100% rename from tests/utils/test_arguments.py rename to tests/fast/utils/test_arguments.py diff --git a/tests/utils/test_mask_utils.py b/tests/fast/utils/test_mask_utils.py similarity index 100% rename from tests/utils/test_mask_utils.py rename to tests/fast/utils/test_mask_utils.py diff --git a/tests/utils/test_misc.py b/tests/fast/utils/test_misc.py similarity index 100% rename from tests/utils/test_misc.py rename to tests/fast/utils/test_misc.py diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py similarity index 100% rename from tests/utils/test_utils/test_mock_sglang_server.py rename to tests/fast/utils/test_utils/test_mock_sglang_server.py diff --git a/tests/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py similarity index 100% rename from tests/utils/test_utils/test_mock_tools.py rename to tests/fast/utils/test_utils/test_mock_tools.py diff --git a/tests/rollout/modular_rollout/test_integration.py b/tests/rollout/modular_rollout/test_integration.py deleted file mode 100644 index ed21ceee51..0000000000 --- a/tests/rollout/modular_rollout/test_integration.py +++ /dev/null @@ -1,98 +0,0 @@ -import pytest - -from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnTrainInput -from miles.rollout.modular_rollout.compatibility import call_rollout_function, load_rollout_function -from miles.utils.types import Sample - - -def _expected_sample(*, group_index: int | None) -> Sample: - return Sample( - group_index=group_index, - index=0, - prompt="What is 1+7?", - tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], - multimodal_inputs=None, - multimodal_train_inputs=None, - response="\\boxed{8}", - response_length=5, - label="8", - reward=1, - loss_mask=None, - weight_versions=[], - rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], - rollout_routed_experts=None, - remove_sample=False, - status=Sample.Status.COMPLETED, - metadata={}, - train_metadata=None, - non_generation_time=0.0, - spec_info=Sample.SpecInfo( - spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 - ), - prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), - ) - - -_ROLLOUT_ARGV_VARIANTS = [ - pytest.param( - [ - "--rollout-function-path", - "miles.rollout.sglang_rollout.generate_rollout", - "--eval-function-path", - "miles.rollout.sglang_rollout.generate_rollout", - "--custom-generate-function-path", - "miles.rollout.sglang_rollout.generate", - ], - id="old_rollout_old_generate", - ), - pytest.param( - [ - "--rollout-function-path", - "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", - "--eval-function-path", - "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", - "--custom-generate-function-path", - "miles.rollout.sglang_rollout.generate", - ], - id="new_rollout_old_generate", - ), - pytest.param( - [ - "--rollout-function-path", - "miles.rollout.modular_rollout.orchestration_train.SimpleTrainRolloutFn", - "--eval-function-path", - "miles.rollout.modular_rollout.orchestration_eval.SimpleEvalRolloutFn", - "--custom-generate-function-path", - "miles.rollout.modular_rollout.inference_wrapper.generate", - ], - id="new_rollout_new_generate", - ), -] - - -@pytest.mark.parametrize("rollout_integration_env", _ROLLOUT_ARGV_VARIANTS, indirect=True) -def test_simple_train_rollout_fn_integration(rollout_integration_env): - args, data_source = rollout_integration_env - fn = load_rollout_function( - RolloutFnConstructorInput(args=args, data_source=data_source), args.rollout_function_path - ) - out = call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) - - assert len(out.samples) == args.rollout_batch_size - group = out.samples[0] - assert len(group) == args.n_samples_per_prompt - assert group[0] == _expected_sample(group_index=0) - - -@pytest.mark.parametrize("rollout_integration_env", _ROLLOUT_ARGV_VARIANTS, indirect=True) -def test_simple_eval_rollout_fn_integration(rollout_integration_env): - args, data_source = rollout_integration_env - fn = load_rollout_function(RolloutFnConstructorInput(args=args, data_source=data_source), args.eval_function_path) - out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) - - assert "toy" in out.data - rewards = out.data["toy"]["rewards"] - samples = out.data["toy"]["samples"] - assert len(rewards) == len(samples) == args.n_samples_per_eval_prompt - assert rewards[0] == 1 - assert samples[0] == _expected_sample(group_index=None) From 6ecdec9a870227478fa68ae3c6b7c039a6bff83c Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:30:53 +0800 Subject: [PATCH 52/77] Add CPU-only tests to CI (#490) --- .github/workflows/pr-test.yml | 40 +++++++++++++++++++ .github/workflows/pr-test.yml.j2 | 19 ++++----- requirements.txt | 1 + tests/ci/gpu_lock_exec.py | 11 +++-- .../generate_hub/test_tool_call_utils.py | 4 +- .../fast/utils/test_utils/test_mock_tools.py | 10 ++--- 6 files changed, 63 insertions(+), 22 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index d34c823aa3..4b8b5dc82c 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,6 +25,46 @@ concurrency: jobs: + fast: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=16g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 0, "test_file": "fast"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 37b6fa4463..c052b8494f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,4 +1,10 @@ <% set jobs = { + 'fast': { + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'fast', 'num_gpus': 0}, + ], + }, 'e2e-test-short': { 'label': 'run-ci-short', 'tests': [ @@ -98,7 +104,7 @@ concurrency: jobs: <% for job_name, config in jobs.items() %> << job_name >>: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>')) + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>) runs-on: self-hosted container: image: << config.image if config.image else 'radixark/miles:latest' >> @@ -153,14 +159,5 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2c20195fc4..dacd51132c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf pillow +pybase64 pylatexenc pyyaml qwen_vl_utils # for VLM diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py index 9507e2e858..20379f76a2 100644 --- a/tests/ci/gpu_lock_exec.py +++ b/tests/ci/gpu_lock_exec.py @@ -19,11 +19,14 @@ def main(): _execute_print_only(args) return - fd_locks = _try_acquire(args) + if args.count == 0 and not args.devices: + print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True) + else: + fd_locks = _try_acquire(args) - dev_list = ",".join(str(x.gpu_id) for x in fd_locks) - os.environ[args.target_env_name] = dev_list - print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ[args.target_env_name] = dev_list + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) _os_execvp(args) diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py index a89ebfb408..0f2305e753 100644 --- a/tests/fast/rollout/generate_hub/test_tool_call_utils.py +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -7,7 +7,7 @@ "Qwen/Qwen3-0.6B", "Qwen/Qwen3-4B-Instruct-2507", "Qwen/Qwen3-Coder-30B-A3B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI "mistralai/Mistral-7B-Instruct-v0.3", "deepseek-ai/DeepSeek-V3", "stepfun-ai/step3", @@ -19,7 +19,7 @@ ] SINGLE_TOOL_CALL_ONLY_MODELS = [ - "meta-llama/Llama-3.2-1B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo ] # Models where tokenize->decode produces extra whitespace vs direct string diff diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py index b905fa8525..3f2116ec01 100644 --- a/tests/fast/utils/test_utils/test_mock_tools.py +++ b/tests/fast/utils/test_utils/test_mock_tools.py @@ -70,7 +70,7 @@ class TestSGLangFunctionCallParser: 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', ( "Let me check for you.", - [ToolCallItem(tool_index=0, name="get_year", parameters="{}")], + [ToolCallItem(tool_index=-1, name="get_year", parameters="{}")], ), id="single_tool_call", ), @@ -81,8 +81,8 @@ class TestSGLangFunctionCallParser: ( "I will get year and temperature.", [ - ToolCallItem(tool_index=0, name="get_year", parameters="{}"), - ToolCallItem(tool_index=1, name="get_temperature", parameters='{"location": "Shanghai"}'), + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'), ], ), id="multi_tool_calls", @@ -97,8 +97,8 @@ class TestSGLangFunctionCallParser: ( "Let me get the year and temperature first.", [ - ToolCallItem(tool_index=0, name="get_year", parameters="{}"), - ToolCallItem(tool_index=1, name="get_temperature", parameters='{"location": "Mars"}'), + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'), ], ), id="multi_turn_first_response", From 59fa9f1466ee9c7941f40ed32631a91b671ca04d Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:31:35 +0800 Subject: [PATCH 53/77] Use new rollout function by default when corresponding flag is on (#491) --- .github/workflows/pr-test.yml | 40 +++++++++++++++++ miles/ray/rollout.py | 4 +- .../inference_rollout_common.py | 44 ++++++++++++++++++- .../inference_rollout_eval.py | 17 ------- .../inference_rollout_train.py | 15 +------ miles/utils/arguments.py | 10 +++-- miles/utils/environ.py | 13 +++++- .../integration/test_basic.py | 4 +- .../inference_rollout/integration/utils.py | 4 +- 9 files changed, 106 insertions(+), 45 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4b8b5dc82c..e2167b93d2 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -65,6 +65,46 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + e2e-test-short: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=16g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 0, "test_file": "fast"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 6198d62360..27211845d8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -21,7 +21,7 @@ ) from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils -from miles.utils.environ import get_experimental_rollout_refactor +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -60,7 +60,7 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.use_experimental_refactor = get_experimental_rollout_refactor() + self.use_experimental_refactor = enable_experimental_rollout_refactor() if self.use_experimental_refactor: input = RolloutFnConstructorInput(args=args, data_source=self.data_source) self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py index 5d6f67de24..8518c6e020 100644 --- a/miles/rollout/inference_rollout/inference_rollout_common.py +++ b/miles/rollout/inference_rollout/inference_rollout_common.py @@ -4,7 +4,16 @@ from copy import deepcopy from typing import Any -from miles.rollout.base_types import GenerateFnInput +from miles.rollout.base_types import ( + GenerateFnInput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) from miles.rollout.generate_hub.single_turn import generate from miles.rollout.inference_rollout.compatibility import load_generate_function from miles.rollout.rm_hub import async_rm, batched_async_rm @@ -148,3 +157,36 @@ def compute_sampling_params( no_stop_trim=True, spaces_between_special_tokens=False, ) + + +class InferenceRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.data_source = input.data_source + self.state = GenerateState(input.args) + self.eval_prompt_dataset_cache = {} + + async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + if input.evaluation: + return await self._call_eval(input) + return await self._call_train(input) + + async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: + from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async + + output, aborted_samples = await generate_rollout_async( + self.state, input.rollout_id, self.data_source.get_samples + ) + self.data_source.add_samples(aborted_samples) + return output + + async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: + from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset + + assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache)) + results_list = await asyncio.gather(*coros) + results = {k: v for r in results_list for k, v in r.items()} + return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py index 3117598f5c..2d052be0ae 100644 --- a/miles/rollout/inference_rollout/inference_rollout_eval.py +++ b/miles/rollout/inference_rollout/inference_rollout_eval.py @@ -5,7 +5,6 @@ from tqdm import tqdm -from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput, RolloutFnEvalOutput from miles.rollout.inference_rollout.inference_rollout_common import ( GenerateState, compute_sampling_params, @@ -111,19 +110,3 @@ async def eval_rollout_single_dataset( "samples": data, } } - - -class SimpleEvalRolloutFn: - def __init__(self, input: RolloutFnConstructorInput): - self.prompt_dataset_cache = {} - self.state = GenerateState(input.args) - - async def __call__(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: - assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" - - coros = [] - for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: - coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.prompt_dataset_cache)) - results_list = await asyncio.gather(*coros) - results = {k: v for r in results_list for k, v in r.items()} - return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py index b0b7741755..bae94ec67b 100644 --- a/miles/rollout/inference_rollout/inference_rollout_train.py +++ b/miles/rollout/inference_rollout/inference_rollout_train.py @@ -7,7 +7,7 @@ from packaging.version import parse from tqdm import tqdm -from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnTrainInput, RolloutFnTrainOutput +from miles.rollout.base_types import RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group from miles.utils.http_utils import get, post @@ -144,16 +144,3 @@ async def generate_rollout_async( f(args, all_samples, data_source) return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples - - -class SimpleTrainRolloutFn: - def __init__(self, input: RolloutFnConstructorInput): - self.data_source = input.data_source - self.state = GenerateState(input.args) - - async def __call__(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: - output, aborted_samples = await generate_rollout_async( - self.state, input.rollout_id, self.data_source.get_samples - ) - self.data_source.add_samples(aborted_samples) - return output diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index c95f91ae90..0710202924 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,7 +10,7 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args -from miles.utils.environ import get_experimental_rollout_refactor +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger from miles.utils.misc import load_function @@ -206,7 +206,11 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-function-path", type=str, - default="miles.rollout.sglang_rollout.generate_rollout", + default=( + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn" + if enable_experimental_rollout_refactor() + else "miles.rollout.sglang_rollout.generate_rollout" + ), help=( "Path to the rollout generation function." "You should use this model to create your own custom rollout function, " @@ -1390,7 +1394,7 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) - if get_experimental_rollout_refactor(): + if enable_experimental_rollout_refactor(): parser = add_user_provided_function_arguments(parser) reset_arg( parser, diff --git a/miles/utils/environ.py b/miles/utils/environ.py index 155e3fbf1b..35d1f350ee 100644 --- a/miles/utils/environ.py +++ b/miles/utils/environ.py @@ -1,5 +1,14 @@ import os +_printed_experimental_rollout_refactor = False -def get_experimental_rollout_refactor() -> bool: - return bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) + +def enable_experimental_rollout_refactor() -> bool: + result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) + + global _printed_experimental_rollout_refactor + if result and not _printed_experimental_rollout_refactor: + print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)") + _printed_experimental_rollout_refactor = True + + return result diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py index a148cdf14c..5b791829d5 100644 --- a/tests/fast/rollout/inference_rollout/integration/test_basic.py +++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py @@ -28,9 +28,7 @@ RolloutEnvConfig( extra_argv=[ "--rollout-function-path", - "miles.rollout.inference_rollout.inference_rollout_train.SimpleTrainRolloutFn", - "--eval-function-path", - "miles.rollout.inference_rollout.inference_rollout_eval.SimpleEvalRolloutFn", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", "--custom-generate-function-path", "miles.rollout.sglang_rollout.generate", ] diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py index 6f3fb1916d..ad413cf949 100644 --- a/tests/fast/rollout/inference_rollout/integration/utils.py +++ b/tests/fast/rollout/inference_rollout/integration/utils.py @@ -42,9 +42,7 @@ def expected_sample(*, group_index: int | None) -> Sample: MODULAR_ROLLOUT_BASE_ARGV = [ "--rollout-function-path", - "miles.rollout.inference_rollout.inference_rollout_train.SimpleTrainRolloutFn", - "--eval-function-path", - "miles.rollout.inference_rollout.inference_rollout_eval.SimpleEvalRolloutFn", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", ] MIXED_DATA_ROWS = [ From f20b28cea00ada70accc13134d99c31282f943ee Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:33:03 +0800 Subject: [PATCH 54/77] Enable experimental rollout flag for CI tests (#508) --- .github/workflows/pr-test.yml | 40 ---------- miles/rollout/base_types.py | 37 +-------- .../generate_hub/multi_turn_single_sample.py | 77 ------------------- tests/test_external_rollout.py | 1 + tests/test_mimo_7B_mtp_only_grad.py | 1 + tests/test_moonlight_16B_A3B.py | 1 + tests/test_quick_start_glm4_9B.py | 1 + tests/test_qwen2.5_0.5B_gsm8k.py | 1 + tests/test_qwen2.5_0.5B_gsm8k_async.py | 1 + tests/test_qwen2.5_0.5B_gsm8k_async_short.py | 1 + tests/test_qwen2.5_0.5B_gsm8k_short.py | 1 + tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 1 + tests/test_qwen3_0.6B_fsdp_distributed.py | 1 + tests/test_qwen3_0.6B_megatron_fsdp_align.py | 3 + tests/test_qwen3_0.6B_parallel_check.py | 2 + tests/test_qwen3_30B_A3B.py | 1 + tests/test_qwen3_4B_ckpt.py | 1 + tests/test_qwen3_4B_fsdp_true_on_policy.py | 1 + tests/test_qwen3_4B_ppo.py | 1 + tests/test_qwen3_vl_4B_fsdp.py | 1 + 20 files changed, 24 insertions(+), 150 deletions(-) delete mode 100644 miles/rollout/generate_hub/multi_turn_single_sample.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index e2167b93d2..4b8b5dc82c 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -65,46 +65,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - e2e-test-short: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) - runs-on: self-hosted - container: - image: radixark/miles:latest - options: > - --gpus all - --ipc=host - --shm-size=16g - --ulimit memlock=-1 - --ulimit stack=67108864 - --memory=0 - --memory-swap=0 - -v /mnt/nvme0n1/miles_ci:/data/miles_ci - -v /mnt/nvme0n1/miles_ci/models:/root/models - -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets - strategy: - fail-fast: false - matrix: - info: [{"num_gpus": 0, "test_file": "fast"}] - defaults: - run: - working-directory: ${{ github.workspace }} - env: - GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install - shell: bash - run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages - - - name: Execute - shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index daa53634c9..c2644e87f9 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -83,37 +83,8 @@ def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) + # compatibility for legacy version + if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): + output = RolloutFnEvalOutput(data=output) if evaluation else RolloutFnTrainOutput(samples=output) -# TODO: may add add_arguments -# TODO: may add save/load if need it to be stateful -# Duck typing, users do not need to extend this class -@runtime_checkable -class RolloutFnProtocol(Protocol): - def __call__(self, input: RolloutFnInput) -> RolloutFnOutput | Awaitable[RolloutFnOutput]: ... - - -# TODO maybe put to modular_rollout folder depending on overall folder structure -@dataclass(frozen=True) -class GenerateFnInput: - state: GenerateState - sample: Sample - sampling_params: dict[str, Any] - evaluation: bool - - @property - def args(self) -> Namespace: - return self.state.args - - -@dataclass(frozen=True) -class GenerateFnOutput: - # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or - # multi-turn with removing thinking tokens. - samples: Sample | list[Sample] - - -# TODO: may add add_arguments -# TODO: may add save/load if need it to be stateful -@runtime_checkable -class GenerateFnProtocol(Protocol): - async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: ... + return output diff --git a/miles/rollout/generate_hub/multi_turn_single_sample.py b/miles/rollout/generate_hub/multi_turn_single_sample.py deleted file mode 100644 index 2f969cef69..0000000000 --- a/miles/rollout/generate_hub/multi_turn_single_sample.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Simple multi-turn generation with tool calling. -""" - -import argparse - -from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput -from miles.rollout.generate_hub.generate_endpoint_wrapper import ( - compute_prompt_ids_from_sample, - compute_request_payload, - update_sample_from_response, -) -from miles.rollout.generate_hub.tool_call_utils import ( - create_tool_call_parser, - execute_tool_calls, - update_sample_with_tool_responses, -) -from miles.utils.http_utils import post -from miles.utils.misc import load_function - - -async def generate(input: GenerateFnInput) -> GenerateFnOutput: - # ----------------------- Setup ------------------------- - - args = input.args - sample = input.sample - tokenizer = input.state.tokenizer - assert not args.partial_rollout - - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - - execute_tool_function = load_function(args.generate_execute_tool_function_path) - - tool_specs = load_function(args.generate_tool_specs_path) - tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) - - # ----------------------- Initial prompts ------------------------- - - prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) - - sample.loss_mask = [] - sample.tokens = prompt_tokens_ids.copy() - - for _turn in range(args.generate_max_turns): - # ----------------------- Call inference endpoint ------------------------- - - payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) - if payload is None: - sample.status = halt_status - break - - output = await post(url, payload) - await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) - - if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): - break - - # ----------------------- Execute tools ------------------------- - - _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) - if len(tool_calls) == 0: - break - - tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) - update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) - - return GenerateFnOutput(samples=sample) - - -def _add_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--generate-max-turns", type=int, default=16) - parser.add_argument("--generate-tool-specs-path", type=str) - parser.add_argument("--generate-tool-call-parser", type=str) - parser.add_argument("--generate-execute-tool-function-path", type=str) - - -generate.add_arguments = _add_arguments diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c53..9b6e69c295 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -126,6 +126,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=_launch_background, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py index 97c76ace5a..d90a2d7a71 100644 --- a/tests/test_mimo_7B_mtp_only_grad.py +++ b/tests/test_mimo_7B_mtp_only_grad.py @@ -135,6 +135,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py index b1255982ed..c35943ec15 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/test_moonlight_16B_A3B.py @@ -113,6 +113,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index 15ca8ce5fe..ae3c383ae8 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -115,6 +115,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index dcdbd58347..4d7f034f6c 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -120,6 +120,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index dcaaf5e1f7..32b60f5937 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -120,6 +120,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py index 90cd15cb68..b1954a4e83 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -118,6 +118,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/test_qwen2.5_0.5B_gsm8k_short.py index 867fdcad60..86e21eac8d 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_short.py @@ -117,6 +117,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d19b48ced..3d4768e420 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -93,6 +93,7 @@ def execute(): train_args=train_args, num_gpus_per_node=2, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 3d70f3e4ce..fcd7772882 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -95,6 +95,7 @@ def execute(): num_gpus_per_node=2 if FEW_GPU else 4, megatron_model_type=None, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/test_qwen3_0.6B_megatron_fsdp_align.py index 1431d8c3d4..b89a2f283b 100644 --- a/tests/test_qwen3_0.6B_megatron_fsdp_align.py +++ b/tests/test_qwen3_0.6B_megatron_fsdp_align.py @@ -97,6 +97,7 @@ def execute(): train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -109,6 +110,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -135,6 +137,7 @@ def execute(): "--debug-train-only " ), num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, megatron_model_type=MODEL_TYPE, ) diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py index 44f5c42fa5..d0ad283d15 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -95,6 +95,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) # 8 GPU CPU 1 for num_gpus in [8, 4, 2]: @@ -124,6 +125,7 @@ def execute(): train_args=args, num_gpus_per_node=num_gpus, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) train_args += "--calculate-per-token-loss " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index adff108043..b30eeed8e5 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -139,6 +139,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py index 22fb2b5fc3..0df4492e10 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/test_qwen3_4B_ckpt.py @@ -124,6 +124,7 @@ def execute(mode: str = ""): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py index 7c975c7cc2..03ba4094e9 100644 --- a/tests/test_qwen3_4B_fsdp_true_on_policy.py +++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py @@ -95,6 +95,7 @@ def execute(): "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", "CUBLAS_WORKSPACE_CONFIG": ":4096:8", "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py index 962f610fac..d4c1ac273a 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/test_qwen3_4B_ppo.py @@ -122,6 +122,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/test_qwen3_vl_4B_fsdp.py index fbdffd237e..bc4ef3293c 100644 --- a/tests/test_qwen3_vl_4B_fsdp.py +++ b/tests/test_qwen3_vl_4B_fsdp.py @@ -92,6 +92,7 @@ def execute(): extra_env_vars = { "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( From f652c2c42a499affb95487044eb91b21ebef19c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Sun, 25 Jan 2026 11:44:30 -0800 Subject: [PATCH 55/77] rather professional readme document (#511) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- README.md | 253 +++++++++++++++++++++--------------------------------- 1 file changed, 100 insertions(+), 153 deletions(-) diff --git a/README.md b/README.md index 809f78471b..aa9ab01560 100644 --- a/README.md +++ b/README.md @@ -1,212 +1,159 @@ -
-logo +
-[![GitHub Repo](https://img.shields.io/badge/github-radixark%2Fmiles-black?logo=github)](https://github.com/radixark/miles) - - -
- - -> A journey of a thousand miles is made one small step at a time. +Miles Logo -**Miles** is an enterprise-facing reinforcement learning framework for **large-scale MoE post-training and production workloads**, forked from and co-evolving with **[slime](https://github.com/THUDM/slime)**. +### **Enterprise-Grade Reinforcement Learning for Large-Scale Model Training** +### **High-Performance Rollout • Low Precision Training • Production Stability** -Miles keeps slime’s lightweight, modular design, but focuses on: - -- New hardware support (e.g., GB300 and beyond) -- Stable, controllable RL for large MoE models -- Production-grade features - - -## News +[![GitHub Repo](https://img.shields.io/badge/github-radixark%2Fmiles-black?logo=github)](https://github.com/radixark/miles) +[![License](https://img.shields.io/github/license/radixark/miles)](LICENSE) +[![Slack](https://img.shields.io/badge/slack-join-brightgreen.svg)](https://slack.sglang.ai) -- [2025/12] Support FSDP2 as A Training Backend for Miles ([blog](https://lmsys.org/blog/2025-12-03-miles-fsdp/)). -- [2025/11] Unified FP8: Moving Beyond Mixed Precision for Stable and Accelerated MoE RL ([blog](https://lmsys.org/blog/2025-11-25-fp8-rl/)). -- [2025/11] Power Up Speculative Decoding In Reinforcement Learning ([blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/spec/readme-en.md)). -- [2025/11] Introduce Miles - born after slime towards enterprise RL training ([blog](https://lmsys.org/blog/2025-11-19-miles/)). +[**Latest Updates**](#latest-updates) | [**Quick Start**](#quick-start) | [**Key Features**](#key-features) | [**Documentation**](docs/en/get_started/quick_start.md) +
--- -## Table of Contents -- [Quick Start](#quick-start) -- [Arguments Walkthrough](#arguments-walkthrough) -- [Developer Guide](#developer-guide) -- [Recent Updates](#recent-updates) -- [Roadmap](#roadmap) -- [Architecture Overview](#architecture-overview) -- [FAQ & Acknowledgements](#faq--acknowledgements) ---- +## Latest Updates -## Quick Start +* **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog Coming Soon] +* **[2026/01]** 🤖 **Multi-Agent Co-Evolution**: Miles now supports **MrlX**, a novel asynchronous co-evolutionary framework for Multi-Agent RL. Achieve superior performance in complex tasks like Doctor-Patient simulations and DeepResearch pipelines by enabling specialized agents to evolve together symbiotically. [[Link]](https://github.com/AQ-MedAI/MrlX) +* **[2025/12]** 🔄 **Rollout Routing Replay (R3)**: In collaboration with SGLang, we've launched R3 to solve MoE RL instability. R3 records inference routing decisions and replays them during training, effectively eliminating the "training-inference mismatch" and preventing training collapse in large MoE models like Qwen3 and DeepSeek-V3. [[Paper]](https://arxiv.org/pdf/2510.11370) +* **[2025/11]** 🔥 **Unified FP8 Release**: Solves the stability issues in MoE RL by ensuring training and inference use the exact same FP8 quantization logic. [[Blog]](https://lmsys.org/blog/2025-11-25-fp8-rl/) +* **[2025/11]** ⚡ **Speculative Decoding in RL**: Integrated speculative rollout with online SFT for draft models, achieving massive throughput gains. [[Blog]](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/spec/readme-en.md) +* **[2025/11]** 🎉 **Miles Project Launch**: A joint effort by InfiXAI, Ant Group, SGLang RL Team, and the Miles community. [[Announcement]](https://lmsys.org/blog/2025-11-19-miles/) -> **Note:** Miles is under active development. Commands and examples may evolve; please check the repo for the latest instructions. +## What is Miles? -For a comprehensive quick start guide covering environment setup, data preparation, training startup, and key code analysis, please refer to: -- [Quick Start Guide](./docs/en/get_started/quick_start.md) +**Miles** is a high-performance, enterprise-ready reinforcement learning (RL) framework specifically optimized for **Large-Scale model Post-Training**. Built as a powerful fork of **[slime](https://github.com/THUDM/slime)**, Miles bridges the gap between research-grade RL and production-grade reliability by integrating **SGLang** for high-throughput rollout and **Megatron-LM** for scalable training. -We also provide examples for some use cases not covered in the quick start guide; please check [examples](examples/). +> *"A journey of a thousand miles begins with a single step."* — Miles focuses on the low-level system optimizations that make large-scale RL stable, efficient, and reproducible. --- -## Arguments Walkthrough - -Arguments in Miles follow the same three-layer pattern as slime: - -1. **Megatron arguments**: Megatron arguments are exposed unchanged, e.g. `--tensor-model-parallel-size 2`. -2. **SGLang arguments**: All SGLang arguments are exposed with a prefix `--sglang-`, e.g. `--mem-fraction-static` → `--sglang-mem-fraction-static`. +## Key Features -3. **Miles-specific arguments*: Please refer to [`miles/utils/arguments.py`](miles/utils/arguments.py) for a full list +### 🌪️ Advanced MoE & Low-Precision Training -For more detailed usage, please refer to the documentation and example configs in the repo as they become available. - +* **Unified FP8 Pipeline**: The first framework to implement end-to-end FP8 sampling and training. By unifying precision across rollout and training, Miles eliminates the quantization-induced discrepancy that causes RL collapse in large MoE models. +* **Rollout Routing Replay (R3)**: Records expert routing decisions during SGLang inference and replays them during training to ensure bit-wise expert alignment. +* **INT4 QAT Support**: Recommendation for 1TB+ models to enable single-machine (e.g., H200) deployment by significantly reducing memory footprint. +### 🛡️ Eliminating Train-Inference Mismatch -## Recent Updates +* **Bit-wise Identical Training and Inference Log Probs**: System-level solution achieving deterministic forward/backward passes through kernel-level optimization (FlashAttention-3, DeepGEMM). +* **Algorithmic Correction (TIS/MIS)**: When mismatch is unavoidable, Miles provides **Truncated Importance Sampling (TIS)** and **Masked Importance Sampling (MIS)** to mitigate off-policy bias and prevent training divergence. -Miles starts from slime’s proven backbone and adds a series of upgrades for production environments. The recent PRs and changes have also been synced to slime side. +### ⚡ Extreme Performance & Efficiency -### ✅ True On-Policy +* **Speculative RL Training**: Achieve **25%+ rollout speedup** by using an **Online SFT Draft Model**. Unlike frozen draft models, Miles updates the draft policy during RL to prevent policy drift. +* **Zero-Copy Weight Sync**: Optimized weight refit via **CUDA IPC zero-copy mapping**, async tensor gathering, and bucketed flattening. Sync time reduced by 50% compared to standard HTTP/RPC transfers. +* **Partial Rollout & Over-Sampling**: Handles the "Long-Tail Effect" in multi-turn RL by over-sampling requests and recycling half-finished trajectories to maximize GPU utilization. -Miles extends slime’s deterministic training and supports **infrastructure-level true on-policy support** for SGLang + FSDP: +## Model Support & Training Diversity -- Keeps the mismatch between **training** and **inference** effectively at **zero** -- Aligns numerical behavior end-to-end between training and deployment -- Uses: - - FlashAttention-3 - - DeepGEMM - - Batch-invariant kernels from Thinking Machines Lab - - `torch.compile` and careful alignment of numeric operations +### 🏗️ Supported Models +Miles supports a wide range of state-of-the-art architectures, with a special emphasis on **DeepSeek, Qwen, Llama** and mainstream models. -This makes Miles suitable for **high-stakes experiments** where repeatability, auditability, and production debugging matter. +| Family | Supported Models | +| :--- | :--- | +| **DeepSeek** | **R1, V3, V3.2** | +| **Qwen** | **Qwen 2, 2.5, 3** | +| **Llama** | **Llama 3, 3.1, 3.3, 4** | +| **Gemma** | **Gemma 2, 3, 3N** | +| **GLM** | **GLM-4.5, GLM-4.6, GLM-4.7** | +| **MiniMax** | **M2, M2.1** | +| **Others** | **Mistral, Mixtral, Phi, gpt-oss and any model supported by SGLang and Megatron** | -### 🧮 Memory Robustness & Efficiency +### 🧩 Diverse Training Scenarios +Miles is designed to handle the complexity of modern RL workloads across various dimensions: +* **Multi-Turn Interaction**: Optimized for complex, multi-round conversations and tool-use scenarios. +* **VLM & LLM Support**: Unified framework for both Vision-Language and pure Text models. +* **Reasoning & Coding**: Specific recipes and optimizations for **Reasoning (Math/Logic)** and **Coding Agent** tasks. +* **Multi-Agent Training**: Support for advanced co-training and collaborative multi-agent reinforcement learning. -To fully utilize precious GPU memory **without** constant OOM failures, Miles includes: - -- Graceful handling of benign OOMs via error propagation -- Memory margins to avoid NCCL-related OOM issues -- Fixes for FSDP excessive memory usage -- Support for move-based and partial offloading -- Host peak memory savings for smoother multi-node training - -The goal is to let large MoE jobs run **closer to the hardware limit** while staying stable. +--- -### ⚡ Speculative Training +## Quick Start -Miles adds **speculative training** support tailored for RL: +### Installation -- Performs **online SFT on the draft model during RL**, instead of freezing it -- Avoids draft policy drift away from the target model -- Achieves **25%+ rollout speedup** vs. frozen MTP, especially in later training stages -- Includes: - - MTP with sequence packing + CP - - Proper loss masking and edge-case handling - - LM head / embedding gradient isolation - - Weight sync flows between Megatron and SGLang +We recommend using our official Docker image for the best performance and compatibility: -### 🧱 Hardware & Examples +```bash +# Pull the latest image +docker pull radixark/miles:latest -Miles actively tracks new hardware and provides usable examples: +# Or install from source +pip install -r requirements.txt +pip install -e . +``` -- GB300 training support, with more recipes coming -- A **formal mathematics (Lean)** example with SFT / RL scripts, showcasing Miles in a verifiable environment setting +### Launch Training -### 🛠 Miscellaneous Improvements +Miles provides a unified entry point for complex RL tasks. Here is an example of FP8 GRPO training for Qwen3: -Additional engineering improvements include: +```bash +python train.py \ + --advantage-estimator grpo \ + --model-name qwen3-30b-a3b \ + --hf-checkpoint /path/to/qwen3-30b-a3b-hf \ + --rollout-batch-size 512 \ + --n-samples-per-prompt 8 +``` -- Enhanced FSDP training backend -- Option to deploy the **rollout subsystem independently** outside the main framework -- Better debugging & profiling: more metrics, post-hoc analyzers, and profiler integration -- Gradual refactoring for clarity and maintainability +For comprehensive guides on environment setup and custom reward functions, see the [Quick Start Guide](docs/en/get_started/quick_start.md). --- ## Roadmap -We are actively evolving Miles toward a **production-ready RL engine** for large-scale MoE and multimodal workloads. Current roadmap items include: +### ✅ Completed -- **Large-scale MoE RL recipes** on new hardware (e.g., GB300 and successors) -- **Multimodal training** support -- **Rollout accelerations** - - Compatibility with SGLang spec v2 for improved performance - - More advanced speculative training schemes (e.g., EAGLE3-style, multi-spec layers) -- **Elasticity & fault tolerance** - - More robust handling of GPU / node failures in long-running jobs -- **Resource scheduling for async training** - - Balancing training and serving in large-scale asynchronous RL systems +- [x] **Unified FP8** E2E Training & Rollout +- [x] **INT4 Quantization-Aware Training (QAT)**: Single-machine 1TB models +- [x] **Speculative RL** with Online SFT +- [x] **Multi-Agent RL** (Co-evolutionary frameworks like [MrlX](https://github.com/AQ-MedAI/MrlX)) +- [x] **Support DeepSeek V3.2 Models** +- [x] **VLM Multi-Turn Training** +- [x] **Aligning SGLang with Megatron in Dense Models** +- [x] **Rollout Routing Replay (R3)** -We’ll continue to iterate based on feedback from users across research labs, startups, and enterprise teams. - ---- - -## Architecture Overview - -Miles inherits slime’s core architecture as below. - - -![arch](./imgs/arch.png) +### 🏗️ In Progress & Planned +- [ ] **Zero mismatch for MoE RL** +- [ ] **Aligning SGLang with Megatron in MoE Models** +- [ ] **Diffusion RL** Support +- [ ] **Omni RL** Support +- [ ] **Diffusion LLM RL** Support +- [ ] **Elastic Resource Scheduling**: Dynamic scaling of rollout vs. training workers -**Module overview:** -- **training (Megatron)** - Main training loop. Reads data from the Data Buffer and synchronizes parameters to the rollout subsystem after updates. - -- **rollout (SGLang + router)** - Generates new samples, including rewards / verifier outputs, and writes them back to the Data Buffer. - -- **data buffer** - Manages prompt initialization, custom data sources, and rollout generation strategies. Serves as the bridge between training and rollout. - -This decoupled design lets you: - -- Swap in different algorithms / reward functions without touching rollout code -- Customize rollout engines independently from training -- Scale rollouts and training differently depending on hardware and deployment constraints --- +## Acknowledgements -## Developer Guide - -* **Contributions welcome!** - We’re especially interested in: - - * New hardware backends & tuning - * MoE RL recipes - * Stability / determinism improvements - * Multimodal & speculative training use cases +Miles is built upon the shoulders of giants in the LLM infrastructure ecosystem: +* **[slime](https://github.com/THUDM/slime)**: The core modular architecture and inspiration. +* **[SGLang](https://github.com/sgl-project/sglang)**: The high-performance inference engine. +* **[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)**: Robust large-scale training components. -* We recommend using [pre-commit](https://pre-commit.com/) to keep style consistent: - -```bash -apt install pre-commit -y -pre-commit install - -# run pre-commit to ensure code style consistency -pre-commit run --all-files --show-diff-on-failure --color=always -``` - -* For debugging tips, performance tuning, and internal architecture notes, see the `docs/` and `developer_guide/` folders (coming soon). - ---- - -## FAQ & Acknowledgements - -* For FAQs, please see `docs/en/get_started/qa.md` (to be added as the project matures). -* **Huge thanks** to the **slime** authors and community — Miles would not exist without slime’s design and ecosystem. -* We also acknowledge and rely on the broader LLM infra ecosystem, including SGLang, Megatron-LM, and related tools. +Special thanks to **InfiXAI Team**, **Ant Group AQ Team**, **SGLang RL Team**, and the **Miles Team**. We also thank **DataCrunch** for compute sponsorship and **NVIDIA** for technical support on Transformer Engine (TE). --- ## Links -* **Miles GitHub**: [https://github.com/radixark/miles](https://github.com/radixark/miles) -* **slime GitHub**: [https://github.com/THUDM/slime](https://github.com/THUDM/slime) +* **GitHub**: [https://github.com/radixark/miles](https://github.com/radixark/miles) +* **Slime Project**: [https://github.com/THUDM/slime](https://github.com/THUDM/slime) +* **Developer Guide**: Check the `docs/` and `examples/` directories for in-depth technical notes. -We’re excited to see what you build — whether you choose **slime**, **Miles**, or both in different parts of your stack. 🚀 +
+**Give Miles a ⭐️ Star if it helps your RL journey!** + +
From 55858437b778eda000c8d41b339b8ddfee4d2759 Mon Sep 17 00:00:00 2001 From: zijiexia <37504505+zijiexia@users.noreply.github.com> Date: Sun, 25 Jan 2026 20:14:35 -0800 Subject: [PATCH 56/77] [Docs] Remove linked blog (#518) --- examples/true_on_policy/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/true_on_policy/README.md b/examples/true_on_policy/README.md index 620564d410..553d2de64d 100644 --- a/examples/true_on_policy/README.md +++ b/examples/true_on_policy/README.md @@ -1,6 +1,6 @@ # True On-Policy between Training and Inference -True on-policy ensures that the log probs generated by inference engine (SGLang) is strictly equal to the one generated by the training Engine. Here's our [blog](https://lmsys.org/blog/2025-12-03-miles-fsdp/) for more details. +True on-policy ensures that the log probs generated by inference engine (SGLang) is strictly equal to the one generated by the training Engine. ## Examples From a8c8687ca7f1e3cf98aa6a63c3c3c0e95a79b6b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Mon, 26 Jan 2026 22:55:29 -0800 Subject: [PATCH 57/77] Adds new blogs to latest update (#520) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index aa9ab01560..24bcb773ba 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,8 @@ ## Latest Updates -* **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog Coming Soon] +* **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog](https://lmsys.org/blog/2026-01-28-int4-qat/) +* **[2026/01]** 💎 **Unified VLM/LLM Multi-Turn Training**: We provided an implementation for the VLM multi-turn sampling paradigm. Developers only need to write a customized `rollout` function to easily start multi-turn RL for VLM, just like training LLM. [blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/vlm-multi-turn/readme-en.md) * **[2026/01]** 🤖 **Multi-Agent Co-Evolution**: Miles now supports **MrlX**, a novel asynchronous co-evolutionary framework for Multi-Agent RL. Achieve superior performance in complex tasks like Doctor-Patient simulations and DeepResearch pipelines by enabling specialized agents to evolve together symbiotically. [[Link]](https://github.com/AQ-MedAI/MrlX) * **[2025/12]** 🔄 **Rollout Routing Replay (R3)**: In collaboration with SGLang, we've launched R3 to solve MoE RL instability. R3 records inference routing decisions and replays them during training, effectively eliminating the "training-inference mismatch" and preventing training collapse in large MoE models like Qwen3 and DeepSeek-V3. [[Paper]](https://arxiv.org/pdf/2510.11370) * **[2025/11]** 🔥 **Unified FP8 Release**: Solves the stability issues in MoE RL by ensuring training and inference use the exact same FP8 quantization logic. [[Blog]](https://lmsys.org/blog/2025-11-25-fp8-rl/) From 81ea4a82e46493fa4e27af5bbf77e36f8c2d62e4 Mon Sep 17 00:00:00 2001 From: "Ethan (Yusheng) Su" Date: Tue, 27 Jan 2026 14:12:08 -0800 Subject: [PATCH 58/77] [CI] Re-organize and enable necessary end2end CI cases (#499) Co-authored-by: Yusheng Su --- .github/workflows/pr-test.yml | 121 +++++++++++++- .github/workflows/pr-test.yml.j2 | 96 ++++++----- miles/backends/training_utils/log_utils.py | 10 +- tests/{ => e2e/ckpt}/test_qwen3_4B_ckpt.py | 0 .../fsdp}/test_qwen3_0.6B_fsdp_distributed.py | 0 .../test_qwen3_0.6B_megatron_fsdp_align.py | 0 .../test_qwen3_4B_fsdp_true_on_policy.py | 0 tests/{ => e2e/fsdp}/test_qwen3_vl_4B_fsdp.py | 0 .../image}/test_mimo_7B_mtp_only_grad.py | 0 .../{ => e2e/image}/test_moonlight_16B_A3B.py | 0 .../image}/test_quick_start_glm4_9B.py | 0 .../image}/test_qwen2.5_0.5B_gsm8k.py | 0 .../image}/test_qwen2.5_0.5B_gsm8k_async.py | 0 .../test_qwen2.5_0.5B_gsm8k_async_short.py | 0 .../image}/test_qwen2.5_0.5B_gsm8k_short.py | 0 .../test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 0 .../image/test_qwen3_0.6B_fsdp_distributed.py | 106 ++++++++++++ .../test_qwen3_0.6B_megatron_fsdp_align.py | 155 ++++++++++++++++++ .../image}/test_qwen3_0.6B_parallel_check.py | 0 tests/{ => e2e/image}/test_qwen3_30B_A3B.py | 0 tests/e2e/image/test_qwen3_4B_ckpt.py | 138 ++++++++++++++++ .../test_qwen3_4B_fsdp_true_on_policy.py | 113 +++++++++++++ tests/{ => e2e/image}/test_qwen3_4B_ppo.py | 0 tests/e2e/image/test_qwen3_vl_4B_fsdp.py | 112 +++++++++++++ tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py | 131 +++++++++++++++ .../e2e/long/test_qwen2.5_0.5B_gsm8k_async.py | 131 +++++++++++++++ .../megatron/test_mimo_7B_mtp_only_grad.py | 147 +++++++++++++++++ tests/e2e/megatron/test_moonlight_16B_A3B.py | 124 ++++++++++++++ .../megatron}/test_moonlight_16B_A3B_r3.py | 0 .../e2e/megatron/test_quick_start_glm4_9B.py | 127 ++++++++++++++ tests/e2e/megatron/test_qwen3_30B_A3B.py | 151 +++++++++++++++++ .../megatron}/test_qwen3_30B_A3B_r3.py | 0 tests/e2e/megatron/test_qwen3_4B_ppo.py | 134 +++++++++++++++ .../test_qwen3_0.6B_megatron_fsdp_align.py | 155 ++++++++++++++++++ .../test_qwen3_0.6B_parallel_check.py | 138 ++++++++++++++++ .../test_qwen2.5_0.5B_gsm8k_async_short.py | 129 +++++++++++++++ .../short/test_qwen2.5_0.5B_gsm8k_short.py | 128 +++++++++++++++ .../test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 104 ++++++++++++ 38 files changed, 2400 insertions(+), 50 deletions(-) rename tests/{ => e2e/ckpt}/test_qwen3_4B_ckpt.py (100%) rename tests/{ => e2e/fsdp}/test_qwen3_0.6B_fsdp_distributed.py (100%) rename tests/{ => e2e/fsdp}/test_qwen3_0.6B_megatron_fsdp_align.py (100%) rename tests/{ => e2e/fsdp}/test_qwen3_4B_fsdp_true_on_policy.py (100%) rename tests/{ => e2e/fsdp}/test_qwen3_vl_4B_fsdp.py (100%) rename tests/{ => e2e/image}/test_mimo_7B_mtp_only_grad.py (100%) rename tests/{ => e2e/image}/test_moonlight_16B_A3B.py (100%) rename tests/{ => e2e/image}/test_quick_start_glm4_9B.py (100%) rename tests/{ => e2e/image}/test_qwen2.5_0.5B_gsm8k.py (100%) rename tests/{ => e2e/image}/test_qwen2.5_0.5B_gsm8k_async.py (100%) rename tests/{ => e2e/image}/test_qwen2.5_0.5B_gsm8k_async_short.py (100%) rename tests/{ => e2e/image}/test_qwen2.5_0.5B_gsm8k_short.py (100%) rename tests/{ => e2e/image}/test_qwen3_0.6B_fsdp_colocated_2xGPU.py (100%) create mode 100644 tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py create mode 100644 tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py rename tests/{ => e2e/image}/test_qwen3_0.6B_parallel_check.py (100%) rename tests/{ => e2e/image}/test_qwen3_30B_A3B.py (100%) create mode 100644 tests/e2e/image/test_qwen3_4B_ckpt.py create mode 100644 tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py rename tests/{ => e2e/image}/test_qwen3_4B_ppo.py (100%) create mode 100644 tests/e2e/image/test_qwen3_vl_4B_fsdp.py create mode 100644 tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py create mode 100644 tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py create mode 100644 tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py create mode 100644 tests/e2e/megatron/test_moonlight_16B_A3B.py rename tests/{ => e2e/megatron}/test_moonlight_16B_A3B_r3.py (100%) create mode 100644 tests/e2e/megatron/test_quick_start_glm4_9B.py create mode 100644 tests/e2e/megatron/test_qwen3_30B_A3B.py rename tests/{ => e2e/megatron}/test_qwen3_30B_A3B_r3.py (100%) create mode 100644 tests/e2e/megatron/test_qwen3_4B_ppo.py create mode 100644 tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py create mode 100644 tests/e2e/precision/test_qwen3_0.6B_parallel_check.py create mode 100644 tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py create mode 100644 tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py create mode 100644 tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4b8b5dc82c..e7a909be2d 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -33,7 +33,7 @@ jobs: options: > --gpus all --ipc=host - --shm-size=16g + --shm-size=32g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 @@ -41,6 +41,9 @@ jobs: -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp strategy: fail-fast: false matrix: @@ -52,11 +55,26 @@ jobs: GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + - name: Install shell: bash run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages @@ -65,6 +83,84 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + + unit-test: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-unit-test')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 2, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted @@ -87,7 +183,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] + info: [{"num_gpus": 4, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -132,6 +228,7 @@ jobs: ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-fsdp: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-fsdp')) runs-on: self-hosted @@ -154,7 +251,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}] + info: [{"num_gpus": 2, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 4, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -199,6 +296,7 @@ jobs: ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-megatron: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-megatron')) runs-on: self-hosted @@ -221,7 +319,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -266,6 +364,7 @@ jobs: ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-precision: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-precision')) runs-on: self-hosted @@ -288,7 +387,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -333,6 +432,7 @@ jobs: ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-ckpt: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-ckpt')) runs-on: self-hosted @@ -355,7 +455,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}] + info: [{"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}] defaults: run: working-directory: ${{ github.workspace }} @@ -400,6 +500,7 @@ jobs: ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-long: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-long')) runs-on: self-hosted @@ -422,7 +523,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 2, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -467,11 +568,12 @@ jobs: ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-image: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-image')) runs-on: self-hosted container: - image: radixark/miles-test:latest + image: radixark/miles:latest options: > --gpus all --ipc=host @@ -489,7 +591,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 4, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/image/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "e2e/image/test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -533,3 +635,4 @@ jobs: pkill -9 -f raylet 2>/dev/null || true ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index c052b8494f..2dae38ff4f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -5,78 +5,84 @@ {'test_file': 'fast', 'num_gpus': 0}, ], }, + 'unit-test': { + 'label': 'run-unit-test', + 'tests': [ + {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2} + ], + }, 'e2e-test-short': { 'label': 'run-ci-short', 'tests': [ - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, + {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, ], }, 'e2e-test-fsdp': { 'label': 'run-ci-fsdp', 'tests': [ - {'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, + {'test_file': 'e2e/fsdp/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'test_file': 'e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, ], }, 'e2e-test-megatron': { 'label': 'run-ci-megatron', 'tests': [ - {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1'}, - {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, - {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8}, - {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8}, - {'test_file': 'test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_quick_start_glm4_9B.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1'}, + {'test_file': 'e2e/megatron/test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, + {'test_file': 'e2e/megatron/test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'e2e/megatron/test_qwen3_4B_ppo.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_moonlight_16B_A3B.py', 'num_gpus': 8}, + {'test_file': 'e2e/megatron/test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'e2e/megatron/test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, ], }, 'e2e-test-precision': { 'label': 'run-ci-precision', 'tests': [ - {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'e2e/precision/test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, ], }, 'e2e-test-ckpt': { 'label': 'run-ci-ckpt', 'tests': [ - {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + {'test_file': 'e2e/ckpt/test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'e2e/ckpt/test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, ], }, 'e2e-test-long': { 'label': 'run-ci-long', 'tests': [ - {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, + {'test_file': 'e2e/long/test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, + {'test_file': 'e2e/long/test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, ], }, 'e2e-test-image': { 'label': 'run-ci-image', - 'image': 'radixark/miles-test:latest', + 'image': 'radixark/miles:latest', 'tests': [ - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, - {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, - {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8}, - {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8}, - {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, - {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, + {'test_file': 'e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_quick_start_glm4_9B.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_30B_A3B.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_4B_ppo.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_moonlight_16B_A3B.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'e2e/image/test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, + {'test_file': 'e2e/image/test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, ], }, } %> @@ -160,4 +166,14 @@ jobs: - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} -<% endfor %> \ No newline at end of file + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + +<% endfor %> diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 1a2f176028..2b9891f64e 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -136,7 +136,15 @@ def log_rollout_data( # NOTE: Here we have to do the clone().detach(), otherwise the tensor will be # modified in place and will cause problem for the next rollout. val = torch.cat(val).clone().detach() - if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: + if key in [ + "log_probs", + "ref_log_probs", + "rollout_log_probs", + "returns", + "advantages", + "values", + "entropy", + ]: sum_of_sample_mean = get_sum_of_sample_mean( total_lengths, response_lengths, diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/e2e/ckpt/test_qwen3_4B_ckpt.py similarity index 100% rename from tests/test_qwen3_4B_ckpt.py rename to tests/e2e/ckpt/test_qwen3_4B_ckpt.py diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py similarity index 100% rename from tests/test_qwen3_0.6B_fsdp_distributed.py rename to tests/e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py similarity index 100% rename from tests/test_qwen3_0.6B_megatron_fsdp_align.py rename to tests/e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py similarity index 100% rename from tests/test_qwen3_4B_fsdp_true_on_policy.py rename to tests/e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/e2e/fsdp/test_qwen3_vl_4B_fsdp.py similarity index 100% rename from tests/test_qwen3_vl_4B_fsdp.py rename to tests/e2e/fsdp/test_qwen3_vl_4B_fsdp.py diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/e2e/image/test_mimo_7B_mtp_only_grad.py similarity index 100% rename from tests/test_mimo_7B_mtp_only_grad.py rename to tests/e2e/image/test_mimo_7B_mtp_only_grad.py diff --git a/tests/test_moonlight_16B_A3B.py b/tests/e2e/image/test_moonlight_16B_A3B.py similarity index 100% rename from tests/test_moonlight_16B_A3B.py rename to tests/e2e/image/test_moonlight_16B_A3B.py diff --git a/tests/test_quick_start_glm4_9B.py b/tests/e2e/image/test_quick_start_glm4_9B.py similarity index 100% rename from tests/test_quick_start_glm4_9B.py rename to tests/e2e/image/test_quick_start_glm4_9B.py diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k.py similarity index 100% rename from tests/test_qwen2.5_0.5B_gsm8k.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k.py diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async.py similarity index 100% rename from tests/test_qwen2.5_0.5B_gsm8k_async.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async.py diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py similarity index 100% rename from tests/test_qwen2.5_0.5B_gsm8k_async_short.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k_async_short.py diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/e2e/image/test_qwen2.5_0.5B_gsm8k_short.py similarity index 100% rename from tests/test_qwen2.5_0.5B_gsm8k_short.py rename to tests/e2e/image/test_qwen2.5_0.5B_gsm8k_short.py diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py similarity index 100% rename from tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py rename to tests/e2e/image/test_qwen3_0.6B_fsdp_colocated_2xGPU.py diff --git a/tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py b/tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py new file mode 100644 index 0000000000..fcd7772882 --- /dev/null +++ b/tests/e2e/image/test_qwen3_0.6B_fsdp_distributed.py @@ -0,0 +1,106 @@ +import os +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" + + +FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + # NOTE cannot be exactly multiple of eval-interval, since async causes some offsets + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 65} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + # "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 1 " "--sglang-enable-metrics " + + misc_args = ( + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " + f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--train-backend fsdp " + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=2 if FEW_GPU else 4, + megatron_model_type=None, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py new file mode 100644 index 0000000000..b89a2f283b --- /dev/null +++ b/tests/e2e/image/test_qwen3_0.6B_megatron_fsdp_align.py @@ -0,0 +1,155 @@ +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 4 +CP_SIZE = 1 +MEGATRON_TP_SIZE = 1 +MEGATRON_PP_SIZE = 1 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/" + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 64 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " "--sglang-chunked-prefill-size 4096 " "--sglang-mem-fraction-static 0.75 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " "--colocate " f"--actor-num-gpus-per-node {NUM_GPUS} " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + debug_data_path = "test_rollout_data_megatron_fsdp_align.pt" + grad_norm_path = "grad_norm_fsdp.pt" + + fsdp_args = ( + "--train-backend fsdp " + "--attn-implementation flash_attention_2 " + "--gradient-checkpointing " + f"--context-parallel-size {CP_SIZE} " + f"--update-weight-buffer-size {512 * 1024 * 1024} " + """--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' """ + ) + + try: + U.execute_train( + train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"{fsdp_args}" + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-save-grad-norm {grad_norm_path} " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + f"--tensor-model-parallel-size {MEGATRON_TP_SIZE} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {MEGATRON_PP_SIZE} " + f"--context-parallel-size {CP_SIZE} " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--train-memory-margin-bytes 3221225472 " + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-load-grad-norm {grad_norm_path} " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + megatron_model_type=MODEL_TYPE, + ) + + finally: + if os.path.exists(grad_norm_path): + os.remove(grad_norm_path) + if os.path.exists(debug_data_path): + os.remove(debug_data_path) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/e2e/image/test_qwen3_0.6B_parallel_check.py similarity index 100% rename from tests/test_qwen3_0.6B_parallel_check.py rename to tests/e2e/image/test_qwen3_0.6B_parallel_check.py diff --git a/tests/test_qwen3_30B_A3B.py b/tests/e2e/image/test_qwen3_30B_A3B.py similarity index 100% rename from tests/test_qwen3_30B_A3B.py rename to tests/e2e/image/test_qwen3_30B_A3B.py diff --git a/tests/e2e/image/test_qwen3_4B_ckpt.py b/tests/e2e/image/test_qwen3_4B_ckpt.py new file mode 100644 index 0000000000..0df4492e10 --- /dev/null +++ b/tests/e2e/image/test_qwen3_4B_ckpt.py @@ -0,0 +1,138 @@ +import os +from argparse import ArgumentParser + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 8 + + +parser = ArgumentParser() +parser.add_argument("--async-save", action="store_true", help="Whether to test async save/load.") + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"rm -rf /root/models/{MODEL_NAME}_miles") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint( + model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS, dir_dst="/root/models" + ) + + +def execute(mode: str = ""): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + if mode == "save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + elif mode == "async_save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + ckpt_args += "--async-save " + elif mode == "load": + ckpt_args += f"--load /root/models/{MODEL_NAME}_miles " + ckpt_args += "--ckpt-step 1 " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.8 --sglang-cuda-graph-bs 1 2 4 8 16 " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + args = parser.parse_args() + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute("save" if not args.async_save else "async_save") + execute("load") diff --git a/tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py b/tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py new file mode 100644 index 0000000000..03ba4094e9 --- /dev/null +++ b/tests/e2e/image/test_qwen3_4B_fsdp_true_on_policy.py @@ -0,0 +1,113 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +NUM_GPUS = 2 + +MODEL_NAME = "Qwen3-4B" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-p 0.7 " + ) + + fsdp_args = "--train-backend fsdp " "--update-weight-buffer-size 536870912 " + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-decode-log-interval 1000 " + "--sglang-enable-metrics " + "--sglang-enable-deterministic-inference " + "--sglang-rl-on-policy-target fsdp " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + "--deterministic-mode " + "--true-on-policy-mode " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{fsdp_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + extra_env_vars = { + "NCCL_ALGO": "allreduce:tree", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + } + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars=extra_env_vars, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_4B_ppo.py b/tests/e2e/image/test_qwen3_4B_ppo.py similarity index 100% rename from tests/test_qwen3_4B_ppo.py rename to tests/e2e/image/test_qwen3_4B_ppo.py diff --git a/tests/e2e/image/test_qwen3_vl_4B_fsdp.py b/tests/e2e/image/test_qwen3_vl_4B_fsdp.py new file mode 100644 index 0000000000..bc4ef3293c --- /dev/null +++ b/tests/e2e/image/test_qwen3_vl_4B_fsdp.py @@ -0,0 +1,112 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +NUM_GPUS = 8 + +MODEL_NAME = "Qwen3-VL-4B-Instruct" +DATASET_NAME = "chenhegu/geo3k_imgurl" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset(DATASET_NAME) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/geo3k_imgurl/train.parquet " + "--input-key problem " + "--label-key answer " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + # multimodal keys required for vlm datasets + multimodal_args = '--multimodal-keys \'{"image": "images"}\' ' + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + ) + + fsdp_args = "--train-backend fsdp " "--gradient-checkpointing " "--update-weight-buffer-size 536870912 " + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-mem-fraction-static 0.6 " + "--sglang-decode-log-interval 1000 " + "--sglang-enable-metrics " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{multimodal_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{fsdp_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + extra_env_vars = { + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + } + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars=extra_env_vars, + ) + + +if __name__ == "__main__": + prepare() + os.environ.pop("http_proxy", None) + os.environ.pop("https_proxy", None) + os.environ.pop("HTTP_PROXY", None) + os.environ.pop("HTTPS_PROXY", None) + execute() diff --git a/tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py new file mode 100644 index 0000000000..4d7f034f6c --- /dev/null +++ b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k.py @@ -0,0 +1,131 @@ +import os +import miles.utils.external_utils.command_utils as U + + +FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 2 if FEW_GPU else 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 250} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + # "--micro-batch-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.55 " # loose threshold at 250 step + ) + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {2 if FEW_GPU else 4} " + "--colocate " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py new file mode 100644 index 0000000000..32b60f5937 --- /dev/null +++ b/tests/e2e/long/test_qwen2.5_0.5B_gsm8k_async.py @@ -0,0 +1,131 @@ +import os +import miles.utils.external_utils.command_utils as U + +FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 2 if FEW_GPU else 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 250} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + # "--micro-batch-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.55 " # loose threshold at 250 step + ) + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " + f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py b/tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py new file mode 100644 index 0000000000..d90a2d7a71 --- /dev/null +++ b/tests/e2e/megatron/test_mimo_7B_mtp_only_grad.py @@ -0,0 +1,147 @@ +"""End-to-end test for MTP-only gradient verification. + +This test verifies that when MTP training is enabled and all outputs are truncated +(due to very short max response length), only MTP parameters receive non-zero +gradients while all other model parameters have zero gradients. + +This validates that the MTP loss computation correctly isolates gradient flow +to only the MTP layers when the main model loss is zero (due to truncation). +""" + +import os + +import miles.utils.external_utils.command_utils as U + + +MODEL_NAME = "MiMo-7B-RL" +MODEL_TYPE = "mimo-7B-rl" +NUM_GPUS = 8 + + +def prepare(): + """Download model and convert checkpoint with MTP layers.""" + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download XiaomiMiMo/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + # Convert checkpoint with MTP layers enabled + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + extra_args=" --mtp-num-layers 1", + dir_dst="/root/models", + ) + + +def execute(): + """Run training with MTP enabled and very short output length to cause truncation.""" + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + + # Use very short rollout-max-response-len to ensure all outputs are truncated + # This should result in zero loss for the main model, leaving only MTP loss + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 2 " + # Very short max response length to cause all outputs to be truncated + "--rollout-max-response-len 128 " + "--rollout-temperature 0.8 " + "--global-batch-size 8 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 4096 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--rollout-num-gpus 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-enable-metrics " + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + ) + + # Enable MTP training with loss scaling + mtp_args = "--mtp-num-layers 1 " "--enable-mtp-training " "--mtp-loss-scaling-factor 0.2 " + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + # MTP grad check is automatically triggered when ci_test and enable_mtp_training are both set + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + # Remove proxy settings that might interfere with local operations + for key in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]: + os.environ.pop(key, None) + execute() diff --git a/tests/e2e/megatron/test_moonlight_16B_A3B.py b/tests/e2e/megatron/test_moonlight_16B_A3B.py new file mode 100644 index 0000000000..c35943ec15 --- /dev/null +++ b/tests/e2e/megatron/test_moonlight_16B_A3B.py @@ -0,0 +1,124 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Moonlight-16B-A3B-Instruct" +MODEL_TYPE = "moonlight" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command( + "hf download moonshotai/Moonlight-16B-A3B-Instruct --local-dir /root/models/Moonlight-16B-A3B-Instruct" + ) + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 2048} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " "--sglang-mem-fraction-static 0.8 " "--sglang-max-running-requests 512 " + ) + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_moonlight_16B_A3B_r3.py b/tests/e2e/megatron/test_moonlight_16B_A3B_r3.py similarity index 100% rename from tests/test_moonlight_16B_A3B_r3.py rename to tests/e2e/megatron/test_moonlight_16B_A3B_r3.py diff --git a/tests/e2e/megatron/test_quick_start_glm4_9B.py b/tests/e2e/megatron/test_quick_start_glm4_9B.py new file mode 100644 index 0000000000..ae3c383ae8 --- /dev/null +++ b/tests/e2e/megatron/test_quick_start_glm4_9B.py @@ -0,0 +1,127 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = U.get_bool_env_var("MILES_TEST_ENABLE_EVAL", "1") +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "GLM-Z1-9B-0414" +MODEL_TYPE = "glm4-9B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/models/GLM-Z1-9B-0414") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_DEVICE_MEMORY else 4608} " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-tis " + "--calculate-per-token-loss " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--use-miles-router " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--rollout-num-gpus 4 " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/megatron/test_qwen3_30B_A3B.py b/tests/e2e/megatron/test_qwen3_30B_A3B.py new file mode 100644 index 0000000000..b30eeed8e5 --- /dev/null +++ b/tests/e2e/megatron/test_qwen3_30B_A3B.py @@ -0,0 +1,151 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) +USE_DEEPEP = bool(int(os.environ.get("MILES_TEST_USE_DEEPEP", "1"))) +USE_FP8_ROLLOUT = bool(int(os.environ.get("MILES_TEST_USE_FP8_ROLLOUT", "1"))) + +MODEL_NAME = "Qwen3-30B-A3B" +MODEL_TYPE = "qwen3-30B-A3B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-30B-A3B --local-dir /root/models/Qwen3-30B-A3B") + U.exec_command("hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/models/Qwen3-30B-A3B-FP8") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + if USE_FP8_ROLLOUT: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}-FP8 " f"--ref-load /root/{MODEL_NAME}_torch_dist " + else: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 4 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--use-tis " + "--use-routing-replay " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-max-running-requests 512 " + "--sglang-enable-metrics " + ) + + if USE_DEEPEP: + sglang_args += "--sglang-moe-a2a-backend deepep --sglang-deepep-mode auto " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + if USE_DEEPEP: + misc_args += "--moe-token-dispatcher-type flex --moe-enable-deepep " + else: + misc_args += "--moe-token-dispatcher-type alltoall " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_30B_A3B_r3.py b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py similarity index 100% rename from tests/test_qwen3_30B_A3B_r3.py rename to tests/e2e/megatron/test_qwen3_30B_A3B_r3.py diff --git a/tests/e2e/megatron/test_qwen3_4B_ppo.py b/tests/e2e/megatron/test_qwen3_4B_ppo.py new file mode 100644 index 0000000000..d4c1ac273a --- /dev/null +++ b/tests/e2e/megatron/test_qwen3_4B_ppo.py @@ -0,0 +1,134 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + ppo_args = ( + "--advantage-estimator ppo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--num-critic-only-steps 1 " + "--normalize-advantages " + "--critic-lr 1e-5 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--rollout-num-gpus 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-max-running-requests 512 " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py new file mode 100644 index 0000000000..b89a2f283b --- /dev/null +++ b/tests/e2e/precision/test_qwen3_0.6B_megatron_fsdp_align.py @@ -0,0 +1,155 @@ +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 4 +CP_SIZE = 1 +MEGATRON_TP_SIZE = 1 +MEGATRON_PP_SIZE = 1 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/" + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 64 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " "--sglang-chunked-prefill-size 4096 " "--sglang-mem-fraction-static 0.75 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " "--colocate " f"--actor-num-gpus-per-node {NUM_GPUS} " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + debug_data_path = "test_rollout_data_megatron_fsdp_align.pt" + grad_norm_path = "grad_norm_fsdp.pt" + + fsdp_args = ( + "--train-backend fsdp " + "--attn-implementation flash_attention_2 " + "--gradient-checkpointing " + f"--context-parallel-size {CP_SIZE} " + f"--update-weight-buffer-size {512 * 1024 * 1024} " + """--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' """ + ) + + try: + U.execute_train( + train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"{fsdp_args}" + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-save-grad-norm {grad_norm_path} " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + f"--tensor-model-parallel-size {MEGATRON_TP_SIZE} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {MEGATRON_PP_SIZE} " + f"--context-parallel-size {CP_SIZE} " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--train-memory-margin-bytes 3221225472 " + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-load-grad-norm {grad_norm_path} " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + megatron_model_type=MODEL_TYPE, + ) + + finally: + if os.path.exists(grad_norm_path): + os.remove(grad_norm_path) + if os.path.exists(debug_data_path): + os.remove(debug_data_path) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/precision/test_qwen3_0.6B_parallel_check.py b/tests/e2e/precision/test_qwen3_0.6B_parallel_check.py new file mode 100644 index 0000000000..d0ad283d15 --- /dev/null +++ b/tests/e2e/precision/test_qwen3_0.6B_parallel_check.py @@ -0,0 +1,138 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS, dir_dst="/root/models" + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--rollout-num-gpus 8 " "--sglang-mem-fraction-static 0.8 " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + for i in range(2): + U.execute_train( + train_args=train_args + + ( + f"--save-debug-rollout-data data-{i}.pt " + f"--ci-save-grad-norm grad_norms-{i}.pt " + f"--actor-num-gpus-per-node {NUM_GPUS} " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + # 8 GPU CPU 1 + for num_gpus in [8, 4, 2]: + remaining_gpus = num_gpus + for tp_size in [1, 2, 4, 8]: + remaining_gpus /= tp_size + for pp_size in [1, 2, 4]: + if remaining_gpus < pp_size: + continue + remaining_gpus /= pp_size + for cp_size in [1, 2, 4, 8]: + if remaining_gpus < cp_size: + continue + args = train_args + ( + f"--load-debug-rollout-data data-{i}.pt " + f"--ci-load-grad-norm grad_norms-{i}.pt " + f"--context-parallel-size {cp_size} " + f"--tensor-model-parallel-size {tp_size} " + f"--pipeline-model-parallel-size {pp_size} " + "--sequence-parallel " + f"--actor-num-gpus-per-node {num_gpus} " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + U.execute_train( + train_args=args, + num_gpus_per_node=num_gpus, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + train_args += "--calculate-per-token-loss " + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py new file mode 100644 index 0000000000..b1954a4e83 --- /dev/null +++ b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -0,0 +1,129 @@ +import os +import miles.utils.external_utils.command_utils as U + +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--over-sampling-batch-size 16 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 32 " + ) + + eval_args = ( + "--eval-interval 8 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.55 if TIGHT_DEVICE_MEMORY else 0.65} " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + fault_tolerance_args = ( + "--use-fault-tolerance " + "--rollout-health-check-interval 5 " + "--rollout-health-check-timeout 10 " + "--rollout-health-check-first-wait 0 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 1 " + "--rollout-num-gpus 3 " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{fault_tolerance_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py new file mode 100644 index 0000000000..86e21eac8d --- /dev/null +++ b/tests/e2e/short/test_qwen2.5_0.5B_gsm8k_short.py @@ -0,0 +1,128 @@ +import os +import miles.utils.external_utils.command_utils as U + +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--over-sampling-batch-size 16 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 32 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + fault_tolerance_args = ( + "--use-fault-tolerance " + "--rollout-health-check-interval 5 " + "--rollout-health-check-timeout 10 " + "--rollout-health-check-first-wait 0 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--colocate " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{fault_tolerance_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py new file mode 100644 index 0000000000..3d4768e420 --- /dev/null +++ b/tests/e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -0,0 +1,104 @@ +import os +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 60} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 256 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + # "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--sglang-decode-log-interval 1000 " "--sglang-enable-metrics " + + fsdp_args = ( + # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default) + # "--fsdp-full-params " # Uncomment this line to enable full params mode + # Set the bucket size for weight update + "--update-weight-buffer-size 536870912 " # 512MB + ) + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + "--ci-metric-checker-key eval/gsm8k " + "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step + ) + + misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{sglang_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{eval_args} " + f"{fsdp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=2, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() From 9440fd1b26b6514ff8843b8bb3006a2c2a692dc3 Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Tue, 27 Jan 2026 15:11:08 -0800 Subject: [PATCH 59/77] fix: fix arg parsing by making help a string instead of a tuple (#522) --- miles/utils/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 0710202924..922d3a4dc4 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1022,7 +1022,7 @@ def add_wandb_arguments(parser): default=None, help=( "Log statistics of the category of reward, such as why the reward function considers it as failed. " - "Specify the key in the reward dict using this argument.", + "Specify the key in the reward dict using this argument." ), ) parser.add_argument( From a681b875fa591edb38cf400cb25d7e9b0bddd90c Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Wed, 28 Jan 2026 15:34:42 -0800 Subject: [PATCH 60/77] [fix] mbridge incorrectly handle all weight precision to bf16 (#524) --- tools/convert_hf_to_torch_dist.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index d6fddf386c..994700293b 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -1,6 +1,7 @@ import gc import os import shutil +from functools import wraps import torch import torch.distributed as dist @@ -11,6 +12,7 @@ import miles_plugins.mbridge # noqa: F401 from mbridge import AutoBridge +from mbridge.core.bridge import Bridge from miles.backends.megatron_utils.arguments import set_default_megatron_args from miles.backends.megatron_utils.initialize import init from miles.backends.megatron_utils.model_provider import get_model_provider_func @@ -18,6 +20,24 @@ from miles.utils.memory_utils import print_memory +def patch_weight_to_mcore_format_preserve_fp32(): + + original_method = Bridge._weight_to_mcore_format + + @wraps(original_method) + def patched_method(self, mcore_weights_name, hf_weights): + original_dtype = getattr(self, "dtype", None) + self.dtype = None + try: + result = original_method(self, mcore_weights_name, hf_weights) + finally: + self.dtype = original_dtype + return result + + Bridge._weight_to_mcore_format = patched_method + print("[Patch] Applied patch to preserve FP32 precision in _weight_to_mcore_format") + + def add_convertion_args(parser): """Add conversion arguments to the parser""" parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path") @@ -111,6 +131,10 @@ def main(): # Load model hf_model_path = args.hf_checkpoint bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True) + + # Patch to preserve FP32 precision for _keep_fp32 params + patch_weight_to_mcore_format_preserve_fp32() + bridge.load_weights(model, hf_model_path, memory_efficient=True) print(f"Model loaded: {hf_model_path}") From d22bc8c17fe7a74a3f86265b94b4ac41e9c42b8c Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Wed, 28 Jan 2026 17:34:59 -0800 Subject: [PATCH 61/77] fix int4 kernel setup.py (#527) --- .../megatron_utils/kernels/int4_qat/setup.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/miles/backends/megatron_utils/kernels/int4_qat/setup.py b/miles/backends/megatron_utils/kernels/int4_qat/setup.py index b27967bc98..8715dd7b8a 100644 --- a/miles/backends/megatron_utils/kernels/int4_qat/setup.py +++ b/miles/backends/megatron_utils/kernels/int4_qat/setup.py @@ -1,3 +1,4 @@ +import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension import torch @@ -10,6 +11,16 @@ arch_list.append(f"{major}.{minor}") arch_list = sorted(set(arch_list)) +# Fallback to TORCH_CUDA_ARCH_LIST env var or default architectures when GPU is not available +if not arch_list: + env_arch = os.environ.get("TORCH_CUDA_ARCH_LIST", "") + if env_arch: + # Parse TORCH_CUDA_ARCH_LIST format: "7.0 7.5 8.0 8.6 9.0+PTX" + arch_list = [a.strip().replace("+PTX", "") for a in env_arch.replace(";", " ").split() if a.strip()] + else: + # Default to common architectures (Volta, Turing, Ampere, Ada, Hopper) + arch_list = ["8.0", "8.6", "8.9", "9.0"] + setup( name="fake_int4_quant_cuda", ext_modules=[ @@ -31,7 +42,8 @@ + [ f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}' for arch in arch_list - ], + ] + + ["-gencode=arch=compute_90a,code=sm_90a"], }, ) ], From ebdea20f4fb1a874a2042cb25beacdc22a196faa Mon Sep 17 00:00:00 2001 From: Hudson Xing <1277646412@qq.com> Date: Fri, 30 Jan 2026 18:45:12 +0800 Subject: [PATCH 62/77] fix: missing comma in runtime env JSON for qwen3-235B-A22B (#531) --- scripts/run-qwen3-235B-A22B.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/run-qwen3-235B-A22B.sh b/scripts/run-qwen3-235B-A22B.sh index e42e17ab29..ffd5972ac0 100644 --- a/scripts/run-qwen3-235B-A22B.sh +++ b/scripts/run-qwen3-235B-A22B.sh @@ -161,7 +161,7 @@ RUNTIME_ENV_JSON="{ \"env_vars\": { \"PYTHONPATH\": \"/root/Megatron-LM/\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", \"no_proxy\": \"${no_proxy}\", \"MASTER_ADDR\": \"${MASTER_ADDR}\" } From be0c84097928d27be759a8d0de473b7017d8539e Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Sun, 1 Feb 2026 13:15:39 -0800 Subject: [PATCH 63/77] Fix accuracy bug in data packing for thd (#542) --- miles/backends/training_utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py index 67bb30108d..c7b0707425 100644 --- a/miles/backends/training_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -121,7 +121,7 @@ def get_batch( tokens = batch["tokens"] # use 0 as the pad token id should be fine? pad_token_id = 0 - pad_size = parallel_state.dp_size * pad_multiplier + pad_size = parallel_state.tp_size * pad_multiplier # for cp, we need all tokens to calculate logprob batch["unconcat_tokens"] = tokens From 511e56016a8f3b086fda106d16cf151c6e4d125f Mon Sep 17 00:00:00 2001 From: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com> Date: Mon, 2 Feb 2026 01:52:37 -0800 Subject: [PATCH 64/77] Fix Memory Leak on Rocm Offload (#545) --- docker/amd_patch/latest/megatron.patch | 232 +------------------------ 1 file changed, 9 insertions(+), 223 deletions(-) diff --git a/docker/amd_patch/latest/megatron.patch b/docker/amd_patch/latest/megatron.patch index c840133cef..b9e6a61d7c 100644 --- a/docker/amd_patch/latest/megatron.patch +++ b/docker/amd_patch/latest/megatron.patch @@ -1,5 +1,5 @@ diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b4..4451f277 100644 +index fe26e8b43..4451f2776 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads @@ -19,7 +19,7 @@ index fe26e8b4..4451f277 100644 + if hasattr(custom_fsdp, 'MegatronFSDP'): + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index 99c3edc0..26ea5cb4 100644 +index 99c3edc05..26ea5cb4b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear): @@ -31,7 +31,7 @@ index 99c3edc0..26ea5cb4 100644 # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index 002edb92..f7273488 100755 +index 002edb925..f72734885 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec( @@ -56,7 +56,7 @@ index 002edb92..f7273488 100755 "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index df9adc3e..2f4f544a 100644 +index df9adc3ef..2f4f544a7 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -443,7 +443,7 @@ class GPTModel(LanguageModule): @@ -69,7 +69,7 @@ index df9adc3e..2f4f544a 100644 input_ids=input_ids, position_ids=position_ids, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index 57332ac3..f3abd642 100644 +index 57332ac39..f2d0fa9c8 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -9,6 +9,7 @@ from typing import Callable, List, Optional @@ -80,222 +80,8 @@ index 57332ac3..f3abd642 100644 from .utils import GlobalMemoryBuffer, is_torch_min_version -@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs): - return None - - -+old_new_group = None -+ -+ -+def monkey_patch_torch_dist(): -+ print("Applying monkey patch to torch.distributed", flush=True) -+ global old_new_group -+ if old_new_group is not None: -+ return -+ -+ old_new_group = dist.new_group -+ -+ def new_group(*args, **kwargs): -+ group = old_new_group(*args, **kwargs) -+ # skip none nccl group. -+ if ( -+ len(args) >= 3 and args[2] == "gloo" or -+ "backend" in kwargs and kwargs["backend"] == "gloo" -+ ): -+ return group -+ -+ # Get ranks from arguments -+ if len(args) >= 1 and args[0] is not None: -+ ranks = args[0] -+ elif "ranks" in kwargs and kwargs["ranks"] is not None: -+ ranks = kwargs["ranks"] -+ else: -+ # If no ranks specified, use all ranks in world -+ ranks = list(range(dist.get_world_size())) -+ -+ if len(ranks) == 1: -+ return group -+ -+ group = ReloadableProcessGroup(group, ranks) -+ return group -+ -+ dist.new_group = new_group -+ -+ def get_new_function(func): -+ def new_function(*args, **kwargs): -+ args = ( -+ arg.group if isinstance(arg, ReloadableProcessGroup) else arg -+ for arg in args -+ ) -+ kwargs = { -+ k: (v.group if isinstance(v, ReloadableProcessGroup) else v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.get_rank = get_new_function(dist.get_rank) -+ dist.get_world_size = get_new_function(dist.get_world_size) -+ dist.get_backend = get_new_function(dist.get_backend) -+ dist.get_global_rank = get_new_function(dist.get_global_rank) -+ dist.get_group_rank = get_new_function(dist.get_group_rank) -+ dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks) -+ -+ dist.all_reduce = get_new_function(dist.all_reduce) -+ dist.all_gather = get_new_function(dist.all_gather) -+ dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor) -+ dist.all_gather_object = get_new_function(dist.all_gather_object) -+ dist.all_to_all = get_new_function(dist.all_to_all) -+ dist.all_to_all_single = get_new_function(dist.all_to_all_single) -+ dist.broadcast = get_new_function(dist.broadcast) -+ dist.reduce = get_new_function(dist.reduce) -+ dist.reduce_scatter = get_new_function(dist.reduce_scatter) -+ dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor) -+ dist.scatter = get_new_function(dist.scatter) -+ dist.gather = get_new_function(dist.gather) -+ dist.barrier = get_new_function(dist.barrier) -+ dist.send = get_new_function(dist.send) -+ dist.recv = get_new_function(dist.recv) -+ dist._coalescing_manager = get_new_function(dist._coalescing_manager) -+ -+ # p2p -+ old_isend = dist.isend -+ old_irecv = dist.irecv -+ -+ dist.isend = get_new_function(dist.isend) -+ dist.irecv = get_new_function(dist.irecv) -+ -+ def get_new_p2pop_function(func): -+ def new_function(*args, **kwargs): -+ def convert(arg): -+ if isinstance(arg, ReloadableProcessGroup): -+ return arg.group -+ elif arg == dist.isend: -+ arg = old_isend -+ elif arg == dist.irecv: -+ arg = old_irecv -+ return arg -+ -+ args = (convert(arg) for arg in args) -+ kwargs = { -+ k: convert(v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__) -+ dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__) -+ -+ -+ -+class ReloadableProcessGroup(torch.distributed.ProcessGroup): -+ GROUPS = [] -+ -+ def __init__(self, group, ranks): -+ super().__init__( -+ rank=dist.get_rank(group), -+ size=dist.get_world_size(group), -+ ) -+ #print(f"Creating ReloadableProcessGroup with ranks: {ranks}", flush=True) -+ self.group = group -+ self.group_info = { -+ "ranks": ranks, -+ } -+ ReloadableProcessGroup.GROUPS.append(self) -+ -+ def __getattr__(self, name): -+ return getattr(self.group, name) -+ -+ @staticmethod -+ def destroy_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is None: -+ continue -+ #print(f"Destroying process group: {reloadable_group.group_info['ranks']}") -+ dist.destroy_process_group(reloadable_group.group) -+ del reloadable_group.group -+ reloadable_group.group = None -+ -+ @staticmethod -+ def reload_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is not None: -+ continue -+ #print(f"Reloading process group: {reloadable_group.group_info['ranks']}") -+ group = old_new_group( -+ ranks=reloadable_group.group_info["ranks"], -+ backend="nccl" -+ ) -+ reloadable_group.group = group -+ -+ def rank(self) -> int: return self.group.rank() -+ def size(self) -> int: return self.group.size() -+ def name(self) -> str: return self.group.name() -+ -+ def shutdown(self) -> None: -+ if self.group is not None: -+ self.group.shutdown() -+ -+ def abort(self) -> None: -+ if self.group is not None: -+ self.group.abort() -+ -+ def _fwd(self, method, *args, **kwargs): -+ inner = self.group -+ if inner is None: -+ raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") -+ return getattr(inner, method)(*args, **kwargs) -+ -+ def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw) -+ def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw) -+ def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw) -+ def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw) -+ def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw) -+ def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw) -+ def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw) -+ def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw) -+ def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw) -+ def gather(self, *a, **kw): return self._fwd("gather", *a, **kw) -+ def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw) -+ def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw) -+ def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw) -+ def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw) -+ def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw) -+ def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw) -+ def send(self, *a, **kw): return self._fwd("send", *a, **kw) -+ def recv(self, *a, **kw): return self._fwd("recv", *a, **kw) -+ def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw) -+ -+ def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw) -+ def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw) -+ def _get_backend_name(self): return self._fwd("_get_backend_name") -+ def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw) -+ def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw) -+ @property -+ def bound_device_id(self): return self.group.bound_device_id -+ @bound_device_id.setter -+ def bound_device_id(self, dev): self.group.bound_device_id = dev -+ -+ -+def destroy_process_groups(): -+ """Destroy all reloadable process groups.""" -+ ReloadableProcessGroup.destroy_process_groups() -+ -+ -+def reload_process_groups(): -+ """Reload all reloadable process groups.""" -+ ReloadableProcessGroup.reload_process_groups() -+ -+ -+monkey_patch_torch_dist() -+ -+ - def create_group( - ranks=None, - timeout=None, diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index 63ee9d1f..b90b744c 100644 +index 63ee9d1f5..b90b744c1 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( @@ -326,7 +112,7 @@ index 63ee9d1f..b90b744c 100644 ops.append(recv_next_op) if len(ops) > 0: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 6f557e1f..b295fd35 100644 +index 6f557e1f5..b295fd351 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig): @@ -340,7 +126,7 @@ index 6f557e1f..b295fd35 100644 """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 84f22bde..b4807d26 100644 +index 84f22bdea..b4807d261 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -224,6 +224,7 @@ class TransformerLayerSubmodules: @@ -412,7 +198,7 @@ index 84f22bde..b4807d26 100644 # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 24ba8926..4f039fd4 100644 +index 24ba89263..4f039fd43 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None): From 6bc0dd5390ad83be1a1ceb3a154e2ca376bfdaca Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:38:17 -0800 Subject: [PATCH 65/77] [bugfix] Fix R3 padding (#551) --- miles/backends/megatron_utils/actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index a92198a674..f196164876 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -228,7 +228,7 @@ def pad_func(experts, pad): # TODO: maybe extract a common process function for here and get_batch? rollout_routed_experts = [slice_with_cp(r, pad_func, self.parallel_state) for r in rollout_routed_experts] rollout_routed_experts = torch.cat(rollout_routed_experts, dim=0) - pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier + pad_size = self.parallel_state.tp_size * self.args.data_pad_size_multiplier pad = (pad_size - rollout_routed_experts.size(0) % pad_size) % pad_size if pad != 0: rollout_routed_experts = pad_func(rollout_routed_experts, pad) From e8c5a8cfa0898131cffe9cd9c3f035db3e420fc9 Mon Sep 17 00:00:00 2001 From: Jinn <47354855+jhinpan@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:55:12 -0600 Subject: [PATCH 66/77] Super tiny fix for blog link (#553) Co-authored-by: root --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 24bcb773ba..2bda4c95b3 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ ## Latest Updates -* **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog](https://lmsys.org/blog/2026-01-28-int4-qat/) -* **[2026/01]** 💎 **Unified VLM/LLM Multi-Turn Training**: We provided an implementation for the VLM multi-turn sampling paradigm. Developers only need to write a customized `rollout` function to easily start multi-turn RL for VLM, just like training LLM. [blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/vlm-multi-turn/readme-en.md) +* **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog](https://lmsys.org/blog/2026-01-26-int4-qat/) +* **[2026/01]** 💎 **Unified VLM/LLM Multi-Turn Training**: We provided an implementation for the VLM multi-turn sampling paradigm. Developers only need to write a customized `rollout` function to easily start multi-turn RL for VLM, just like training LLM. [Blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/vlm-multi-turn/readme-en.md) * **[2026/01]** 🤖 **Multi-Agent Co-Evolution**: Miles now supports **MrlX**, a novel asynchronous co-evolutionary framework for Multi-Agent RL. Achieve superior performance in complex tasks like Doctor-Patient simulations and DeepResearch pipelines by enabling specialized agents to evolve together symbiotically. [[Link]](https://github.com/AQ-MedAI/MrlX) * **[2025/12]** 🔄 **Rollout Routing Replay (R3)**: In collaboration with SGLang, we've launched R3 to solve MoE RL instability. R3 records inference routing decisions and replays them during training, effectively eliminating the "training-inference mismatch" and preventing training collapse in large MoE models like Qwen3 and DeepSeek-V3. [[Paper]](https://arxiv.org/pdf/2510.11370) * **[2025/11]** 🔥 **Unified FP8 Release**: Solves the stability issues in MoE RL by ensuring training and inference use the exact same FP8 quantization logic. [[Blog]](https://lmsys.org/blog/2025-11-25-fp8-rl/) From e67fc49220d9a7956900f5d523a57b943daba8aa Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Thu, 5 Feb 2026 14:08:18 -0800 Subject: [PATCH 67/77] Update CODEOWNERS with new paths and owners (#564) --- .github/CODEOWNERS | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index dc0cc7cbc9..416cb53856 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,8 @@ .github/CODEOWNERS @fzyzcjy @Ying1123 .github/workflows/ @yushengsu-thu /miles/ @fzyzcjy @yueming-yuan +/miles/backends/ @maocheng23 +/miles/ray/ @maocheng23 +/miles/rollout/ @guapisolo +/miles/router/ @guapisolo +/miles/utils/ @guapisolo @maocheng23 From 32d4c8552dab63ec793e40318983d9a2c8013703 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:35:48 -0800 Subject: [PATCH 68/77] feat: Support OAI TITO v1 (#502) --- .github/workflows/pr-test.yml | 132 ++++----- .github/workflows/pr-test.yml.j2 | 19 +- docker/patch/v0.5.7/sglang.patch | 207 ++++++++++++++ docs/en/get_started/gen_endpoint.md | 104 +++++++ docs/en/get_started/oai_endpoint.md | 136 +++++++++ examples/openai_format/__init__.py | 1 + examples/openai_format/dapo_math.py | 19 ++ examples/openai_format/run-qwen3-4B.sh | 158 +++++++++++ .../rollout/generate_hub/agentic_tool_call.py | 80 +++--- .../generate_utils/openai_endpoint_utils.py | 25 +- miles/router/middleware_hub/radix_tree.py | 4 +- .../middleware_hub/radix_tree_middleware.py | 27 +- miles/router/router.py | 43 +-- miles/router/session/naive_trajectory.py | 70 +++++ miles/router/session/session_types.py | 15 + miles/router/session/sessions.py | 94 +++++++ miles/router/sessions.py | 124 --------- miles/utils/arguments.py | 11 + miles/utils/test_utils/mock_sglang_server.py | 8 +- miles/utils/test_utils/mock_tools.py | 55 ++++ tests/e2e/sglang_patch/__init__.py | 0 tests/e2e/sglang_patch/sglang_server.py | 123 +++++++++ .../test_chat_input_ids_equivalence.py | 122 ++++++++ tests/fast/fixtures/generation_fixtures.py | 22 +- tests/fast/router/test_sessions.py | 261 +++++++++--------- .../test_utils/test_mock_sglang_server.py | 29 +- 26 files changed, 1449 insertions(+), 440 deletions(-) create mode 100644 docs/en/get_started/gen_endpoint.md create mode 100644 docs/en/get_started/oai_endpoint.md create mode 100644 examples/openai_format/__init__.py create mode 100644 examples/openai_format/dapo_math.py create mode 100644 examples/openai_format/run-qwen3-4B.sh create mode 100644 miles/router/session/naive_trajectory.py create mode 100644 miles/router/session/session_types.py create mode 100644 miles/router/session/sessions.py delete mode 100644 miles/router/sessions.py create mode 100644 tests/e2e/sglang_patch/__init__.py create mode 100644 tests/e2e/sglang_patch/sglang_server.py create mode 100644 tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index e7a909be2d..eb2e20b9c8 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -83,16 +83,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - unit-test: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-unit-test')) runs-on: self-hosted @@ -151,15 +141,63 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() + e2e-test-sglang: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-sglang')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 1, "test_file": "e2e/sglang_patch/test_chat_input_ids_equivalence.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes shell: bash run: | pkill -9 -f 'ray::' 2>/dev/null || true pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true ray stop --force 2>/dev/null || true rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) @@ -219,16 +257,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - e2e-test-fsdp: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-fsdp')) runs-on: self-hosted @@ -287,16 +315,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - e2e-test-megatron: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-megatron')) runs-on: self-hosted @@ -355,16 +373,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - e2e-test-precision: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-precision')) runs-on: self-hosted @@ -423,16 +431,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - e2e-test-ckpt: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-ckpt')) runs-on: self-hosted @@ -491,16 +489,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - e2e-test-long: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-long')) runs-on: self-hosted @@ -559,16 +547,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - - e2e-test-image: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-image')) runs-on: self-hosted @@ -626,13 +604,3 @@ jobs: - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 2dae38ff4f..5fdcc201f2 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -8,7 +8,14 @@ 'unit-test': { 'label': 'run-unit-test', 'tests': [ - {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2} + {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2} + ], + }, + 'e2e-test-sglang': { + 'label': 'run-ci-sglang', + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'e2e/sglang_patch/test_chat_input_ids_equivalence.py', 'num_gpus': 1}, ], }, 'e2e-test-short': { @@ -166,14 +173,4 @@ jobs: - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} - - - name: Post-test cleanup - if: always() - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - <% endfor %> diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch index 42d23ed659..ea44e24be1 100644 --- a/docker/patch/v0.5.7/sglang.patch +++ b/docker/patch/v0.5.7/sglang.patch @@ -74,6 +74,213 @@ index 0478526ef..cfb1aa669 100644 def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py +index 34aa364cf..da5d0d6b6 100644 +--- a/python/sglang/srt/entrypoints/openai/protocol.py ++++ b/python/sglang/srt/entrypoints/openai/protocol.py +@@ -81,6 +81,7 @@ class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) ++ token_ids: List[int] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +@@ -92,6 +93,7 @@ class TopLogprob(BaseModel): + + class ChatCompletionTokenLogprob(BaseModel): + token: str ++ token_id: int + bytes: List[int] + logprob: float + top_logprobs: List[TopLogprob] +@@ -501,6 +503,7 @@ class ChatCompletionRequest(BaseModel): + top_k: Optional[int] = None + min_p: Optional[float] = None + min_tokens: int = 0 ++ logprob_start_len: Optional[int] = None + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: Optional[float] = None +@@ -536,6 +539,9 @@ class ChatCompletionRequest(BaseModel): + + # For data parallel rank routing + data_parallel_rank: Optional[int] = None ++ ++ # Input ids, if provided, it will override the message input. ++ input_ids: Optional[Union[List[List[int]], List[int]]] = None + + # OpenAI/SGLang default sampling parameters + _DEFAULT_SAMPLING_PARAMS = { +@@ -618,8 +624,8 @@ class ChatCompletionRequest(BaseModel): + + def to_sampling_params( + self, +- stop: List[str], + model_generation_config: Dict[str, Any], ++ stop: Optional[List[str]] = None, + tool_call_constraint: Optional[ToolCallConstraint] = None, + ) -> Dict[str, Any]: + """ +@@ -706,6 +712,7 @@ class ChatMessage(BaseModel): + class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage ++ input_token_ids: Optional[List[int]] = None + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: Optional[ + Literal[ +diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py +index cb0c084a3..daa0db7bb 100644 +--- a/python/sglang/srt/entrypoints/openai/serving_chat.py ++++ b/python/sglang/srt/entrypoints/openai/serving_chat.py +@@ -146,6 +146,62 @@ class OpenAIServingChat(OpenAIServingBase): + + return None + ++ def _convert_chat_completion_with_input_ids_to_internal_request( ++ self, ++ request: ChatCompletionRequest, ++ raw_request: Request = None, ++ ) -> tuple[GenerateReqInput, ChatCompletionRequest]: ++ ++ # Notice: currently, if input_ids is provided, the stop token is not used. ++ sampling_params = request.to_sampling_params( ++ model_generation_config=self.default_sampling_params ++ ) ++ ++ prompt_kwargs = {"input_ids": request.input_ids} ++ ++ # Extract custom labels from raw request headers ++ custom_labels = self.extract_custom_labels(raw_request) ++ ++ # Resolve LoRA adapter from model parameter or explicit lora_path ++ lora_path = self._resolve_lora_path(request.model, request.lora_path) ++ if lora_path: ++ first_adapter = ( ++ lora_path ++ if isinstance(lora_path, str) ++ else next((a for a in lora_path if a), None) ++ ) ++ if first_adapter: ++ self._validate_lora_enabled(first_adapter) ++ ++ logprob_start_len = ( ++ request.logprob_start_len if request.logprob_start_len is not None else -1 ++ ) ++ ++ adapted_request = GenerateReqInput( ++ **prompt_kwargs, ++ sampling_params=sampling_params, ++ return_logprob=request.logprobs, ++ logprob_start_len=logprob_start_len, ++ top_logprobs_num=request.top_logprobs or 0, ++ stream=request.stream, ++ return_text_in_logprobs=True, ++ lora_path=lora_path, ++ bootstrap_host=request.bootstrap_host, ++ bootstrap_port=request.bootstrap_port, ++ bootstrap_room=request.bootstrap_room, ++ data_parallel_rank=request.data_parallel_rank, ++ return_hidden_states=request.return_hidden_states, ++ rid=request.rid, ++ extra_key=self._compute_extra_key(request), ++ require_reasoning=self._get_reasoning_from_request(request), ++ priority=request.priority, ++ custom_labels=custom_labels, ++ custom_logit_processor=request.custom_logit_processor, ++ ) ++ ++ return adapted_request, request ++ ++ + def _convert_to_internal_request( + self, + request: ChatCompletionRequest, +@@ -162,6 +218,9 @@ class OpenAIServingChat(OpenAIServingBase): + """Convert OpenAI chat completion request to internal format""" + is_multimodal = self.tokenizer_manager.model_config.is_multimodal + ++ if request.input_ids: ++ return self._convert_chat_completion_with_input_ids_to_internal_request(request, raw_request) ++ + # Process messages and apply chat template + processed_messages = self._process_messages(request, is_multimodal) + +@@ -195,6 +254,10 @@ class OpenAIServingChat(OpenAIServingBase): + if first_adapter: + self._validate_lora_enabled(first_adapter) + ++ logprob_start_len = ( ++ request.logprob_start_len if request.logprob_start_len is not None else -1 ++ ) ++ + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=processed_messages.image_data, +@@ -202,7 +265,7 @@ class OpenAIServingChat(OpenAIServingBase): + audio_data=processed_messages.audio_data, + sampling_params=sampling_params, + return_logprob=request.logprobs, +- logprob_start_len=-1, ++ logprob_start_len=logprob_start_len, + top_logprobs_num=request.top_logprobs or 0, + stream=request.stream, + return_text_in_logprobs=True, +@@ -768,8 +831,13 @@ class OpenAIServingChat(OpenAIServingBase): + for idx, ret_item in enumerate(ret): + # Process logprobs + choice_logprobs = None ++ input_token_ids = None + if request.logprobs: + choice_logprobs = self._process_response_logprobs(ret_item) ++ input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] ++ input_token_ids = [ ++ token_id for _, token_id, _ in input_token_logprobs ++ ] + + # Handle hidden states + hidden_states = process_hidden_states_from_ret(ret_item, request) +@@ -824,6 +892,7 @@ class OpenAIServingChat(OpenAIServingBase): + tool_calls=tool_calls, + reasoning_content=reasoning_text if reasoning_text else None, + ), ++ input_token_ids=input_token_ids, + logprobs=choice_logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( +@@ -865,6 +934,7 @@ class OpenAIServingChat(OpenAIServingBase): + for token_idx, (token, logprob) in enumerate( + zip(logprobs.tokens, logprobs.token_logprobs) + ): ++ token_id = logprobs.token_ids[token_idx] + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: +@@ -885,6 +955,7 @@ class OpenAIServingChat(OpenAIServingBase): + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, ++ token_id=token_id, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, +diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py +index 94ac5458d..e718ddb2a 100644 +--- a/python/sglang/srt/entrypoints/openai/utils.py ++++ b/python/sglang/srt/entrypoints/openai/utils.py +@@ -19,9 +19,10 @@ def to_openai_style_logprobs( + ret_logprobs = LogProbs() + + def append_token_logprobs(token_logprobs): +- for logprob, _, token_text in token_logprobs: ++ for logprob, token_id, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) ++ ret_logprobs.token_ids.append(token_id) + + # Not supported yet + ret_logprobs.text_offset.append(-1) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index b07164c53..8e6722ce0 100644 --- a/python/sglang/srt/layers/layernorm.py diff --git a/docs/en/get_started/gen_endpoint.md b/docs/en/get_started/gen_endpoint.md new file mode 100644 index 0000000000..a8e9d2ae1e --- /dev/null +++ b/docs/en/get_started/gen_endpoint.md @@ -0,0 +1,104 @@ +# Gen Endpoint Usage + +This document covers generate_hub usage for the `/generate` endpoint. For OpenAI +format usage, see `docs/en/get_started/oai_endpoint.md`. + +## 1. What generate_hub is + +`miles/rollout/generate_hub/` contains reusable generate functions that plug into +rollout through `--custom-generate-function-path`. They use the refactor +interface (`GenerateFnInput` / `GenerateFnOutput`) and are meant to be composed +with custom agents, tool use, or multi-turn logic. + +Key types and entry points: + +- `miles/rollout/base_types.py` defines `GenerateFnInput` and `GenerateFnOutput`. +- `miles/rollout/inference_rollout/inference_rollout_common.py` builds a + `GenerateState` and calls the generate function. +- `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1` enables the new path (see + `examples/openai_format/*.sh`). + +## 2. Generate function basics + +The intended abstraction is: + +1. The rollout engine provides a `GenerateFnInput` with: + - `state` (tokenizer, processor, args, sampling defaults) + - `sample` (prompt, current tokens, response, status) + - `sampling_params` (max_new_tokens, temperature, top_p, etc.) +2. The generate function focuses only on: + - turning the sample into a model request + - executing the request (SGLang `/generate` or OpenAI format) + - updating the `Sample` with tokens, logprobs, loss mask, and status + +Minimal skeleton: + +```python +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.types import Sample + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + sampling_params = input.sampling_params + + # 1) build request from prompt and sampling params + # 2) call backend + # 3) update sample.tokens, sample.response, sample.rollout_log_probs, sample.loss_mask, sample.status + + return GenerateFnOutput(samples=sample) + +def _add_arguments(parser): + parser.add_argument("--your-arg", type=str) + +generate.add_arguments = _add_arguments +``` + +Notes: + +- `generate.add_arguments = _add_arguments` is the hook for custom CLI flags. + Add any arguments you want; they are parsed into `input.args` and can be used + freely by your generator without touching rollout core code. +- Use `compute_prompt_ids_from_sample` and `compute_request_payload` from + `miles/rollout/generate_utils/generate_endpoint_utils.py` to build requests + for the `/generate` endpoint. +- If you want to return multiple samples, set `--generate-multi-samples` and + return a list. + +## 3. /generate endpoint examples + +Examples (library side): + +- `miles/rollout/generate_hub/single_turn.py` + - Single-turn generation using `/generate`. + - Works with text or multimodal prompts. +- `miles/rollout/generate_hub/multi_turn.py` + - Multi-turn tool calling using `/generate`. + - CLI flags: `--generate-max-turns`, `--generate-tool-specs-path`, + `--generate-tool-call-parser`, `--generate-execute-tool-function-path`, + `--generate-multi-samples`. +- `miles/rollout/generate_hub/benchmarkers.py` + - Benchmark helper that forces random output sequence length (OSL). + +## 4. Radix tree middleware helper (full TITO for `/generate`) + +Full TITO caching for the `/generate` endpoint is provided by the radix tree +middleware. This is unrelated to session middleware and works only on the +`/generate` and `/retrieve_from_text` routes. + +What it does: + +- Caches token ids and logprobs by prompt text in a radix tree. +- Lets `/generate` requests include `input_tokens` and avoids re-tokenization. +- Enables `update_sample_from_response` to fetch tokens via + `/retrieve_from_text` for training. + +How to enable: + +``` +--use-miles-router \ +--miles-router-middleware-paths miles.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware +``` + +Make sure `--sglang-router-ip` and `--sglang-router-port` point to the Miles +Router so `/retrieve_from_text` can be reached during rollout. diff --git a/docs/en/get_started/oai_endpoint.md b/docs/en/get_started/oai_endpoint.md new file mode 100644 index 0000000000..9b882ec846 --- /dev/null +++ b/docs/en/get_started/oai_endpoint.md @@ -0,0 +1,136 @@ +# OAI Endpoint Usage + +This document explains how to use the OpenAI-format chat endpoint through Miles +Router sessions. For the `/generate` endpoint, see +`docs/en/get_started/gen_endpoint.md`. + +## 1. Minimal `run_agent` loop + +Your `run_agent` receives a session-scoped `base_url`. Send OpenAI-format chat +requests to `base_url/v1/chat/completions` and pass the `messages` list as the +prompt. + +Minimal custom agent example: + +```python +from miles.utils.http_utils import post + +async def run_agent(base_url: str, prompt, request_kwargs: dict | None = None) -> None: + payload = {"model": "default", "messages": prompt, **(request_kwargs or {})} + await post(f"{base_url}/v1/chat/completions", payload) +``` + +Notes for `run_agent`: + +- `base_url` already includes the session path (e.g. `/sessions/`), so you + should not manually add the session id. Just append the OpenAI route. +- `request_kwargs` already contains the default sampling settings from + `agentic_tool_call.build_chat_request_kwargs`, so you can directly expand it + into the chat request payload. +- If you pass rollout sampling params, `max_new_tokens` will be mapped to the + OpenAI `max_tokens` field before the request is sent. +- If you need structured parsing payloads, use SGLang's + `ChatCompletionRequest`-compatible format. It is compatible with native OpenAI + fields, plus extra SGLang parameters. + +## 2. OpenAI chat messages and the basic request + +The OpenAI-format chat API uses a list of `messages`, each with a `role` and +`content`. + +Minimal request shape: + +```json +{ + "model": "default", + "messages": [ + {"role": "system", "content": "You are a concise assistant."}, + {"role": "user", "content": "Answer with one word: 2+2?"} + ], + "logprobs": true, + "logprob_start_len": 0 +} +``` + +You can pass any OpenAI-compatible parameters in the payload, or any +SGLang-compatible `ChatCompletionRequest` parameters. Note: +`logprobs=True` and `logprob_start_len=0` are required to extract token ids and +logprobs for TITO (see below), and are already set in `request_kwargs`. + +## 3. Quickstart index + +If you just want something runnable, start here: + +Generator entry point: + +- `miles/rollout/generate_hub/agentic_tool_call.py` + - OpenAI-format agent loop via router sessions. + +OpenAI-format examples that use `agentic_tool_call.generate`: + +- `examples/openai_format/dapo_math.py` + - Single-turn OpenAI format agent (DAPO math). +- Launcher scripts: + - `examples/openai_format/run-qwen3-4B-dapo-math.sh` + + +You can customize generate function like: +``` +CUSTOM_ARGS=( + --custom-generate-function-path miles.rollout.generate_hub.agentic_tool_call.generate + --custom-agent-function-path examples.openai_format.dapo_math.run_agent +) +``` + +For OpenAI format, do not add `--apply-chat-template`; the +prompt must remain a `messages` list. + +More agentic multi-turn examples will come in the future. + +## 4. Further customization (OpenAI wrapper generate function) + +For OpenAI-format rollout, the key generate function is +`miles/rollout/generate_hub/agentic_tool_call.generate`. It is a thin wrapper +around your custom agent: + +1. Create a session on Miles Router and build a session-scoped `base_url`. +2. Call the custom agent (from `--custom-agent-function-path`) to send one or + more chat requests to `base_url/v1/chat/completions`, typically using + `prompt` and `request_kwargs`. +3. Collect session records via `OpenAIEndpointTracer`. +4. Convert records into `Sample` objects with + `compute_samples_from_openai_records`. + +If you want general generate-function customization beyond the OpenAI wrapper, +see `docs/en/get_started/gen_endpoint.md`. + +## 5. TITO (token-in token-out) + +TITO needs two things: + +1. Prompt token ids returned by the backend (e.g. `input_logprobs` or + `input_token_ids`). These can come from tokenizing `messages`, or from a + provided `input_ids` payload. +2. Output token ids returned by the backend (`logprobs.content[*].token_id`). + +By default, the session middleware forwards raw `messages` to SGLang. With +`logprobs=True` and `logprob_start_len=0`, SGLang tokenizes the prompt and +returns prompt token ids along with output token ids, which is sufficient for +TITO. You do not need to provide `input_ids`. + +If you prefer to send `input_ids` to SGLang, you can enable token input for chat +completions in the router via +`--miles-router-enable-token-input-for-chat-completions`. The session route +will tokenize `messages` and inject `input_ids` before proxying to SGLang. The +backend still returns prompt token ids, and they should match any `input_ids` +you supplied. + +We can save multi-turn samples within a single session, but we still do not +inherit or reuse prompt tokens across turns. Each request is tokenized +independently, regardless of which option you choose. + +### Common pitfalls + +- Ensure `logprobs=True` in OpenAI chat requests, and ensure + `logprob_start_len=0` if you rely on SGLang to return prompt token ids. +- Ensure the tokenizer matches `--hf-checkpoint`. diff --git a/examples/openai_format/__init__.py b/examples/openai_format/__init__.py new file mode 100644 index 0000000000..30436bcc42 --- /dev/null +++ b/examples/openai_format/__init__.py @@ -0,0 +1 @@ +"""OpenAI format examples.""" diff --git a/examples/openai_format/dapo_math.py b/examples/openai_format/dapo_math.py new file mode 100644 index 0000000000..dae0c4ed9a --- /dev/null +++ b/examples/openai_format/dapo_math.py @@ -0,0 +1,19 @@ +""" +Custom agent example: single-turn DAPO math via OpenAI endpoints. +""" + +from __future__ import annotations + +from typing import Any + + +# Notice: only function based agent can use post API in miles +from miles.utils.http_utils import post + + +async def run_agent( + base_url: str, prompt: list[dict[str, Any]] | str, request_kwargs: dict[str, Any] | None = None +) -> None: + request_kwargs = request_kwargs or {} + payload = {"model": "default", "messages": prompt, "logprobs": True, **request_kwargs} + await post(base_url + "/v1/chat/completions", payload) diff --git a/examples/openai_format/run-qwen3-4B.sh b/examples/openai_format/run-qwen3-4B.sh new file mode 100644 index 0000000000..d6bbfddec6 --- /dev/null +++ b/examples/openai_format/run-qwen3-4B.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 +export MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/shared/Qwen3-4B + #--hf-checkpoint /root/shared/Qwen3-4B-FP8 + --ref-load /root/shared/Qwen3-4B_torch_dist +# --load /root/shared/Qwen3-4B_miles/ + --save /root/shared/Qwen3-4B_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --rollout-shuffle + --rm-type deepscaler + --num-rollout 200 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 1 + + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project miles-oai + --wandb-group qwen3-4B-test + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.8 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash + --use-miles-router +) + +CUSTOM_ARGS=( + --custom-generate-function-path miles.rollout.generate_hub.agentic_tool_call.generate + --custom-agent-function-path examples.openai_format.dapo_math.run_agent +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export CUDA_VISIBLE_DEVICES=4,5,6,7 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py index 05223a6544..3e1ae9ef54 100644 --- a/miles/rollout/generate_hub/agentic_tool_call.py +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -3,10 +3,10 @@ """ import argparse -from copy import deepcopy +from collections.abc import Callable from typing import Any -from openai import AsyncOpenAI +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput from miles.rollout.generate_utils.openai_endpoint_utils import ( @@ -14,19 +14,20 @@ compute_samples_from_openai_records, ) from miles.rollout.generate_utils.sample_utils import merge_samples -from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls from miles.utils.misc import load_function async def generate(input: GenerateFnInput) -> GenerateFnOutput: tracer = await OpenAIEndpointTracer.create(input.args) - await _run_blackbox_tool_call_agent( + custom_agent_function: Callable = load_function(input.args.custom_agent_function_path) + assert ( + custom_agent_function is not None + ), f"Custom agent function {input.args.custom_agent_function_path} not found" + await custom_agent_function( base_url=tracer.base_url, prompt=input.sample.prompt, - max_turns=input.args.generate_max_turns, - tool_specs_path=input.args.generate_tool_specs_path, - execute_tool_function_path=input.args.generate_execute_tool_function_path, + request_kwargs=build_chat_request_kwargs(input.sampling_params), ) records = await tracer.collect_records() @@ -37,49 +38,32 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: def _add_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--generate-max-turns", type=int, default=16) - parser.add_argument("--generate-tool-specs-path", type=str) - parser.add_argument("--generate-execute-tool-function-path", type=str) - parser.add_argument("--generate-multi-samples", action="store_true") + parser.add_argument("--custom-agent-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true", default=False) generate.add_arguments = _add_arguments -async def _run_blackbox_tool_call_agent( - base_url: str, - prompt: list[dict[str, Any]], - max_turns: int, - tool_specs_path: str, - execute_tool_function_path: str, -): - """ - Imagine this is a black-box agent, e.g. SWE-agent, which does arbitrarily complex work, - only understands OpenAI compatible API, and never understands Miles or the Sample data structure. - """ - - # ----------------------- Setup ------------------------- - - client = AsyncOpenAI(base_url=base_url, api_key="empty") - execute_tool_function = load_function(execute_tool_function_path) - tool_specs = load_function(tool_specs_path) - - # ----------------------- Initial prompts ------------------------- - - messages = deepcopy(prompt) - - for _turn in range(max_turns): - # ----------------------- Call inference endpoint ------------------------- - - response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs) - - choice = response.choices[0] - messages.append(choice.message.model_dump()) - - if choice.finish_reason in ("stop", "length"): - break - - # ----------------------- Execute tools ------------------------- - - if x := choice.message.tool_calls: - messages += await execute_tool_calls(x, execute_tool_function) +# Process keys to match ChatCompletionRequest input +def build_chat_request_kwargs(sampling_params: dict[str, Any]) -> dict[str, Any]: + request_kwargs = dict(sampling_params) + key_map = { + "max_new_tokens": "max_tokens", + "min_new_tokens": "min_tokens", + "sampling_seed": "seed", + } + for src, dst in key_map.items(): + if src in request_kwargs: + if dst not in request_kwargs: + request_kwargs[dst] = request_kwargs[src] + request_kwargs.pop(src, None) + + # Notice: Here we force the inference backend to return token information and start from 0 + # The start len should be 0 to make sure prompt token ids and be correctly returned from SGLang. + request_kwargs["logprobs"] = True + request_kwargs["logprob_start_len"] = 0 + + reserved_keys = {"model", "messages"} + allowed_keys = set(ChatCompletionRequest.model_fields) - reserved_keys + return {key: value for key, value in request_kwargs.items() if key in allowed_keys and value is not None} diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 73ba8198bf..f5bf52d6ca 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -6,7 +6,7 @@ from argparse import Namespace from copy import deepcopy -from miles.router.sessions import GetSessionResponse, SessionRecord +from miles.router.session.sessions import GetSessionResponse, SessionRecord from miles.utils.http_utils import post from miles.utils.types import Sample @@ -17,16 +17,21 @@ class OpenAIEndpointTracer: def __init__(self, router_url: str, session_id: str): self.router_url = router_url self.session_id = session_id - self.base_url = f"{router_url}/sessions/{session_id}/v1" + self.base_url = f"{router_url}/sessions/{session_id}" @staticmethod async def create(args: Namespace): router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" - session_id = (await post(f"{router_url}/sessions", {}))["session_id"] + response = await post(f"{router_url}/sessions", {}, action="post") + session_id = response["session_id"] return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) async def collect_records(self) -> list[SessionRecord]: - response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + try: + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + except Exception as e: + logger.warning(f"Failed to get session {self.session_id} records: {e}") + raise response = GetSessionResponse.model_validate(response) records = response.records @@ -35,7 +40,7 @@ async def collect_records(self) -> list[SessionRecord]: except Exception as e: logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") - return records + return records or [] def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: @@ -45,11 +50,19 @@ def compute_samples_from_openai_records(input_sample: Sample, records: list[Sess def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: # TODO may refine after @guapisolo's implementation choice = record.response["choices"][0] + + input_token_ids = choice["input_token_ids"] output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] sample = deepcopy(input_sample) - sample.tokens = record.request["input_ids"] + output_token_ids + # sample.tokens = record.request["input_ids"] + output_token_ids + request_input_ids = record.request.get("input_ids") + if request_input_ids is not None: + assert ( + request_input_ids == input_token_ids + ), "for prompt part, input_ids return by sglang should match with the request input_ids" + sample.tokens = input_token_ids + output_token_ids sample.rollout_log_probs = output_log_probs sample.response = tokenizer.decode(output_token_ids) sample.response_length = len(output_token_ids) diff --git a/miles/router/middleware_hub/radix_tree.py b/miles/router/middleware_hub/radix_tree.py index 6e722f1e25..67b9d6fe4e 100644 --- a/miles/router/middleware_hub/radix_tree.py +++ b/miles/router/middleware_hub/radix_tree.py @@ -584,8 +584,8 @@ def retrieve_from_text(self, text: str, return_logprob: bool = True): text: Input text to get tokens for return_logprob: If True, also return log probabilities Returns: - List of token IDs corresponding to the input text if return_logprob is False. - Tuple of (token_ids, logp) if return_logprob is True. + List of token (IDs, logp, loss_mask) corresponding to the input text + if return_logprob is False, all logp will be 0.0 """ # Call find_longest_prefix to get the match result result = self.find_longest_prefix(text) diff --git a/miles/router/middleware_hub/radix_tree_middleware.py b/miles/router/middleware_hub/radix_tree_middleware.py index db57f64564..b9d62d8415 100644 --- a/miles/router/middleware_hub/radix_tree_middleware.py +++ b/miles/router/middleware_hub/radix_tree_middleware.py @@ -66,12 +66,14 @@ def __init__(self, app, *, router): self.router.radix_tree = self.radix_tree async def dispatch(self, request: Request, call_next): - path = request.url.path + if path == "/generate": + return await self._generate(request, call_next) + if path == "/retrieve_from_text": + return await self._retrieve_from_text(request) + return await call_next(request) - if path != "/generate": - return await call_next(request) - + async def _generate(self, request: Request, call_next): request_json = await request.json() if "text" in request_json: input_text = request_json.pop("text", "") @@ -154,6 +156,23 @@ async def dispatch(self, request: Request, call_next): print(f"[miles-router] Warning: Failed to cache trajectory: {e}") return response + async def _retrieve_from_text(self, request: Request): + payload = await request.json() + text = payload.get("text", "") + token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True) + result = { + "response": text, + "tokens": token_ids, + "loss_mask": loss_mask, + "rollout_logp": logp, + "token_length": len(token_ids), + "loss_mask_length": len(loss_mask), + } + assert ( + len(token_ids) == len(logp) == len(loss_mask) + ), "Token IDs, logp, and loss mask must have the same length" + return JSONResponse(result) + async def postprocess_sample_with_radix_tree(args, sample: Sample, output: dict): assert not args.partial_rollout, "Currently partial rollout is not supported when using miles router" diff --git a/miles/router/router.py b/miles/router/router.py index 7d3ecd9806..f092f359a7 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from starlette.responses import Response -from miles.router.sessions import setup_session_routes +from miles.router.session.sessions import setup_session_routes from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -65,11 +65,10 @@ def __init__(self, args, verbose=False): self.app.add_middleware(middleware, router=self) def _setup_routes(self): - """Setup all the HTTP routes""" + """Setup all the HTTP routes except catch-all proxy""" # sglang-router api self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) - self.app.post("/retrieve_from_text")(self.retrieve_from_text) # Session routes - must be registered before catch-all setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST @@ -136,13 +135,23 @@ async def proxy(self, request: Request, path: str): result = await self._do_proxy(request, path) return self._build_proxy_response(result) - async def _do_proxy(self, request: Request, path: str) -> dict: + async def _do_proxy( + self, + request: Request, + path: str, + body: bytes | None = None, + headers: dict | None = None, + ) -> dict: """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" worker_url = self._use_url() url = f"{worker_url}/{path}" - body = await request.body() - headers = dict(request.headers) + if body is None: + body = await request.body() + if headers is None: + headers = dict(request.headers) + if body is not None: + headers = {k: v for k, v in headers.items() if k.lower() not in ("content-length", "transfer-encoding")} try: response = await self.client.request(request.method, url, content=body, headers=headers) @@ -202,28 +211,6 @@ async def list_workers(self, request: Request): """List all registered workers""" return {"urls": list(self.worker_request_counts.keys())} - async def retrieve_from_text(self, request: Request): - """Get token information from text input""" - body = await request.body() - payload = json.loads(body) if body else {} - - text = payload.get("text", "") - - # Use radix tree's retrieve_from_text method (no need to fetch weight version here) - token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True) - - # Handle the result based on whether logp was requested - result = { - "tokens": token_ids, # token IDs - "response": text, # The input text - "loss_mask": loss_mask, # Loss mask for the tokens - "token_length": len(token_ids), - "loss_mask_length": len(loss_mask), - "rollout_logp": logp, - } - - return result - def _use_url(self): """Select worker URL with minimal active requests.""" diff --git a/miles/router/session/naive_trajectory.py b/miles/router/session/naive_trajectory.py new file mode 100644 index 0000000000..3cd4ff1b75 --- /dev/null +++ b/miles/router/session/naive_trajectory.py @@ -0,0 +1,70 @@ +import threading +import uuid +from typing import Any + +from pydantic import BaseModel, Field + +from miles.router.session.session_types import SessionRecord + + +class NaiveTrajectory(BaseModel): + messages: list[dict[str, Any]] = Field(default_factory=list) + records: list[SessionRecord] = Field(default_factory=list) + + def append_session_record(self, record: SessionRecord): + self.records.append(record) + + +# This is only a naive trajectory manager to store history session record. +# Cross turn token input not implemented. +class NaiveTrajectoryManager: + def __init__(self, args, tokenizer: Any): + self.sessions: dict[str, NaiveTrajectory] = {} + self.args = args + self.tokenizer = tokenizer + self._lock = threading.RLock() + + def create_session(self) -> str: + with self._lock: + session_id = uuid.uuid4().hex + self.sessions[session_id] = NaiveTrajectory() + return session_id + + def get_session_records_by_id(self, session_id: str) -> list[SessionRecord] | None: + with self._lock: + session = self.sessions.get(session_id) + if session is None: + return None + return session.records + + def calc_prompt_tokens( + self, + session_id: str, + messages: list[dict[str, Any]], + ) -> list[int] | None: + with self._lock: + session = self.sessions.get(session_id) + if session is None: + return None + token_ids = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_special_tokens=False, + add_generation_prompt=True, + ) + return token_ids + + def delete_session_by_id(self, session_id: str) -> bool | None: + with self._lock: + session = self.sessions.pop(session_id, None) + if session is None: + return None + return True + + def append_session_record(self, session_id: str, record: SessionRecord) -> bool | None: + with self._lock: + session = self.sessions.get(session_id) + if session is None: + return None + session.append_session_record(record) + return True diff --git a/miles/router/session/session_types.py b/miles/router/session/session_types.py new file mode 100644 index 0000000000..c895b5e38d --- /dev/null +++ b/miles/router/session/session_types.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + + +class SessionRecord(BaseModel): + timestamp: float + method: str + path: str + request: dict + response: dict + status_code: int + + +class GetSessionResponse(BaseModel): + session_id: str + records: list[SessionRecord] diff --git a/miles/router/session/sessions.py b/miles/router/session/sessions.py new file mode 100644 index 0000000000..349272f818 --- /dev/null +++ b/miles/router/session/sessions.py @@ -0,0 +1,94 @@ +import json +import logging +import time +from typing import TYPE_CHECKING + +from fastapi import Request +from fastapi.responses import JSONResponse, Response +from transformers import AutoTokenizer + +from miles.router.session.naive_trajectory import NaiveTrajectoryManager +from miles.router.session.session_types import GetSessionResponse, SessionRecord + +if TYPE_CHECKING: + from miles.router.router import MilesRouter + +logger = logging.getLogger(__name__) + + +def setup_session_routes(app, router: "MilesRouter"): + hf_checkpoint = getattr(router.args, "hf_checkpoint", None) + if not hf_checkpoint: + if getattr(router, "verbose", False): + logger.info("[miles-router] Skipping session routes (hf_checkpoint not set).") + return + + tokenizer = AutoTokenizer.from_pretrained(hf_checkpoint, trust_remote_code=True) + manager = NaiveTrajectoryManager(router.args, tokenizer) + + @app.post("/sessions") + async def create_session(): + session_id = manager.create_session() + return {"session_id": session_id} + + @app.get("/sessions/{session_id}") + async def get_session(session_id: str): + records = manager.get_session_records_by_id(session_id) + if records is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return GetSessionResponse( + session_id=session_id, + records=records, + ) + + @app.delete("/sessions/{session_id}") + async def delete_session(session_id: str): + deleted = manager.delete_session_by_id(session_id) + if deleted is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return Response(status_code=204) + + @app.post("/sessions/{session_id}/v1/chat/completions") + async def chat_completions(request: Request, session_id: str): + body = await request.body() + request_body = json.loads(body) if body else {} + + if router.args.miles_router_enable_token_input_for_chat_completions: + if "messages" in request_body and "input_ids" not in request_body: + prompt_token_ids = manager.calc_prompt_tokens(session_id, request_body["messages"]) + if prompt_token_ids is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + request_body["input_ids"] = prompt_token_ids + body = json.dumps(request_body).encode("utf-8") + + result = await router._do_proxy(request, "v1/chat/completions", body=body) + + response = json.loads(result["response_body"]) + + choice = response.get("choices", [{}])[0] + # messages = request_body["messages"] + [choice["message"]] + + if "logprobs" not in choice or "content" not in choice["logprobs"]: + raise RuntimeError("logprobs must be in choice") + logprobs_content = choice["logprobs"]["content"] + + for item in logprobs_content: + if "token_id" not in item: + raise RuntimeError("token_id must be in item") + record = SessionRecord( + timestamp=time.time(), + method=request.method, + path="/v1/chat/completions", + status_code=result["status_code"], + request=request_body, + response=response, + ) + appended = manager.append_session_record(session_id, record) + if appended is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return router._build_proxy_response(result) + + @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def session_proxy(request: Request, session_id: str, path: str): + result = await router._do_proxy(request, path) + return router._build_proxy_response(result) diff --git a/miles/router/sessions.py b/miles/router/sessions.py deleted file mode 100644 index 9d753e5975..0000000000 --- a/miles/router/sessions.py +++ /dev/null @@ -1,124 +0,0 @@ -import json -import time -import uuid -from typing import TYPE_CHECKING - -from fastapi import Request -from fastapi.responses import JSONResponse, Response -from pydantic import BaseModel -from transformers import AutoTokenizer - -if TYPE_CHECKING: - from miles.router.router import MilesRouter - - -class SessionRecord(BaseModel): - timestamp: float - method: str - path: str - request: dict - response: dict - status_code: int - - -class GetSessionResponse(BaseModel): - session_id: str - records: list[SessionRecord] - - -class SessionManager: - def __init__(self): - self.sessions: dict[str, list[SessionRecord]] = {} - - def create_session(self) -> str: - session_id = uuid.uuid4().hex - self.sessions[session_id] = [] - return session_id - - def get_session(self, session_id: str) -> list[SessionRecord] | None: - return self.sessions.get(session_id) - - def delete_session(self, session_id: str) -> list[SessionRecord]: - assert session_id in self.sessions - return self.sessions.pop(session_id) - - def add_record(self, session_id: str, record: SessionRecord): - assert session_id in self.sessions - self.sessions[session_id].append(record) - - -def setup_session_routes(app, router: "MilesRouter"): - manager = SessionManager() - - # TODO temporary hack before @guapisolo implements TITO - # ============================= HACK START =============================== - # Lazy load tokenizer only when needed (for tests that don't have hf_checkpoint) - tokenizer = None - - def get_tokenizer(): - nonlocal tokenizer - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) - return tokenizer - - # ============================= HACK END =============================== - - @app.post("/sessions") - async def create_session(): - session_id = manager.create_session() - return {"session_id": session_id} - - @app.get("/sessions/{session_id}") - async def get_session(session_id: str): - records = manager.get_session(session_id) - if records is None: - return JSONResponse(status_code=404, content={"error": "session not found"}) - return GetSessionResponse(session_id=session_id, records=records) - - @app.delete("/sessions/{session_id}") - async def delete_session(session_id: str): - if session_id not in manager.sessions: - return JSONResponse(status_code=404, content={"error": "session not found"}) - manager.delete_session(session_id) - return Response(status_code=204) - - @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) - async def session_proxy(request: Request, session_id: str, path: str): - if session_id not in manager.sessions: - return JSONResponse(status_code=404, content={"error": "session not found"}) - - result = await router._do_proxy(request, path) - - request_body = json.loads(result["request_body"]) - response_body = json.loads(result["response_body"]) - - # TODO: remove this hack when @guapisolo implements the real TITO - # ============================= HACK START =============================== - if "messages" in request_body and "input_ids" not in request_body: - request_body["input_ids"] = get_tokenizer().apply_chat_template( - request_body["messages"], - add_generation_prompt=True, - add_special_tokens=False, - tools=request_body.get("tools"), - ) - if ( - "logprobs" in response_body.get("choices", [{}])[0] - and "content" in response_body["choices"][0]["logprobs"] - ): - logprobs_content = response_body["choices"][0]["logprobs"]["content"] - for item in logprobs_content: - if "token" in item and "token_id" not in item: - item["token_id"] = get_tokenizer().convert_tokens_to_ids(item["token"]) - # ============================= HACK END =============================== - - record = SessionRecord( - timestamp=time.time(), - method=request.method, - path=path, - request=request_body, - response=response_body, - status_code=result["status_code"], - ) - manager.add_record(session_id, record) - - return router._build_proxy_response(result) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 922d3a4dc4..09c817c5e8 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -959,6 +959,17 @@ def add_router_arguments(parser): default=3, help="Number of consecutive failures before marking a worker as unhealthy.", ) + parser.add_argument( + "--miles-router-enable-token-input-for-chat-completions", + action="store_true", + default=False, + help=( + "This is an experimental feature, and only supports for text model." + "Whether to enable token input for chat completions. If set, we will calculate " + "the input_ids for the prompt part inside miles and add it to the request body." + "This is reserved for cross turn token in under OAI format." + ), + ) RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) return parser diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py index 2c0dddfe54..3647c86265 100644 --- a/miles/utils/test_utils/mock_sglang_server.py +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -144,12 +144,17 @@ def _compute_chat_completions_response(self, payload: dict) -> dict: prompt_str = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, tools=tools ) + prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False) process_result = self.process_fn(prompt_str) output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) logprobs_content = [ - {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i} + { + "token": self.tokenizer.convert_ids_to_tokens(tid), + "token_id": tid, + "logprob": -1 / 128 * i, + } for i, tid in enumerate(output_ids) ] @@ -188,6 +193,7 @@ def _compute_chat_completions_response(self, payload: dict) -> dict: "tool_calls": tool_calls, }, "logprobs": {"content": logprobs_content}, + "input_token_ids": prompt_ids, "finish_reason": finish_reason, } ], diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py index 6b99e36739..f38344c8c6 100644 --- a/miles/utils/test_utils/mock_tools.py +++ b/miles/utils/test_utils/mock_tools.py @@ -1,9 +1,14 @@ import json +from copy import deepcopy +from typing import Any from transformers import AutoTokenizer from miles.utils.test_utils.mock_sglang_server import ProcessResult +AGENTIC_MAX_TURNS: int | None = None +from miles.utils.http_utils import post + SAMPLE_TOOLS = [ { "type": "function", @@ -54,6 +59,56 @@ async def execute_tool_call(name: str, params: dict) -> str: return TOOL_EXECUTORS[name](params) +async def run_agentic_tool_call( + base_url: str, + prompt: list[dict[str, Any]] | str, + request_kwargs: dict[str, Any] | None = None, + max_turns: int = 8, +) -> None: + if AGENTIC_MAX_TURNS is not None: + max_turns = AGENTIC_MAX_TURNS + messages = deepcopy(prompt) if isinstance(prompt, list) else [{"role": "user", "content": prompt}] + request_kwargs = request_kwargs or {} + model = request_kwargs.get("model", "default") + tools = request_kwargs.get("tools", SAMPLE_TOOLS) + + for _ in range(max_turns): + payload = {"model": model, "messages": messages, "tools": tools} + response = await post(base_url + "/v1/chat/completions", payload) + choice = response["choices"][0]["message"] + tool_calls = choice.get("tool_calls") or [] + if not tool_calls: + break + + assistant_msg = { + "content": choice.get("content"), + "refusal": choice.get("refusal"), + "role": choice.get("role", "assistant"), + "annotations": choice.get("annotations"), + "audio": choice.get("audio"), + "function_call": choice.get("function_call"), + "tool_calls": tool_calls, + } + messages.append(assistant_msg) + + for tool_call in tool_calls: + name = tool_call["function"]["name"] + raw_args = tool_call["function"].get("arguments") or "{}" + try: + params = json.loads(raw_args) if isinstance(raw_args, str) else raw_args + except json.JSONDecodeError: + params = {} + result = await execute_tool_call(name, params) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.get("id"), + "content": result, + "name": name, + } + ) + + _SYSTEM_PROMPT = ( "<|im_start|>system\n" "# Tools\n" diff --git a/tests/e2e/sglang_patch/__init__.py b/tests/e2e/sglang_patch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/sglang_patch/sglang_server.py b/tests/e2e/sglang_patch/sglang_server.py new file mode 100644 index 0000000000..44214de056 --- /dev/null +++ b/tests/e2e/sglang_patch/sglang_server.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import os +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import IO + +import requests + +from miles.utils.http_utils import find_available_port + +DEFAULT_HOST = "127.0.0.1" +DEFAULT_BASE_PORT = 34000 +DEFAULT_STARTUP_TIMEOUT_SECS = 900.0 +DEFAULT_SHUTDOWN_TIMEOUT_SECS = 30.0 + + +@dataclass +class SGLangServer: + process: subprocess.Popen + host: str + port: int + log_path: Path + _log_file: IO[str] + + @property + def base_url(self) -> str: + return f"http://{self.host}:{self.port}" + + def stop(self, timeout_secs: float = DEFAULT_SHUTDOWN_TIMEOUT_SECS) -> None: + if self.process.poll() is None: + self.process.terminate() + try: + self.process.wait(timeout=timeout_secs) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait(timeout=timeout_secs) + self._log_file.close() + + +def start_sglang_server( + *, + model_path: str, + host: str = DEFAULT_HOST, + port: int | None = None, + startup_timeout_secs: float = DEFAULT_STARTUP_TIMEOUT_SECS, + enable_deterministic_inference: bool = True, + extra_args: list[str] | None = None, +) -> SGLangServer: + if port is None: + port = find_available_port(DEFAULT_BASE_PORT) + + log_path = Path(f"/tmp/sglang_e2e_{port}.log") + log_file = log_path.open("w", encoding="utf-8") + + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model-path", + model_path, + "--host", + host, + "--port", + str(port), + "--trust-remote-code", + ] + if enable_deterministic_inference: + cmd.append("--enable-deterministic-inference") + if extra_args: + cmd.extend(extra_args) + + env = os.environ.copy() + env.setdefault("PYTHONUNBUFFERED", "1") + + process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, env=env) + server = SGLangServer(process=process, host=host, port=port, log_path=log_path, _log_file=log_file) + + _wait_for_ready(server, timeout_secs=startup_timeout_secs) + return server + + +def _wait_for_ready(server: SGLangServer, *, timeout_secs: float) -> None: + deadline = time.monotonic() + timeout_secs + last_error = "" + + while time.monotonic() < deadline: + if server.process.poll() is not None: + log_tail = _read_log_tail(server.log_path) + raise RuntimeError( + "SGLang server exited early. " f"Exit code: {server.process.returncode}. " f"Log tail:\n{log_tail}" + ) + + try: + response = requests.get(f"{server.base_url}/health", timeout=5) + if response.status_code == 200: + return + last_error = f"status_code={response.status_code}" + except requests.RequestException as exc: + last_error = str(exc) + + time.sleep(1.0) + + log_tail = _read_log_tail(server.log_path) + raise TimeoutError( + "Timed out waiting for SGLang server to become healthy. " + f"Last error: {last_error}. " + f"Log tail:\n{log_tail}" + ) + + +def _read_log_tail(path: Path, max_lines: int = 80) -> str: + if not path.exists(): + return "" + + content = path.read_text(encoding="utf-8", errors="ignore") + lines = content.splitlines() + if len(lines) <= max_lines: + return content + return "\n".join(lines[-max_lines:]) diff --git a/tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py b/tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py new file mode 100644 index 0000000000..adcf08fb2b --- /dev/null +++ b/tests/e2e/sglang_patch/test_chat_input_ids_equivalence.py @@ -0,0 +1,122 @@ +import math +import os + +import pytest +import requests +from tests.e2e.sglang_patch.sglang_server import start_sglang_server +from transformers import AutoTokenizer + +MODEL_PATH = os.environ.get("SGLANG_E2E_MODEL_PATH", "Qwen/Qwen3-0.6B") +SEED = 1234 +TEMPERATURE = 1.0 +TOP_P = 1.0 +MAX_COMPLETION_TOKENS = 64 +LOGPROB_TOL = 1e-6 + + +@pytest.fixture(scope="module") +def sglang_server(): + server = start_sglang_server(model_path=MODEL_PATH) + try: + yield server + finally: + server.stop() + + +@pytest.mark.system +def test_chat_completions_input_ids_equivalence(sglang_server): + """Validate that providing input_ids yields the same completion as raw messages.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + messages = _build_messages() + # Build the same prompt two ways: message list vs. explicit input_ids. + input_ids = _build_input_ids(tokenizer, messages) + + # Request completions for both payload variants. + response_a = _post_chat(sglang_server.base_url, _build_payload(messages)) + response_b = _post_chat(sglang_server.base_url, _build_payload(messages, input_ids)) + + choice_a = response_a["choices"][0] + choice_b = response_b["choices"][0] + + # The generated content and finish reason should match across variants. + assert choice_a["message"]["content"] == choice_b["message"]["content"] + assert choice_a["finish_reason"] == choice_b["finish_reason"] + + # Compare token ids and per-token logprobs for exact equivalence. + token_ids_a, logprobs_a = _extract_tokens_and_logprobs(choice_a) + token_ids_b, logprobs_b = _extract_tokens_and_logprobs(choice_b) + + assert token_ids_a == token_ids_b + assert len(logprobs_a) == len(logprobs_b) + + for index, (a_val, b_val) in enumerate(zip(logprobs_a, logprobs_b, strict=True)): + assert math.isclose(a_val, b_val, abs_tol=LOGPROB_TOL), f"logprob mismatch at {index}: {a_val} vs {b_val}" + + +@pytest.mark.system +def test_chat_completions_input_logprobs_prompt_ids_match(sglang_server): + """Ensure input_ids are echoed exactly in input_token_ids and logprobs are present.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + messages = _build_messages() + input_ids = _build_input_ids(tokenizer, messages) + + response = _post_chat(sglang_server.base_url, _build_payload(messages, input_ids)) + choice = response["choices"][0] + + input_token_ids = _extract_input_token_ids(choice) + + assert input_token_ids == input_ids + assert choice.get("logprobs", {}).get("content"), "logprobs content is missing" + + +def _post_chat(base_url: str, payload: dict) -> dict: + response = requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=120) + print(f"response: {response.json()}", flush=True) + assert response.status_code == 200, response.text + return response.json() + + +def _build_messages() -> list[dict]: + return [ + {"role": "system", "content": "You are a concise assistant."}, + {"role": "user", "content": "Answer with one word: 2+2?"}, + ] + + +def _build_input_ids(tokenizer, messages: list[dict]) -> list[int]: + return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + + +def _build_payload(messages: list[dict], input_ids: list[int] | None = None) -> dict: + payload = { + "model": MODEL_PATH, + "temperature": TEMPERATURE, + "top_p": TOP_P, + "max_completion_tokens": MAX_COMPLETION_TOKENS, + "seed": SEED, + "logprobs": True, + "messages": messages, + "logprob_start_len": 0, + } + if input_ids is not None: + payload["input_ids"] = input_ids + return payload + + +def _extract_tokens_and_logprobs(choice: dict) -> tuple[list[int], list[float]]: + logprobs = choice.get("logprobs", {}).get("content") + assert logprobs, "logprobs content is missing" + + token_ids = [] + for item in logprobs: + token_ids.append(item["token_id"]) + values = [item["logprob"] for item in logprobs] + return token_ids, values + + +def _extract_input_token_ids(choice: dict) -> list[int]: + token_ids = choice.get("input_token_ids") + assert token_ids is not None, "input_token_ids is missing in response" + + print(f"input_token_ids: {token_ids}", flush=True) + return token_ids diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py index 816371ee3a..2dfabfa3ee 100644 --- a/tests/fast/fixtures/generation_fixtures.py +++ b/tests/fast/fixtures/generation_fixtures.py @@ -19,6 +19,7 @@ from miles.utils.async_utils import run from miles.utils.http_utils import find_available_port, init_http_client from miles.utils.misc import SingletonMeta +from miles.utils.test_utils import mock_tools from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer from miles.utils.types import Sample @@ -45,18 +46,14 @@ def extra_argv_for_variant( generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", generate_tool_call_parser: str = "qwen25", generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + custom_agent_function_path: str = "miles.utils.test_utils.mock_tools.run_agentic_tool_call", ) -> list[str]: argv = [ "--custom-generate-function-path", custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], ] - if variant in ( - "multi_turn_single_sample", - "multi_turn_multi_samples", - "agentic_tool_call_single_sample", - "agentic_tool_call_multi_samples", - ): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): argv += [ "--generate-max-turns", str(generate_max_turns), @@ -65,9 +62,12 @@ def extra_argv_for_variant( "--generate-execute-tool-function-path", generate_execute_tool_function_path, ] - if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): - argv += ["--generate-tool-call-parser", generate_tool_call_parser] - if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): + argv += ["--generate-tool-call-parser", generate_tool_call_parser] + if variant == "multi_turn_multi_samples": + argv.append("--generate-multi-samples") + elif variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples"): + argv += ["--custom-agent-function-path", custom_agent_function_path] + if variant == "agentic_tool_call_multi_samples": argv.append("--generate-multi-samples") return argv @@ -219,6 +219,7 @@ def with_miles_router(backend_url: str, model_name: str): miles_router_middleware_paths=[], rollout_health_check_interval=60, miles_router_health_check_failure_threshold=3, + miles_router_enable_token_input_for_chat_completions=False, hf_checkpoint=model_name, ) router = MilesRouter(router_args) @@ -269,6 +270,9 @@ def process_fn(_): custom_generate_function_path=custom_generate_function_path, **other_args_kwargs, ) + if variant.startswith("agentic_tool_call"): + mock_tools.AGENTIC_MAX_TURNS = args_kwargs.get("generate_max_turns") yield GenerateEnv(args=args, mock_server=mock_server) + mock_tools.AGENTIC_MAX_TURNS = None SingletonMeta.clear_all_instances() diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py index 5c6edafe20..566bb938f7 100644 --- a/tests/fast/router/test_sessions.py +++ b/tests/fast/router/test_sessions.py @@ -1,195 +1,210 @@ from types import SimpleNamespace +from unittest.mock import patch import pytest import requests from miles.router.router import MilesRouter -from miles.router.sessions import SessionManager, SessionRecord +from miles.router.session.naive_trajectory import NaiveTrajectoryManager +from miles.router.session.session_types import SessionRecord from miles.utils.http_utils import find_available_port -from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer -class TestSessionManager: - def test_create_session(self): - manager = SessionManager() - session_id = manager.create_session() +class DummyTokenizer: + """Minimal tokenizer stub for testing NaiveTrajectoryManager.""" + + def apply_chat_template( + self, + messages, + tokenize: bool = True, + add_special_tokens: bool = False, + add_generation_prompt: bool = True, + ): + """Return deterministic token ids based on message count.""" + base = len(messages) or 1 + return [base, base + 1, base + 2] + + +@pytest.fixture +def naive_manager(): + """Create a NaiveTrajectoryManager with a dummy tokenizer.""" + args = SimpleNamespace() + tokenizer = DummyTokenizer() + return NaiveTrajectoryManager(args, tokenizer) + + +class TestNaiveTrajectoryManager: + def test_create_session(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() assert session_id is not None assert len(session_id) == 32 - assert session_id in manager.sessions - assert manager.sessions[session_id] == [] + assert session_id in naive_manager.sessions - def test_get_session_exists(self): - manager = SessionManager() - session_id = manager.create_session() - records = manager.get_session(session_id) + def test_get_session_records_by_id(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + records = naive_manager.get_session_records_by_id(session_id) assert records == [] - def test_get_session_not_exists(self): - manager = SessionManager() - records = manager.get_session("nonexistent") + def test_get_session_records_by_id_not_found(self, naive_manager: NaiveTrajectoryManager): + records = naive_manager.get_session_records_by_id("nonexistent") assert records is None - def test_delete_session_exists(self): - manager = SessionManager() - session_id = manager.create_session() - records = manager.delete_session(session_id) - assert records == [] - assert session_id not in manager.sessions + def test_calc_prompt_tokens_for_existing_session(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + messages = [{"role": "user", "content": "hello"}] + + token_ids = naive_manager.calc_prompt_tokens(session_id, messages) + + assert token_ids == [1, 2, 3] - def test_delete_session_not_exists(self): - manager = SessionManager() - with pytest.raises(AssertionError): - manager.delete_session("nonexistent") + def test_calc_prompt_tokens_for_missing_session(self, naive_manager: NaiveTrajectoryManager): + messages = [{"role": "user", "content": "hello"}] + token_ids = naive_manager.calc_prompt_tokens("missing", messages) + assert token_ids is None - def test_add_record(self): - manager = SessionManager() - session_id = manager.create_session() + def test_delete_session_by_id(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() + assert naive_manager.delete_session_by_id(session_id) is True + assert session_id not in naive_manager.sessions + assert naive_manager.delete_session_by_id(session_id) is None + + def test_append_session_record(self, naive_manager: NaiveTrajectoryManager): + session_id = naive_manager.create_session() record = SessionRecord( - timestamp=1234567890.0, + timestamp=0.0, method="POST", - path="generate", - request={"prompt": "hello"}, - response={"text": "world"}, + path="/v1/chat/completions", status_code=200, + request={"messages": [{"role": "user", "content": "hello"}]}, + response={"choices": []}, ) - manager.add_record(session_id, record) - assert len(manager.sessions[session_id]) == 1 - assert manager.sessions[session_id][0] == record - def test_add_record_nonexistent_session(self): - manager = SessionManager() + appended = naive_manager.append_session_record(session_id, record) + + assert appended is True + records = naive_manager.get_session_records_by_id(session_id) + assert records is not None + assert len(records) == 1 + assert records[0].path == record.path + + def test_append_session_record_missing_session(self, naive_manager: NaiveTrajectoryManager): record = SessionRecord( - timestamp=1234567890.0, + timestamp=0.0, method="POST", - path="generate", + path="/v1/chat/completions", + status_code=200, request={}, response={}, - status_code=200, ) - with pytest.raises(AssertionError): - manager.add_record("nonexistent", record) + appended = naive_manager.append_session_record("missing", record) + assert appended is None @pytest.fixture(scope="class") -def router_url(): +def router_env(): + """Create a MilesRouter with session routes and a mock backend.""" + def process_fn(prompt: str) -> ProcessResult: return ProcessResult(text=f"echo: {prompt}", finish_reason="stop") - with with_mock_server(process_fn=process_fn) as backend: - args = SimpleNamespace( - miles_router_max_connections=10, - miles_router_timeout=30, - miles_router_middleware_paths=[], - rollout_health_check_interval=60, - miles_router_health_check_failure_threshold=3, - hf_checkpoint="Qwen/Qwen3-0.6B", - ) - router = MilesRouter(args) + original_chat_response = MockSGLangServer._compute_chat_completions_response + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + logprobs_content = response["choices"][0]["logprobs"]["content"] + for item in logprobs_content: + item["token_id"] = self.tokenizer.convert_tokens_to_ids(item["token"]) + return response + + with patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response): + with with_mock_server(process_fn=process_fn) as backend: + args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + miles_router_enable_token_input_for_chat_completions=False, + hf_checkpoint="Qwen/Qwen3-0.6B", + trajectory_manager="naive_trajectory", + ) + router = MilesRouter(args) - port = find_available_port(31000) - server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) - server.start() + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() - url = f"http://127.0.0.1:{port}" - requests.post(f"{url}/add_worker", json={"url": backend.url}) + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend.url}, timeout=5.0) - try: - yield url - finally: - server.stop() + try: + yield SimpleNamespace(url=url) + finally: + server.stop() class TestSessionRoutes: - def test_create_session(self, router_url): - response = requests.post(f"{router_url}/sessions") + def test_create_session(self, router_env): + response = requests.post(f"{router_env.url}/sessions", timeout=5.0) assert response.status_code == 200 data = response.json() assert "session_id" in data assert len(data["session_id"]) == 32 - def test_get_session(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + def test_get_session_initial_state(self, router_env): + session_id = requests.post(f"{router_env.url}/sessions", timeout=5.0).json()["session_id"] - get_resp = requests.get(f"{router_url}/sessions/{session_id}") + get_resp = requests.get(f"{router_env.url}/sessions/{session_id}", timeout=5.0) assert get_resp.status_code == 200 data = get_resp.json() assert data["session_id"] == session_id assert data["records"] == [] - def test_get_session_not_found(self, router_url): - response = requests.get(f"{router_url}/sessions/nonexistent") + def test_get_session_not_found(self, router_env): + response = requests.get(f"{router_env.url}/sessions/nonexistent", timeout=5.0) assert response.status_code == 404 assert response.json()["error"] == "session not found" - def test_get_with_records(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - requests.post( - f"{router_url}/sessions/{session_id}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, - ) + def test_delete_session(self, router_env): + session_id = requests.post(f"{router_env.url}/sessions", timeout=5.0).json()["session_id"] - get_resp = requests.get(f"{router_url}/sessions/{session_id}") - assert get_resp.status_code == 200 - data = get_resp.json() - assert data["session_id"] == session_id - assert len(data["records"]) == 1 - - def test_delete_session(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] - - delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + delete_resp = requests.delete(f"{router_env.url}/sessions/{session_id}", timeout=5.0) assert delete_resp.status_code == 204 assert delete_resp.text == "" - assert requests.delete(f"{router_url}/sessions/{session_id}").status_code == 404 + assert requests.delete(f"{router_env.url}/sessions/{session_id}", timeout=5.0).status_code == 404 - def test_delete_session_not_found(self, router_url): - response = requests.delete(f"{router_url}/sessions/nonexistent") + def test_delete_session_not_found(self, router_env): + response = requests.delete(f"{router_env.url}/sessions/nonexistent", timeout=5.0) assert response.status_code == 404 assert response.json()["error"] == "session not found" class TestSessionProxy: - def test_proxy_session_not_found(self, router_url): - response = requests.post(f"{router_url}/sessions/nonexistent/generate", json={}) - assert response.status_code == 404 - assert response.json()["error"] == "session not found" - - def test_proxy_records_request_response(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + def test_proxy_chat_appends_record(self, router_env): + session_id = requests.post(f"{router_env.url}/sessions", timeout=5.0).json()["session_id"] + payload = { + "messages": [{"role": "user", "content": "What is 1+2?"}], + "return_logprob": True, + } resp = requests.post( - f"{router_url}/sessions/{session_id}/generate", - json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + f"{router_env.url}/sessions/{session_id}/v1/chat/completions", + json=payload, + timeout=10.0, ) assert resp.status_code == 200 - assert "text" in resp.json() - - get_resp = requests.get(f"{router_url}/sessions/{session_id}") - records = get_resp.json()["records"] - assert len(records) == 1 - assert records[0]["method"] == "POST" - assert records[0]["path"] == "generate" - assert records[0]["request"]["input_ids"] == [1, 2, 3] - assert "text" in records[0]["response"] - - delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") - assert delete_resp.status_code == 204 - - def test_proxy_accumulates_records(self, router_url): - session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + body = resp.json() + assert "choices" in body + assert body["choices"] - for _ in range(3): - requests.post( - f"{router_url}/sessions/{session_id}/generate", - json={"input_ids": [1], "sampling_params": {}, "return_logprob": True}, - ) - - get_resp = requests.get(f"{router_url}/sessions/{session_id}") + get_resp = requests.get(f"{router_env.url}/sessions/{session_id}", timeout=5.0) records = get_resp.json()["records"] - assert len(records) == 3 - delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") - assert delete_resp.status_code == 204 + assert isinstance(records, list) + assert len(records) == 1 + record = records[0] + assert record["path"] == "/v1/chat/completions" + assert record["status_code"] == 200 diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py index 6633678da1..e387fd78bd 100644 --- a/tests/fast/utils/test_utils/test_mock_sglang_server.py +++ b/tests/fast/utils/test_utils/test_mock_sglang_server.py @@ -17,7 +17,15 @@ def expected_logprobs(tokenizer, text: str) -> list[dict]: output_ids = tokenizer.encode(text, add_special_tokens=False) - return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)] + return [ + {"token": tokenizer.convert_ids_to_tokens(tid), "token_id": tid, "logprob": -i / 128} + for i, tid in enumerate(output_ids) + ] + + +def expected_input_token_ids(tokenizer, messages: list[dict], tools: list[dict] | None) -> list[int]: + prompt_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + return tokenizer.encode(prompt_str, add_special_tokens=False) @pytest.fixture(scope="module") @@ -230,11 +238,12 @@ def process_fn(_: str) -> ProcessResult: class TestChatCompletionsEndpoint: def test_basic(self, mock_server): + messages = [{"role": "user", "content": "What is 1+5?"}] response = requests.post( f"{mock_server.url}/v1/chat/completions", json={ "model": "test-model", - "messages": [{"role": "user", "content": "What is 1+5?"}], + "messages": messages, }, timeout=5.0, ) @@ -253,6 +262,7 @@ def test_basic(self, mock_server): "index": 0, "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, + "input_token_ids": expected_input_token_ids(mock_server.tokenizer, messages, None), "finish_reason": "stop", } ], @@ -286,6 +296,11 @@ def process_fn(_: str) -> ProcessResult: ], }, "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, + "input_token_ids": expected_input_token_ids( + server.tokenizer, + [{"role": "user", "content": "What year is it?"}], + SAMPLE_TOOLS, + ), "finish_reason": "tool_calls", } @@ -311,6 +326,11 @@ def process_fn(_: str) -> ProcessResult: "index": 0, "message": {"role": "assistant", "content": response_text, "tool_calls": None}, "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, + "input_token_ids": expected_input_token_ids( + server.tokenizer, + [{"role": "user", "content": "What's the weather?"}], + SAMPLE_TOOLS, + ), "finish_reason": "stop", } @@ -351,6 +371,11 @@ def process_fn(_: str) -> ProcessResult: ], }, "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, + "input_token_ids": expected_input_token_ids( + server.tokenizer, + [{"role": "user", "content": "What year and temperature?"}], + SAMPLE_TOOLS, + ), "finish_reason": "tool_calls", } From 824b8490aaaa637f4a8840e241d6e237053322ae Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Fri, 6 Feb 2026 11:30:27 -0800 Subject: [PATCH 69/77] Update CODEOWNERS for miles directory ownership (#569) --- .github/CODEOWNERS | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 416cb53856..c81945d40e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ .github/CODEOWNERS @fzyzcjy @Ying1123 .github/workflows/ @yushengsu-thu /miles/ @fzyzcjy @yueming-yuan -/miles/backends/ @maocheng23 -/miles/ray/ @maocheng23 -/miles/rollout/ @guapisolo -/miles/router/ @guapisolo -/miles/utils/ @guapisolo @maocheng23 +/miles/backends/ @fzyzcjy @yueming-yuan @maocheng23 +/miles/ray/ @fzyzcjy @yueming-yuan @maocheng23 +/miles/rollout/ @fzyzcjy @yueming-yuan @guapisolo +/miles/router/ @fzyzcjy @yueming-yuan @guapisolo +/miles/utils/ @fzyzcjy @yueming-yuan @guapisolo @maocheng23 From a93d4842e4597807e282043e647ce295ec94bfe3 Mon Sep 17 00:00:00 2001 From: Chengxi Li <114854555+Hecate0821@users.noreply.github.com> Date: Sat, 7 Feb 2026 11:16:18 -0800 Subject: [PATCH 70/77] [Doc] Add doc for miles router (#538) --- README.md | 2 +- docs/en/advanced/miles-router.md | 93 ++++++++++++++++++++++++++++ docs/en/get_started/customization.md | 1 + docs/en/index.rst | 1 + 4 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 docs/en/advanced/miles-router.md diff --git a/README.md b/README.md index 2bda4c95b3..1d5bd5d4a8 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ * **[2026/01]** 💎 **INT4 Quantization-Aware Training (QAT)**: Inspired by the Kimi K2-Thinking report, Miles now features a full-stack INT4 W4A16 QAT pipeline. This allows 1TB-scale models to fit into single-machine VRAM (e.g., NVIDIA H200), doubling rollout efficiency by eliminating cross-node bottlenecks while maintaining BF16-equivalent accuracy. [Blog](https://lmsys.org/blog/2026-01-26-int4-qat/) * **[2026/01]** 💎 **Unified VLM/LLM Multi-Turn Training**: We provided an implementation for the VLM multi-turn sampling paradigm. Developers only need to write a customized `rollout` function to easily start multi-turn RL for VLM, just like training LLM. [Blog](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/vlm-multi-turn/readme-en.md) * **[2026/01]** 🤖 **Multi-Agent Co-Evolution**: Miles now supports **MrlX**, a novel asynchronous co-evolutionary framework for Multi-Agent RL. Achieve superior performance in complex tasks like Doctor-Patient simulations and DeepResearch pipelines by enabling specialized agents to evolve together symbiotically. [[Link]](https://github.com/AQ-MedAI/MrlX) -* **[2025/12]** 🔄 **Rollout Routing Replay (R3)**: In collaboration with SGLang, we've launched R3 to solve MoE RL instability. R3 records inference routing decisions and replays them during training, effectively eliminating the "training-inference mismatch" and preventing training collapse in large MoE models like Qwen3 and DeepSeek-V3. [[Paper]](https://arxiv.org/pdf/2510.11370) +* **[2025/12]** 🔄 **Rollout Routing Replay (R3)**: In collaboration with SGLang, we've launched R3 to solve MoE RL instability. R3 records inference routing decisions and replays them during training, effectively eliminating the "training-inference mismatch" and preventing training collapse in large MoE models like Qwen3 and DeepSeek-V3. [[Paper]](https://arxiv.org/pdf/2510.11370) [[Docs]](docs/en/advanced/miles-router.md#22-rollout-routing-replay-r3-for-moe) * **[2025/11]** 🔥 **Unified FP8 Release**: Solves the stability issues in MoE RL by ensuring training and inference use the exact same FP8 quantization logic. [[Blog]](https://lmsys.org/blog/2025-11-25-fp8-rl/) * **[2025/11]** ⚡ **Speculative Decoding in RL**: Integrated speculative rollout with online SFT for draft models, achieving massive throughput gains. [[Blog]](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/spec/readme-en.md) * **[2025/11]** 🎉 **Miles Project Launch**: A joint effort by InfiXAI, Ant Group, SGLang RL Team, and the Miles community. [[Announcement]](https://lmsys.org/blog/2025-11-19-miles/) diff --git a/docs/en/advanced/miles-router.md b/docs/en/advanced/miles-router.md new file mode 100644 index 0000000000..1eeb42b659 --- /dev/null +++ b/docs/en/advanced/miles-router.md @@ -0,0 +1,93 @@ +# Miles Router + +miles includes an optional Miles Router used during rollout / data generation. It is a lightweight HTTP router/proxy that sits in front of one or more SGLang worker servers and adds training-oriented capabilities that are not the main goal of serving-focused routers. + +--- + +## 1. What is Miles Router? + +Miles Router is a small FastAPI service that: + +- Registers workers (SGLang HTTP servers) into a local pool +- Routes requests to a selected worker (simple least-inflight load balancing) +- Proxies arbitrary paths to the selected worker (e.g. `/generate`) +- Runs periodic health checks and quarantines unhealthy workers +- Supports middleware plugins (via `--miles-router-middleware-paths`) to implement rollout-specific processing (e.g. caching, request/response transforms) + +In miles's architecture, the router is part of the rollout system ("SGLang + router") that generates samples and pushes them into the data buffer. + +### How it is launched + +In distributed training, miles will start a router automatically when `--sglang-router-ip` is not provided: + +- If `--use-miles-router` is set, miles starts Miles Router +- Otherwise, miles starts SGLang Model Gateway + +--- + +## 2. Why we need Miles Router + +Unlike production inference, RL rollout needs to capture additional metadata for training: token-level logprobs, loss masks, and (for MoE models) expert routing decisions. Miles Router provides these capabilities through its middleware system and passthrough proxy design. + +### 2.1 Radix-tree cache (transparent token management) + +> Use this when your rollout pipeline is text-in/text-out and you cannot reliably persist token IDs; if you already control token-in/token-out (e.g. search r1, multiturn VLM examples), you likely don't need the radix-tree cache. + +Text-in text-out interfaces can cause token retokenization mismatches - re-tokenizing text at training time may produce different token sequences than rollout, breaking per-token alignment needed for PPO/GRPO losses. + +The radix-tree cache solves this transparently: it intercepts text-based requests, tokenizes them, and stores trajectories (text, token IDs, logprobs, loss masks) keyed by the text prefix. After rollout finishes, calling `/retrieve_from_text` returns the exact token sequence with aligned metadata, without requiring any changes to your rollout code. + +Implementation-wise, the radix-tree cache: + +- Accepts text plus tokens/metadata and stores them in a radix tree +- Uses longest-prefix matching to reuse cached token sequences (enabling token-in/token-out downstream) +- Allows insertion of new text continuations as rollout proceeds (multiple trajectories per prompt, e.g. GRPO) +- Periodically cleans up stale nodes to control memory usage + +Use the radix cache when you have text-based rollout code and want token-level precision without rewriting, or when running GRPO with multiple trajectories sharing the same prompt prefix. + +### 2.2 Rollout routing replay (R3) for MoE + +For MoE models, miles supports rollout routing replay (R3): record expert routing decisions during rollout and replay them during training to improve stability. + +#### SGLang side + +SGLang provides expert routing capture via: + +- `--enable-return-routed-experts`: server argument to enable routing capture +- `RoutedExpertsCapturer`: captures `topk_ids` (selected expert IDs) at each MoE layer during forward pass +- `return_routed_experts`: request parameter to retrieve routing data +- Returns `routed_experts` in response `meta_info` - a `[seq_len - 1, num_layers, top_k]` tensor of expert IDs + +#### miles side + +miles consumes the routing data and replays it during training: + +- `--use-miles-router --use-rollout-routing-replay`: both flags required to enable R3 +- Rollout sends `return_routed_experts=True` and stores results in `sample.rollout_routed_experts` +- Training calls `fill_routing_replay()` to load routing data into `RoutingReplay` objects +- During forward pass, recorded routing decisions are replayed instead of recomputed + +#### Why Miles Router is needed + +We need Miles Router because the SGLang worker returns routed experts in the response (`meta_info.routed_experts`) when the request sets `return_routed_experts=true`, and Miles Router preserves this field end-to-end. SGLang Model Gateway may drop this extra metadata when it reconstructs responses with a fixed schema (see section 3). + +--- + +## 3. Differences vs SGLang Model Gateway + +Miles Router and SGLang Model Gateway can both route requests to workers, but they are optimized for different goals. + +### Key differences + +Miles Router is a lightweight Python/FastAPI proxy that acts as a passthrough to SGLang workers. This passthrough design enables RL-specific features like radix-tree trajectory caching and R3 (which require preserving raw response metadata like `routed_experts`). + +SGLang Model Gateway is a high-performance Rust-based router optimized for large-scale inference: async non-blocking routing, advanced fault tolerance (retries, circuit breakers), multiple load balancing policies (including cache-aware routing), and PD disaggregation support. However, it reconstructs responses with a fixed schema, so it does not preserve the metadata needed for miles's R3 flow. + +For more details on SGLang Model Gateway, see the [official documentation](https://docs.sglang.io/advanced_features/sgl_model_gateway.html). + +### When to use which + +- Use Miles Router when you need R3 or radix-tree caching +- Use SGLang Model Gateway for everything else (recommended default) + diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index b1088ce643..bfd5024228 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -417,3 +417,4 @@ Stabilize MoE RL training by recording and replaying expert routing decisions to | `--use-routing-replay` | Forward-backward routing consistency in training. ([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | | `--use-rollout-routing-replay` | R3: Replay routing from rollout during training. **Requires `--use-miles-router`**. ([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | +For detailed explanation of R3 and MilesRouter, see [Miles Router](../advanced/miles-router.md). diff --git a/docs/en/index.rst b/docs/en/index.rst index afafc67966..3f08d98d02 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -41,6 +41,7 @@ miles is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models :caption: Advanced Features _examples_synced/reproducibility/README.md + advanced/miles-router.md advanced/speculative-decoding.md advanced/fault-tolerance.md advanced/arch-support-beyond-megatron.md From 05e73c376883377cc5d00cbf909cc2098323c7e1 Mon Sep 17 00:00:00 2001 From: Banghua Zhu Date: Mon, 9 Feb 2026 20:39:03 -0800 Subject: [PATCH 71/77] fix: allow --seq-length CLI override in megatron backend (#568) Co-authored-by: Banghua Zhu Co-authored-by: Claude Opus 4.5 --- miles/backends/megatron_utils/arguments.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index 0eb2bcd444..24496011b1 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -14,7 +14,8 @@ def set_default_megatron_args(args): # TODO: maybe change this after megatron has good fp8 support args.bf16 = not args.fp16 # placeholders - args.seq_length = 4096 + if args.seq_length is None: + args.seq_length = 4096 args.max_position_embeddings = args.seq_length # TODO: revisit this when megatron(dev) have solved the optimizer-cpu-offload ckpt saving bug args.dist_ckpt_save_pre_mcore_014 = True From 6fd53f61425e5c0936c8da927a797d04e1c585c3 Mon Sep 17 00:00:00 2001 From: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:45:35 -0800 Subject: [PATCH 72/77] [AMD] Bump Megatron to 3714d81d and SGLang v0.5.7 (#563) --- docker/Dockerfile.rocm | 2 +- docker/Dockerfile.rocm_MI350-5 | 29 +- docker/amd_patch/sglv0.5.0rc0/megatron.patch | 872 ++++++++++++------ .../amd_megatron_fused_kernels_init.patch | 51 + docker/amd_patch/sglv0.5.7/megatron.patch | 792 ++++++++++++++++ docker/amd_patch/sglv0.5.7/sglang.patch | 36 + 6 files changed, 1515 insertions(+), 267 deletions(-) create mode 100644 docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch create mode 100644 docker/amd_patch/sglv0.5.7/megatron.patch create mode 100644 docker/amd_patch/sglv0.5.7/sglang.patch diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 8dafd4cd55..41c5e93563 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -263,7 +263,7 @@ COPY amd_patch/sglv0.5.0rc0 /workspace/patch RUN pip uninstall -y megatron-core && \ git clone https://github.com/NVIDIA/Megatron-LM && \ cd Megatron-LM && \ - git checkout 48406695c4efcf1026a7ed70bb390793918dd97b && \ + git checkout 3714d81d418c9f1bca4594fc35f9e8289f652862 && \ git apply /workspace/patch/amd_megatron_fused_kernels_init.patch && \ pip install -vvv -e . && \ cd /workspace/ diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 index dd32f32c5e..7db29a7517 100644 --- a/docker/Dockerfile.rocm_MI350-5 +++ b/docker/Dockerfile.rocm_MI350-5 @@ -1,8 +1,23 @@ #### Use the base image for ROCm 7 / gfx950 (MI355) -# The Docker image built with this Dockerfile: +# ===================================================================== +# Docker Image Version Information (Updated: Feb 5, 2026) +# ===================================================================== # Base image: ROCm 7 with vllm pre-built for gfx950 # Target GPU: MI355 (gfx950) +# +# Key Dependencies: +# - sglang: v0.5.7 +# - sgl_kernel: 0.3.20 (built from sglang v0.5.7) +# - Megatron-LM: commit 3714d81d418c9f1bca4594fc35f9e8289f652862 +# - TransformerEngine: commit 90c04bcdc3c109505b318f40a39680263af55edf +# - aiter: v0.1.7.post2 +# - Ray: 2.47.1 +# +# Patches: amd_patch/sglv0.5.7/ +# - sglang.patch +# - megatron.patch, amd_megatron_fused_kernels_init.patch +# ===================================================================== FROM rocm/sgl-dev:rocm7-vllm-20250904 @@ -70,7 +85,7 @@ RUN pip uninstall -y megatron-core || true RUN rm -rf Megatron-LM RUN git clone https://github.com/NVIDIA/Megatron-LM \ && cd Megatron-LM \ - && git checkout 48406695c4efcf1026a7ed70bb390793918dd97b \ + && git checkout 3714d81d418c9f1bca4594fc35f9e8289f652862 \ && pip install -e . ######################################### ######################################### @@ -99,7 +114,7 @@ RUN pip install "ray[data,train,tune,serve]==2.47.1" ######################################### ###6. Install torch_memory_saver######### ######################################### -RUN pip install torch_memory_saver +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@64a92e1d7fb822ea4af5579c8cebb162692c531c --no-cache-dir --force-reinstall ######################################### ######################################### @@ -167,7 +182,7 @@ RUN pip uninstall -y sgl_kernel sglang || true RUN rm -rf sglang RUN git clone https://github.com/sgl-project/sglang.git \ && cd sglang \ - && git checkout v0.5.6 + && git checkout v0.5.7 # Build sgl-kernel for gfx950 RUN cd sglang/sgl-kernel \ @@ -194,8 +209,8 @@ RUN python -m pip cache purge #### APPLY PATCHES (gfx950/MI355) ######### ########################################### -# Copy patches from miles repo -COPY amd_patch/latest /app/patch +# Copy patches from miles repo (sglang v0.5.7 specific) +COPY amd_patch/sglv0.5.7 /app/patch # Apply Megatron patches RUN cd /app/Megatron-LM \ @@ -209,7 +224,7 @@ RUN cd /app/Megatron-LM \ # Apply SGLang patch RUN cd /app/sglang \ - && git apply /app/patch/sglang.patch || echo "Check patch compatibility with v0.5.6" \ + && git apply /app/patch/sglang.patch \ && if grep -R -n '^<<<<<<< ' .; then \ echo "Patch failed to apply cleanly. Please resolve conflicts." && \ exit 1; \ diff --git a/docker/amd_patch/sglv0.5.0rc0/megatron.patch b/docker/amd_patch/sglv0.5.0rc0/megatron.patch index c840133cef..b129959aff 100644 --- a/docker/amd_patch/sglv0.5.0rc0/megatron.patch +++ b/docker/amd_patch/sglv0.5.0rc0/megatron.patch @@ -1,5 +1,56 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b4..4451f277 100644 +index fe26e8b43..4451f2776 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads @@ -19,10 +70,10 @@ index fe26e8b4..4451f277 100644 + if hasattr(custom_fsdp, 'MegatronFSDP'): + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index 99c3edc0..26ea5cb4 100644 +index acb93ef78..d239db4ab 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py -@@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear): +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): ) for param in self.parameters(): @@ -30,49 +81,418 @@ index 99c3edc0..26ea5cb4 100644 if is_expert: # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index 002edb92..f7273488 100755 +index e21127b87..712793853 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec( - use_te_op_fuser: Optional[bool] = False, +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( use_kitchen: bool = False, use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). -@@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec( - ), - ), - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map={ - "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", - "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index df9adc3e..2f4f544a 100644 +index a1230568c..1fd52f65a 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py -@@ -443,7 +443,7 @@ class GPTModel(LanguageModule): +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - - if mtp_in_postprocess: -+ if mtp_in_postprocess and labels is not None: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index 57332ac3..f3abd642 100644 +index a273002b9..4f821cfd5 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py -@@ -9,6 +9,7 @@ from typing import Callable, List, Optional +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional import numpy as np import torch @@ -80,222 +500,8 @@ index 57332ac3..f3abd642 100644 from .utils import GlobalMemoryBuffer, is_torch_min_version -@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs): - return None - - -+old_new_group = None -+ -+ -+def monkey_patch_torch_dist(): -+ print("Applying monkey patch to torch.distributed", flush=True) -+ global old_new_group -+ if old_new_group is not None: -+ return -+ -+ old_new_group = dist.new_group -+ -+ def new_group(*args, **kwargs): -+ group = old_new_group(*args, **kwargs) -+ # skip none nccl group. -+ if ( -+ len(args) >= 3 and args[2] == "gloo" or -+ "backend" in kwargs and kwargs["backend"] == "gloo" -+ ): -+ return group -+ -+ # Get ranks from arguments -+ if len(args) >= 1 and args[0] is not None: -+ ranks = args[0] -+ elif "ranks" in kwargs and kwargs["ranks"] is not None: -+ ranks = kwargs["ranks"] -+ else: -+ # If no ranks specified, use all ranks in world -+ ranks = list(range(dist.get_world_size())) -+ -+ if len(ranks) == 1: -+ return group -+ -+ group = ReloadableProcessGroup(group, ranks) -+ return group -+ -+ dist.new_group = new_group -+ -+ def get_new_function(func): -+ def new_function(*args, **kwargs): -+ args = ( -+ arg.group if isinstance(arg, ReloadableProcessGroup) else arg -+ for arg in args -+ ) -+ kwargs = { -+ k: (v.group if isinstance(v, ReloadableProcessGroup) else v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.get_rank = get_new_function(dist.get_rank) -+ dist.get_world_size = get_new_function(dist.get_world_size) -+ dist.get_backend = get_new_function(dist.get_backend) -+ dist.get_global_rank = get_new_function(dist.get_global_rank) -+ dist.get_group_rank = get_new_function(dist.get_group_rank) -+ dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks) -+ -+ dist.all_reduce = get_new_function(dist.all_reduce) -+ dist.all_gather = get_new_function(dist.all_gather) -+ dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor) -+ dist.all_gather_object = get_new_function(dist.all_gather_object) -+ dist.all_to_all = get_new_function(dist.all_to_all) -+ dist.all_to_all_single = get_new_function(dist.all_to_all_single) -+ dist.broadcast = get_new_function(dist.broadcast) -+ dist.reduce = get_new_function(dist.reduce) -+ dist.reduce_scatter = get_new_function(dist.reduce_scatter) -+ dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor) -+ dist.scatter = get_new_function(dist.scatter) -+ dist.gather = get_new_function(dist.gather) -+ dist.barrier = get_new_function(dist.barrier) -+ dist.send = get_new_function(dist.send) -+ dist.recv = get_new_function(dist.recv) -+ dist._coalescing_manager = get_new_function(dist._coalescing_manager) -+ -+ # p2p -+ old_isend = dist.isend -+ old_irecv = dist.irecv -+ -+ dist.isend = get_new_function(dist.isend) -+ dist.irecv = get_new_function(dist.irecv) -+ -+ def get_new_p2pop_function(func): -+ def new_function(*args, **kwargs): -+ def convert(arg): -+ if isinstance(arg, ReloadableProcessGroup): -+ return arg.group -+ elif arg == dist.isend: -+ arg = old_isend -+ elif arg == dist.irecv: -+ arg = old_irecv -+ return arg -+ -+ args = (convert(arg) for arg in args) -+ kwargs = { -+ k: convert(v) -+ for k, v in kwargs.items() -+ } -+ return func(*args, **kwargs) -+ return new_function -+ -+ dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__) -+ dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__) -+ -+ -+ -+class ReloadableProcessGroup(torch.distributed.ProcessGroup): -+ GROUPS = [] -+ -+ def __init__(self, group, ranks): -+ super().__init__( -+ rank=dist.get_rank(group), -+ size=dist.get_world_size(group), -+ ) -+ #print(f"Creating ReloadableProcessGroup with ranks: {ranks}", flush=True) -+ self.group = group -+ self.group_info = { -+ "ranks": ranks, -+ } -+ ReloadableProcessGroup.GROUPS.append(self) -+ -+ def __getattr__(self, name): -+ return getattr(self.group, name) -+ -+ @staticmethod -+ def destroy_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is None: -+ continue -+ #print(f"Destroying process group: {reloadable_group.group_info['ranks']}") -+ dist.destroy_process_group(reloadable_group.group) -+ del reloadable_group.group -+ reloadable_group.group = None -+ -+ @staticmethod -+ def reload_process_groups(): -+ for reloadable_group in ReloadableProcessGroup.GROUPS: -+ if reloadable_group.group is not None: -+ continue -+ #print(f"Reloading process group: {reloadable_group.group_info['ranks']}") -+ group = old_new_group( -+ ranks=reloadable_group.group_info["ranks"], -+ backend="nccl" -+ ) -+ reloadable_group.group = group -+ -+ def rank(self) -> int: return self.group.rank() -+ def size(self) -> int: return self.group.size() -+ def name(self) -> str: return self.group.name() -+ -+ def shutdown(self) -> None: -+ if self.group is not None: -+ self.group.shutdown() -+ -+ def abort(self) -> None: -+ if self.group is not None: -+ self.group.abort() -+ -+ def _fwd(self, method, *args, **kwargs): -+ inner = self.group -+ if inner is None: -+ raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") -+ return getattr(inner, method)(*args, **kwargs) -+ -+ def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw) -+ def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw) -+ def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw) -+ def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw) -+ def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw) -+ def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw) -+ def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw) -+ def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw) -+ def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw) -+ def gather(self, *a, **kw): return self._fwd("gather", *a, **kw) -+ def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw) -+ def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw) -+ def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw) -+ def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw) -+ def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw) -+ def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw) -+ def send(self, *a, **kw): return self._fwd("send", *a, **kw) -+ def recv(self, *a, **kw): return self._fwd("recv", *a, **kw) -+ def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw) -+ -+ def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw) -+ def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw) -+ def _get_backend_name(self): return self._fwd("_get_backend_name") -+ def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw) -+ def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw) -+ @property -+ def bound_device_id(self): return self.group.bound_device_id -+ @bound_device_id.setter -+ def bound_device_id(self, dev): self.group.bound_device_id = dev -+ -+ -+def destroy_process_groups(): -+ """Destroy all reloadable process groups.""" -+ ReloadableProcessGroup.destroy_process_groups() -+ -+ -+def reload_process_groups(): -+ """Reload all reloadable process groups.""" -+ ReloadableProcessGroup.reload_process_groups() -+ -+ -+monkey_patch_torch_dist() -+ -+ - def create_group( - ranks=None, - timeout=None, diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index 63ee9d1f..b90b744c 100644 +index ac839c21f..f18309217 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( @@ -325,13 +531,148 @@ index 63ee9d1f..b90b744c 100644 ) ops.append(recv_next_op) if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..48c9c1a25 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..3e95858a6 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 6f557e1f..b295fd35 100644 +index e2705bd9f..a0aa109b5 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py -@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig): - qk_layernorm: bool = False - """Whether to apply `normalization` type of normalization to the query and key embeddings.""" +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + post_self_attn_layernorm: bool = False + post_mlp_layernorm: bool = False @@ -340,10 +681,10 @@ index 6f557e1f..b295fd35 100644 """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 84f22bde..b4807d26 100644 +index 3ea405770..5a42001b9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py -@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: input_layernorm: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -351,7 +692,7 @@ index 84f22bde..b4807d26 100644 pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -359,7 +700,7 @@ index 84f22bde..b4807d26 100644 # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): # [Module 3: BiasDropoutFusion] self.self_attn_bda = build_module(submodules.self_attn_bda) @@ -369,14 +710,13 @@ index 84f22bde..b4807d26 100644 + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) -+ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn self.pre_cross_attn_layernorm = build_module( submodules.pre_cross_attn_layernorm, -@@ -399,6 +409,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): - # [Module 9: BiasDropoutFusion] - self.mlp_bda = build_module(submodules.mlp_bda) +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + self.post_mlp_layernorm = build_module( + submodules.post_mlp_layernorm, @@ -388,19 +728,18 @@ index 84f22bde..b4807d26 100644 self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False -@@ -535,6 +552,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): attention_output_with_bias[0] ) + attention_output, attention_output_bias = attention_output_with_bias + attention_output = self.post_self_attn_layernorm(attention_output) + attention_output_with_bias = (attention_output, attention_output_bias) -+ + # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") -@@ -635,6 +657,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) @@ -412,12 +751,12 @@ index 84f22bde..b4807d26 100644 # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 24ba8926..4f039fd4 100644 +index b267c8a81..83736acdc 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py -@@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None): - if args.is_hybrid_model: - kw_args['is_hybrid_model'] = args.is_hybrid_model +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm @@ -425,7 +764,7 @@ index 24ba8926..4f039fd4 100644 # handle quantization config # NOTE: Kitchen arguments are only added to the namespace when # Kitchen library is available. -@@ -1481,6 +1484,10 @@ def _add_network_size_args(parser): +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): action='store_true', help='If set, use original BERT residula connection ' 'ordering.') @@ -433,6 +772,21 @@ index 24ba8926..4f039fd4 100644 + help='If set, use post self attention layernorm.') + group.add_argument('--post-mlp-layernorm', action='store_true', + help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') group.add_argument('--openai-gelu', action='store_true', help='Use OpenAIs GeLU implementation. This option' 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch b/docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch new file mode 100644 index 0000000000..f6efca346d --- /dev/null +++ b/docker/amd_patch/sglv0.5.7/amd_megatron_fused_kernels_init.patch @@ -0,0 +1,51 @@ +diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py +index 87cceac3..ac686d74 100644 +--- a/megatron/legacy/fused_kernels/__init__.py ++++ b/megatron/legacy/fused_kernels/__init__.py +@@ -3,6 +3,7 @@ + import os + import pathlib + import subprocess ++import torch + + from torch.utils import cpp_extension + +@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + + def load(args): +- +- # Check if cuda 11 is installed for compute capability 8.0 +- cc_flag = [] +- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( +- cpp_extension.CUDA_HOME +- ) +- if int(bare_metal_major) >= 11: +- cc_flag.append('-gencode') +- cc_flag.append('arch=compute_80,code=sm_80') +- if int(bare_metal_minor) >= 8: ++ if torch.cuda.is_available() and torch.version.cuda: ++ # Check if cuda 11 is installed for compute capability 8.0 ++ cc_flag = [] ++ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( ++ cpp_extension.CUDA_HOME ++ ) ++ if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') +- cc_flag.append('arch=compute_90,code=sm_90') ++ cc_flag.append('arch=compute_80,code=sm_80') ++ if int(bare_metal_minor) >= 8: ++ cc_flag.append('-gencode') ++ cc_flag.append('arch=compute_90,code=sm_90') + +- # Build path +- srcpath = pathlib.Path(__file__).parent.absolute() +- buildpath = srcpath / "build" +- _create_build_dir(buildpath) ++ # Build path ++ srcpath = pathlib.Path(__file__).parent.absolute() ++ buildpath = srcpath / "build" ++ _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.7/megatron.patch b/docker/amd_patch/sglv0.5.7/megatron.patch new file mode 100644 index 0000000000..b129959aff --- /dev/null +++ b/docker/amd_patch/sglv0.5.7/megatron.patch @@ -0,0 +1,792 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py +index fe26e8b43..4451f2776 100644 +--- a/megatron/core/distributed/__init__.py ++++ b/megatron/core/distributed/__init__.py +@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads + from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel + from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig ++ ++# Backward compatibility patch for FSDP module reorganization ++import sys ++import importlib.util ++ ++spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') ++if spec: ++ custom_fsdp = importlib.util.module_from_spec(spec) ++ spec.loader.exec_module(custom_fsdp) ++ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp ++ if hasattr(custom_fsdp, 'MegatronFSDP'): ++ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index acb93ef78..d239db4ab 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index e21127b87..712793853 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index a1230568c..1fd52f65a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() +- if mtp_in_postprocess: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a273002b9..4f821cfd5 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index ac839c21f..f18309217 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..48c9c1a25 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..3e95858a6 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index e2705bd9f..a0aa109b5 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 3ea405770..5a42001b9 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index b267c8a81..83736acdc 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.7/sglang.patch b/docker/amd_patch/sglv0.5.7/sglang.patch new file mode 100644 index 0000000000..e1d6562e16 --- /dev/null +++ b/docker/amd_patch/sglv0.5.7/sglang.patch @@ -0,0 +1,36 @@ +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index 8e3429dec..494a754b3 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1849,7 +1849,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py +index b38a83d57..a492e3ef8 100644 +--- a/python/sglang/srt/models/qwen3_next.py ++++ b/python/sglang/srt/models/qwen3_next.py +@@ -45,13 +45,14 @@ from sglang.srt.utils import ( + LazyValue, + add_prefix, + is_cuda, ++ is_cuda_alike, + is_npu, + make_layers, + set_weight_attrs, + ) + + logger = logging.getLogger(__name__) +-_is_cuda = is_cuda() ++_is_cuda = is_cuda_alike() + _is_npu = is_npu() + + From 500d9e99ccb24cc6480ceea72dc3ebdf84ac15c0 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:55:12 -0800 Subject: [PATCH 73/77] [CI] Fix CI oom in Qwen-30B-A3B (#580) --- tests/e2e/image/test_qwen3_30B_A3B.py | 2 +- tests/e2e/megatron/test_qwen3_30B_A3B.py | 2 +- tests/e2e/megatron/test_qwen3_30B_A3B_r3.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e/image/test_qwen3_30B_A3B.py b/tests/e2e/image/test_qwen3_30B_A3B.py index b30eeed8e5..95649e2a33 100644 --- a/tests/e2e/image/test_qwen3_30B_A3B.py +++ b/tests/e2e/image/test_qwen3_30B_A3B.py @@ -93,7 +93,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " - "--sglang-mem-fraction-static 0.8 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) diff --git a/tests/e2e/megatron/test_qwen3_30B_A3B.py b/tests/e2e/megatron/test_qwen3_30B_A3B.py index b30eeed8e5..95649e2a33 100644 --- a/tests/e2e/megatron/test_qwen3_30B_A3B.py +++ b/tests/e2e/megatron/test_qwen3_30B_A3B.py @@ -93,7 +93,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " - "--sglang-mem-fraction-static 0.8 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) diff --git a/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py index 5a5b968aa6..8b54176d12 100644 --- a/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py +++ b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py @@ -94,7 +94,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " - "--sglang-mem-fraction-static 0.8 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) From 9286b1a4548fa907bde1952bc484f260d389c04e Mon Sep 17 00:00:00 2001 From: Ratish P <114130421+Ratish1@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:32:12 +0530 Subject: [PATCH 74/77] docs: add Miles server arguments (#517) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Zijie Xia Co-authored-by: zijiexia <37504505+zijiexia@users.noreply.github.com> Co-authored-by: 赵晨阳 --- docs/en/advanced/miles_server_args.md | 492 ++++++++++++++++++++++++++ docs/en/get_started/customization.md | 24 +- miles/utils/arguments.py | 27 +- 3 files changed, 530 insertions(+), 13 deletions(-) create mode 100644 docs/en/advanced/miles_server_args.md diff --git a/docs/en/advanced/miles_server_args.md b/docs/en/advanced/miles_server_args.md new file mode 100644 index 0000000000..c02de1545c --- /dev/null +++ b/docs/en/advanced/miles_server_args.md @@ -0,0 +1,492 @@ +# Miles Server Arguments + +This document provides a detailed list of command-line arguments used to configure Miles for RL training and inference. These arguments enable precise control over cluster resources, training backends (Megatron/FSDP), inference optimization via SGLang, and RL algorithmic hyperparameters. + +You can find all arguments by running: +```bash +python3 train.py --help +``` + +Note that this document is based on commit `a93d484` and was last updated on 02/09/2026. We try our best to ensure the quality and accuracy of these documents. Even so, it's hard to accurately describe all the hundreds of parameters' effect on such complex RL scenarios. This doc is for reference and may contain some tiny errors. + +## Argument Sources + +Miles acts as an orchestrator that integrates multiple frameworks. To help identify where an argument is directed, we follow these prefix conventions: + +* **`--sglang-*`**: Arguments passed directly to the **SGLang** rollout. +* **`--router-*`**: Arguments directed to the **SGLang Model Gateway/Router**. +* **No Prefix**: Default arguments corresponding to **Megatron-LM** (when using the Megatron backend) or **Miles native** configuration. +* **`--fsdp-*`**: Specific arguments for the experimental **FSDP** backend. + +**Note** that Arguments labeled as **Megatron-LM (Reset by Miles)** are native Megatron-LM parameters where Miles has modified the default value or behavior to better suit RL training workflows. + +## Table of Contents + +1. [Cluster and Resource Management](#cluster-and-resource-management) +2. [Training Backend](#training-backend) +3. [Rollout Management](#rollout-management) +4. [Sampling and Filtering](#sampling-and-filtering) +5. [Data Arguments](#data-arguments) +6. [Evaluation Arguments](#evaluation-arguments) +7. [Checkpointing and Resuming](#checkpointing-and-resuming) +8. [Algorithm and RL Arguments](#algorithm-and-rl-arguments) +9. [Logging and Monitoring](#logging-and-monitoring) +10. [Fault Tolerance](#fault-tolerance) +11. [Miles Router](#miles-router) +12. [Reward Model Arguments](#reward-model-arguments) +13. [Rollout Buffer Management](#rollout-buffer-management) +14. [Multi-Token Prediction (MTP) Arguments](#multi-token-prediction-mtp-arguments) +15. [SGLang Backend Arguments](#sglang-backend-arguments) +16. [Megatron Specific Arguments](#megatron-specific-arguments) +17. [FSDP Specific Arguments](#fsdp-specific-arguments) +18. [Debug and Profiling](#debug-and-profiling) +19. [Environment Variables](#environment-variables) +20. [Multi-Turn and Agentic Arguments](#multi-turn-and-agentic-arguments) +21. [Advanced Developer Hooks and CI](#advanced-developer-hooks-and-ci) +22. [Miscellaneous and System](#miscellaneous-and-system) + +## Cluster and Resource Management + +Arguments for configuring Ray cluster resources and GPU allocation. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--actor-num-nodes` | Number of nodes for training the Actor. | `1` | Type: int | Miles Native | +| `--actor-num-gpus-per-node` | Number of GPUs per node for training the Actor. | `8` | Type: int | Miles Native | +| `--critic-num-nodes` | Number of nodes for the Critic. Defaults to `--actor-num-nodes`. | `None` | Type: int | Miles Native | +| `--critic-num-gpus-per-node` | Number of GPUs per node for the Critic. Defaults to `--actor-num-gpus-per-node`. | `None` | Type: int | Miles Native | +| `--rollout-num-gpus` | Total number of GPUs required for rollout (inference). In `--colocate` mode, this is ignored and set to `actor-num-gpus-per-node * actor-num-nodes` (plus critic GPUs if enabled). | `None` | Type: int | Miles Native | +| `--rollout-num-gpus-per-engine` | Number of GPUs per inference engine, same as `tp_size` in SGLang. For multi-node serving, this should be the total GPU count / `tp_size` for each SGLang instance. | `1` | Type: int | Miles Native | +| `--num-gpus-per-node` | Total GPUs per node on the physical machine. This informs the Ray scheduler of the hardware capacity. In **Colocate mode**, it is required that the machine has fewer than 8 GPUs to calculate correct VRAM offsets. In **Disaggregated mode**, it ensures SGLang engines are distributed correctly across nodes without exceeding per-node GPU limits. | `8` | Type: int | Miles Native | +| `--colocate` | Deploy training and rollout on the same GPUs. This mode automatically enables `--offload-train` and `--offload-rollout` to facilitate weight-swapping between the training actor and inference engine. **Note:** The offload parameters are currently only used for AMD GPUs and will be removed soon. **Memory Tip:** When colocating, it is highly recommended to set `--sglang-mem-fraction-static` to **0.8** (especially on **NVIDIA Blackwell B200/B300** GPUs). This leaves sufficient VRAM (~20%) for Megatron to initialize its structures before the first weight offload to CPU occurs. On GB200/GB300, values up to 0.75 are safer for long-running jobs to prevent potential OOMs. #TODO: Verify optimal fraction for Blackwell in production | `False` | bool flag (set to enable) | Miles Native | +| `--prefill-num-servers` | Number of dedicated prefill servers for PD disaggregation. | `None` | Type: int | Miles Native | +| `--distributed-backend` | Backend for distributed communication. | `nccl` | `nccl`, `gloo` | Megatron-LM (Reset by Miles) | +| `--distributed-timeout-minutes` | Timeout for distributed operations in minutes. | `10` | Type: int | Megatron-LM (Reset by Miles) | + +Note that most use cases do not need to consider offload parameters, including `--offload-rollout, --no-offload-rollout, --offload-train, --no-offload-train`. They are used only on AMD GPUs and will eventually be removed. + +## Training Backend + +Arguments for configuring the training engine (Megatron or FSDP). + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--train-backend` | The backend for training. Highly suggest Megatron for numerical stability and efficiency. | `"megatron"` | `megatron`, `fsdp` | Miles Native | +| `--qkv-format` | Whether to pack variable-length sequences into the token dimension for training. `thd` (T-H-D, a.k.a. varlen / packed sequence) concatenates sequences and uses `cu_seqlens` to avoid padding; it is the default and is usually faster by reducing padding overhead. `bshd` (B-S-H-D) uses fixed-shape padded batches; use it for newer models with novel attention architectures (e.g., sparse attention, attention sink) where the training backend does not support `thd`. | `"thd"` | `thd`, `bshd` | Miles Native | +| `--optimizer` | Optimizer type. | `adam` | `adam`, `sgd` | Megatron-LM & FSDP | +| `--lr` | Learning rate for the Actor. | `1e-6` | Type: float | Megatron-LM (Reset by Miles) & FSDP | +| `--lr-warmup-init` | Initial learning rate for warmup. | `0.0` | Type: float | Megatron-LM & FSDP | +| `--min-lr` | Minimum learning rate after decay. | `0.0` | Type: float | Megatron-LM & FSDP | +| `--lr-decay-style` | Learning rate decay style. | `constant`(FSDP), `linear`(Megatron) | Type: str | Megatron-LM & FSDP | +| `--lr-warmup-iters` | Number of iterations for warmup. | `0` | Type: int | Megatron-LM & FSDP | +| `--lr-decay-iters` | Number of iterations for learning rate decay. | `None` | Type: int | Megatron-LM & FSDP | +| `--lr-warmup-fraction` | Fraction of total steps to warmup. | `None` | Type: float | Megatron-LM & FSDP | +| `--adam-beta1` | Beta1 for Adam optimizer. | `0.9` | Type: float | Megatron-LM & FSDP | +| `--adam-beta2` | Beta2 for Adam optimizer. | `0.95` | Type: float | Megatron-LM & FSDP | +| `--adam-eps` | Epsilon for Adam optimizer. | `1e-8` | Type: float | Megatron-LM & FSDP | +| `--true-on-policy-mode` | Strictly align SGLang's log probs and training engine's log probs to bit-wise equal. This parameter is only used for FSDP right now. [Ref](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/mismatch/blog-en.md#truly-on-policy-training) | `False` | bool flag (set to enable) | Miles Native | +| `--train-env-vars` | Extra environment variables for training process, e.g., PyTorch memory management ones. | `{}` | Type: JSON / Dict | Miles Native | +| `--train-memory-margin-bytes` | Reserved memory margin for training in bytes. Defaults to 1GB. | `1073741824` | Type: int | Miles Native | +| `--disable-weights-backuper` | Applies to `megatron` training backend only. Disables the system that backs up model weights (Actor, Ref, Old Actor) to CPU RAM. Disabling saves significant host memory but prevents features that rely on weight-swapping, such as computing the KL-divergence against a reference model. **Note**: do not set `--ref-load` and `--keep-old-actor` if disable weights backuper. | `False` | bool flag (set to disable) | Miles Native | +| `--custom-model-provider-path` | Path to a custom function that replaces the default model provider. [Ref](../get_started/customization.md#20-model-provider---custom-model-provider-path) | `None` | Type: str | Miles Native | +| `--recompute-loss-function` | Enable recomputing the loss function to save memory during training. | `False` | bool flag (set to enable) | Miles Native | +| `--log-probs-chunk-size` | Specifies the chunk size for logprobs computation to reduce peak memory usage. Processing logits in smaller batches, it prevents CUDA OOM errors during long-context prefilling or re-computation. Set to `-1` to disable chunking. [Ref](https://github.com/sgl-project/sglang/pull/6318) | `-1` | Type: int | Miles Native | +| `--keep-old-actor` | Maintains a "Model Queue" (Actor, Rollout Actor, Old Actor) to ensure importance sampling ratios are calculated against the exact policy version that generated the data. Essential for asynchronous RL where training and inference are decoupled, preventing mathematical incorrectness due to model staleness. It consumes additional Host Memory (extra ~1x model size for `update_weights_interval > 1` or 2x for `update_weights_interval == 1`) depending on update interval. | `False` | bool flag (set to enable) | Miles Native | +| `--update-weight-buffer-size` | Buffer size for updating weights, in bytes. [Ref](https://hebiao064.github.io/rl-weight-sync#42-optimizing-sglang-server-calls-with-tensor-bucketing-from-50s-to-30s) | `536870912` | Type: int | Miles Native | +| `--update-weights-interval` | Interval (in rollout rounds) for syncing weights to inference engines. Set to `>1` for async RL. | `1` | Type: int | Miles Native | +| `--fp16` | Enable FP16 mixed precision. | `False` | bool flag (set to enable) | Megatron-LM & FSDP | +| `--context-parallel-size` | Size of context parallelism. | `1` | Type: int | Megatron-LM & FSDP | +| `--deterministic-mode` | Enable deterministic mode for reproducibility. [Ref](https://lmsys.org/blog/2025-09-22-sglang-deterministic/) | `False` | bool flag (set to enable) | Megatron-LM & FSDP | + +## Rollout Management + +Arguments for configuring the rollout (inference) process and custom rollout logic. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--hf-checkpoint` | Path to the HuggingFace checkpoint used to initialize SGLang and provide the tokenizer. | `None` | Type: str | Miles Native | +| `--model-name` | The name of the model that is used to convert the Megatron weights into HuggingFace format. If not set, we will use `type(AutoConfig.from_pretrained(args.hf_checkpoint)).__name__.lower()` as `model_name`. Providing this argument can also help in cases where transformers cannot find certain models. | `None` | Type: str | Miles Native | +| `--rollout-function-path` | Path to the rollout generation function. Use this to inject custom logic (e.g., for multi-turn or tool use). [Ref](../get_started/customization.md#1-rollout-function---rollout-function-path) | `miles.rollout.sglang_rollout.generate_rollout` (or `miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn` when `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1`) | Type: str | Miles Native | +| `--rollout-temperature` | Sampling temperature for the inference engine during rollout. | `1.0` | Type: float | Miles Native | +| `--rollout-top-p` | Top-p (nucleus) sampling threshold during rollout. | `1.0` | Type: float | Miles Native | +| `--rollout-top-k` | Top-k sampling threshold during rollout. `-1` means disabled. | `-1` | Type: int | Miles Native | +| `--rollout-max-context-len` | The maximum context size for the inference engine during rollout. It should not exceed the `max_position_embeddings` in the HuggingFace model's `config.json`. **Note:** This acts as a hard cap for the total tokens (Prompt + Response). | `None` | Type: int | Miles Native | +| `--rollout-max-prompt-len` | Maximum length of the prompt. Longer prompts are filtered during dataset initialization. This is not recommended if the dataset is large. **Note:** Defaults to `rollout-max-context-len - 1` if not set, ensuring at least one token can be generated. | `None` | Type: int | Miles Native | +| `--rollout-max-response-len` | Maximum length of the response (`max_tokens` in SGLang). **Note:** Generation will stop when either this limit is reached or the total session length hits `rollout-max-context-len`. | `None` | Type: int | Miles Native | +| `--rollout-skip-special-tokens` | Skip special tokens (e.g., `<\|im_end\|>`, `<\|endoftext\|>`) in the decoded response string. **Critical for Multi-Turn RL:** Ensures that when a response is appended to the conversation history for the next turn, it doesn't include terminal special tokens that would interfere with chat template formatting or cause early termination in subsequent turns. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-stop` | A list of strings that trigger termination of generation if they appear in the output (e.g., `"\nUser:"`). | `None` | Type: List[str] | Miles Native | +| `--rollout-stop-token-ids` | A list of numerical token IDs that trigger termination. This is the token-level equivalent of `--rollout-stop` and is preferred for special control tokens that are difficult to input as strings. | `None` | Type: List[int] | Miles Native | +| `--rollout-shuffle` | Shuffle the prompts during rollout. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-seed` | Seed for the random number generator during rollout (used for shuffling and sampling). | `42` | Type: int | Miles Native | +| `--rollout-external` | Use external SGLang instances instead of launching them inside the framework. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-external-engine-addrs` | Addresses and ports of the external engines. | `None` | Type: List[str] | Miles Native | +| `--custom-generate-function-path` | Path to override only the `generate` step within the default rollout function. If your custom `generate` returns `list[Sample]` (multi-sample), make sure your rollout pipeline can handle it; the default rollout expects a flat `list[Sample]` of length `--n-samples-per-prompt` for each prompt group. [Ref](../get_started/customization.md#2-custom-generate-function---custom-generate-function-path) | `None` | Type: str | Miles Native | +| `--custom-rollout-log-function-path` | Path to a custom function for logging training rollout data. [Ref](../get_started/customization.md#14-logging-functions) | `None` | Type: str | Miles Native | +| `--custom-eval-rollout-log-function-path` | Path to a custom function for logging evaluation rollout data. [Ref](../get_started/customization.md#14-logging-functions) | `None` | Type: str | Miles Native | +| `--rollout-data-postprocess-path` | Path to a function called after all rollout data (including log probs) is ready. [Ref](../get_started/customization.md#8-rollout-data-postprocess---rollout-data-postprocess-path) | `None` | Type: str | Miles Native | + +## Sampling and Filtering + +Arguments for sampling strategies and data filtering during rollout and buffer management. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--over-sampling-batch-size` | Number of prompts requested in each **oversampling** round when **dynamic sampling** is enabled. Miles samples `over_sampling_batch_size` prompts, generates `--n-samples-per-prompt` responses per prompt asynchronously, and then keeps/discards each prompt group via `--dynamic-sampling-filter-path`. If filtering is strict and the remaining accepted batch size drops below the target `--rollout-batch-size`, Miles automatically triggers another oversampling round of the same size. If unset, defaults to `--rollout-batch-size`. See [Dynamic Sampling](../get_started/quick_start.md#dynamic-sampling). | `None` | Type: int | Miles Native | +| `--dynamic-sampling-filter-path` | Path to the filter function for dynamic sampling. [Ref](../get_started/customization.md#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | `None` | Type: str | Miles Native | +| `--partial-rollout` | Enable partial rollout for **dynamic sampling**: cache partially generated (aborted/unfinished) samples and resume generation in later rollout steps, reducing wasted compute for long responses. Cached samples are stored in the rollout buffer and can be prioritized/selected via `--buffer-filter-path` (default FIFO behavior). See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native | +| `--mask-offpolicy-in-partial-rollout` | When using partial rollout, mask the previously generated (cached) response tokens so they do not contribute to the loss; only tokens generated after resuming are used for training. This helps avoid training on a cached prefix produced by an older policy version. See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native | +| `--buffer-filter-path` | Path to the function to filter or sort samples in the rollout buffer before training. [Ref](../get_started/customization.md#5-buffer-filter---buffer-filter-path) | `None` | Type: str | Miles Native | +| `--rollout-sample-filter-path` | Path to the function that marks individual samples to be excluded from loss calculation. [Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path) | `None` | Type: str | Miles Native | +| `--rollout-all-samples-process-path` | Path to the function to process all samples (including filtered ones) after rollout. [Ref](../get_started/customization.md#7-rollout-all-samples-process---rollout-all-samples-process-path) | `None` | Type: str | Miles Native | + +## Data Arguments + +Arguments for dataset configuration, prompt mapping, and training batch sizes. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--prompt-data` | Path to the prompt dataset (JSONL format), and each line should contain `--input-key` and `--label-key`, which will be used as the prompt and the label, respectively. | `None` | Type: str | Miles Native | +| `--disable-rollout-global-dataset` | Disable the global dataset for rollout. By default, Miles loads `--prompt-data` into a global dataset and samples from it for rollout. Setting this flag turns off this behavior. Use this flag only when providing a custom `--rollout-function-path` (and usually a custom `--data-source-path`) that handles data loading independently. | `False` | bool flag (set to disable) | Miles Native | +| `--data-source-path` | Path to a custom Python class for the rollout data source. [Ref](../get_started/customization.md#15-data-source---data-source-path) | `miles.rollout.data_source.RolloutDataSourceWithBuffer` | Type: str | Miles Native | +| `--input-key` | Key in the JSONL data representing the user input/prompt. | `"input"` | Type: str | Miles Native | +| `--label-key` | Key in the JSONL data representing the label/ground truth. | `None` | Type: str | Miles Native | +| `--metadata-key` | When adding tools during `apply_chat_template`, provide the key for the tools to the prompt dataset. | `"metadata"` | Type: str | Miles Native | +| `--multimodal-keys` | JSON string for multimodal data mapping media types to data keys. Example: `'{"image": "image_file"}'` | `None` | Type: str | Miles Native | +| `--tool-key` | JSON key for tool definitions in the prompt dataset (used when applying chat templates). | `"tools"` | Type: str | Miles Native | +| `--apply-chat-template` | Whether to apply the chat template to the input prompt. The input should be the same structure as an OpenAI message, e.g., `[{'role': 'user', 'content': 'blabla'}]`. | `False` | bool flag (set to enable) | Miles Native | +| `--apply-chat-template-kwargs` | Extra arguments for the chat template processing (JSON string). | `"{}"` | Type: str | Miles Native | +| `--num-rollout` | Number of rollout steps. If not set, Miles will calculate the number of rollout steps from the dataset size. **Note:** This value will be overwritten if `--num-epoch` is also set. | `None` | Type: int | Miles Native | +| `--num-epoch` | Number of epochs for the training. If set, `num_rollout` is calculated as `(num_epoch * dataset_size) // rollout_batch_size`. **Note:** This argument takes precedence and will overwrite `--num-rollout` if both are specified. | `None` | Type: int | Miles Native | +| `--rollout-batch-size` | Number of prompts per rollout batch. The total data returned should be `rollout_batch_size` * `n_samples_per_prompt`. | Required | Type: int | Miles Native | +| `--n-samples-per-prompt` | Number of responses to generate for each prompt, e.g., the group size of GRPO. The default rollout pipeline expects each prompt group to contain exactly `n_samples_per_prompt` samples. | `1` | Type: int | Miles Native | +| `--global-batch-size` | Total samples per optimizer step. Automatically calculated or **overridden** if `num_steps_per_rollout` is set. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--num-steps-per-rollout` | The number of training steps to perform using the data collected in a single rollout round. Setting this to `n` means the policy model will be updated `n` times using the same batch of rollout data. Miles ensures that `(rollout-batch-size * n-samples-per-prompt) = (global-batch-size * num-steps-per-rollout)`. If this value is not provided, you have to set `--global-batch-size` explicitly. If both are provided, `--num-steps-per-rollout` will **override** the global batch size with `num_steps_per_rollout = (rollout_batch_size * n_samples_per_prompt) // num_steps_per_rollout`. | `None` | Type: int | Miles Native | +| `--use-dynamic-batch-size` | Dynamically packs variable-length samples into micro-batches to maximize GPU utilization, ensuring the total token count per batch does not exceed `--max-tokens-per-gpu`. For example, with a 300-token limit, samples of lengths 100, 200, and 300 would be packed into two batches: `[100, 200]` and `[300]`. **Note:** Miles ensures that enabling this optimization does not affect the mathematical correctness of per-sample or per-token loss calculation. It is **strongly recommended** to enable this for maximum efficiency. **Compatibility:** only supported when `--qkv-format` is `thd` (does not work for `bshd`). | `False` | bool flag (set to enable) | Miles Native | +| `--max-tokens-per-gpu` | The maximum number of tokens (Prompt + Response combined) per GPU for dynamic batch size. This parameter defines the total sequence length budget for packing samples into micro-batches during training. Note that when enabling context parallel (CP), the effective capacity is shared, so the value should be approximately `(Total_Sequence_Length) // cp_size`. | `None` | Type: int | Miles Native | +| `--log-probs-max-tokens-per-gpu` | The maximum number of tokens per GPU for calculating log probs. This is used to calculate the log probs of the responses during rollout, and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. | `None` | Type: int | Miles Native | +| `--balance-data` | Repartition each rollout batch so each data-parallel rank gets a similar total token count via the Karmarkar-Karp method. It may be beneficial for training speed, but changes per-rank sample grouping and adds a small CPU scheduling overhead. | `False` | bool flag (set to enable) | Miles Native | +| `--data-pad-size-multiplier` | Multiplier used to calculate the sequence padding boundary. Miles rounds sequence lengths up to a multiple of `tensor_parallel_size * data_pad_size_multiplier`. This optimization ensures that matrix dimensions are aligned with NVIDIA Tensor Core requirements, maximizing throughput and reducing VRAM fragmentation. **Note:** better not change this; values `<128` may trigger accuracy loss under `--qkv-format thd` when `TP>=4`. | `128` | Type: int | Miles Native | +| `--micro-batch-size` | Micro batch size per GPU. Ignored when `--use-dynamic-batch-size` is enabled. Works for both `--qkv-format thd` and `--qkv-format bshd` (and is required for `bshd` because dynamic batch size is unsupported). | `1` | Type: int | Megatron-LM (Reset by Miles) | + +## Evaluation Arguments + +Arguments for configuring the evaluation process during training. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--eval-interval` | Interval (in rollout steps) between evaluations. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--eval-prompt-data` | List of name and path pairs for evaluation datasets (e.g., `aime /path/to/aime.jsonl`). | `None` | Type: List[str] | Miles Native | +| `--eval-config` | Path to an OmegaConf YAML/JSON file describing evaluation datasets (overrides `--eval-prompt-data`). | `None` | Type: str | Miles Native | +| `--skip-eval-before-train` | Skip the evaluation step before training starts. | `False` | bool flag (set to enable) | Miles Native | +| `--n-samples-per-eval-prompt` | Number of responses for each prompt in generation. | `1` | Type: int | Miles Native | +| `--eval-temperature` | Temperature for evaluation (defaults to rollout temperature if not set). | `None` | Type: float | Miles Native | +| `--eval-top-p` | Top-p sampling threshold for evaluation (defaults to rollout top-p if not set). | `None` | Type: float | Miles Native | +| `--eval-top-k` | Top-k sampling threshold for evaluation (defaults to rollout top-k if not set). | `None` | Type: int | Miles Native | +| `--eval-max-response-len` | Maximum response length for evaluation (defaults to rollout max response length if not set). | `None` | Type: int | Miles Native | +| `--eval-max-prompt-len` | Maximum prompt length for evaluation. | `None` | Type: int | Miles Native | +| `--eval-min-new-tokens` | Minimum tokens to generate for evaluation responses (Not used). | `None` | Type: int | Miles Native | +| `--eval-max-context-len` | Maximum context length for evaluation (defaults to rollout max context length if not set). | `None` | Type: int | Miles Native | +| `--eval-function-path` | Path to a custom evaluation function. [Ref](../get_started/customization.md#16-evaluation-function---eval-function-path) | `None` | Type: str | Miles Native | +| `--eval-input-key` | JSON key for input text in evaluation datasets. | `None` | Type: str | Miles Native | +| `--eval-label-key` | JSON key for ground truth labels in evaluation datasets. | `None` | Type: str | Miles Native | +| `--eval-tool-key` | JSON key for tool definitions in evaluation datasets. | `None` | Type: str | Miles Native | + +## Checkpointing and Resuming + +Arguments for saving and loading model states. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--load` | Path to the training model checkpoint to load. | `None` | Type: str | Megatron-LM (Reset by Miles) | +| `--save` | Path to save checkpoints. | `None` | Type: str | Megatron-LM (Reset by Miles) | +| `--save-interval` | Interval (in rollout steps) to save checkpoints. Requires `--save` to be set. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--async-save` | Enable asynchronous checkpoint saving (Megatron backend only). | `False` | bool flag (set to enable) | Megatron-LM (Reset by Miles) | +| `--save-hf` | Path to save the model in HuggingFace format when using Megatron backend. The model will be saved to `save_hf.format(rollout_id)`. | `None` | Type: str | Miles Native | +| `--no-save-optim` | If set, optimizer state is not saved with checkpoints to reduce size, but prevents resumption of training. | `False` | bool flag (set to enable) | Megatron-LM (Reset by Miles) | +| `--ref-load` | Path to the reference model checkpoint. Used as an initial checkpoint if `--load` is not set. | `None` | Type: str | Miles Native | +| `--ref-ckpt-step` | The checkpoint step for the reference model. | `None` | Type: int | Miles Native | +| `--critic-load` | Checkpoint to load for the critic model. | value of `--load` | Type: str | Miles Native | +| `--critic-save` | Path to save the critic model. | `None` | Type: str | Miles Native | +| `--start-rollout-id` | The starting rollout step. If not set, it is inferred from the --load checkpoint when resuming training. Otherwise, if training is not continuous, Miles will start training from scratch | `None` | Type: int | Miles Native | + +--- + +## Algorithm and RL Arguments + +Arguments for reinforcement learning algorithms and loss calculation. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--advantage-estimator` | Advantage estimator to use. | `"grpo"` | `grpo`, `gspo`, `ppo`, `reinforce_plus_plus`, `reinforce_plus_plus_baseline`, `on_policy_distillation` | Miles Native | +| `--loss-type` | Type of loss function to use. | `"policy_loss"` | `policy_loss`, `sft_loss`, `custom_loss` | Miles Native | +| `--custom-loss-function-path` | Path to a custom loss calculation function (requires `--loss-type custom_loss`). [Ref](../get_started/customization.md#9-custom-loss-function---custom-loss-function-path) | `None` | Type: str | Miles Native | +| `--critic-lr` | Learning rate for the Critic. Defaults to `--lr`. | `None` | Type: float | Miles Native | +| `--critic-lr-warmup-iters` | Number of iterations for Critic learning rate linear warmup. | `0` | Type: int | Miles Native | +| `--num-critic-only-steps` | Number of initial steps dedicated to training only the Critic. | `0` | Type: int | Miles Native | +| `--eps-clip` | PPO clip range. | `0.2` | Type: float | Miles Native | +| `--eps-clip-high` | PPO clip upper range (defaults to `--eps-clip` if not set). | `None` | Type: float | Miles Native | +| `--eps-clip-c` | Lower bound for [Dual-clip PPO](https://arxiv.org/pdf/1912.09729). | `None` | Type: float | Miles Native | +| `--value-clip` | Clip range for value loss. | `0.2` | Type: float | Miles Native | +| `--kl-coef` | KL penalty coefficient for reward shaping. This is applied to the reward signal before advantage calculation for PPO and REINFORCE-style estimator. | `0.00` | Type: float | Miles Native | +| `--use-kl-loss` | Enable KL loss term in the final objective (as in GRPO). | `False` | bool flag (set to enable) | Miles Native | +| `--kl-loss-coef` | Weight of the KL loss term in the final objective. | `0.0` | Type: float | Miles Native | +| `--kl-loss-type` | Selection of the KL loss implementation. See [Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for more details. | `k1` | `k1`, `k2`, `k3`, `low_var_kl` | Miles Native | +| `--use-unbiased-kl` | Apply Importance Sampling (IS) correction to the KL estimator. Reduces bias from distribution shift. | `False` | bool flag (set to enable) | Miles Native | +| `--entropy-coef` | Coefficient for entropy regularization term. Penalizes low entropy to encourage exploration and prevent premature convergence. | `0.0` | Type: float | Miles Native | +| `--gamma` | Discount factor for future rewards. Used in PPO (GAE) and REINFORCE++. | `1.0` | Type: float | Miles Native | +| `--lambd` | PPO GAE lambda. | `1.0` | Type: float | Miles Native | +| `--normalize-advantages` | Performs distributed masked whitening of advantages. Normalization statistics are computed globally across the Data-Parallel group, ignoring padding tokens. | `False` | bool flag (set to enable) | Miles Native | +| `--disable-compute-advantages-and-returns` | Disables the calculation of advantages and returns. This is typically used for SFT or custom loss functions where value estimation is not required. | `False` | bool flag (set to enable) | Miles Native | +| `--use-tis` | Enable Token-level Importance Sampling (TIS) from this [blog](https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33). | `False` | bool (set to enable) | Miles Native | +| `--tis-clip` | Clipping threshold C for importance sampling ratios to control variance. | `2.0` | Type: float | Miles Native | +| `--tis-clip-low` | Lower bound clipping threshold C for importance sampling ratios to control variance. | `0.0` | Type: float | Miles Native | +| `--custom-tis-function-path` | Path to a custom TIS or MIS function. [Ref](../get_started/customization.md#10-custom-tisrs-function---custom-tis-function-path) | `None` | Type: str | Miles Native | +| `--custom-pg-loss-reducer-function-path` | Custom reducer function for policy gradient loss. [Ref](../get_started/customization.md#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | `None` | Type: str | Miles Native | +| `--use-routing-replay` | Enable R2 (Routing Replay) for MoE: record expert routing decisions during forward and replay them during backward. [Paper](https://arxiv.org/abs/2507.18071) **Note:** automatically set to `True` when `--use-rollout-routing-replay` is enabled. | `False` | bool flag (set to enable) | Miles Native | +| `--use-rollout-routing-replay` | Enable R3 (Rollout Routing Replay) for MoE: record expert routing decisions during rollout and replay them during training. **Requires `--use-miles-router`**. [Paper](https://arxiv.org/abs/2510.11370) [Ref](miles-router.md#22-rollout-routing-replay-r3-for-moe) | `False` | bool flag (set to enable) | Miles Native | +| `--use-opsm` | Enable Off-Policy Sequence Masking (OPSM). Filters sequences that have **BOTH** negative advantages (bad results) AND high KL divergence (stale data). This stabilizes training by preventing updates from unreliable, highly off-policy samples. | `False` | bool flag (set to enable) | Miles Native | +| `--opsm-delta` | The threshold for Off-Policy Sequence Masking (OPSM). | `1e-4` | Type: float | Miles Native | +| `--get-mismatch-metrics` | Calculate mismatch metrics. If it is set, you need to provide a custom TIS function via `--custom-tis-function-path`. | `False` | bool flag (set to enable) | Miles Native | +| `--ref-update-interval` | Interval (in rollout steps) to update ref model from actor. If `None`, ref model is not updated. | `None` | Type: int | Miles Native | +| `--reset-optimizer-states` | Resets the optimizer state after each rollout round. This clears the optimization history, which can improve stability or satisfy specific experimental requirements. | `False` | bool flag (set to enable) | Miles Native | +| `--disable-grpo-std-normalization` | Disable standard deviation normalization for GRPO. From [Dr.GRPO](https://arxiv.org/pdf/2503.20783) | `False` | bool flag (set to enable) | Miles Native | +| `--disable-rewards-normalization` | Disable the default group-wise reward normalization for GRPO, GSPO, and REINFORCE++. This effectively skips the baseline subtraction step. | `False` | bool flag (set to enable) | Miles Native | +| `--use-rollout-entropy` | Enable entropy calculation when calculating the logprobs from actor and reference model. This is useful for implementing custom entropy-based loss masking. | `False` | bool flag (set to enable) | Miles Native | +| `--use-rollout-logprobs` | Use rollout logprobs as the old-policy logprobs when computing importance sampling ratios / PPO-style KL in GRPO/GSPO/PPO. If not set, Miles recomputes old-policy logprobs with the training actor (e.g., `old_actor` or `actor`, depending on configuration). If `--get-mismatch-metrics` is set, the log probs will still be recomputed by the training engine (one more forward pass will be applied). | `False` | bool flag (set to enable) | Miles Native | +| `--calculate-per-token-loss` | Calculate loss on a per-token basis. | `False` | bool flag (set to enable) | Megatron-LM (Reset by Miles) | +| `--seed` | Random seed for the training process. **Also passed to SGLang servers as `random_seed`** (Miles uses `seed + engine_rank` so each engine has a distinct but reproducible seed). | `1234` | Type: int | Megatron-LM (Reset by Miles) | +| `--clip-grad` | Maximum gradient norm for gradient clipping. | `1.0` | Type: float | Megatron-LM (Reset by Miles) | + +--- + +## Logging and Monitoring + +Arguments for WandB, Tensorboard, and general logging. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--use-wandb` | Enable WandB logging. | `False` | bool flag (set to enable) | Miles Native | +| `--wandb-mode` | WandB operating mode. Overrides `WANDB_MODE`. | `None` | `online`, `offline`, `disabled` | Miles Native | +| `--wandb-project` | WandB project name. | `None` | Type: str | Megatron-LM (Reset by Miles) | +| `--wandb-group` | WandB group name. | `None` | Type: str | Miles Native | +| `--wandb-team` | WandB team name. | `None` | Type: str | Miles Native | +| `--wandb-host` | WandB host address. | `None` | Type: str | Miles Native | +| `--wandb-key` | WandB API key. | `None` | Type: str | Miles Native | +| `--wandb-run-id` | Specific WandB run ID to resume. | `None` | Type: str | Miles Native | +| `--wandb-dir` | Directory to store WandB logs. Default is `./wandb` in current directory. | `None` | Type: str | Miles Native | +| `--disable-wandb-random-suffix` | Disable adding a random suffix to the WandB run name. By default, we will add a random 6 length string with characters to the run name. | `False` | bool flag (set to enable) | Miles Native | +| `--wandb-always-use-train-step` | Use training steps instead of rollout steps for the x-axis. | `False` | bool flag (set to enable) | Miles Native | +| `--use-tensorboard` | Enable Tensorboard logging. | `False` | bool flag (set to enable) | Miles Native | +| `--tb-project-name` | Tensorboard project directory. | `None` | Type: str | Miles Native | +| `--tb-experiment-name` | Tensorboard experiment name. | `None` | Type: str | Miles Native | +| `--tensorboard-dir` | Directory to store Tensorboard logs. | `None` | Type: str | Miles Native | +| `--log-multi-turn` | Log detailed information for multi-turn conversations. | `False` | bool flag (set to enable) | Miles Native | +| `--log-passrate` | Enable logging of `pass@n` metrics. | `False` | bool flag (set to enable) | Miles Native | +| `--log-correct-samples` | Explicitly log metrics for correct samples. | `False` | bool flag (set to enable) | Miles Native | +| `--log-reward-category` | Log reward-category statistics (e.g., why the reward function marked a failure). Use this argument to specify the key in the reward dict. | `None` | Type: str | Miles Native | + +--- + +## Fault Tolerance + +Arguments for handling server failures during rollout. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--use-fault-tolerance` | Enable fault tolerance for rollout engines. Periodically sends `/health_generate` heartbeats. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-health-check-interval` | Interval in seconds between rollout engine `/health_generate` checks during generate/eval. | `30.0` | Type: float | Miles Native | +| `--rollout-health-check-timeout` | Timeout in seconds to wait for a rollout engine `/health_generate` response before killing it. | `30.0` | Type: float | Miles Native | +| `--rollout-health-check-first-wait` | Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm. | `0.0` | Type: float | Miles Native | + +--- + +## Miles Router + +Arguments for the specialized Miles text-based router. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--use-miles-router` | Use Miles Router (FastAPI passthrough proxy) instead of SGLang Model Gateway for rollout routing. Required for features that depend on preserving extra rollout metadata (e.g., R3). [Ref](miles-router.md) | `False` | bool flag (set to enable) | Miles Native | +| `--miles-router-middleware-paths` | Paths to custom MilesRouter middleware functions. [Ref](../get_started/customization.md#18-miles-router-middleware---miles-router-middleware-paths) | `""` | Type: List[str] | Miles Native | +| `--miles-router-timeout` | Timeout for router HTTP requests in seconds. | `None` | Type: float | Miles Native | +| `--miles-router-max-connections` | Max connections for MilesRouter HTTP client. | `None` | Type: int | Miles Native | +| `--miles-router-health-check-failure-threshold` | Number of consecutive failures before marking a worker as unhealthy. | `3` | Type: int | Miles Native | + +--- + +## Reward Model Arguments + +Arguments for configuring reward signals and post-processing. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--rm-type` | Built-in reward model selection. | `None` | `remote_rm`, `deepscaler`, `dapo`, `math`, `f1`, `gpqa`, `ifbench`, `random` | Miles Native | +| `--rm-url` | URL for the reward model service (used with `--rm-type remote_rm`). | `None` | Type: str | Miles Native | +| `--reward-key` | JSON key to extract the numerical reward from a returned dictionary if reward model returns a dict instead of a value. | `None` | Type: str | Miles Native | +| `--eval-reward-key` | Evaluation variant for `--reward-key`. | `None` | Type: str | Miles Native | +| `--custom-rm-path` | Path to a custom Python reward function. [Ref](../get_started/customization.md#3-reward-model---custom-rm-path) | `None` | Type: str | Miles Native | +| `--group-rm` | Defer reward computation to process the entire group of samples (per-prompt) at once. Essential for comparative/ranking reward models and improves throughput. **Not supported in eval**. | `False` | bool flag (set to enable) | Miles Native | +| `--custom-reward-post-process-path` | Path to a custom reward post-processor. [Ref](../get_started/customization.md#12-reward-post-processing---custom-reward-post-process-path) | `None` | Type: str | Miles Native | +| `--custom-convert-samples-to-train-data-path` | Path to a custom data format converter. [Ref](../get_started/customization.md#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | `None` | Type: str | Miles Native | + +--- + +## Rollout Buffer Management + +Arguments for managing the rollout data buffer. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--rollout-buffer-url` | URL for the rollout buffer service. | `None` | Type: str | Miles Native | +| `--fetch-trajectory-retry-times` | Number of times to retry fetching trajectory, -1 means unlimited retry. | `-1` | Type: int | Miles Native | +| `--min-batch-collection-ratio` | Minimum batch collection ratio before proceeding. | `1.0` | Type: float | Miles Native | +| `--disable-rollout-trim-samples` | Disable trim samples in rollout buffer when converting samples to train data. | `False` | bool flag (set to enable) | Miles Native | +| `--use-dynamic-global-batch-size` | Enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data. | `False` | bool flag (set to enable) | Miles Native | +| `--rollout-task-type` | Type of task being performed. | `math` | Type: str | Miles Native | +| `--loss-mask-type` | Selection of the token masking logic. | `qwen` | `qwen`, `qwen3`, `distill_qwen` | Miles Native | + +--- + +## Multi-Token Prediction (MTP) Arguments + +Arguments for MTP-based training. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--enable-mtp-training` | Enable MTP layer parameter updates during training. | `False` | bool flag (set to enable) | Miles Native | +| `--mtp-num-layers` | Number of MTP layers to include. | `None` | Type: int | Megatron-LM (Reset by Miles) | +| `--mtp-loss-scaling-factor` | Scaling factor applied to the MTP loss. | `0.2` | Type: float | Megatron-LM (Reset by Miles) | + +--- + +## SGLang Backend Arguments + +Most SGLang server arguments can be passed through by adding the `--sglang-` prefix (some are intentionally skipped, e.g. `model_path`, `tp_size`, `port`, `nnodes`, `node_rank`). For a full list, refer to the [SGLang Server Arguments documentation](https://docs.sglang.io/advanced_features/server_arguments.html). + +Commonly used arguments: + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--sglang-mem-fraction-static` | Fraction of GPU memory to reserve for SGLang KV cache. | `0.9` | Type: float | SGLang | +| `--sglang-server-concurrency` | Maximum number of concurrent requests. | `512` | Type: int | SGLang | +| `--sglang-router-ip` | IP address of the SGLang router and Miles Router. | `None` | Type: str | SGLang Gateway & Miles Router | +| `--sglang-router-port` | Port of the SGLang router and Miles Router. | `None` | Type: int | SGLang Gateway & Miles Router | +| `--sglang-router-request-timeout-secs` | Timeout for requests to the SGLang router. | `14400` | Type: int | SGLang Gateway | + +--- + +## Megatron Specific Arguments + +Arguments applicable when using `--train-backend megatron`. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--megatron-to-hf-mode` | Method to convert Megatron weights to HuggingFace format for SGLang integration. | `raw` | `raw`, `bridge` | Miles Native | +| `--seq-length` | Megatron’s “maximum sequence length” parameter. **In miles training, this parameter has no effect in most setups**: miles uses varlen/packed samples (no truncation based on `seq_length`), forces variable sequence lengths for PP communication buffers, and uses all-to-all token dispatch for MoE. This parameter mainly matters in Megatron’s dataset pipeline. | `None` | Type: int | Megatron-LM | + +--- + +## FSDP Specific Arguments + +Arguments applicable when using `--train-backend fsdp`. **Note: The FSDP backend is still under development and experimental.** + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--warmup-ratio` | Ratio of total steps for warmup. | `0.03` | Type: float | Miles Native | +| `--weight-decay` | Weight decay for the optimizer. | `0.0` | Type: float | Miles Native | +| `--gradient-checkpointing` | Enable gradient checkpointing. | `False` | bool flag (set to enable) | Miles Native | +| `--fsdp-cpu-offload` | Offload parameters and gradients to CPU. | `False` | bool flag (set to enable) | Miles Native | +| `--fsdp-state-dict-cpu-offload` | Offload full state dict to CPU during collection. | `False` | bool flag (set to enable) | Miles Native | +| `--fsdp-cpu-backend` | CPU backend for FSDP CPU offload. | `gloo` | `gloo`, `None` | Miles Native | +| `--attn-implementation` | Selection of the attention implementation. | `flash_attention_2` | `flash_attention_2`, `sdpa`, `eager` | Miles Native | +| `--use-pytorch-profiler` | Enable PyTorch-native profiling. | `False` | bool flag (set to enable) | Miles Native | +| `--profile-step-start` | Starting step for profiling. | `10` | Type: int | Miles Native | +| `--profile-step-end` | Ending step for profiling. | `12` | Type: int | Miles Native | +| `--lr-wsd-decay-iters` | Number of iterations for WSD decay. | `None` | Type: int | Miles Native | +| `--lr-wsd-decay-style` | Decay style for WSD. | `None` | Type: str | Miles Native | +| `--use-checkpoint-lr-scheduler` | Use the checkpoint's LR scheduler state. | `False` | bool flag (set to enable) | Miles Native | +| `--override-lr-scheduler` | Override the loaded LR scheduler state. | `False` | bool flag (set to enable) | Miles Native | +| `--wandb-run-name` | Specific run name for WandB (FSDP backend). | `None` | Type: str | Miles Native | + +--- + +## Debug and Profiling + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--check-weight-update-equal` | Use SGLang's weight checker to check and ensure that the loaded weight from HF checkpoint and received from Megatron are bit-wise equal. | `False` | bool flag (set to enable) | Miles Native | +| `--save-debug-rollout-data` | Path to save rollout data for offline analysis. [Ref](../developer_guide/debug.md) | `None` | Type: str | Miles Native | +| `--load-debug-rollout-data` | Path to load debug rollout data (bypasses SGLang). [Ref](../developer_guide/debug.md) | `None` | Type: str | Miles Native | +| `--load-debug-rollout-data-subsample` | Percentage of debug data to load (0.0 to 1.0). [Ref](../developer_guide/debug.md) | `None` | Type: float | Miles Native | +| `--debug-rollout-only` | Run the rollout phase only without training. [Ref](../developer_guide/debug.md) | `False` | bool flag (set to enable) | Miles Native | +| `--debug-train-only` | Run the training phase only without launching SGLang servers. [Ref](../developer_guide/debug.md) | `False` | bool flag (set to enable) | Miles Native | +| `--save-debug-train-data` | Path to save training batches for offline math debugging. | `None` | Type: str | Miles Native | +| `--dump-details` | Dump exhaustive training details for post-hoc visualization. | `None` | Type: str | Miles Native | +| `--memory-snapshot-path` | Path to save memory snapshots. | `snapshot.pickle` | Type: str | Miles Native | +| `--record-memory-history` | Record memory history for snapshots. | `False` | bool flag (set to enable) | Miles Native | +| `--memory-snapshot-dir` | Directory for PyTorch memory snapshots. | `.` | Type: str | Miles Native | +| `--memory-snapshot-num-steps` | Number of steps to record before saving snapshot. | `None` | Type: int | Miles Native | +| `--memory-recorder` | Selection of the memory recording backend. | `torch` | `torch`, `memray` | Miles Native | +| `--profile-target` | Training components to profile (accepts multiple). | `train_overall` | `train_overall`, `train_actor`, `train_log_probs` | Miles Native | + +--- + +## Environment Variables + +Miles recognizes several environment variables for advanced configuration. + +| Variable | Description | Source | +| :--- | :--- | :--- | +| `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR` | Set to `1` to enable the experimental rollout implementation refactor. | Miles Native | +| `ENABLE_ROUTING_REPLAY` | Internal variable used to enable MoE routing consistency checks during training. | Miles Native | +| `TENSORBOARD_DIR` | Base directory for Tensorboard logs. | Miles Native | +| `MILES_HOST_IP` | Overrides the host IP used for distributed communication. | Miles Native | +| `PYTHONPATH` | Must include the path to your `Megatron-LM` installation when using the Megatron backend. | System | +| `NCCL_SOCKET_IFNAME` | Specifies the network interface for NCCL communication (e.g., `eth0`, `bond0`). | System | +| `GLOO_SOCKET_IFNAME` | Specifies the network interface for GLOO communication. | System | +| `NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME` | Network interface for NVSHMEM bootstrap. | System | + +--- + +## Multi-Turn and Agentic Arguments + +Arguments for managing interactions and tools. Only available when `MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1` and the rollout/generate function exposes `add_arguments`. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--generate-max-turns` | Maximum number of turns in a conversation. | `16` | Type: int | Miles Native | +| `--generate-tool-specs-path` | Path to the tool specifications (JSON). | `None` | Type: str | Miles Native | +| `--generate-tool-call-parser` | The parser used to extract tool calls from text. | `None` | Type: str | Miles Native | +| `--generate-execute-tool-function-path` | Path to the function that executes the tool. | `None` | Type: str | Miles Native | +| `--generate-multi-samples` | Whether to generate multiple samples within one turn. | `False` | bool flag (set to enable) | Miles Native | + +--- + +## Advanced Developer Hooks and CI + +Hooks for custom logic and Continuous Integration testing flags. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--custom-megatron-init-path` | Path to custom Megatron initialization logic. [Ref](../get_started/customization.md#17-megatron-hooks) | `None` | Type: str | Miles Native | +| `--custom-megatron-before-log-prob-hook-path` | Hook called before calculating log probabilities. [Ref](../get_started/customization.md#17-megatron-hooks) | `None` | Type: str | Miles Native | +| `--custom-megatron-before-train-step-hook-path` | Hook called before each training step. [Ref](../get_started/customization.md#17-megatron-hooks) | `None` | Type: str | Miles Native | +| `--ci-test` | Enable Continuous Integration testing mode. | `False` | bool flag (set to enable) | Miles Native | +| `--ci-disable-kl-checker` | Disable KL divergence sanity checks in CI. | `False` | bool flag (set to enable) | Miles Native | +| `--ci-metric-checker-key` | Metric key to monitor for pass/fail in CI. | `None` | Type: str | Miles Native | +| `--ci-metric-checker-threshold` | Pass/fail threshold (minimum value) for the monitored metric. | `None` | Type: float | Miles Native | +| `--ci-save-grad-norm` | Path to save gradient norms for CI comparison. | `None` | Type: str | Miles Native | +| `--ci-load-grad-norm` | Path to load gradient norms for CI verification. | `None` | Type: str | Miles Native | + +--- + +## Miscellaneous and System + +General arguments for infrastructure and configuration overrides. + +| Argument | Description | Default | Options | Source | +| :--- | :--- | :--- | :--- | :--- | +| `--http-proxy` | HTTP proxy server for remote reward model calls. | `None` | Type: str | Miles Native | +| `--use-distributed-post` | Use distributed POST requests for remote reward models. | `False` | bool flag (set to enable) | Miles Native | +| `--custom-config-path` | Path to the YAML config for custom function arguments. | `None` | Type: str | Miles Native | +| `--padded-vocab-size` | Manually specify the vocab size for padding. | `None` | Type: int | Miles Native | diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index bfd5024228..8aa63c23fb 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -29,12 +29,19 @@ Below is a summary of all available customization interfaces and their purposes. | [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hooks) | Custom logic before log probability computation. | | [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hooks) | Custom logic before each training step. | | [`--miles-router-middleware-paths`](#18-miles-router-middleware---miles-router-middleware-paths) | Add custom middleware to miles router. | +| [`--custom-model-provider-path`](#20-model-provider---custom-model-provider-path) | Path to a custom function that replaces the default model provider. | ## Detailed Interface Reference ### 1. Rollout Function (`--rollout-function-path`) -**Default**: `miles.rollout.sglang_rollout.generate_rollout` +**Default**: +```python +if enable_experimental_rollout_refactor(): + miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn +else: + miles.rollout.sglang_rollout.generate_rollout +``` **Purpose**: Override the entire rollout generation logic. @@ -418,3 +425,18 @@ Stabilize MoE RL training by recording and replaying expert routing decisions to | `--use-rollout-routing-replay` | R3: Replay routing from rollout during training. **Requires `--use-miles-router`**. ([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | For detailed explanation of R3 and MilesRouter, see [Miles Router](../advanced/miles-router.md). + +--- + +### 20. Model Provider (`--custom-model-provider-path`) + +**Default**: `None` + +**Purpose**: Path to a custom function that replaces the default model provider (e.g., `'my_module.my_provider'`). The function must return a GPTModel. + +**Signature**: +```python +def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel +``` + + diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 09c817c5e8..ac43859521 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -124,7 +124,7 @@ def add_train_arguments(parser): type=str, choices=["thd", "bshd"], default="thd", - help="The qkv layout for Megatron backend.", + help="The qkv layout.", ) parser.add_argument( "--true-on-policy-mode", @@ -148,7 +148,12 @@ def add_train_arguments(parser): "--disable-weights-backuper", action="store_false", dest="enable_weights_backuper", - help="Whether to disable weights backuper to save host memory.", + help=( + "Applies to `megatron` training backend only. " + "Disables the system that backups model weights (Actor, Ref, Old Actor) to CPU RAM. " + "Disabling saves significant host memory but prevents features that rely on weight-swapping, such as computing KL-divergence against a reference model. " + "Note: do not set `--ref-load` and `--keep-old-actor` if disable weights backuper." + ), ) parser.add_argument( "--megatron-to-hf-mode", @@ -171,7 +176,7 @@ def add_train_arguments(parser): parser.add_argument( "--recompute-loss-function", action="store_true", - help="Whether to disable recompute loss function to save memory during training.", + help="Whether to enable recompute loss function to save memory during training.", ) parser.add_argument( "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" @@ -493,10 +498,8 @@ def add_data_arguments(parser): action="store_false", dest="rollout_global_dataset", help=( - "Whether to use a global dataset for rollout. " - "If set, the rollout will use the `--prompt-data` as the prompt dataset, " - "and the prompts for rollout will be sampled from the dataset. " - "If not set, you need to manage the data by your self." + "Disable the global dataset for rollout. By default, Miles loads `--prompt-data` into a global dataset and samples from it for rollout. " + "Setting this flag turns off this behavior, Use this flag only when providing a custom `--rollout-function-path` (and usually a custom `--data-source-path`) that handles data loading independently." ), ) @@ -513,7 +516,7 @@ def add_data_arguments(parser): help=( "The path to the prompt data. " "Currently we only support jsonl format, and each line should contains --input-key and --label-key, " - "which will be used as the prompt and the label respectively. " + "which will be used as the prompt and the label respectively." "If you want to use a custom template, you can set --apply-chat-template to true, in that case, " "the input should be the same structure as an openai message, e.g. [{'role': 'user', 'content': 'blabla'}]. " ), @@ -585,8 +588,8 @@ def add_data_arguments(parser): action="store_true", default=False, help=( - "Balance the number of tokens between data parallel ranks with `karmarkar_karp` for verl. " - "Note that this may allocate the different response of the same prompt into different training steps." + "Repartition each rollout batch so each data-parallel rank gets a similar total token count via Karmarkar-Karp method. " + "It may be beneficial for training speed but changes per-rank sample grouping and adds a small CPU scheduling overhead." ), ) @@ -875,7 +878,7 @@ def add_algo_arguments(parser): "--use-tis", action="store_true", default=False, - help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", + help="Enable TIS from https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33.", ) parser.add_argument( "--tis-clip", @@ -1040,7 +1043,7 @@ def add_wandb_arguments(parser): "--log-correct-samples", action="store_true", default=False, - help="Whether to turn on passrate logging, which will log the pass@n of the responses in the rollout.", + help="Explicitly log metrics for correct samples.", ) parser.add_argument("--wandb-run-id", type=str, default=None) return parser From 725dcef16554927ba2be69bc181e1697a9f89efd Mon Sep 17 00:00:00 2001 From: lizamd Date: Fri, 13 Feb 2026 13:53:10 -0500 Subject: [PATCH 75/77] [AMD] Unify run-qwen3-4B.sh to support both AMD and NVIDIA GPUs Auto-detect GPU vendor (/dev/kfd or torch.version.hip for AMD, nvidia-smi for NVIDIA) and conditionally apply platform-specific settings: - AMD: HIP_VISIBLE_DEVICES, RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES, --no-gradient-accumulation-fusion, --no-offload-train/rollout - NVIDIA: NVLink detection, NCCL_NVLS_ENABLE - Both: dynamic Megatron-LM path detection, configurable MODEL_DIR/DATA_DIR This eliminates the need for a separate run-qwen3-4B-amd.sh script. Co-Authored-By: Claude Opus 4.6 --- scripts/run-qwen3-4B.sh | 76 +++++++++++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/scripts/run-qwen3-4B.sh b/scripts/run-qwen3-4B.sh index c7f01abd93..2d0fd0eb35 100644 --- a/scripts/run-qwen3-4B.sh +++ b/scripts/run-qwen3-4B.sh @@ -10,33 +10,54 @@ sleep 3 pkill -9 ray pkill -9 python -set -ex +set -euxo pipefail -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 +# ==================== Platform Detection ==================== +if [ -e /dev/kfd ] || python3 -c "import torch; assert torch.version.hip" 2>/dev/null; then + GPU_VENDOR="amd" +elif command -v nvidia-smi &>/dev/null; then + GPU_VENDOR="nvidia" else + echo "ERROR: No supported GPU detected (need NVIDIA or AMD)" + exit 1 +fi +echo "Detected GPU vendor: ${GPU_VENDOR}" + +# ==================== Configurable Paths ==================== +MODEL_DIR="${MODEL_DIR:-/root}" +DATA_DIR="${DATA_DIR:-/root}" +export MODEL_DIR DATA_DIR + +# ==================== Platform-Specific Setup ==================== +if [ "$GPU_VENDOR" = "amd" ]; then + export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} + export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} + NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) HAS_NVLINK=0 +else + NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) + if [ "$NVLINK_COUNT" -gt 0 ]; then HAS_NVLINK=1; else HAS_NVLINK=0; fi + echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + NUM_GPUS=8 fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/qwen3-4B.sh" CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-4B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load /root/Qwen3-4B_torch_dist - --load /root/Qwen3-4B_miles/ - --save /root/Qwen3-4B_miles/ + --hf-checkpoint ${MODEL_DIR}/Qwen3-4B + #--hf-checkpoint ${MODEL_DIR}/Qwen3-4B-FP8 + --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist + --load ${MODEL_DIR}/Qwen3-4B_miles/ + --save ${MODEL_DIR}/Qwen3-4B_miles/ --save-interval 20 ) ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl --input-key prompt --label-key label --apply-chat-template @@ -54,7 +75,7 @@ ROLLOUT_ARGS=( EVAL_ARGS=( --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 --eval-top-p 1 @@ -119,14 +140,26 @@ MISC_ARGS=( --attention-backend flash ) -# launch the master node of ray in container +# ==================== Platform-Specific Args ==================== +PLATFORM_TRAIN_ARGS=() +if [ "$GPU_VENDOR" = "amd" ]; then + # Apex not available on ROCm + MISC_ARGS+=(--no-gradient-accumulation-fusion) + # Disable offloading (torch_memory_saver may not support ROCm; MI300X has 192GB HBM) + PLATFORM_TRAIN_ARGS+=(--no-offload-train --no-offload-rollout) +fi + +# ==================== Launch Ray ==================== export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Dynamically detect Megatron-LM installation path +MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") -# Build the runtime environment JSON with proper variable substitution +# Build the runtime environment JSON RUNTIME_ENV_JSON="{ \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" } @@ -136,8 +169,9 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ + ${PLATFORM_TRAIN_ARGS[@]+"${PLATFORM_TRAIN_ARGS[@]}"} \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ ${ROLLOUT_ARGS[@]} \ @@ -147,4 +181,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ No newline at end of file + ${MISC_ARGS[@]} From daab467c463f84d22cf6d73b17145d5285441425 Mon Sep 17 00:00:00 2001 From: lizamd Date: Mon, 16 Feb 2026 19:03:34 -0500 Subject: [PATCH 76/77] Address PR review comments - Use dynamic NVIDIA GPU count via nvidia-smi -L instead of hardcoded 8 - Remove --no-gradient-accumulation-fusion (AMD Docker now supports it) - Remove --no-offload-train/rollout (torch_memory_saver resolved for ROCm) - Expand compact if/else to multi-line for readability Co-Authored-By: Claude Opus 4.6 --- scripts/run-qwen3-4B.sh | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/scripts/run-qwen3-4B.sh b/scripts/run-qwen3-4B.sh index 2d0fd0eb35..a4fc2d9dfc 100644 --- a/scripts/run-qwen3-4B.sh +++ b/scripts/run-qwen3-4B.sh @@ -36,9 +36,13 @@ if [ "$GPU_VENDOR" = "amd" ]; then HAS_NVLINK=0 else NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) - if [ "$NVLINK_COUNT" -gt 0 ]; then HAS_NVLINK=1; else HAS_NVLINK=0; fi + if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 + else + HAS_NVLINK=0 + fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - NUM_GPUS=8 + NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) fi # will prevent ray from buffering stdout/stderr @@ -140,15 +144,6 @@ MISC_ARGS=( --attention-backend flash ) -# ==================== Platform-Specific Args ==================== -PLATFORM_TRAIN_ARGS=() -if [ "$GPU_VENDOR" = "amd" ]; then - # Apex not available on ROCm - MISC_ARGS+=(--no-gradient-accumulation-fusion) - # Disable offloading (torch_memory_saver may not support ROCm; MI300X has 192GB HBM) - PLATFORM_TRAIN_ARGS+=(--no-offload-train --no-offload-rollout) -fi - # ==================== Launch Ray ==================== export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 @@ -171,7 +166,6 @@ ray job submit --address="http://127.0.0.1:8265" \ --actor-num-nodes 1 \ --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ - ${PLATFORM_TRAIN_ARGS[@]+"${PLATFORM_TRAIN_ARGS[@]}"} \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ ${ROLLOUT_ARGS[@]} \ From 0ea7c3b930321dc105998666009222065b31f711 Mon Sep 17 00:00:00 2001 From: lizamd Date: Mon, 16 Feb 2026 19:06:18 -0500 Subject: [PATCH 77/77] Add --sglang-disable-custom-all-reduce for AMD Prevent driver-level deadlocks when offload is enabled on AMD GPUs, consistent with PR #588 changes to run-qwen3-4B-amd.sh. Co-Authored-By: Claude Opus 4.6 --- scripts/run-qwen3-4B.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/run-qwen3-4B.sh b/scripts/run-qwen3-4B.sh index a4fc2d9dfc..cecb41704e 100644 --- a/scripts/run-qwen3-4B.sh +++ b/scripts/run-qwen3-4B.sh @@ -133,6 +133,11 @@ SGLANG_ARGS=( --sglang-mem-fraction-static 0.7 ) +# AMD: disable custom all-reduce to prevent driver-level deadlocks with offload enabled +if [ "$GPU_VENDOR" = "amd" ]; then + SGLANG_ARGS+=(--sglang-disable-custom-all-reduce) +fi + MISC_ARGS=( # default dropout in megatron is 0.1 --attention-dropout 0.0