From e666b3caa9049de9f5c172eccff27e72986174fd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 4 Jun 2026 15:46:44 +0800 Subject: [PATCH 1/2] honor data_status_update thread Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 9 +- tests/test_ray_p2p.py | 1 - transfer_queue/controller.py | 164 +++++++-------------- transfer_queue/storage/managers/base.py | 2 +- 4 files changed, 57 insertions(+), 119 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 6c4da3d2..a6875600 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -53,7 +53,7 @@ async def mock_async_storage_manager(): role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", - ports={"handshake_socket": 12347, "data_status_update_socket": 12348}, + ports={"handshake_socket": 12347}, ) config = { @@ -68,7 +68,6 @@ async def mock_async_storage_manager(): manager.config = config manager.controller_info = controller_info manager.storage_unit_infos = storage_unit_infos - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -158,7 +157,7 @@ async def test_async_storage_manager_error_handling(): role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", - ports={"handshake_socket": 12346, "data_status_update_socket": 12347}, + ports={"handshake_socket": 12346}, ) config = { @@ -257,7 +256,6 @@ async def test_get_data_routes_from_hash(): manager.storage_manager_id = "test_get" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -310,7 +308,6 @@ async def test_clear_data_routes_from_hash(): manager.storage_manager_id = "test_clear" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -361,7 +358,6 @@ async def test_hash_routing_stable_across_batch_sizes(): manager.storage_manager_id = "test_hash_batch" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -422,7 +418,6 @@ async def test_hash_routing_stable_reversed_order(): manager.storage_manager_id = "test_hash_order" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index 353bb926..92f3f9c9 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -60,7 +60,6 @@ def create_mock_controller(): ip="127.0.0.1", ports={ "request_handle_socket": 9981, - "data_status_update_socket": 9982, "handshake_socket": 9983, }, ) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 9661ba0e..5c28a88d 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, field from itertools import groupby from operator import itemgetter -from threading import Lock, Thread +from threading import Thread from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -361,10 +361,6 @@ class DataPartitionStatus: keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key - # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. - # No need to strictly lock for every read/write operation since freshness is not critical. - data_status_lock: Lock = field(default_factory=Lock) - # Dynamic configuration - these are computed from the current state @property def total_samples_num(self) -> int: @@ -409,8 +405,7 @@ def register_pre_allocated_indexes(self, allocated_indexes: list[int]): max_sample_idx = max(allocated_indexes) required_samples = max_sample_idx + 1 - with self.data_status_lock: - self.ensure_samples_capacity(required_samples) + self.ensure_samples_capacity(required_samples) logger.debug(f"Pre-allocated indexes in {self.partition_id}: {allocated_indexes}") @@ -526,9 +521,8 @@ def update_production_status( max_sample_idx = max(global_indices) if global_indices else -1 required_samples = max_sample_idx + 1 - with self.data_status_lock: - # Ensure we have enough rows - self.ensure_samples_capacity(required_samples) + # Ensure we have enough rows + self.ensure_samples_capacity(required_samples) # Register new fields if needed new_fields = [f for f in field_names if f not in self.field_name_mapping] @@ -538,14 +532,12 @@ def update_production_status( self.field_name_mapping[f] = len(self.field_name_mapping) required_fields = len(self.field_name_mapping) - with self.data_status_lock: - self.ensure_fields_capacity(required_fields) + self.ensure_fields_capacity(required_fields) - with self.data_status_lock: - # Update production status - if self.production_status is not None and global_indices and field_names: - field_indices = [self.field_name_mapping.get(f) for f in field_names] - self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 + # Update production status + if self.production_status is not None and global_indices and field_names: + field_indices = [self.field_name_mapping.get(f) for f in field_names] + self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata self._update_field_metadata(global_indices, field_schema, custom_backend_meta) @@ -641,8 +633,7 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te if partition_global_index.numel() == 0: empty_status = self.consumption_status[task_name].new_zeros(0) return partition_global_index, empty_status - with self.data_status_lock: - self.ensure_samples_capacity(max(partition_global_index) + 1) + self.ensure_samples_capacity(max(partition_global_index) + 1) consumption_status = self.consumption_status[task_name][partition_global_index] else: consumption_status = self.consumption_status[task_name] @@ -730,23 +721,22 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: if field_name not in self.field_name_mapping: return [] - with self.data_status_lock: - row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) + row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) - # Apply consumption filter (exclude already consumed samples) - _, consumption_status = self.get_consumption_status(task_name, mask=False) - if consumption_status is not None: - unconsumed_mask = consumption_status == 0 - row_mask &= unconsumed_mask + # Apply consumption filter (exclude already consumed samples) + _, consumption_status = self.get_consumption_status(task_name, mask=False) + if consumption_status is not None: + unconsumed_mask = consumption_status == 0 + row_mask &= unconsumed_mask - # Create column mask for requested fields - col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool) - field_indices = [self.field_name_mapping[field] for field in field_names] - if field_indices: - col_mask[field_indices] = True + # Create column mask for requested fields + col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool) + field_indices = [self.field_name_mapping[field] for field in field_names] + if field_indices: + col_mask[field_indices] = True - # Filter production status by masks - relevant_status = self.production_status[row_mask][:, col_mask] + # Filter production status by masks + relevant_status = self.production_status[row_mask][:, col_mask] # Check if all required fields are ready for each sample all_fields_ready = torch.all(relevant_status, dim=1) @@ -886,9 +876,6 @@ def _perform_copy(): snapshot = cls.__new__(cls) for name, value in self.__dict__.items(): - if name == "data_status_lock": - continue - if isinstance(value, Tensor): new_val = value.clone().detach() else: @@ -897,13 +884,7 @@ def _perform_copy(): setattr(snapshot, name, new_val) return snapshot - lock_obj = getattr(self, "data_status_lock", None) - - if lock_obj: - with lock_obj: - return _perform_copy() - else: - return _perform_copy() + return _perform_copy() def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = True): """Clear all production and optionally consumption data for given global_indexes.""" @@ -1021,7 +1002,6 @@ def __init__( # Start background processing threads self._start_process_handshake() - self._start_process_update_data_status() self._start_process_request() logger.info(f"TransferQueue Controller {self.controller_id} initialized") @@ -1623,8 +1603,7 @@ def kv_retrieve_meta( partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i] partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]] - with partition.data_status_lock: - partition.ensure_samples_capacity(max(batch_global_indexes) + 1) + partition.ensure_samples_capacity(max(batch_global_indexes) + 1) verified_global_indexes = [idx for idx in global_indexes if idx is not None] assert len(verified_global_indexes) == len(keys) @@ -1685,7 +1664,6 @@ def _init_zmq_socket(self): try: self._handshake_socket_port = get_free_port(ip=self._node_ip) self._request_handle_socket_port = get_free_port(ip=self._node_ip) - self._data_status_update_socket_port = get_free_port(ip=self._node_ip) self.handshake_socket = create_zmq_socket( ctx=self.zmq_context, @@ -1701,15 +1679,6 @@ def _init_zmq_socket(self): ) self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port)) - self.data_status_update_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ip=self._node_ip, - ) - self.data_status_update_socket.bind( - format_zmq_address(self._node_ip, self._data_status_update_socket_port) - ) - break except zmq.ZMQError: logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...") @@ -1722,7 +1691,6 @@ def _init_zmq_socket(self): ports={ "handshake_socket": self._handshake_socket_port, "request_handle_socket": self._request_handle_socket_port, - "data_status_update_socket": self._data_status_update_socket_port, }, ) @@ -1781,15 +1749,6 @@ def _start_process_handshake(self): ) self.wait_connection_thread.start() - def _start_process_update_data_status(self): - """Start the data status update processing thread.""" - self.process_update_data_status_thread = Thread( - target=self._update_data_status, - name="TransferQueueControllerProcessUpdateDataStatusThread", - daemon=True, - ) - self.process_update_data_status_thread.start() - def _start_process_request(self): """Start the request processing thread.""" self.process_request_thread = Thread( @@ -1834,6 +1793,35 @@ def _process_request(self): body={"metadata": metadata}, ) + elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: + with monitor.measure(op_type="NOTIFY_DATA_UPDATE"): + message_data = request_msg.body + partition_id = message_data.get("partition_id") + global_indexes = message_data.get("global_indexes", []) + + # Update production status + success = self.update_production_status( + partition_id=partition_id, + global_indexes=global_indexes, + field_schema=message_data.get("field_schema", {}), + custom_backend_meta=message_data.get("custom_backend_meta", {}), + ) + if success: + if self._metrics is not None: + self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) + logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}") + + # Send acknowledgment + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, + sender_id=self.controller_id, + body={ + "controller_id": self.controller_id, + "partition_id": partition_id, + "success": success, + }, + ) + elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META: with monitor.measure(op_type="GET_PARTITION_META"): params = request_msg.body @@ -2047,50 +2035,6 @@ def _process_request(self): self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) - def _update_data_status(self): - """Process data status update messages from storage units - adapted for partitions.""" - logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...") - - perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) - - while True: - monitor = self._metrics if self._metrics is not None else perf_monitor - - messages = self.data_status_update_socket.recv_multipart(copy=False) - identity = messages.pop(0) - serialized_msg = messages - request_msg = ZMQMessage.deserialize(serialized_msg) - - if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: - with monitor.measure(op_type="NOTIFY_DATA_UPDATE"): - message_data = request_msg.body - partition_id = message_data.get("partition_id") - global_indexes = message_data.get("global_indexes", []) - - # Update production status - success = self.update_production_status( - partition_id=partition_id, - global_indexes=global_indexes, - field_schema=message_data.get("field_schema", {}), - custom_backend_meta=message_data.get("custom_backend_meta", {}), - ) - if success: - if self._metrics is not None: - self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) - logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}") - - # Send acknowledgment - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - body={ - "controller_id": self.controller_id, - "partition_id": partition_id, - "success": success, - }, - ) - self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()]) - def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information.""" return self.zmq_server_info diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 180d466e..7057cc1e 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -210,7 +210,7 @@ async def notify_data_update( sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity) try: - sock.connect(self.controller_info.to_addr("data_status_update_socket")) + sock.connect(self.controller_info.to_addr("request_handle_socket")) normalized_field_schema = {} for field_name, field in field_schema.items(): From 8816640c7278fc5373d703f4831fde5fd22ae1a3 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 4 Jun 2026 16:05:55 +0800 Subject: [PATCH 2/2] fix Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 5c28a88d..8d8c9e1f 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1050,7 +1050,7 @@ def _get_partition(self, partition_id: str) -> DataPartitionStatus | None: def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None: """ - Get a copy of partition status information, without threading.Lock(). + Get a copy of partition status information. Args: partition_id: ID of the partition to retrieve @@ -1815,6 +1815,7 @@ def _process_request(self): response_msg = ZMQMessage.create( request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, sender_id=self.controller_id, + receiver_id=request_msg.sender_id, body={ "controller_id": self.controller_id, "partition_id": partition_id,