Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion openadapt_evals/training/standalone/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ class TrainingConfig:
task_dir: str | None = None
screen_size: tuple[int, int] = (1920, 1080)
stuck_window: int = 3
learning_rate: float = 5e-6
learning_rate: float = 1e-6
# Maximum gradient norm for clipping. Critical for stable training:
# grad_norm > 100 means gradients are dominated by clipping direction
# (effectively random) rather than the actual gradient signal. Lower
# values (0.5-1.0) stabilize training at the cost of slower learning.
max_grad_norm: float = 1.0
num_training_steps: int = 1000
save_every_steps: int = 50
output_dir: str = "checkpoints/grpo"
Expand Down
20 changes: 16 additions & 4 deletions openadapt_evals/training/standalone/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,18 @@ def _training_step(self, rollouts: list[Rollout]) -> dict[str, float]:
l = self._compute_rollout_loss(r, a, 1.0 / n)
losses.append(l)
grad_norm = torch.nn.utils.clip_grad_norm_(
[p for p in self._model.parameters() if p.requires_grad], max_norm=1.0)
[p for p in self._model.parameters() if p.requires_grad],
max_norm=self._config.max_grad_norm,
)
gn = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
if gn > 10 * self._config.max_grad_norm:
logger.warning(
"grad_norm=%.1f is %.0fx the clip threshold (%.1f). "
"Gradients are dominated by clipping, not learning signal. "
"Consider lowering learning_rate (current: %.1e).",
gn, gn / self._config.max_grad_norm,
self._config.max_grad_norm, self._config.learning_rate,
)
self._optimizer.step()

avg_loss = sum(losses) / max(n, 1)
Expand Down Expand Up @@ -485,9 +496,10 @@ def train(self) -> str:
"""Run GRPO training loop. Returns path to final checkpoint."""
import torch

logger.warning(
"The standalone GRPO trainer is deprecated. Use scripts/train_trl_grpo.py "
"with TRL's GRPOTrainer instead. See docs/eval_results/ for migration guide."
logger.info(
"Using standalone GRPO trainer. This is the production training "
"path for VLM agents with dynamic screenshots. TRL migration "
"pending multimodal environment_factory support (TRL PR #5323)."
)

self._load_task_configs()
Expand Down
Loading