From affcbf185d564da25fc96dea2f3ad4f66ec993b0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Mar 2026 14:01:24 +0000 Subject: [PATCH 1/2] Initial plan From 731fdb9b8f6ec8051c14782787a59cbc87a6c030 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Mar 2026 14:08:00 +0000 Subject: [PATCH 2/2] Fix TRL implementation to comply with current TRL documentation Co-authored-by: Antix5 <96021131+Antix5@users.noreply.github.com> --- train_utils/agent.py | 2 +- train_utils/train.py | 2 +- train_utils/train_grpo.py | 42 +++++++++++++++------------------------ 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/train_utils/agent.py b/train_utils/agent.py index 18fad55..7c53434 100644 --- a/train_utils/agent.py +++ b/train_utils/agent.py @@ -25,7 +25,7 @@ def initialize_model(): model = AutoModelForCausalLM.from_pretrained( MODEL_ID, device_map="auto", - dtype=torch.bfloat16, + torch_dtype=torch.bfloat16, trust_remote_code=True, ) return tokenizer, model diff --git a/train_utils/train.py b/train_utils/train.py index 511bd9b..e6b92c7 100644 --- a/train_utils/train.py +++ b/train_utils/train.py @@ -78,7 +78,7 @@ def main(): tokenizer.padding_side = "right" model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", dtype=torch.bfloat16, trust_remote_code=True, attn_implementation="sdpa" + MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, attn_implementation="sdpa" ) # Enable gradient checkpointing to save VRAM model.gradient_checkpointing_enable() diff --git a/train_utils/train_grpo.py b/train_utils/train_grpo.py index de63ef3..ee362be 100644 --- a/train_utils/train_grpo.py +++ b/train_utils/train_grpo.py @@ -22,38 +22,29 @@ def extract_commands(completion: str) -> list[str]: return commands if commands else [completion.strip()] -def env_reward_function(prompts: list[str], completions: list[list[str]], target_buffers: list[str | None], **kwargs: dict[str, object]) -> list[float]: +def env_reward_function(prompts: list[str], completions: list[str], **kwargs: object) -> list[float]: """ The vectorized pure mapped reward function for TRL GRPO. Instantiates ViClient up front, configures the environment objective via `init_session`, and parses the LLM output into keystroke acts. + + TRL passes `completions` as a plain list[str] (one decoded string per completion). + Extra dataset columns (e.g. `target_buffer`) are forwarded via **kwargs. """ client = ViClient() rewards: list[float] = [] - # TRL passes a list of N completions (each completion is a list containing the generated text) - # But usually GRPO `completions` is a list of N strings (one per prompt if batch=N, or N completions for 1 prompt) - # The signature in modern TRL passes lists of strings for completions. - - # Flatten/normalize depending on how TRL supplies it - normalized_completions: list[str] = [] - for c in completions: - if isinstance(c, list): - normalized_completions.append(c[0]) # Depending on format, sometimes wrapped - else: - normalized_completions.append(str(c)) # type: ignore - + # Extra dataset columns (e.g. target_buffer) are passed by TRL as keyword arguments. + raw_targets = kwargs.get("target_buffer") + target_buffers: list[str | None] = list(raw_targets) if isinstance(raw_targets, (list, tuple)) else [None] * len(prompts) # type: ignore - # In GRPO, prompts is repeated G times for each group. + # In GRPO, prompts is repeated G times for each group. # We iterate over the batch and get exactly one reward float per completion. for i in range(len(prompts)): prompt_text = prompts[i] - completion_text = normalized_completions[i] - - # We can extract the target buffer if it was supplied in the dataset - # In a real setup, kwargs might contain dataset columns. For simplicity, we assume - # target_buffers are broadcasted or we extract them. - target = target_buffers[i] if target_buffers and i < len(target_buffers) else None + completion_text = completions[i] + + target = target_buffers[i] if i < len(target_buffers) else None try: session_id = client.init_session( @@ -109,22 +100,21 @@ def format_trl(example: dict[str, str]) -> dict[str, str]: ds = ds.map(format_trl) # type: ignore - # We need a closure to capture dataset target_buffers if they aren't natively passed. - # We can fetch them directly from kwargs if we configure the trainer correctly. - def reward_wrapper(prompts: list[str], completions: list[list[str]], target_buffer: list[str | None], **kwargs: dict[str, object]) -> list[float]: - return env_reward_function(prompts, completions, target_buffer, **kwargs) + # TRL forwards extra dataset columns (e.g. target_buffer) to the reward function via **kwargs. + # The wrapper simply delegates to env_reward_function which reads target_buffer from kwargs. + def reward_wrapper(prompts: list[str], completions: list[str], **kwargs: object) -> list[float]: + return env_reward_function(prompts, completions, **kwargs) training_args = GRPOConfig( output_dir="grpo_outputs", learning_rate=1e-5, per_device_train_batch_size=4, gradient_accumulation_steps=4, - max_prompt_length=512, max_completion_length=256, num_train_epochs=args.epochs, beta=args.beta, use_vllm=True, # Important: Scale GRPO with vLLM - vllm_device="cuda:0", + vllm_mode="colocate", # Run vLLM in the same process, sharing training GPUs vllm_gpu_memory_utilization=0.5, )