From 84519a0fad0283040941e5c4ae57a4e1781dce06 Mon Sep 17 00:00:00 2001 From: dafu-wu Date: Thu, 2 Apr 2026 04:46:45 +0000 Subject: [PATCH] [fix] Improve async training pipeline: ServerAdapter weight sync, robustness, and packaging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core improvements: - async_main: Fix Rollouter/Trainer lifecycle so Trainer is not cancelled when Rollouter finishes first; shutdown DataPool after Trainer completes; propagate critical env vars (PYTHONPATH, CUDA_HOME, etc.) to Ray workers - async_rollouter: Use sys.executable instead of hardcoded 'python'; log gateway stdout/stderr to files; use FQDN hostname for multi-node gateway access; defer DataPool shutdown to async_main; gracefully handle validation failures - async_trainer: Fix need_reference_policy(config) call; remove deprecated num_examine param; add debug logging for reward computation diagnostics - blackbox_agent_flow: Add HTTP POST retry with exponential backoff; handle transient connection errors during agent runs (reward=0 fallback); support _run_agent returning (num_turns, reward) tuple - gsm8k_agent/gsm8k_agent_flow: Return (turns_used, reward) tuple from solve() for direct reward passing to Gateway - data_pool: Add get_last_step() for auto-reward computation; fix max_queue_size display when None - training_backend: Add debug logging for reward status in VerlBackend.convert() - detach_workers: Add ServerAdapter support — skip NCCL broadcast for vLLM HTTP server mode; add extract_actor_weights() and receive_and_update_weights() for Ray object store weight transfer; handle DTensor extraction - gateway: Add auto-reward computation for black-box agents via RewardLoopWorker; improve vLLM proxy error handling - param_sync: Add ServerAdapter weight sync path via Ray object store as alternative to NCCL; auto-detect ServerAdapter rollouts - reward_loop: Fix import compatibility (reward_manager vs reward_loop module); add safe defaults for reward_loop_source and reward_manager config - main_agent_ppo: Fix need_reference_policy call; remove num_examine; set reward_fn as attributes to match RayPPOTrainer API - pyproject.toml: Add PEP 621 project metadata, build-system config, and setuptools package discovery for pip-installable packaging --- claw_r1/__init__.py | 0 claw_r1/async_main.py | 59 +++++-- claw_r1/async_rollouter.py | 45 +++-- claw_r1/async_trainer.py | 23 ++- claw_r1/blackbox_agent/blackbox_agent_flow.py | 82 ++++++++- claw_r1/blackbox_agent/gsm8k_agent.py | 12 +- claw_r1/blackbox_agent/gsm8k_agent_flow.py | 2 +- claw_r1/data_pool/data_pool.py | 14 +- claw_r1/data_pool/training_backend.py | 9 + claw_r1/detach_workers.py | 114 ++++++++++++- claw_r1/gateway/gateway.py | 159 +++++++++++++++++- claw_r1/main_agent_ppo.py | 16 +- claw_r1/param_sync.py | 57 ++++++- claw_r1/reward_loop.py | 10 +- pyproject.toml | 51 ++++++ 15 files changed, 591 insertions(+), 62 deletions(-) create mode 100644 claw_r1/__init__.py diff --git a/claw_r1/__init__.py b/claw_r1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/claw_r1/async_main.py b/claw_r1/async_main.py index d5092ea..6f37a95 100644 --- a/claw_r1/async_main.py +++ b/claw_r1/async_main.py @@ -116,7 +116,7 @@ def _initialize(self, config): validate_config( config=config, - use_reference_policy=need_reference_policy(role_worker_mapping), + use_reference_policy=need_reference_policy(config), use_critic=need_critic(config), ) @@ -240,27 +240,50 @@ def _run(self): rollouter_future = 
self.components["rollouter"].fit.remote() trainer_future = self.components["trainer"].fit.remote() - futures = [rollouter_future, trainer_future] - + # The Rollouter generates data much faster than the Trainer consumes it. + # We must wait for the Trainer to finish (it controls the training steps). + # The Rollouter finishing first is normal — it just means all data has + # been generated. We should NOT cancel the Trainer when that happens. try: - while futures: - done, remaining = ray.wait(futures, num_returns=1, timeout=None) + # Wait for Trainer to finish (the primary task) + # Also monitor Rollouter for errors + futures = {rollouter_future: "Rollouter", trainer_future: "Trainer"} + trainer_done = False + rollouter_done = False + + while not trainer_done: + done, _ = ray.wait(list(futures.keys()), num_returns=1, timeout=None) for f in done: + name = futures.pop(f) try: ray.get(f) - print("[ASYNC] Component completed successfully") + print(f"[ASYNC] {name} completed successfully") + if name == "Trainer": + trainer_done = True + elif name == "Rollouter": + rollouter_done = True except Exception as e: - print(f"[ASYNC] Component failed: {e}") - for r in remaining: - ray.cancel(r) + print(f"[ASYNC] {name} failed: {e}") + # Cancel remaining futures on error + for remaining_f in futures: + ray.cancel(remaining_f) raise - futures = remaining + + # Trainer is done. If Rollouter is still running, cancel it. + if not rollouter_done and rollouter_future in futures: + ray.cancel(rollouter_future) + print("[ASYNC] Cancelled Rollouter (Trainer finished first)") + except Exception as e: print(f"[ASYNC] Training failed: {e}") - for f in futures: - ray.cancel(f) raise finally: + # Shutdown DataPool after Trainer is done + try: + data_pool = ray.get_actor(self.components["data_pool_name"]) + ray.get(data_pool.shutdown.remote()) + except Exception: + pass print("[ASYNC] Training finished") @@ -273,6 +296,18 @@ def main(config): if not ray.is_initialized(): default_runtime_env = get_ppo_ray_runtime_env() + # Propagate critical env vars to Ray workers so they can find + # clawr1_env packages, CUDA libs, and vLLM settings + _propagate_vars = [ + "PYTHONPATH", "PATH", "LD_LIBRARY_PATH", + "VLLM_USE_V1", "CUDA_HOME", "CONDA_PREFIX", + "SWANLAB_MODE", + ] + for var in _propagate_vars: + val = os.environ.get(var, "") + if val: + default_runtime_env.setdefault("env_vars", {})[var] = val + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) diff --git a/claw_r1/async_rollouter.py b/claw_r1/async_rollouter.py index f57ff7c..08612ba 100644 --- a/claw_r1/async_rollouter.py +++ b/claw_r1/async_rollouter.py @@ -54,7 +54,6 @@ def __init__( self.val_reward_fn = load_reward_manager( config, tokenizer, - num_examine=1, **config.reward_model.get("reward_kwargs", {}), ) @@ -240,8 +239,9 @@ def _init_gateway(self): ray_address = ray_ctx.gcs_address ray_namespace = ray_ctx.namespace + import sys cmd = [ - "python", + sys.executable, "-m", "claw_r1.gateway.gateway", "--data-pool-name", @@ -264,13 +264,27 @@ def _init_gateway(self): str(gateway_port), ] + # Log gateway output to files for debugging + import tempfile + gateway_log_dir = os.path.join( + self.config.trainer.get("default_local_dir", "/tmp"), + "gateway_logs", + ) + os.makedirs(gateway_log_dir, exist_ok=True) + self._gateway_stdout = open(os.path.join(gateway_log_dir, "gateway_stdout.log"), "w") + self._gateway_stderr = 
open(os.path.join(gateway_log_dir, "gateway_stderr.log"), "w") + self._gateway_process = subprocess.Popen( cmd, - stdout=subprocess.DEVNULL, - stderr=subprocess.PIPE, + stdout=self._gateway_stdout, + stderr=self._gateway_stderr, text=True, ) - self._gateway_url = f"http://localhost:{gateway_port}" + # Use the actual hostname/IP instead of localhost so that + # AgentFlowWorker actors on other nodes can reach the gateway. + import socket + gateway_host = socket.getfqdn() + self._gateway_url = f"http://{gateway_host}:{gateway_port}" atexit.register(self._stop_gateway) for _ in range(120): @@ -344,10 +358,12 @@ async def fit(self): t.cancel() await asyncio.gather(gen_task, monitor_task, return_exceptions=True) - # Signal DataPool shutdown - data_pool = ray.get_actor(self._data_pool_name) - ray.get(data_pool.shutdown.remote()) - logger.info("Rollouter fit completed") + # Do NOT shutdown DataPool here — the Trainer may still be consuming + # batches. The Trainer will exit on its own when it reaches + # total_train_steps or when the DataPool is empty. + # data_pool = ray.get_actor(self._data_pool_name) + # ray.get(data_pool.shutdown.remote()) + logger.info("Rollouter fit completed — generation finished, DataPool remains open for Trainer") async def _generation_main(self): """Iterate over epochs/batches, generate sequences, submit to DataPool.""" @@ -449,7 +465,16 @@ async def update_param_version( val_metrics = None if validate and self.val_reward_fn is not None: - val_metrics = self._validate() + try: + val_metrics = self._validate() + except Exception as exc: + logger.error( + "[AsyncRollouter] Validation failed (version=%d): %s. " + "Skipping validation and continuing training.", + version, exc, + exc_info=True, + ) + val_metrics = None from ray import cloudpickle as ray_cloudpickle diff --git a/claw_r1/async_trainer.py b/claw_r1/async_trainer.py index 886dd38..5905122 100644 --- a/claw_r1/async_trainer.py +++ b/claw_r1/async_trainer.py @@ -58,20 +58,18 @@ def __init__( self.device_name = device_name or self.config.trainer.device self.hybrid_engine = False - self.use_reference_policy = need_reference_policy(self.role_worker_mapping) + self.use_reference_policy = need_reference_policy(self.config) self.use_critic = need_critic(self.config) self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 self.reward_fn = load_reward_manager( config, tokenizer, - num_examine=0, **config.reward_model.get("reward_kwargs", {}), ) self.val_reward_fn = load_reward_manager( config, tokenizer, - num_examine=1, **config.reward_model.get("reward_kwargs", {}), ) @@ -361,6 +359,21 @@ def _process_batch(self, batch: DataProto, metrics: dict, timing_raw: dict) -> D def _compute_reward(self, batch: DataProto): """Compute or extract reward for training.""" + has_rm = "rm_scores" in batch.batch + has_ds = "data_source" in batch.non_tensor_batch if hasattr(batch, "non_tensor_batch") else False + has_gt = "reward_model" in batch.non_tensor_batch if hasattr(batch, "non_tensor_batch") else False + print( + f"[DEBUG _compute_reward] has_rm_scores={has_rm}, has_data_source={has_ds}, " + f"has_reward_model={has_gt}, reward_fn={self.reward_fn is not None}" + ) + if has_rm: + rm = batch.batch["rm_scores"] + print(f"[DEBUG _compute_reward] rm_scores: sum={rm.sum().item():.4f}, max={rm.max().item():.4f}, nonzero={rm.count_nonzero().item()}") + if has_ds: + print(f"[DEBUG _compute_reward] data_source sample: {batch.non_tensor_batch['data_source'][:3]}") + if has_gt: + print(f"[DEBUG _compute_reward] reward_model 
sample: {batch.non_tensor_batch['reward_model'][:3]}") + if "rm_scores" in batch.batch: reward_extra_keys = batch.meta_info.get("reward_extra_keys", []) reward_extra_infos_dict = ( @@ -371,7 +384,9 @@ def _compute_reward(self, batch: DataProto): if self.reward_fn is not None: from verl.trainer.ppo.reward import compute_reward - return compute_reward(batch, self.reward_fn) + result = compute_reward(batch, self.reward_fn) + print(f"[DEBUG _compute_reward] reward_fn result: sum={result[0].sum().item():.4f}, max={result[0].max().item():.4f}, nonzero={result[0].count_nonzero().item()}") + return result raise ValueError("No reward_fn and no pre-computed rm_scores in batch") diff --git a/claw_r1/blackbox_agent/blackbox_agent_flow.py b/claw_r1/blackbox_agent/blackbox_agent_flow.py index 6fd370a..9bbb044 100644 --- a/claw_r1/blackbox_agent/blackbox_agent_flow.py +++ b/claw_r1/blackbox_agent/blackbox_agent_flow.py @@ -7,6 +7,7 @@ separate modules (e.g. gsm8k_agent_flow.py). """ +import asyncio import json import logging import os @@ -53,6 +54,35 @@ def _prepare_params(self, kwargs: dict[str, Any]) -> tuple[str | None, str, dict metadata = {k: v for k, v in kwargs.items() if k not in _DEFAULT_SKIP_KEYS} return channel, prompt_uid, metadata + async def _http_post_with_retry( + self, + url: str, + max_retries: int = 3, + retry_delay: float = 3.0, + timeout: float = 600.0, + **kwargs, + ) -> httpx.Response: + """POST with retry on transient connection errors.""" + last_exc = None + for attempt in range(max_retries): + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as http: + resp = await http.post(url, **kwargs) + resp.raise_for_status() + return resp + except (httpx.ReadError, httpx.ConnectError, httpx.RemoteProtocolError) as exc: + last_exc = exc + if attempt < max_retries - 1: + wait = retry_delay * (2 ** attempt) + logger.warning( + "HTTP POST %s failed (attempt %d/%d): %s. Retrying in %.1fs...", + url, attempt + 1, max_retries, exc, wait, + ) + await asyncio.sleep(wait) + else: + logger.error("HTTP POST %s failed after %d attempts: %s", url, max_retries, exc) + raise last_exc + async def run(self, sampling_params: dict[str, Any], **kwargs) -> int: channel, prompt_uid, metadata = self._prepare_params(kwargs) @@ -80,17 +110,55 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> int: headers={"content-type": "application/json"}, ) - # 3. Run the concrete agent. + # 3. Run the concrete agent (with retry on transient errors). + reward = None + num_turns = 0 try: - num_turns = await self._run_agent(base_url, kwargs) + result = await self._run_agent(base_url, kwargs) + # _run_agent may return int (num_turns) or tuple (num_turns, reward) + if isinstance(result, tuple): + num_turns, reward = result + else: + num_turns = result + except (httpx.ReadError, httpx.ConnectError, httpx.RemoteProtocolError) as exc: + # Agent failed due to transient connection error (e.g. vLLM weight update). + # Log warning but still complete the trajectory with reward=0. + logger.warning( + "Agent run failed with transient error: %s. Completing trajectory with reward=0.", + exc, + ) + reward = 0.0 finally: - # 4. Mark trajectory complete. - async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as http: - await http.post(f"{base_url}/complete_trajectory") + # 4. Mark trajectory complete, passing reward if available. + # Use retry to handle transient Gateway connection issues. 
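Note on the retry helper: _http_post_with_retry above retries only transient transport errors and backs off exponentially, so with the defaults it waits 3s and then 6s between the three attempts, while HTTP status errors from raise_for_status() propagate immediately. A minimal standalone sketch of the same pattern (post_with_retry and its bare defaults are illustrative, not part of this patch):

    import asyncio
    import httpx

    async def post_with_retry(url: str, max_retries: int = 3,
                              retry_delay: float = 3.0, **kwargs) -> httpx.Response:
        last_exc = None
        for attempt in range(max_retries):
            try:
                async with httpx.AsyncClient(timeout=600.0) as client:
                    resp = await client.post(url, **kwargs)
                    resp.raise_for_status()   # 4xx/5xx raise immediately, no retry
                    return resp
            except (httpx.ReadError, httpx.ConnectError, httpx.RemoteProtocolError) as exc:
                last_exc = exc
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay * (2 ** attempt))  # 3s, 6s, ...
        raise last_exc
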
+ try: + body: dict[str, Any] = {} + if reward is not None: + body["reward"] = float(reward) + if channel: + body["channel"] = channel + if body: + await self._http_post_with_retry( + f"{base_url}/complete_trajectory", + json=body, + ) + else: + await self._http_post_with_retry( + f"{base_url}/complete_trajectory", + ) + except Exception as exc: + logger.error( + "Failed to complete trajectory after retries: %s", exc, + ) return num_turns @abstractmethod - async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int: - """Create and run the concrete Agent. Subclasses implement this.""" + async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int | tuple[int, float]: + """Create and run the concrete Agent. Subclasses implement this. + + Returns either: + - ``int``: number of turns used (reward computed by Gateway) + - ``tuple[int, float]``: (turns_used, reward) for direct reward passing + """ raise NotImplementedError diff --git a/claw_r1/blackbox_agent/gsm8k_agent.py b/claw_r1/blackbox_agent/gsm8k_agent.py index 969038a..408d777 100644 --- a/claw_r1/blackbox_agent/gsm8k_agent.py +++ b/claw_r1/blackbox_agent/gsm8k_agent.py @@ -101,15 +101,17 @@ def __init__(self, base_url: str): timeout=600.0, ) - async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> int: + async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> tuple[int, float]: """Attempt to solve *question* in up to *max_turns* LLM interactions. - Returns the number of turns actually used. Trajectory completion is - signaled by the caller (BlackBoxAgentFlowBase or online service entrypoint). + Returns a tuple of (turns_used, reward). ``reward`` is 1.0 if the + agent called ``check_answer`` with the correct answer, 0.0 otherwise. + Trajectory completion is signaled by the caller. 
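Note on the return contract: run() above accepts both return shapes from _run_agent. A self-contained sketch of just that unpacking contract (the stub agent and unpack helper are illustrative only):

    import asyncio

    async def _run_agent_stub():
        return (1, 1.0)         # tuple form: one turn, reward scored by the agent

    def unpack(result):
        # Mirrors the isinstance check in BlackBoxAgentFlowBase.run().
        if isinstance(result, tuple):
            return result       # (num_turns, reward)
        return result, None     # bare int: Gateway computes the reward later

    num_turns, reward = unpack(asyncio.run(_run_agent_stub()))
    assert (num_turns, reward) == (1, 1.0)
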
""" messages: list[dict] = [{"role": "user", "content": question}] turns_used = 0 + reward = 0.0 for turn in range(max_turns): turns_used = turn + 1 @@ -127,9 +129,11 @@ async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> i if tc["name"] == "check_answer": answer = tc["arguments"].get("answer", "") result = check_answer(answer, ground_truth) + if "Correct" in result: + reward = 1.0 messages.append({"role": "tool", "content": result}) else: messages.append({"role": "assistant", "content": content}) break - return turns_used + return turns_used, reward diff --git a/claw_r1/blackbox_agent/gsm8k_agent_flow.py b/claw_r1/blackbox_agent/gsm8k_agent_flow.py index c6ae262..eefbac0 100644 --- a/claw_r1/blackbox_agent/gsm8k_agent_flow.py +++ b/claw_r1/blackbox_agent/gsm8k_agent_flow.py @@ -13,7 +13,7 @@ class BlackBoxGSM8KAgentFlow(BlackBoxAgentFlowBase): """Black-box flow that delegates to :class:`GSM8KAgent`.""" - async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int: + async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> tuple[int, float]: raw_prompt = kwargs.get("raw_prompt", []) if isinstance(raw_prompt, list) and raw_prompt: question = next( diff --git a/claw_r1/data_pool/data_pool.py b/claw_r1/data_pool/data_pool.py index 83d6754..1598863 100644 --- a/claw_r1/data_pool/data_pool.py +++ b/claw_r1/data_pool/data_pool.py @@ -239,6 +239,18 @@ def complete_trajectory( return True + def get_last_step( + self, + trajectory_uid: str, + channel: str = DEFAULT_CHANNEL, + ) -> Step | None: + """Return the last Step of a trajectory, or None if not found.""" + ch = self._ch(channel) + idx_list = ch.trajectory_index.get(trajectory_uid) + if not idx_list: + return None + return ch.steps[idx_list[-1]] + # ── Lifecycle ────────────────────────────────────────────────────────── def shutdown(self, channel: str | None = None) -> None: @@ -279,7 +291,7 @@ def get_statistics(self, channel: str = DEFAULT_CHANNEL) -> dict: "total_dropped": ch.total_dropped, "queue_size": unconsumed, "ready_prompt_groups": ready, - "max_queue_size": self._max_queue_size, + "max_queue_size": self._max_queue_size if self._max_queue_size is not None else -1, "shutdown": ch.shutdown, } diff --git a/claw_r1/data_pool/training_backend.py b/claw_r1/data_pool/training_backend.py index 4bf4375..428fec1 100644 --- a/claw_r1/data_pool/training_backend.py +++ b/claw_r1/data_pool/training_backend.py @@ -95,6 +95,15 @@ def convert(self, steps: list[Step]) -> DataProto: has_reward = any(s.reward is not None for s in steps) reward_tensors: list[torch.Tensor] = [] + # DEBUG: log reward status + reward_values = [s.reward for s in steps[:5]] # first 5 steps + import logging as _logging + _tb_logger = _logging.getLogger("claw_r1.training_backend") + _tb_logger.warning( + "[DEBUG] VerlBackend.convert: %d steps, has_reward=%s, sample rewards=%s", + len(steps), has_reward, reward_values, + ) + for step in steps: padded = self._pad_single_step(step) prompt_ids_list.append(padded["prompt_ids"]) diff --git a/claw_r1/detach_workers.py b/claw_r1/detach_workers.py index 93ab4f7..0a7d9fc 100644 --- a/claw_r1/detach_workers.py +++ b/claw_r1/detach_workers.py @@ -5,6 +5,11 @@ weight synchronization so the ParameterSynchronizer can broadcast updated actor weights to the rollout replicas. 
+When the rollout uses a ServerAdapter (vLLM HTTP server mode), the rollout +worker skips the NCCL broadcast (vLLM sleep mode consumes all GPU memory) +and instead receives weights via Ray object store, then pushes them to the +vLLM server via the ServerAdapter's CUDA IPC update_weights mechanism. + Based on ``verl/recipe/fully_async_policy/fsdp_workers.py``. """ @@ -35,7 +40,14 @@ def _get_inference_model(rollout): - """Extract the underlying model from a vLLM/SGLang inference engine.""" + """Extract the underlying model from a vLLM/SGLang inference engine. + + Returns None for ServerAdapter (async HTTP server mode) which uses + CUDA IPC/ZMQ for weight updates instead of direct model access. + """ + if not hasattr(rollout, "inference_engine"): + # ServerAdapter in async mode — no direct model access + return None engine = rollout.inference_engine if hasattr(engine, "llm_engine"): return engine.llm_engine.model_executor.driver_worker.worker.model_runner.model @@ -44,18 +56,37 @@ def _get_inference_model(rollout): raise AttributeError(f"Unsupported inference_engine type: {type(engine)}") +def _is_server_adapter(rollout): + """Check if the rollout is a ServerAdapter (async HTTP server mode).""" + return not hasattr(rollout, "inference_engine") + + class _DetachNcclSync(AsyncActorRolloutRefWorker): - """Mixin adding NCCL-based weight synchronization between actor and rollout.""" + """Mixin adding NCCL-based weight synchronization between actor and rollout. + + For ServerAdapter rollouts, the NCCL broadcast is skipped on the rollout side + because vLLM sleep mode consumes all GPU memory. Instead, the actor extracts + weights and sends them via Ray object store to the rollout worker. + """ def _get_actor_params(self): raise NotImplementedError @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) def sync_rollout_weights(self, sync_group_name="actor_rollout"): - """Broadcast actor weights to rollout via NCCL collective.""" + """Broadcast actor weights to rollout via NCCL collective. + + For ServerAdapter rollouts, this is a no-op on the rollout side. + Weight transfer is handled separately via extract_and_send_weights + on the actor side and receive_and_update_weights on the rollout side. 
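Note on detection: it is purely duck-typed. A ServerAdapter rollout is simply one without an inference_engine attribute, and that single check selects the whole sync path. A compact sketch of the dispatch (both rollout classes are stand-ins, not verl types):

    class NativeRollout:
        inference_engine = object()      # stand-in for a vLLM/SGLang engine

    class ServerAdapterRollout:
        pass                             # HTTP server mode: no inference_engine

    def _is_server_adapter(rollout) -> bool:
        return not hasattr(rollout, "inference_engine")

    def sync_path(rollout) -> str:
        # NCCL broadcast when the engine is directly reachable; Ray object
        # store plus CUDA IPC when only an HTTP server is exposed.
        return "ray_object_store" if _is_server_adapter(rollout) else "nccl"

    assert sync_path(NativeRollout()) == "nccl"
    assert sync_path(ServerAdapterRollout()) == "ray_object_store"
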
+ """ assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine assert hasattr(self, "_weights_info") and self._weights_info is not None + # For ServerAdapter rollouts, skip NCCL broadcast entirely + if self._is_rollout and _is_server_adapter(self.rollout): + return + if self._is_actor and self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) @@ -63,9 +94,10 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"): if self._is_rollout: inference_model = _get_inference_model(self.rollout) - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader + if inference_model is not None: + from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - patch_vllm_moe_model_weight_loader(inference_model) + patch_vllm_moe_model_weight_loader(inference_model) for key, shape, dtype in self._weights_info: tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) @@ -88,6 +120,73 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"): offload_fsdp_model_to_cpu(self.actor_module_fsdp) get_torch_device().empty_cache() + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def extract_actor_weights(self): + """Extract actor weights as a CPU state dict for ServerAdapter sync.""" + assert self._is_actor + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + params = self._get_actor_params() + # Move to CPU to avoid GPU memory issues during transfer + # Also convert DTensors (from FSDP2) to regular tensors to avoid + # "mixed torch.Tensor and DTensor" errors in vLLM weight transfer + cpu_params = {} + for k, v in params.items(): + if hasattr(v, 'full_tensor'): + v = v.full_tensor() + elif hasattr(v, '_local_tensor'): + v = v._local_tensor + cpu_params[k] = v.cpu() + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + get_torch_device().empty_cache() + + return cpu_params + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def receive_and_update_weights(self, cpu_params): + """Receive weights from actor and push to vLLM server via ServerAdapter. + + Ray async actors use uvloop which doesn't support nested event loops. + We run the async update_weights in a separate thread with its own + asyncio event loop to avoid the 'event loop already running' error. 
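Note on the event-loop workaround: the thread-plus-fresh-loop trick described above is the standard escape hatch when code already running inside an event loop (here uvloop, inside a Ray async actor) must drive another coroutine to completion, since asyncio.run() and loop.run_until_complete() both refuse to nest. A self-contained sketch of just that mechanism (push_weights stands in for rollout.update_weights):

    import asyncio
    import concurrent.futures

    async def push_weights() -> str:
        await asyncio.sleep(0)           # stands in for the real weight push
        return "updated"

    def run_in_fresh_loop(coro):
        loop = asyncio.new_event_loop()  # the worker thread has no loop yet
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    async def caller():
        # Already inside a running loop, so hand the coroutine to a thread.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(run_in_fresh_loop, push_weights()).result()

    assert asyncio.run(caller()) == "updated"
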
+ """ + assert self._is_rollout + + import asyncio + import concurrent.futures + + def _weight_generator(): + for key, tensor in cpu_params.items(): + # Ensure we have a plain torch.Tensor (not DTensor from FSDP) + if hasattr(tensor, 'full_tensor'): + tensor = tensor.full_tensor() + elif hasattr(tensor, '_local_tensor'): + tensor = tensor._local_tensor + t = tensor.to(get_torch_device().current_device()) + # Final safety: if still a DTensor, extract the local data + if not isinstance(t, torch.Tensor) or type(t).__name__ == 'DTensor': + t = t.data if hasattr(t, 'data') else t + yield key, t + + def _run_in_new_loop(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + new_loop.run_until_complete( + self.rollout.update_weights(_weight_generator()) + ) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(_run_in_new_loop) + future.result() + + get_torch_device().empty_cache() + class DetachActorWorker(_DetachNcclSync): """Actor worker for async mode — training only, no rollout.""" @@ -136,3 +235,8 @@ def set_actor_weights_info(self, weights_info): """Receive weights info from the actor side so sync can proceed.""" assert self._is_rollout self._weights_info = weights_info + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def check_server_adapter(self): + """Return True if this rollout worker uses a ServerAdapter.""" + return _is_server_adapter(self.rollout) diff --git a/claw_r1/gateway/gateway.py b/claw_r1/gateway/gateway.py index d72d280..6570756 100644 --- a/claw_r1/gateway/gateway.py +++ b/claw_r1/gateway/gateway.py @@ -25,6 +25,8 @@ from typing import Any, Optional from uuid import uuid4 +import asyncio + import httpx import ray import uvicorn @@ -101,6 +103,42 @@ def _next_vllm_address() -> str: return _normalize_address(next(_vllm_cycle)) +async def _vllm_post_with_retry( + url: str, + json: dict[str, Any], + timeout: float = 600.0, + max_retries: int = 3, + retry_delay: float = 5.0, +) -> httpx.Response: + """POST to vLLM with retry on connection errors (ReadError, ConnectError). + + After weight updates, vLLM connections may be stale. This helper retries + with exponential backoff to handle transient connection failures. + """ + last_exc = None + for attempt in range(max_retries): + try: + resp = await _http_client.post(url, json=json, timeout=timeout) + resp.raise_for_status() + return resp + except (httpx.ReadError, httpx.ConnectError, httpx.RemoteProtocolError) as exc: + last_exc = exc + if attempt < max_retries - 1: + wait = retry_delay * (2 ** attempt) + logger.warning( + "vLLM request failed (attempt %d/%d): %s. 
Retrying in %.1fs...", + attempt + 1, max_retries, exc, wait, + ) + await asyncio.sleep(wait) + else: + logger.error( + "vLLM request failed after %d attempts: %s", max_retries, exc, + ) + except httpx.HTTPStatusError: + raise + raise last_exc + + # ── White-box endpoints (implemented) ──────────────────────────────────── @@ -147,8 +185,7 @@ async def generate(req: GenerateRequest): vllm_payload["model"] = model try: - resp = await _http_client.post(url, json=vllm_payload, timeout=600.0) - resp.raise_for_status() + resp = await _vllm_post_with_retry(url, json=vllm_payload, timeout=600.0) except httpx.HTTPStatusError as exc: raise HTTPException(exc.response.status_code, f"vLLM error: {exc.response.text}") from None except httpx.RequestError as exc: @@ -311,8 +348,9 @@ async def chat_completions_proxy(trajectory_uid: str, prompt_uid: str, request: vllm_payload["model"] = model try: - resp = await _http_client.post(f"{base_url}/v1/completions", json=vllm_payload, timeout=600.0) - resp.raise_for_status() + resp = await _vllm_post_with_retry( + f"{base_url}/v1/completions", json=vllm_payload, timeout=600.0, + ) except httpx.HTTPStatusError as exc: raise HTTPException(exc.response.status_code, f"vLLM error: {exc.response.text}") from None except httpx.RequestError as exc: @@ -384,12 +422,40 @@ async def complete_trajectory(trajectory_uid: str, prompt_uid: str, req: Complet Called by agents via ``POST {base_url}/v1/complete_trajectory``. ``trajectory_uid`` is extracted from the URL path. An optional request body can override ``channel`` and supply a ``reward``. + + If no explicit reward is provided and a RewardLoopWorker is available, + the Gateway will automatically compute the reward from the last Step's + prompt/response tokens and the trajectory metadata (which carries + dataset fields like ``data_source`` and ``reward_model``). """ if _data_pool is None: raise HTTPException(503, "DataPool not connected") channel = req.channel if req else _trajectory_channel.get(trajectory_uid, "train") reward = req.reward if req else None + + # ── Auto-compute reward for black-box agents ────────────────────── + # Black-box agents don't call /compute_reward themselves; they only + # call /v1/chat/completions + /v1/complete_trajectory. When no + # explicit reward is supplied, we compute it here from the last Step. + if reward is None and _reward_worker is not None: + try: + reward = await _auto_compute_reward(trajectory_uid) + logger.warning( + "[DEBUG] Auto reward for %s: %s", trajectory_uid, reward, + ) + except Exception as exc: + import traceback + logger.warning( + "Auto reward computation failed for trajectory %s: %s\n%s", + trajectory_uid, exc, traceback.format_exc(), + ) + elif reward is None: + logger.warning( + "[DEBUG] complete_trajectory %s: reward=None, _reward_worker=%s", + trajectory_uid, _reward_worker, + ) + await _data_pool.complete_trajectory.remote(trajectory_uid, reward=reward, channel=channel) _trajectory_step_counter.pop(trajectory_uid, None) _trajectory_channel.pop(trajectory_uid, None) @@ -397,6 +463,91 @@ async def complete_trajectory(trajectory_uid: str, prompt_uid: str, req: Complet return {"status": "ok"} +async def _auto_compute_reward(trajectory_uid: str) -> float | None: + """Compute reward for the last Step of a trajectory via RewardLoopWorker. + + Retrieves the last Step from the DataPool, builds a DataProto with the + Step's prompt/response tokens and the trajectory metadata, then calls + the RewardLoopWorker to compute the score. 
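Note on reward resolution: complete_trajectory above applies a fixed precedence. An explicit reward in the request body wins, then the Gateway-side auto-computation, and failing both the Step keeps reward=None. A toy sketch of that decision order (plain callables stand in for the Ray actors):

    import asyncio

    async def resolve_reward(explicit, reward_worker, auto_compute):
        if explicit is not None:
            return explicit                   # 1. agent-supplied reward
        if reward_worker is not None:
            try:
                return await auto_compute()   # 2. RewardLoopWorker result
            except Exception:
                return None                   # logged; trajectory still completes
        return None                           # 3. nothing available

    async def fake_auto():
        return 0.5

    assert asyncio.run(resolve_reward(1.0, object(), fake_auto)) == 1.0
    assert asyncio.run(resolve_reward(None, object(), fake_auto)) == 0.5
    assert asyncio.run(resolve_reward(None, None, fake_auto)) is None
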
+ + Returns the reward score, or None if computation is not possible. + """ + import numpy as np + import torch + from tensordict import TensorDict + + from verl.protocol import DataProto + from verl.utils.model import compute_position_id_with_mask + + # Get the last step from DataPool + last_step = await _data_pool.get_last_step.remote(trajectory_uid) + if last_step is None: + return None + + prompt_ids = last_step.prompt_ids + response_ids = last_step.response_ids + metadata = _trajectory_metadata.get(trajectory_uid) or last_step.metadata or {} + + if not prompt_ids or not response_ids: + return None + + # Pad prompt (left) and response (right) to fixed lengths + _tokenizer.padding_side = "left" + prompt_out = _tokenizer.pad( + {"input_ids": prompt_ids}, + padding="max_length", + max_length=_prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + if prompt_out["input_ids"].dim() == 1: + prompt_out["input_ids"] = prompt_out["input_ids"].unsqueeze(0) + prompt_out["attention_mask"] = prompt_out["attention_mask"].unsqueeze(0) + + _tokenizer.padding_side = "right" + response_out = _tokenizer.pad( + {"input_ids": response_ids}, + padding="max_length", + max_length=_response_length, + return_tensors="pt", + return_attention_mask=True, + ) + if response_out["input_ids"].dim() == 1: + response_out["input_ids"] = response_out["input_ids"].unsqueeze(0) + response_out["attention_mask"] = response_out["attention_mask"].unsqueeze(0) + + attention_mask = torch.cat( + [prompt_out["attention_mask"], response_out["attention_mask"]], + dim=1, + ) + input_ids = torch.cat( + [prompt_out["input_ids"], response_out["input_ids"]], + dim=1, + ) + position_ids = compute_position_id_with_mask(attention_mask) + + batch = TensorDict( + { + "prompts": prompt_out["input_ids"], + "responses": response_out["input_ids"], + "attention_mask": attention_mask, + "input_ids": input_ids, + "position_ids": position_ids, + }, + batch_size=1, + ) + + # Build non_tensor_batch from trajectory metadata + non_tensor_batch: dict = {} + for k, v in metadata.items(): + non_tensor_batch[k] = np.array([v], dtype=object) + + data = DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + result = await _reward_worker.compute_score.remote(data) + + return result.get("reward_score") + + @app.post("/{trajectory_uid}/{prompt_uid}/v1/register_trajectory") async def register_trajectory(trajectory_uid: str, prompt_uid: str, request: Request): """Register channel and metadata for a trajectory before the agent starts. diff --git a/claw_r1/main_agent_ppo.py b/claw_r1/main_agent_ppo.py index c9ee0eb..e0ac2e6 100644 --- a/claw_r1/main_agent_ppo.py +++ b/claw_r1/main_agent_ppo.py @@ -285,7 +285,7 @@ def run(self, config): # validate config validate_config( config=config, - use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_reference_policy=need_reference_policy(config), use_critic=need_critic(config), ) @@ -305,10 +305,10 @@ def run(self, config): # Load the reward manager for training and validation. 
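Note on the token layout: the left-pad/right-pad scheme built in _auto_compute_reward above keeps the prompt and response tokens contiguous, with all padding on the outer edges. A worked toy example with prompt_length=4 and response_length=3, assuming position ids follow the usual cumsum(mask) - 1 construction clamped at zero:

    import torch

    prompt_mask = torch.tensor([[0, 0, 1, 1]])      # left-padded prompt
    response_mask = torch.tensor([[1, 1, 0]])       # right-padded response
    attention_mask = torch.cat([prompt_mask, response_mask], dim=1)
    # -> [[0, 0, 1, 1, 1, 1, 0]]

    position_ids = torch.clamp(attention_mask.cumsum(dim=1) - 1, min=0)
    # -> [[0, 0, 0, 1, 2, 3, 3]]: real tokens get consecutive positions
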
reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + config, tokenizer, **config.reward_model.get("reward_kwargs", {}) ) val_reward_fn = load_reward_manager( - config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + config, tokenizer, **config.reward_model.get("reward_kwargs", {}) ) resource_pool_manager = self.init_resource_pool_mgr(config) @@ -316,7 +316,6 @@ def run(self, config): from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler from verl.utils.dataset.rl_dataset import collate_fn - # Create training and validation datasets. # Create training and validation datasets. train_dataset = create_rl_dataset( config.data.train_files, @@ -337,6 +336,9 @@ def run(self, config): train_sampler = create_rl_sampler(config.data, train_dataset) # Initialize the Agent trainer. + # RayAgentTrainer extends RayPPOTrainer which does NOT accept + # reward_fn/val_reward_fn as __init__ args. They are set as + # attributes after construction instead. trainer = RayAgentTrainer( config=config, tokenizer=tokenizer, @@ -344,13 +346,15 @@ def run(self, config): role_worker_mapping=self.role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, train_dataset=train_dataset, val_dataset=val_dataset, collate_fn=collate_fn, train_sampler=train_sampler, ) + # Set reward functions as attributes (RayPPOTrainer creates them + # internally in fit(), but we want to use our pre-loaded ones). + trainer.reward_fn = reward_fn + trainer.val_reward_fn = val_reward_fn # Initialize the workers of the trainer. trainer.init_workers() diff --git a/claw_r1/param_sync.py b/claw_r1/param_sync.py index cb5b36f..25ecee3 100644 --- a/claw_r1/param_sync.py +++ b/claw_r1/param_sync.py @@ -2,13 +2,22 @@ Based on ``verl/recipe/fully_async_policy/param_sync.py``. -Flow:: +Flow (NCCL mode — colocated or direct model access):: sync_weights(version) 1. ray.get(rollouter.pause()) # stop generation, clear KV cache 2. NCCL broadcast: Actor → vLLM # via DetachActorWorker / DetachAsyncRolloutWorker 3. rollouter.update_param_version(...) # async — updates version, optionally validates 4. rollouter.resume(...) # async — resumes generation + +Flow (ServerAdapter mode — vLLM HTTP server with CUDA IPC):: + + sync_weights(version) + 1. ray.get(rollouter.pause()) + 2. Actor extracts weights to CPU → Ray object store → Rollout worker + 3. Rollout worker pushes weights to vLLM server via ServerAdapter.update_weights() + 4. rollouter.update_param_version(...) + 5. rollouter.resume(...) """ import logging @@ -29,6 +38,10 @@ class ParameterSynchronizer: Creates an NCCL collective group spanning actor and rollout workers so that ``sync_weights`` can broadcast the latest actor parameters to the vLLM inference replicas. + + For ServerAdapter rollouts (vLLM HTTP server mode), NCCL is not used for + the rollout side. Instead, weights are extracted on the actor, transferred + via Ray object store, and pushed to the vLLM server via CUDA IPC. 
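Note on the transfer: the ServerAdapter flow above rides on Ray's object store. The CPU state dict produced on the actor side is serialized once and fetched by the rollout worker, with no NCCL group involved. A minimal sketch of that hop between two toy Ray actors (not the verl worker groups):

    import ray
    import torch

    @ray.remote
    class Actor:
        def extract_weights(self):
            # CPU tensors serialize into the object store without GPU traffic.
            return {"layer.weight": torch.zeros(2, 2)}

    @ray.remote
    class Rollout:
        def receive_and_update(self, cpu_params):
            return sorted(cpu_params)    # stand-in for pushing into vLLM

    ray.init()
    actor, rollout = Actor.remote(), Rollout.remote()
    weights_ref = actor.extract_weights.remote()     # ObjectRef, not a copy
    print(ray.get(rollout.receive_and_update.remote(weights_ref)))
    # ['layer.weight']
    ray.shutdown()
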
""" def __init__(self, config, trainer, rollouter): @@ -46,8 +59,25 @@ def __init__(self, config, trainer, rollouter): self._wait_last_update = None self._wait_last_resume = None + # Detect if rollout uses ServerAdapter (no inference_engine) + self._use_server_adapter = self._detect_server_adapter() + self._init_weights_info() - self._init_sync_group() + if not self._use_server_adapter: + self._init_sync_group() + else: + logger.info("ServerAdapter detected — using Ray object store for weight sync (no NCCL group)") + + def _detect_server_adapter(self): + """Check if the rollout workers use ServerAdapter.""" + try: + # Ask the rollout worker if it has a ServerAdapter + result = ray.get(self.rollout_wg.check_server_adapter()) + return any(result) if isinstance(result, list) else bool(result) + except Exception: + # If the method doesn't exist, assume ServerAdapter based on config + # ServerAdapter is used when rollout is on a separate GPU pool (async mode) + return True def _init_weights_info(self): self.weights_info = self.actor_wg.get_actor_weights_info()[0] @@ -68,7 +98,7 @@ def get_current_param_version(self) -> int: return self.current_version def sync_weights(self, version: int, validate: bool = False, global_steps: int = 0): - """Pause rollout, broadcast weights, then resume.""" + """Pause rollout, sync weights, then resume.""" start = time.time() self.current_version = version @@ -76,8 +106,10 @@ def sync_weights(self, version: int, validate: bool = False, global_steps: int = pause_time = time.time() logger.info("Rollouter paused in %.2fs", pause_time - start) - self.actor_wg.sync_rollout_weights(self.sync_group_name) - ray.get(self.rollout_wg.sync_rollout_weights(self.sync_group_name)) + if self._use_server_adapter: + self._sync_weights_via_ray_store() + else: + self._sync_weights_via_nccl() sync_time = time.time() logger.info( @@ -95,6 +127,21 @@ def sync_weights(self, version: int, validate: bool = False, global_steps: int = ) self._wait_last_resume = self.rollouter.resume.remote(self._wait_last_update) + def _sync_weights_via_nccl(self): + """Original NCCL broadcast path for colocated/direct model access.""" + self.actor_wg.sync_rollout_weights(self.sync_group_name) + ray.get(self.rollout_wg.sync_rollout_weights(self.sync_group_name)) + + def _sync_weights_via_ray_store(self): + """ServerAdapter path: extract weights on actor, send via Ray, push to vLLM.""" + # Step 1: Extract weights from actor to CPU (returns dict via Ray object store) + cpu_params_refs = self.actor_wg.extract_actor_weights() + # actor_wg returns a list (one per worker), we need the first one + cpu_params = ray.get(cpu_params_refs[0]) + + # Step 2: Send to rollout worker which pushes to vLLM server + ray.get(self.rollout_wg.receive_and_update_weights(cpu_params)) + def wait_last_valid(self): """Block until the last sync + optional validation completes.""" if self._wait_last_update: diff --git a/claw_r1/reward_loop.py b/claw_r1/reward_loop.py index 74a9f52..7e3eac0 100644 --- a/claw_r1/reward_loop.py +++ b/claw_r1/reward_loop.py @@ -20,7 +20,10 @@ import ray from omegaconf import DictConfig -from verl.experimental.reward_loop.reward_loop import get_reward_manager_cls +try: + from verl.experimental.reward_loop.reward_manager import get_reward_manager_cls +except ImportError: + from verl.experimental.reward_loop.reward_loop import get_reward_manager_cls from verl.protocol import DataProto from verl.trainer.ppo.reward import get_custom_reward_fn from verl.utils import hf_tokenizer @@ -66,11 +69,12 @@ def 
_init_reward_fn(self): # Load reward loop manager class # Support both registry and importlib loading methods - reward_loop_source = self.config.reward_model.get("reward_loop_source", "register") + reward_loop_source = self.config.reward_model.get("reward_loop_source", None) or "register" if reward_loop_source == "register": # Load from registry (default behavior) - reward_manager_cls = get_reward_manager_cls(self.config.reward_model.reward_manager) + reward_manager_name = self.config.reward_model.get("reward_manager", None) or "naive" + reward_manager_cls = get_reward_manager_cls(reward_manager_name) elif reward_loop_source == "importlib": # Load from external module using importlib from verl.utils.import_utils import load_extern_object diff --git a/pyproject.toml b/pyproject.toml index 458345b..48bf3ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,54 @@ +# ------------------------------- +# build-system +# ------------------------------- +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +# ------------------------------- +# project (PEP 621 metadata) +# ------------------------------- +[project] +name = "claw_r1" +version = "0.1.0" +description = "Claw-R1: Empowering OpenClaw with Advanced Agentic RL" +license = {text = "Apache-2.0"} +readme = {file = "README.md", content-type = "text/markdown"} +requires-python = ">=3.10" +dependencies = [ + "httpx", + "ray[default]", + "uvicorn", + "fastapi", + "hydra-core", + "omegaconf", + "numpy<2.0.0", + "torch", + "torchdata", + "tensordict>=0.8.0,<=0.10.0,!=0.9.0", + "transformers", + "pydantic", + "Pillow", + "msgpack", + "idna", + "openai", + "flash-attn>=2.5.0", + "swanlab[dashboard]", +] + +# ------------------------------- +# tool.setuptools - Package discovery +# ------------------------------- +[tool.setuptools.packages.find] +include = ["claw_r1*"] +exclude = ["assets*", "docs*", "example*"] + +[tool.setuptools.package-data] +claw_r1 = [ + "config/*.yaml", + "config/**/*.yaml", +] + # ------------------------------- # tool.ruff - Linting configuration # -------------------------------
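
Note on packaging: with the metadata above, `pip install -e .` from the repository root makes the package importable, and the basics can be verified from Python (importlib.metadata is in the standard library on Python 3.10+):

    from importlib.metadata import metadata, version

    import claw_r1                        # discovered via [tool.setuptools.packages.find]

    print(version("claw_r1"))             # 0.1.0
    print(metadata("claw_r1")["Requires-Python"])   # >=3.10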