diff --git a/openadapt_evals/training/standalone/config.py b/openadapt_evals/training/standalone/config.py index 45defc0..a6edd81 100644 --- a/openadapt_evals/training/standalone/config.py +++ b/openadapt_evals/training/standalone/config.py @@ -47,7 +47,12 @@ class TrainingConfig: task_dir: str | None = None screen_size: tuple[int, int] = (1920, 1080) stuck_window: int = 3 - learning_rate: float = 5e-6 + learning_rate: float = 1e-6 + # Maximum gradient norm for clipping. Critical for stable training: + # grad_norm > 100 means gradients are dominated by clipping direction + # (effectively random) rather than the actual gradient signal. Lower + # values (0.5-1.0) stabilize training at the cost of slower learning. + max_grad_norm: float = 1.0 num_training_steps: int = 1000 save_every_steps: int = 50 output_dir: str = "checkpoints/grpo" diff --git a/openadapt_evals/training/standalone/trainer.py b/openadapt_evals/training/standalone/trainer.py index 312eebf..e4388af 100644 --- a/openadapt_evals/training/standalone/trainer.py +++ b/openadapt_evals/training/standalone/trainer.py @@ -457,7 +457,18 @@ def _training_step(self, rollouts: list[Rollout]) -> dict[str, float]: l = self._compute_rollout_loss(r, a, 1.0 / n) losses.append(l) grad_norm = torch.nn.utils.clip_grad_norm_( - [p for p in self._model.parameters() if p.requires_grad], max_norm=1.0) + [p for p in self._model.parameters() if p.requires_grad], + max_norm=self._config.max_grad_norm, + ) + gn = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm) + if gn > 10 * self._config.max_grad_norm: + logger.warning( + "grad_norm=%.1f is %.0fx the clip threshold (%.1f). " + "Gradients are dominated by clipping, not learning signal. " + "Consider lowering learning_rate (current: %.1e).", + gn, gn / self._config.max_grad_norm, + self._config.max_grad_norm, self._config.learning_rate, + ) self._optimizer.step() avg_loss = sum(losses) / max(n, 1) @@ -485,9 +496,10 @@ def train(self) -> str: """Run GRPO training loop. Returns path to final checkpoint.""" import torch - logger.warning( - "The standalone GRPO trainer is deprecated. Use scripts/train_trl_grpo.py " - "with TRL's GRPOTrainer instead. See docs/eval_results/ for migration guide." + logger.info( + "Using standalone GRPO trainer. This is the production training " + "path for VLM agents with dynamic screenshots. TRL migration " + "pending multimodal environment_factory support (TRL PR #5323)." ) self._load_task_configs()