Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion train_utils/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def initialize_model():
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
dtype=torch.bfloat16,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
return tokenizer, model
Expand Down
2 changes: 1 addition & 1 deletion train_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main():
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", dtype=torch.bfloat16, trust_remote_code=True, attn_implementation="sdpa"
MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, attn_implementation="sdpa"
)
# Enable gradient checkpointing to save VRAM
model.gradient_checkpointing_enable()
Expand Down
42 changes: 16 additions & 26 deletions train_utils/train_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,29 @@ def extract_commands(completion: str) -> list[str]:
return commands if commands else [completion.strip()]


def env_reward_function(prompts: list[str], completions: list[list[str]], target_buffers: list[str | None], **kwargs: dict[str, object]) -> list[float]:
def env_reward_function(prompts: list[str], completions: list[str], **kwargs: object) -> list[float]:
"""
The vectorized pure mapped reward function for TRL GRPO.
Instantiates ViClient up front, configures the environment objective via `init_session`,
and parses the LLM output into keystroke acts.

TRL passes `completions` as a plain list[str] (one decoded string per completion).
Extra dataset columns (e.g. `target_buffer`) are forwarded via **kwargs.
"""
client = ViClient()
rewards: list[float] = []

# TRL passes a list of N completions (each completion is a list containing the generated text)
# But usually GRPO `completions` is a list of N strings (one per prompt if batch=N, or N completions for 1 prompt)
# The signature in modern TRL passes lists of strings for completions.

# Flatten/normalize depending on how TRL supplies it
normalized_completions: list[str] = []
for c in completions:
if isinstance(c, list):
normalized_completions.append(c[0]) # Depending on format, sometimes wrapped
else:
normalized_completions.append(str(c)) # type: ignore

# Extra dataset columns (e.g. target_buffer) are passed by TRL as keyword arguments.
raw_targets = kwargs.get("target_buffer")
target_buffers: list[str | None] = list(raw_targets) if isinstance(raw_targets, (list, tuple)) else [None] * len(prompts) # type: ignore

# In GRPO, prompts is repeated G times for each group.
# In GRPO, prompts is repeated G times for each group.
# We iterate over the batch and get exactly one reward float per completion.
for i in range(len(prompts)):
prompt_text = prompts[i]
completion_text = normalized_completions[i]

# We can extract the target buffer if it was supplied in the dataset
# In a real setup, kwargs might contain dataset columns. For simplicity, we assume
# target_buffers are broadcasted or we extract them.
target = target_buffers[i] if target_buffers and i < len(target_buffers) else None
completion_text = completions[i]

target = target_buffers[i] if i < len(target_buffers) else None

try:
session_id = client.init_session(
Expand Down Expand Up @@ -109,22 +100,21 @@ def format_trl(example: dict[str, str]) -> dict[str, str]:

ds = ds.map(format_trl) # type: ignore

# We need a closure to capture dataset target_buffers if they aren't natively passed.
# We can fetch them directly from kwargs if we configure the trainer correctly.
def reward_wrapper(prompts: list[str], completions: list[list[str]], target_buffer: list[str | None], **kwargs: dict[str, object]) -> list[float]:
return env_reward_function(prompts, completions, target_buffer, **kwargs)
# TRL forwards extra dataset columns (e.g. target_buffer) to the reward function via **kwargs.
# The wrapper simply delegates to env_reward_function which reads target_buffer from kwargs.
def reward_wrapper(prompts: list[str], completions: list[str], **kwargs: object) -> list[float]:
return env_reward_function(prompts, completions, **kwargs)

training_args = GRPOConfig(
output_dir="grpo_outputs",
learning_rate=1e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
max_prompt_length=512,
max_completion_length=256,
num_train_epochs=args.epochs,
beta=args.beta,
use_vllm=True, # Important: Scale GRPO with vLLM
vllm_device="cuda:0",
vllm_mode="colocate", # Run vLLM in the same process, sharing training GPUs
vllm_gpu_memory_utilization=0.5,
)

Expand Down