Add trajectory-level deduplication for GRPO advantage normalization #462
zzjweb wants to merge 1 commit into microsoft:main
Conversation
Pull request overview
This pull request adds trajectory-level deduplication to GRPO advantage normalization to address turn-level bias in multi-turn reinforcement learning scenarios. The implementation introduces a new compute_grpo_outcome_advantage function that tracks unique (data_id, rollout_id) pairs to ensure each trajectory is counted only once when computing baseline statistics for advantage estimation.
Changes:
- Added `compute_grpo_outcome_advantage` function with trajectory-level deduplication logic
- Integrated new advantage computation into the training pipeline with configurable behavior via the `compute_mean_std_cross_all_data` parameter
- Added an assertion to restrict trajectory-level normalization to the GRPO algorithm only
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    traj_index: np.ndarray | None = None,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    compute_mean_std_cross_all_data: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for GRPO with trajectory-level deduplication support.

    This is a minimal extension of VeRL's GRPO implementation, adding support for
    trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`.

    Args:
        token_level_rewards: Shape (bs, response_length).
        response_mask: Shape (bs, response_length).
        index: Group index array (e.g., data_id).
        traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication.
        epsilon: Small value for numerical stability.
        norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style.
        compute_mean_std_cross_all_data: If True (default), compute mean/std across all data.
            If False, compute mean/std per unique (index, traj_index) trajectory.

    Returns:
        Tuple of (advantages, returns), both shape (bs, response_length).
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score: dict = defaultdict(list)
    id2mean: dict = {}
    id2std: dict = {}
    seen_pairs: set = set()

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            # Trajectory deduplication: skip if (index, traj_index) already seen
            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
                continue
            id2score[index[i]].append(scores[i])
            # Mark as seen only when compute_mean_std_cross_all_data is False
            if traj_index is not None and not compute_mean_std_cross_all_data:
                seen_pairs.add((index[i], traj_index[i]))

        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
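A hedged usage sketch of the function above with toy tensors (numbers invented for illustration; not part of the PR):

```python
import numpy as np
import torch

# Toy batch: 4 samples, response length 2. Samples 2 and 3 are two turns of the
# same trajectory (data_id=1, rollout_id=0), so they share a (index, traj_index) pair.
token_level_rewards = torch.tensor([[1.0, 0.0], [3.0, 0.0], [2.0, 0.0], [2.0, 0.0]])
response_mask = torch.ones_like(token_level_rewards)
index = np.array([0, 0, 1, 1])        # group ids (data_id)
traj_index = np.array([0, 1, 0, 0])   # rollout ids within each group

# Default cross-all-data mode: the duplicated trajectory is counted twice in group 1,
# so that group's baseline mean is 2.0 and both of its advantages collapse to ~0.
adv_all, _ = compute_grpo_outcome_advantage(
    token_level_rewards, response_mask, index, traj_index,
    compute_mean_std_cross_all_data=True,
)

# Trajectory-level mode: group 1 keeps a single score and falls into the len==1
# branch (mean=0, std=1), so the advantage stays ~2.0 for both of its turns.
adv_dedup, _ = compute_grpo_outcome_advantage(
    token_level_rewards, response_mask, index, traj_index,
    compute_mean_std_cross_all_data=False,
)
print(adv_all[2:, 0], adv_dedup[2:, 0])  # ~[0., 0.] vs ~[2., 2.]
```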
The new compute_grpo_outcome_advantage function lacks test coverage. Given that this is a critical mathematical computation affecting training outcomes, unit tests should be added to verify:
- Correct behavior when `compute_mean_std_cross_all_data=True` vs `False`
- Proper handling of trajectory deduplication with different `(index, traj_index)` combinations
- Device consistency (tensors on GPU)
- Edge cases: single-sample groups, all identical scores, etc.
- Correct advantage normalization with and without std division
Consider adding tests in tests/trainer/ directory or a new test file specifically for GRPO advantage computation.
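A minimal pytest-style sketch of such a test, assuming the function is importable from `agentlightning.verl.trainer` as listed in the PR description (the exact path and expected values are illustrative):

```python
import numpy as np
import torch

from agentlightning.verl.trainer import compute_grpo_outcome_advantage  # assumed path


def test_single_sample_group_and_dedup_mode():
    token_level_rewards = torch.tensor([[1.0, 0.0], [3.0, 0.0], [2.0, 0.0]])
    response_mask = torch.ones_like(token_level_rewards)
    index = np.array([0, 0, 1])        # group 1 contains a single sample
    traj_index = np.array([0, 1, 0])

    adv, ret = compute_grpo_outcome_advantage(
        token_level_rewards, response_mask, index, traj_index,
        compute_mean_std_cross_all_data=False,
    )

    assert adv.shape == token_level_rewards.shape
    assert torch.equal(adv, ret)  # the function returns the same tensor twice
    # Single-sample group falls into the len==1 branch (mean=0, std=1),
    # so its advantage is approximately the raw score.
    assert torch.allclose(adv[2], torch.full((2,), 2.0), atol=1e-4)
    # Two-sample group 0: scores 1 and 3, mean 2, std sqrt(2).
    assert torch.allclose(adv[0], torch.full((2,), (1.0 - 2.0) / (2 ** 0.5)), atol=1e-4)
```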
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")
The error message uses f-string formatting but doesn't include the idx variable value. The message should be updated to include the actual index value that's causing the issue for better debugging:
`raise ValueError(f"no score in prompt index: {idx}")` should ensure the value is actually included in the error output.
        for i in range(bsz):
            # Trajectory deduplication: skip if (index, traj_index) already seen
            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
                continue
            id2score[index[i]].append(scores[i])
            # Mark as seen only when compute_mean_std_cross_all_data is False
            if traj_index is not None and not compute_mean_std_cross_all_data:
                seen_pairs.add((index[i], traj_index[i]))

        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
The function accepts index as np.ndarray but uses it directly to index into dictionaries (lines 90, 108, 110). In Python dictionaries, NumPy array elements may not hash correctly depending on their dtype. If index contains NumPy scalars (e.g., np.int64), this could cause issues.
Consider converting array elements to Python native types when using them as dictionary keys:
    idx_key = int(index[i])
    id2score[idx_key].append(scores[i])
Or document that `index` must contain hashable types that work as dictionary keys.
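A minimal sketch of the key-normalization idea; the `_to_key` helper below is hypothetical and not part of the PR:

```python
def _to_key(x):
    # NumPy scalars (np.int64, np.str_, ...) expose .item(); converting them to
    # plain Python values keeps dictionary keys consistent across input dtypes.
    return x.item() if hasattr(x, "item") else x

# Inside the loops, dictionary accesses would then use the normalized key, e.g.:
#   key = _to_key(index[i])
#   id2score[key].append(scores[i])
```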
if not compute_mean_std_cross_all_data:
    assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, (
        f"compute_mean_std_cross_all_data=False is only supported for GRPO, "
        f"got {self.config.algorithm.adv_estimator}"
    )
The assertion on lines 432-435 is only evaluated when compute_mean_std_cross_all_data=False, but the new GRPO implementation is used for ALL GRPO cases (the condition on line 438). This means that when compute_mean_std_cross_all_data=True is combined with a non-GRPO estimator, the assertion is never checked, and the code still goes through the else branch at line 452.
While this is not necessarily incorrect (the else branch handles non-GRPO cases properly), the control flow could be clearer. Consider restructuring to make the relationship between the flag and the GRPO check more explicit, or add a comment explaining why the assertion only needs to check the False case.
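One possible restructuring, sketched with simplified stand-ins rather than the actual trainer code around lines 432-452, so the flag/estimator relationship is checked in one place:

```python
def select_advantage_path(adv_estimator: str, compute_mean_std_cross_all_data: bool) -> str:
    """Validate the flag/estimator combination up front, then dispatch explicitly."""
    if adv_estimator == "grpo":
        # Both normalization modes are handled by the new GRPO implementation.
        return "grpo_trajectory_aware"
    if not compute_mean_std_cross_all_data:
        raise ValueError(
            "compute_mean_std_cross_all_data=False is only supported for GRPO, "
            f"got {adv_estimator}"
        )
    # Non-GRPO estimators keep VeRL's stock advantage computation.
    return "verl_default"
```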
@microsoft-github-policy-service agree
Problem
Agent-lightning inherits VeRL's default advantage estimation, which assumes each batch sample is independent. In multi-turn scenarios, this causes turn-level bias: trajectories with more turns contribute more to baseline statistics (mean/std), leading to biased advantage estimation and inefficient optimization.
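A toy illustration of the bias (numbers invented for this example): when a 3-turn trajectory and a 1-turn trajectory share a group, the turn-level baseline is pulled toward the longer trajectory.

```python
import numpy as np

# One 3-turn trajectory with return 1.0 and one 1-turn trajectory with return 0.0.
turn_level_scores = np.array([1.0, 1.0, 1.0, 0.0])  # each turn is a separate batch sample
trajectory_scores = np.array([1.0, 0.0])             # each trajectory counted once

print(turn_level_scores.mean())   # 0.75 -- baseline skewed toward the 3-turn trajectory
print(trajectory_scores.mean())   # 0.5  -- trajectory-level baseline
```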
Solution
Implements trajectory-level deduplication using `(data_id, rollout_id)` pairs. Set `algorithm.compute_mean_std_cross_all_data=False` to ensure each trajectory is counted only once when computing baselines.
In `agentlightning.verl.trainer`, we re-implement `compute_grpo_outcome_advantage` to integrate the new trajectory-level deduplication logic while keeping the dependency on VeRL minimal.
Example Configuration
Control the normalization behavior via the `compute_mean_std_cross_all_data` parameter:
- `compute_mean_std_cross_all_data=True` (default): cross-all-data normalization, more stable but still counts each turn
- `compute_mean_std_cross_all_data=False`: trajectory-level normalization; each trajectory is counted only once, eliminating the bias
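A hedged sketch of where the flag could sit in the algorithm config; only the two keys below come from this PR, and the surrounding structure is an assumption (consult the actual trainer config schema):

```python
# Hypothetical config fragment for illustration only.
algorithm_config = {
    "adv_estimator": "grpo",                   # trajectory-level mode requires GRPO
    "compute_mean_std_cross_all_data": False,  # count each (data_id, rollout_id) pair once
}
```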
Implementation
Affected algorithms (currently only GRPO is supported):
- GRPO
Files modified:
- `agentlightning/verl/trainer.py`: Add `compute_grpo_outcome_advantage`