From 810ed3a3197de806ff4e83097ed41fa014b58d39 Mon Sep 17 00:00:00 2001 From: yuetian Date: Tue, 3 Feb 2026 15:52:47 +0800 Subject: [PATCH] [feat] Supprot async_reset_consumpation enable reuse same batch data Signed-off-by: yuetian --- tests/test_client.py | 53 ++++++++++ tests/test_controller.py | 157 ++++++++++++++++++++++++++++++ transfer_queue/client.py | 70 +++++++++++++ transfer_queue/controller.py | 70 +++++++++++++ transfer_queue/utils/zmq_utils.py | 2 + 5 files changed, 352 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index 604ced82..38f140b0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -137,6 +137,13 @@ def _handle_requests(self): elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META: response_body = {"message": "success"} response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE + elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION: + # Mock reset consumption - always succeed + response_body = { + "success": True, + "message": "Consumption reset successfully", + } + response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE else: response_body = {"error": f"Unknown request type: {request_msg.request_type}"} response_type = ZMQRequestType.CLEAR_META_RESPONSE @@ -531,6 +538,52 @@ def test_get_partition_list(client_setup): assert "test_partition" in partition_list +def test_reset_consumption(client_setup): + """Test synchronous reset_consumption - resets consumption status for a partition""" + client, _, _ = client_setup + + # Test synchronous reset_consumption with task_name + success = client.reset_consumption(partition_id="train_0", task_name="generate_sequences") + assert success is True + + print("✓ reset_consumption with task_name returns True") + + +def test_reset_consumption_all_tasks(client_setup): + """Test synchronous reset_consumption without task_name (resets all tasks)""" + client, _, _ = client_setup + + # Test synchronous reset_consumption without task_name (reset all tasks) + success = client.reset_consumption(partition_id="train_0") + assert success is True + + print("✓ reset_consumption without task_name (all tasks) returns True") + + +@pytest.mark.asyncio +async def test_async_reset_consumption(client_setup): + """Test async reset_consumption - resets consumption status for a partition""" + client, _, _ = client_setup + + # Test async_reset_consumption with task_name + success = await client.async_reset_consumption(partition_id="train_0", task_name="generate_sequences") + assert success is True + + print("✓ async_reset_consumption with task_name returns True") + + +@pytest.mark.asyncio +async def test_async_reset_consumption_all_tasks(client_setup): + """Test async reset_consumption without task_name (resets all tasks)""" + client, _, _ = client_setup + + # Test async_reset_consumption without task_name (reset all tasks) + success = await client.async_reset_consumption(partition_id="train_0") + assert success is True + + print("✓ async_reset_consumption without task_name (all tasks) returns True") + + @pytest.mark.asyncio async def test_async_check_consumption_status(client_setup): """Test async consumption status checking""" diff --git a/tests/test_controller.py b/tests/test_controller.py index 14fd3aad..09f5778c 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -198,6 +198,163 @@ def test_controller_with_single_partition(self, ray_setup): assert partition is None print("✓ Clear partition correct") + def test_controller_reset_consumption(self, ray_setup): + """Test reset_consumption functionality - allows data to be re-consumed""" + gbs = 4 + num_n_samples = 2 + partition_id = "test_reset_consumption" + + tq_controller = TransferQueueController.remote() + + # Step 1: Create metadata in insert mode + data_fields = ["prompt_ids", "attention_mask"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs * num_n_samples, + partition_id=partition_id, + mode="insert", + ) + ) + assert metadata.global_indexes == list(range(gbs * num_n_samples)) + + # Step 2: Update production status + dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes} + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + ) + ) + assert success + + # Step 3: Verify consumption status BEFORE consumption (should be all zeros) + global_index, consumption_status = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="generate_sequences", + ) + ) + expected_consumption_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_before) + print("✓ Consumption status before fetch is all zeros") + + # Step 4: Fetch data (mark as consumed) + gen_meta = ray.get( + tq_controller.get_metadata.remote( + data_fields=["prompt_ids"], + batch_size=gbs * num_n_samples, + partition_id=partition_id, + mode="fetch", + task_name="generate_sequences", + ) + ) + assert gen_meta.global_indexes == list(range(gbs * num_n_samples)) + + # Step 5: Verify consumption status AFTER consumption (should be all ones) + global_index, consumption_status = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="generate_sequences", + ) + ) + expected_consumption_after = torch.ones(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_after) + print("✓ Consumption status after fetch is all ones") + + # Step 6: Reset consumption for specific task + ray.get( + tq_controller.reset_consumption.remote( + partition_id=partition_id, + task_name="generate_sequences", + ) + ) + + # Step 7: Verify consumption status is reset (should be all zeros again) + global_index, consumption_status = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="generate_sequences", + ) + ) + expected_consumption_reset = torch.zeros(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_reset) + print("✓ Consumption status after reset is all zeros") + + # Step 8: Consume again and test reset all tasks + gen_meta_2 = ray.get( + tq_controller.get_metadata.remote( + data_fields=["prompt_ids"], + batch_size=gbs * num_n_samples, + partition_id=partition_id, + mode="fetch", + task_name="generate_sequences", + ) + ) + assert gen_meta_2.global_indexes == list(range(gbs * num_n_samples)) + + # Also consume with another task + gen_meta_3 = ray.get( + tq_controller.get_metadata.remote( + data_fields=["attention_mask"], + batch_size=gbs * num_n_samples, + partition_id=partition_id, + mode="fetch", + task_name="another_task", + ) + ) + assert gen_meta_3.global_indexes == list(range(gbs * num_n_samples)) + + # Verify both tasks have consumed + _, consumption_status_task1 = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="generate_sequences", + ) + ) + _, consumption_status_task2 = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="another_task", + ) + ) + assert torch.equal(consumption_status_task1, torch.ones(gbs * num_n_samples, dtype=torch.int8)) + assert torch.equal(consumption_status_task2, torch.ones(gbs * num_n_samples, dtype=torch.int8)) + print("✓ Both tasks consumed successfully") + + # Step 9: Reset all tasks (task_name=None) + ray.get( + tq_controller.reset_consumption.remote( + partition_id=partition_id, + task_name=None, # Reset all tasks + ) + ) + + # Step 10: Verify all tasks are reset + _, consumption_status_task1_reset = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="generate_sequences", + ) + ) + _, consumption_status_task2_reset = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id, + task_name="another_task", + ) + ) + assert torch.equal(consumption_status_task1_reset, torch.zeros(gbs * num_n_samples, dtype=torch.int8)) + assert torch.equal(consumption_status_task2_reset, torch.zeros(gbs * num_n_samples, dtype=torch.int8)) + print("✓ Reset all tasks successful - both tasks have zero consumption status") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + print("✓ Reset consumption test completed successfully") + def test_controller_with_multi_partitions(self, ray_setup): gbs_1 = 8 num_n_samples_1 = 4 diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 1bfc9ffa..24b4e559 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -768,6 +768,63 @@ async def async_check_consumption_status( return False return torch.all(consumption_status == 1).item() + @dynamic_socket(socket_name="request_handle_socket") + async def async_reset_consumption( + self, + partition_id: str, + task_name: Optional[str] = None, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> bool: + """Reset consumption status for a partition, allowing data to be re-consumed. + This is useful for debugging scenarios where the same rollout data needs to be + trained multiple times without regenerating the data. + Args: + partition_id: Partition id to reset consumption status for + task_name: Name of the task to reset. If None, resets all tasks. + socket: ZMQ async socket for message transmission (injected by decorator) + Returns: + bool: True if reset was successful, False otherwise + Raises: + RuntimeError: If communication fails or controller returns error response + Example: + >>> # Reset consumption for train task to re-train on same data + >>> success = asyncio.run(client.async_reset_consumption( + ... partition_id="train_0", + ... task_name="train" + ... )) + >>> print(f"Reset successful: {success}") + """ + assert socket is not None + body = {"partition_id": partition_id} + if task_name is not None: + body["task_name"] = task_name + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.RESET_CONSUMPTION, + sender_id=self.client_id, + receiver_id=self._controller.id, + body=body, + ) + try: + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client reset consumption response: {response_msg} " + f"from controller {self._controller.id}" + ) + if response_msg.request_type == ZMQRequestType.RESET_CONSUMPTION_RESPONSE: + success = response_msg.body.get("success", False) + if not success: + logger.warning(f"[{self.client_id}]: Reset consumption failed: {response_msg.body.get('message')}") + return success + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to reset consumption from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e + async def async_check_production_status( self, data_fields: list[str], @@ -917,6 +974,7 @@ def wrapper(*args, **kwargs): self._check_production_status = _make_sync(self.async_check_production_status) self._get_partition_list = _make_sync(self.async_get_partition_list) self._set_custom_meta = _make_sync(self.async_set_custom_meta) + self._reset_consumption = _make_sync(self.async_reset_consumption) def put( self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None @@ -1138,6 +1196,18 @@ def get_consumption_status( """ return self._get_consumption_status(task_name, partition_id) + def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: + """Synchronously reset consumption status for a partition. + This allows the same data to be re-consumed, useful for debugging scenarios + where the same rollout data needs to be trained multiple times. + Args: + partition_id: Partition id to reset consumption status for + task_name: Name of the task to reset. If None, resets all tasks. + Returns: + bool: True if reset was successful, False otherwise + """ + return self._reset_consumption(partition_id, task_name) + def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 0049c4e9..c91eaf7c 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -546,6 +546,28 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te return partition_global_index, consumption_status + def reset_consumption(self, task_name: Optional[str] = None): + """ + Reset consumption status for a specific task or all tasks. + + This allows the same data to be re-consumed without clearing the actual data. + Useful for debugging scenarios where the same rollout data needs to be + trained multiple times. + Args: + task_name: Name of the task to reset consumption for. + If None, resets consumption status for all tasks. + """ + if task_name is not None: + # Reset specific task + if task_name in self.consumption_status: + self.consumption_status[task_name].zero_() + logger.debug(f"Reset consumption status for task '{task_name}' in partition {self.partition_id}") + else: + # Reset all tasks + for name, status_tensor in self.consumption_status.items(): + status_tensor.zero_() + logger.debug(f"Reset consumption status for all tasks in partition {self.partition_id}") + # ==================== Production Status Interface ==================== def get_production_status_for_fields( self, field_names: list[str], mask: bool = False @@ -1341,6 +1363,24 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): self.partitions.pop(partition_id) self.sampler.clear_cache(partition_id) + def reset_consumption(self, partition_id: str, task_name: Optional[str] = None): + """ + Reset consumption status for a partition without clearing the actual data. + + This allows the same data to be re-consumed, useful for debugging scenarios + where the same rollout data needs to be trained multiple times. + Args: + partition_id: ID of the partition to reset consumption for + task_name: Name of the task to reset. If None, resets all tasks. + Raises: + ValueError: If partition not found + """ + logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}") + partition = self._get_partition(partition_id) + if not partition: + raise ValueError(f"Partition {partition_id} not found") + partition.reset_consumption(task_name) + def clear_meta( self, global_indexes: list[int], @@ -1624,6 +1664,36 @@ def _process_request(self): }, ) + elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION: + with perf_monitor.measure(op_type="RESET_CONSUMPTION"): + # Handle reset consumption status request + params = request_msg.body + partition_id = params["partition_id"] + task_name = params.get("task_name") # Optional + try: + self.reset_consumption(partition_id, task_name) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.RESET_CONSUMPTION_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={ + "partition_id": partition_id, + "success": True, + "message": f"Consumption reset for partition {partition_id}", + }, + ) + except Exception as e: + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.RESET_CONSUMPTION_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={ + "partition_id": partition_id, + "success": False, + "message": str(e), + }, + ) + elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION: with perf_monitor.measure(op_type="GET_PRODUCTION"): # Handle production status checks diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 0a39113d..1f6ed922 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -84,6 +84,8 @@ class ZMQRequestType(ExplicitEnum): # GET_CONSUMPTION GET_CONSUMPTION = "GET_CONSUMPTION" CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE" + RESET_CONSUMPTION = "RESET_CONSUMPTION" + RESET_CONSUMPTION_RESPONSE = "RESET_CONSUMPTION_RESPONSE" # GET_PRODUCTION GET_PRODUCTION = "GET_PRODUCTION"