Fix TRL GRPO implementation to comply with current TRL API (≥0.29) #2

Draft
Copilot wants to merge 2 commits into reward_design from copilot/fix-trl-implementation-issues

Copilot AI commented Mar 8, 2026

Several breaking incompatibilities between the GRPO training code and the current TRL library API caused silent failures (wrong dtype, wrong reward signature) or hard crashes (unknown config params).

train_grpo.py

  • Reward function signature: completions changed from list[list[str]] to list[str] — TRL decodes completions to plain strings before passing them; removed the now-dead normalization loop
  • target_buffer kwarg: TRL forwards extra dataset columns as **kwargs, not positional args — updated both env_reward_function and reward_wrapper accordingly
  • Removed vllm_device: parameter no longer exists in GRPOConfig; replaced with vllm_mode="colocate" to preserve in-process single-GPU vLLM behaviour (default changed to "server" in current TRL)
  • Removed max_prompt_length: parameter no longer exists in GRPOConfig
# Before — crashed at runtime: unknown GRPOConfig fields + wrong reward signature
def env_reward_function(prompts, completions: list[list[str]], target_buffers: list[str | None], **kwargs): ...
GRPOConfig(..., max_prompt_length=512, vllm_device="cuda:0")

# After — matches current TRL API
def env_reward_function(prompts: list[str], completions: list[str], **kwargs):
    target_buffers = list(kwargs["target_buffer"]) if isinstance(kwargs.get("target_buffer"), (list, tuple)) else [None] * len(prompts)
    ...
GRPOConfig(..., vllm_mode="colocate", vllm_gpu_memory_utilization=0.5)
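
The keyword-argument handling described above can be exercised without TRL installed. The following is a minimal sketch, not the PR's actual reward logic: the helper name `extract_target_buffers` and the exact-match scoring rule are assumptions made for illustration, since the real scoring body is elided (`...`) in the description.

```python
from typing import Any, Optional


def extract_target_buffers(n: int, **kwargs: Any) -> list[Optional[str]]:
    """Return one target buffer per sample, or None placeholders if the column is absent."""
    column = kwargs.get("target_buffer")
    if isinstance(column, (list, tuple)):
        return list(column)
    return [None] * n


def env_reward_function(
    prompts: list[str], completions: list[str], **kwargs: Any
) -> list[float]:
    """Score plain-string completions; extra dataset columns arrive via **kwargs."""
    target_buffers = extract_target_buffers(len(prompts), **kwargs)
    rewards = []
    for completion, target in zip(completions, target_buffers):
        # Placeholder scoring (assumption): 1.0 on an exact match, else 0.0.
        matched = target is not None and completion.strip() == target.strip()
        rewards.append(1.0 if matched else 0.0)
    return rewards
```

Calling `env_reward_function(["p"], ["foo"], target_buffer=["foo"])` mimics how a trainer would forward a `target_buffer` dataset column by keyword; omitting the keyword falls back to `None` placeholders instead of raising.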

train.py / agent.py

  • dtype= → torch_dtype=: the AutoModelForCausalLM.from_pretrained calls passed dtype= where torch_dtype= is expected; the incorrect kwarg silently loaded the model in float32 instead of bfloat16

Warning

Firewall rules blocked me from connecting to one or more addresses

I tried to connect to the following addresses, but was blocked by firewall rules:

  • huggingface.co
    • Triggering command: /home/REDACTED/work/_temp/ghcca-node/node/bin/node /home/REDACTED/work/_temp/ghcca-node/node/bin/node --enable-source-maps /home/REDACTED/work/_temp/copilot-developer-action-main/dist/index.js (dns block)


Co-authored-by: Antix5 <96021131+Antix5@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Fix TRL implementation compliance with documentation" to "Fix TRL GRPO implementation to comply with current TRL API (≥0.29)" on Mar 8, 2026