Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added claw_r1/__init__.py
Empty file.
59 changes: 47 additions & 12 deletions claw_r1/async_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _initialize(self, config):

validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_reference_policy=need_reference_policy(config),
use_critic=need_critic(config),
)

Expand Down Expand Up @@ -240,27 +240,50 @@ def _run(self):
rollouter_future = self.components["rollouter"].fit.remote()
trainer_future = self.components["trainer"].fit.remote()

futures = [rollouter_future, trainer_future]

# The Rollouter generates data much faster than the Trainer consumes it.
# We must wait for the Trainer to finish (it controls the training steps).
# The Rollouter finishing first is normal — it just means all data has
# been generated. We should NOT cancel the Trainer when that happens.
try:
while futures:
done, remaining = ray.wait(futures, num_returns=1, timeout=None)
# Wait for Trainer to finish (the primary task)
# Also monitor Rollouter for errors
futures = {rollouter_future: "Rollouter", trainer_future: "Trainer"}
trainer_done = False
rollouter_done = False

while not trainer_done:
done, _ = ray.wait(list(futures.keys()), num_returns=1, timeout=None)
for f in done:
name = futures.pop(f)
try:
ray.get(f)
print("[ASYNC] Component completed successfully")
print(f"[ASYNC] {name} completed successfully")
if name == "Trainer":
trainer_done = True
elif name == "Rollouter":
rollouter_done = True
except Exception as e:
print(f"[ASYNC] Component failed: {e}")
for r in remaining:
ray.cancel(r)
print(f"[ASYNC] {name} failed: {e}")
# Cancel remaining futures on error
for remaining_f in futures:
ray.cancel(remaining_f)
raise
futures = remaining

# Trainer is done. If Rollouter is still running, cancel it.
if not rollouter_done and rollouter_future in futures:
ray.cancel(rollouter_future)
print("[ASYNC] Cancelled Rollouter (Trainer finished first)")

except Exception as e:
print(f"[ASYNC] Training failed: {e}")
for f in futures:
ray.cancel(f)
raise
finally:
# Shutdown DataPool after Trainer is done
try:
data_pool = ray.get_actor(self.components["data_pool_name"])
ray.get(data_pool.shutdown.remote())
except Exception:
pass
print("[ASYNC] Training finished")


Expand All @@ -273,6 +296,18 @@ def main(config):

if not ray.is_initialized():
default_runtime_env = get_ppo_ray_runtime_env()
# Propagate critical env vars to Ray workers so they can find
# clawr1_env packages, CUDA libs, and vLLM settings
_propagate_vars = [
"PYTHONPATH", "PATH", "LD_LIBRARY_PATH",
"VLLM_USE_V1", "CUDA_HOME", "CONDA_PREFIX",
"SWANLAB_MODE",
]
for var in _propagate_vars:
val = os.environ.get(var, "")
if val:
default_runtime_env.setdefault("env_vars", {})[var] = val

ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
Expand Down
45 changes: 35 additions & 10 deletions claw_r1/async_rollouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
self.val_reward_fn = load_reward_manager(
config,
tokenizer,
num_examine=1,
**config.reward_model.get("reward_kwargs", {}),
)

Expand Down Expand Up @@ -240,8 +239,9 @@ def _init_gateway(self):
ray_address = ray_ctx.gcs_address
ray_namespace = ray_ctx.namespace

import sys
cmd = [
"python",
sys.executable,
"-m",
"claw_r1.gateway.gateway",
"--data-pool-name",
Expand All @@ -264,13 +264,27 @@ def _init_gateway(self):
str(gateway_port),
]

# Log gateway output to files for debugging
import tempfile
gateway_log_dir = os.path.join(
self.config.trainer.get("default_local_dir", "/tmp"),
"gateway_logs",
)
os.makedirs(gateway_log_dir, exist_ok=True)
self._gateway_stdout = open(os.path.join(gateway_log_dir, "gateway_stdout.log"), "w")
self._gateway_stderr = open(os.path.join(gateway_log_dir, "gateway_stderr.log"), "w")

self._gateway_process = subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
stdout=self._gateway_stdout,
stderr=self._gateway_stderr,
text=True,
)
self._gateway_url = f"http://localhost:{gateway_port}"
# Use the actual hostname/IP instead of localhost so that
# AgentFlowWorker actors on other nodes can reach the gateway.
import socket
gateway_host = socket.getfqdn()
self._gateway_url = f"http://{gateway_host}:{gateway_port}"
atexit.register(self._stop_gateway)

for _ in range(120):
Expand Down Expand Up @@ -344,10 +358,12 @@ async def fit(self):
t.cancel()
await asyncio.gather(gen_task, monitor_task, return_exceptions=True)

# Signal DataPool shutdown
data_pool = ray.get_actor(self._data_pool_name)
ray.get(data_pool.shutdown.remote())
logger.info("Rollouter fit completed")
# Do NOT shutdown DataPool here — the Trainer may still be consuming
# batches. The Trainer will exit on its own when it reaches
# total_train_steps or when the DataPool is empty.
# data_pool = ray.get_actor(self._data_pool_name)
# ray.get(data_pool.shutdown.remote())
logger.info("Rollouter fit completed — generation finished, DataPool remains open for Trainer")

async def _generation_main(self):
"""Iterate over epochs/batches, generate sequences, submit to DataPool."""
Expand Down Expand Up @@ -449,7 +465,16 @@ async def update_param_version(

val_metrics = None
if validate and self.val_reward_fn is not None:
val_metrics = self._validate()
try:
val_metrics = self._validate()
except Exception as exc:
logger.error(
"[AsyncRollouter] Validation failed (version=%d): %s. "
"Skipping validation and continuing training.",
version, exc,
exc_info=True,
)
val_metrics = None

from ray import cloudpickle as ray_cloudpickle

Expand Down
23 changes: 19 additions & 4 deletions claw_r1/async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,18 @@ def __init__(
self.device_name = device_name or self.config.trainer.device

self.hybrid_engine = False
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
self.use_reference_policy = need_reference_policy(self.config)
self.use_critic = need_critic(self.config)
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0

self.reward_fn = load_reward_manager(
config,
tokenizer,
num_examine=0,
**config.reward_model.get("reward_kwargs", {}),
)
self.val_reward_fn = load_reward_manager(
config,
tokenizer,
num_examine=1,
**config.reward_model.get("reward_kwargs", {}),
)

Expand Down Expand Up @@ -361,6 +359,21 @@ def _process_batch(self, batch: DataProto, metrics: dict, timing_raw: dict) -> D

def _compute_reward(self, batch: DataProto):
"""Compute or extract reward for training."""
has_rm = "rm_scores" in batch.batch
has_ds = "data_source" in batch.non_tensor_batch if hasattr(batch, "non_tensor_batch") else False
has_gt = "reward_model" in batch.non_tensor_batch if hasattr(batch, "non_tensor_batch") else False
print(
f"[DEBUG _compute_reward] has_rm_scores={has_rm}, has_data_source={has_ds}, "
f"has_reward_model={has_gt}, reward_fn={self.reward_fn is not None}"
)
if has_rm:
rm = batch.batch["rm_scores"]
print(f"[DEBUG _compute_reward] rm_scores: sum={rm.sum().item():.4f}, max={rm.max().item():.4f}, nonzero={rm.count_nonzero().item()}")
if has_ds:
print(f"[DEBUG _compute_reward] data_source sample: {batch.non_tensor_batch['data_source'][:3]}")
if has_gt:
print(f"[DEBUG _compute_reward] reward_model sample: {batch.non_tensor_batch['reward_model'][:3]}")

if "rm_scores" in batch.batch:
reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
reward_extra_infos_dict = (
Expand All @@ -371,7 +384,9 @@ def _compute_reward(self, batch: DataProto):
if self.reward_fn is not None:
from verl.trainer.ppo.reward import compute_reward

return compute_reward(batch, self.reward_fn)
result = compute_reward(batch, self.reward_fn)
print(f"[DEBUG _compute_reward] reward_fn result: sum={result[0].sum().item():.4f}, max={result[0].max().item():.4f}, nonzero={result[0].count_nonzero().item()}")
return result

raise ValueError("No reward_fn and no pre-computed rm_scores in batch")

Expand Down
82 changes: 75 additions & 7 deletions claw_r1/blackbox_agent/blackbox_agent_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
separate modules (e.g. gsm8k_agent_flow.py).
"""

import asyncio
import json
import logging
import os
Expand Down Expand Up @@ -53,6 +54,35 @@ def _prepare_params(self, kwargs: dict[str, Any]) -> tuple[str | None, str, dict
metadata = {k: v for k, v in kwargs.items() if k not in _DEFAULT_SKIP_KEYS}
return channel, prompt_uid, metadata

async def _http_post_with_retry(
self,
url: str,
max_retries: int = 3,
retry_delay: float = 3.0,
timeout: float = 600.0,
**kwargs,
) -> httpx.Response:
"""POST with retry on transient connection errors."""
last_exc = None
for attempt in range(max_retries):
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as http:
resp = await http.post(url, **kwargs)
resp.raise_for_status()
return resp
except (httpx.ReadError, httpx.ConnectError, httpx.RemoteProtocolError) as exc:
last_exc = exc
if attempt < max_retries - 1:
wait = retry_delay * (2 ** attempt)
logger.warning(
"HTTP POST %s failed (attempt %d/%d): %s. Retrying in %.1fs...",
url, attempt + 1, max_retries, exc, wait,
)
await asyncio.sleep(wait)
else:
logger.error("HTTP POST %s failed after %d attempts: %s", url, max_retries, exc)
raise last_exc

async def run(self, sampling_params: dict[str, Any], **kwargs) -> int:
channel, prompt_uid, metadata = self._prepare_params(kwargs)

Expand Down Expand Up @@ -80,17 +110,55 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> int:
headers={"content-type": "application/json"},
)

# 3. Run the concrete agent.
# 3. Run the concrete agent (with retry on transient errors).
reward = None
num_turns = 0
try:
num_turns = await self._run_agent(base_url, kwargs)
result = await self._run_agent(base_url, kwargs)
# _run_agent may return int (num_turns) or tuple (num_turns, reward)
if isinstance(result, tuple):
num_turns, reward = result
else:
num_turns = result
except (httpx.ReadError, httpx.ConnectError, httpx.RemoteProtocolError) as exc:
# Agent failed due to transient connection error (e.g. vLLM weight update).
# Log warning but still complete the trajectory with reward=0.
logger.warning(
"Agent run failed with transient error: %s. Completing trajectory with reward=0.",
exc,
)
reward = 0.0
finally:
# 4. Mark trajectory complete.
async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as http:
await http.post(f"{base_url}/complete_trajectory")
# 4. Mark trajectory complete, passing reward if available.
# Use retry to handle transient Gateway connection issues.
try:
body: dict[str, Any] = {}
if reward is not None:
body["reward"] = float(reward)
if channel:
body["channel"] = channel
if body:
await self._http_post_with_retry(
f"{base_url}/complete_trajectory",
json=body,
)
else:
await self._http_post_with_retry(
f"{base_url}/complete_trajectory",
)
except Exception as exc:
logger.error(
"Failed to complete trajectory after retries: %s", exc,
)

return num_turns

@abstractmethod
async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int:
"""Create and run the concrete Agent. Subclasses implement this."""
async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int | tuple[int, float]:
"""Create and run the concrete Agent. Subclasses implement this.

Returns either:
- ``int``: number of turns used (reward computed by Gateway)
- ``tuple[int, float]``: (turns_used, reward) for direct reward passing
"""
raise NotImplementedError
12 changes: 8 additions & 4 deletions claw_r1/blackbox_agent/gsm8k_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,17 @@ def __init__(self, base_url: str):
timeout=600.0,
)

async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> int:
async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> tuple[int, float]:
"""Attempt to solve *question* in up to *max_turns* LLM interactions.

Returns the number of turns actually used. Trajectory completion is
signaled by the caller (BlackBoxAgentFlowBase or online service entrypoint).
Returns a tuple of (turns_used, reward). ``reward`` is 1.0 if the
agent called ``check_answer`` with the correct answer, 0.0 otherwise.
Trajectory completion is signaled by the caller.
"""
messages: list[dict] = [{"role": "user", "content": question}]

turns_used = 0
reward = 0.0
for turn in range(max_turns):
turns_used = turn + 1

Expand All @@ -127,9 +129,11 @@ async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> i
if tc["name"] == "check_answer":
answer = tc["arguments"].get("answer", "")
result = check_answer(answer, ground_truth)
if "Correct" in result:
reward = 1.0
messages.append({"role": "tool", "content": result})
else:
messages.append({"role": "assistant", "content": content})
break

return turns_used
return turns_used, reward
2 changes: 1 addition & 1 deletion claw_r1/blackbox_agent/gsm8k_agent_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class BlackBoxGSM8KAgentFlow(BlackBoxAgentFlowBase):
"""Black-box flow that delegates to :class:`GSM8KAgent`."""

async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int:
async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> tuple[int, float]:
raw_prompt = kwargs.get("raw_prompt", [])
if isinstance(raw_prompt, list) and raw_prompt:
question = next(
Expand Down
Loading