Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions tests/test_async_simple_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def mock_async_storage_manager():
role=Role.CONTROLLER,
id="controller_0",
ip="127.0.0.1",
ports={"handshake_socket": 12347, "data_status_update_socket": 12348},
ports={"handshake_socket": 12347},
)

config = {
Expand All @@ -68,7 +68,6 @@ async def mock_async_storage_manager():
manager.config = config
manager.controller_info = controller_info
manager.storage_unit_infos = storage_unit_infos
manager.data_status_update_socket = None
manager.controller_handshake_socket = None
manager.zmq_context = None

Expand Down Expand Up @@ -158,7 +157,7 @@ async def test_async_storage_manager_error_handling():
role=Role.CONTROLLER,
id="controller_0",
ip="127.0.0.1",
ports={"handshake_socket": 12346, "data_status_update_socket": 12347},
ports={"handshake_socket": 12346},
)

config = {
Expand Down Expand Up @@ -257,7 +256,6 @@ async def test_get_data_routes_from_hash():
manager.storage_manager_id = "test_get"
manager.storage_unit_infos = storage_unit_infos
manager.controller_info = None
manager.data_status_update_socket = None
manager.controller_handshake_socket = None
manager.zmq_context = None

Expand Down Expand Up @@ -310,7 +308,6 @@ async def test_clear_data_routes_from_hash():
manager.storage_manager_id = "test_clear"
manager.storage_unit_infos = storage_unit_infos
manager.controller_info = None
manager.data_status_update_socket = None
manager.controller_handshake_socket = None
manager.zmq_context = None

Expand Down Expand Up @@ -361,7 +358,6 @@ async def test_hash_routing_stable_across_batch_sizes():
manager.storage_manager_id = "test_hash_batch"
manager.storage_unit_infos = storage_unit_infos
manager.controller_info = None
manager.data_status_update_socket = None
manager.controller_handshake_socket = None
manager.zmq_context = None

Expand Down Expand Up @@ -422,7 +418,6 @@ async def test_hash_routing_stable_reversed_order():
manager.storage_manager_id = "test_hash_order"
manager.storage_unit_infos = storage_unit_infos
manager.controller_info = None
manager.data_status_update_socket = None
manager.controller_handshake_socket = None
manager.zmq_context = None

Expand Down
1 change: 0 additions & 1 deletion tests/test_ray_p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def create_mock_controller():
ip="127.0.0.1",
ports={
"request_handle_socket": 9981,
"data_status_update_socket": 9982,
"handshake_socket": 9983,
},
)
Expand Down
167 changes: 56 additions & 111 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dataclasses import dataclass, field
from itertools import groupby
from operator import itemgetter
from threading import Lock, Thread
from threading import Thread
from typing import TYPE_CHECKING, Any
from uuid import uuid4

Expand Down Expand Up @@ -361,10 +361,6 @@ class DataPartitionStatus:
keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx
revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key

# Threading lock for concurrency control; only for preventing mask operation error when expanding production_status.
# No need to strictly lock for every read/write operation since freshness is not critical.
data_status_lock: Lock = field(default_factory=Lock)

# Dynamic configuration - these are computed from the current state
@property
def total_samples_num(self) -> int:
Expand Down Expand Up @@ -409,8 +405,7 @@ def register_pre_allocated_indexes(self, allocated_indexes: list[int]):
max_sample_idx = max(allocated_indexes)
required_samples = max_sample_idx + 1

with self.data_status_lock:
self.ensure_samples_capacity(required_samples)
self.ensure_samples_capacity(required_samples)

logger.debug(f"Pre-allocated indexes in {self.partition_id}: {allocated_indexes}")

Expand Down Expand Up @@ -526,9 +521,8 @@ def update_production_status(
max_sample_idx = max(global_indices) if global_indices else -1
required_samples = max_sample_idx + 1

with self.data_status_lock:
# Ensure we have enough rows
self.ensure_samples_capacity(required_samples)
# Ensure we have enough rows
self.ensure_samples_capacity(required_samples)

# Register new fields if needed
new_fields = [f for f in field_names if f not in self.field_name_mapping]
Expand All @@ -538,14 +532,12 @@ def update_production_status(
self.field_name_mapping[f] = len(self.field_name_mapping)

required_fields = len(self.field_name_mapping)
with self.data_status_lock:
self.ensure_fields_capacity(required_fields)
self.ensure_fields_capacity(required_fields)

with self.data_status_lock:
# Update production status
if self.production_status is not None and global_indices and field_names:
field_indices = [self.field_name_mapping.get(f) for f in field_names]
self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1
# Update production status
if self.production_status is not None and global_indices and field_names:
field_indices = [self.field_name_mapping.get(f) for f in field_names]
self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1

# Update field metadata
self._update_field_metadata(global_indices, field_schema, custom_backend_meta)
Expand Down Expand Up @@ -641,8 +633,7 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
if partition_global_index.numel() == 0:
empty_status = self.consumption_status[task_name].new_zeros(0)
return partition_global_index, empty_status
with self.data_status_lock:
self.ensure_samples_capacity(max(partition_global_index) + 1)
self.ensure_samples_capacity(max(partition_global_index) + 1)
consumption_status = self.consumption_status[task_name][partition_global_index]
else:
consumption_status = self.consumption_status[task_name]
Expand Down Expand Up @@ -730,23 +721,22 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
if field_name not in self.field_name_mapping:
return []

with self.data_status_lock:
row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool)
row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool)

# Apply consumption filter (exclude already consumed samples)
_, consumption_status = self.get_consumption_status(task_name, mask=False)
if consumption_status is not None:
unconsumed_mask = consumption_status == 0
row_mask &= unconsumed_mask
# Apply consumption filter (exclude already consumed samples)
_, consumption_status = self.get_consumption_status(task_name, mask=False)
if consumption_status is not None:
Comment on lines +724 to +728
unconsumed_mask = consumption_status == 0
row_mask &= unconsumed_mask

# Create column mask for requested fields
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
field_indices = [self.field_name_mapping[field] for field in field_names]
if field_indices:
col_mask[field_indices] = True
# Create column mask for requested fields
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
field_indices = [self.field_name_mapping[field] for field in field_names]
if field_indices:
col_mask[field_indices] = True

# Filter production status by masks
relevant_status = self.production_status[row_mask][:, col_mask]
# Filter production status by masks
relevant_status = self.production_status[row_mask][:, col_mask]

# Check if all required fields are ready for each sample
all_fields_ready = torch.all(relevant_status, dim=1)
Expand Down Expand Up @@ -886,9 +876,6 @@ def _perform_copy():
snapshot = cls.__new__(cls)

for name, value in self.__dict__.items():
if name == "data_status_lock":
continue

if isinstance(value, Tensor):
new_val = value.clone().detach()
else:
Expand All @@ -897,13 +884,7 @@ def _perform_copy():
setattr(snapshot, name, new_val)
return snapshot

lock_obj = getattr(self, "data_status_lock", None)

if lock_obj:
with lock_obj:
return _perform_copy()
else:
return _perform_copy()
return _perform_copy()

def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = True):
"""Clear all production and optionally consumption data for given global_indexes."""
Expand Down Expand Up @@ -1021,7 +1002,6 @@ def __init__(

# Start background processing threads
self._start_process_handshake()
self._start_process_update_data_status()
self._start_process_request()
Comment on lines 1003 to 1005

logger.info(f"TransferQueue Controller {self.controller_id} initialized")
Expand Down Expand Up @@ -1070,7 +1050,7 @@ def _get_partition(self, partition_id: str) -> DataPartitionStatus | None:

def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None:
"""
Get a copy of partition status information, without threading.Lock().
Get a copy of partition status information.

Args:
partition_id: ID of the partition to retrieve
Expand Down Expand Up @@ -1623,8 +1603,7 @@ def kv_retrieve_meta(
partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i]
partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]]

with partition.data_status_lock:
partition.ensure_samples_capacity(max(batch_global_indexes) + 1)
partition.ensure_samples_capacity(max(batch_global_indexes) + 1)

verified_global_indexes = [idx for idx in global_indexes if idx is not None]
assert len(verified_global_indexes) == len(keys)
Expand Down Expand Up @@ -1685,7 +1664,6 @@ def _init_zmq_socket(self):
try:
self._handshake_socket_port = get_free_port(ip=self._node_ip)
self._request_handle_socket_port = get_free_port(ip=self._node_ip)
self._data_status_update_socket_port = get_free_port(ip=self._node_ip)

self.handshake_socket = create_zmq_socket(
ctx=self.zmq_context,
Expand All @@ -1701,15 +1679,6 @@ def _init_zmq_socket(self):
)
self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port))

self.data_status_update_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.data_status_update_socket.bind(
format_zmq_address(self._node_ip, self._data_status_update_socket_port)
)

break
except zmq.ZMQError:
logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...")
Expand All @@ -1722,7 +1691,6 @@ def _init_zmq_socket(self):
ports={
"handshake_socket": self._handshake_socket_port,
"request_handle_socket": self._request_handle_socket_port,
"data_status_update_socket": self._data_status_update_socket_port,
},
)

Expand Down Expand Up @@ -1781,15 +1749,6 @@ def _start_process_handshake(self):
)
self.wait_connection_thread.start()

def _start_process_update_data_status(self):
"""Start the data status update processing thread."""
self.process_update_data_status_thread = Thread(
target=self._update_data_status,
name="TransferQueueControllerProcessUpdateDataStatusThread",
daemon=True,
)
self.process_update_data_status_thread.start()

def _start_process_request(self):
"""Start the request processing thread."""
self.process_request_thread = Thread(
Expand Down Expand Up @@ -1834,6 +1793,36 @@ def _process_request(self):
body={"metadata": metadata},
)

elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
with monitor.measure(op_type="NOTIFY_DATA_UPDATE"):
message_data = request_msg.body
partition_id = message_data.get("partition_id")
global_indexes = message_data.get("global_indexes", [])

# Update production status
success = self.update_production_status(
partition_id=partition_id,
global_indexes=global_indexes,
field_schema=message_data.get("field_schema", {}),
custom_backend_meta=message_data.get("custom_backend_meta", {}),
)
if success:
if self._metrics is not None:
self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes))
logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}")

# Send acknowledgment
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
sender_id=self.controller_id,
receiver_id=request_msg.sender_id,
body={
"controller_id": self.controller_id,
"partition_id": partition_id,
"success": success,
},
)

elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META:
with monitor.measure(op_type="GET_PARTITION_META"):
params = request_msg.body
Expand Down Expand Up @@ -2047,50 +2036,6 @@ def _process_request(self):

self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])

def _update_data_status(self):
"""Process data status update messages from storage units - adapted for partitions."""
logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...")

perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)

while True:
monitor = self._metrics if self._metrics is not None else perf_monitor

messages = self.data_status_update_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)

if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
with monitor.measure(op_type="NOTIFY_DATA_UPDATE"):
message_data = request_msg.body
partition_id = message_data.get("partition_id")
global_indexes = message_data.get("global_indexes", [])

# Update production status
success = self.update_production_status(
partition_id=partition_id,
global_indexes=global_indexes,
field_schema=message_data.get("field_schema", {}),
custom_backend_meta=message_data.get("custom_backend_meta", {}),
)
if success:
if self._metrics is not None:
self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes))
logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}")

# Send acknowledgment
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
sender_id=self.controller_id,
body={
"controller_id": self.controller_id,
"partition_id": partition_id,
"success": success,
},
)
self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()])

def get_zmq_server_info(self) -> ZMQServerInfo:
"""Get ZMQ server connection information."""
return self.zmq_server_info
Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ async def notify_data_update(
sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity)

try:
sock.connect(self.controller_info.to_addr("data_status_update_socket"))
sock.connect(self.controller_info.to_addr("request_handle_socket"))

normalized_field_schema = {}
for field_name, field in field_schema.items():
Expand Down
Loading