diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py index b37b2ebf997..6cded06593b 100644 --- a/verl/workers/rollout/chat_scheduler.py +++ b/verl/workers/rollout/chat_scheduler.py @@ -405,6 +405,9 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: # validation dataset has already been repeated in `PPOTrainer._validate`. n = 1 if batch.meta_info.get("validate", False) else self.config.n tasks, batch_conversations = [], [None] * len(batch) * n + + all_sequences_turn_data = [{} for _ in range(len(batch) * n)] + for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)): # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] batch_conversations[batch_index] = conversation.tolist() @@ -415,6 +418,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: messages=batch_conversations[batch_index], request_id=None, sampling_params=kwargs, + turn_data=all_sequences_turn_data[batch_index], ) ) ) @@ -422,11 +426,12 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: await asyncio.gather(*tasks) output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n) output_batch.meta_info["timing"] = {"generate_sequences": time.time() - t_start} + output_batch.non_tensor_batch["turn_info"] = np.array(all_sequences_turn_data) print("[ChatCompletionScheduler] generate_sequences done") return output_batch async def _submit_chat_completions_semaphore( - self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any] + self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any], turn_data: Dict[str, Any] ): done = asyncio.Event() @@ -434,6 +439,8 @@ async def _submit_chat_completions_semaphore( "__done__": done, "__depth__": 0, # indicate how many ongoing completion requests "__sampling_params__": sampling_params, + "current_turn": 0, + "turn_data": turn_data, # used to collect turn data for each sequence } self.submit_chat_completions(messages=messages, request_id=request_id, info=info)