From 6c96b61aed17d485deccae35815934a3e05e4c86 Mon Sep 17 00:00:00 2001 From: Thibaut Barroyer Date: Wed, 2 Jul 2025 10:48:46 +0200 Subject: [PATCH 1/2] fix chat scheduler --- verl/workers/rollout/chat_scheduler.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py index b37b2ebf997..b534b64c2b0 100644 --- a/verl/workers/rollout/chat_scheduler.py +++ b/verl/workers/rollout/chat_scheduler.py @@ -405,6 +405,9 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: # validation dataset has already been repeated in `PPOTrainer._validate`. n = 1 if batch.meta_info.get("validate", False) else self.config.n tasks, batch_conversations = [], [None] * len(batch) * n + + all_sequences_turn_data = [[] for _ in range(len(batch) * n)] + for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)): # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] batch_conversations[batch_index] = conversation.tolist() @@ -415,6 +418,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: messages=batch_conversations[batch_index], request_id=None, sampling_params=kwargs, + turn_data=all_sequences_turn_data[batch_index], ) ) ) @@ -422,11 +426,12 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: await asyncio.gather(*tasks) output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n) output_batch.meta_info["timing"] = {"generate_sequences": time.time() - t_start} + output_batch.non_tensor_batch["turn_data"] = np.array(all_sequences_turn_data) print("[ChatCompletionScheduler] generate_sequences done") return output_batch async def _submit_chat_completions_semaphore( - self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any] + self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any], turn_data: List[Dict[str, Any]] ): done = asyncio.Event() @@ -434,6 +439,8 @@ async def _submit_chat_completions_semaphore( "__done__": done, "__depth__": 0, # indicate how many ongoing completion requests "__sampling_params__": sampling_params, + "current_turn": 0, + "turn_data": turn_data, # used to collect turn data for each sequence } self.submit_chat_completions(messages=messages, request_id=request_id, info=info) From 36b79f661ddb2a59d9c977fb311da77acc3dca5a Mon Sep 17 00:00:00 2001 From: Thibaut Barroyer Date: Wed, 2 Jul 2025 11:44:49 +0200 Subject: [PATCH 2/2] fix chat scheduler --- verl/workers/rollout/chat_scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py index b534b64c2b0..6cded06593b 100644 --- a/verl/workers/rollout/chat_scheduler.py +++ b/verl/workers/rollout/chat_scheduler.py @@ -406,7 +406,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: n = 1 if batch.meta_info.get("validate", False) else self.config.n tasks, batch_conversations = [], [None] * len(batch) * n - all_sequences_turn_data = [[] for _ in range(len(batch) * n)] + all_sequences_turn_data = [{} for _ in range(len(batch) * n)] for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)): # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] @@ -426,12 +426,12 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: await asyncio.gather(*tasks) output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n) output_batch.meta_info["timing"] = {"generate_sequences": time.time() - t_start} - output_batch.non_tensor_batch["turn_data"] = np.array(all_sequences_turn_data) + output_batch.non_tensor_batch["turn_info"] = np.array(all_sequences_turn_data) print("[ChatCompletionScheduler] generate_sequences done") return output_batch async def _submit_chat_completions_semaphore( - self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any], turn_data: List[Dict[str, Any]] + self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any], turn_data: Dict[str, Any] ): done = asyncio.Event()