Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ def _handle_requests(self):
elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META:
response_body = {"message": "success"}
response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE
elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION:
# Mock reset consumption - always succeed
response_body = {
"success": True,
"message": "Consumption reset successfully",
}
response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE
else:
response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
response_type = ZMQRequestType.CLEAR_META_RESPONSE
Expand Down Expand Up @@ -531,6 +538,52 @@ def test_get_partition_list(client_setup):
assert "test_partition" in partition_list


def test_reset_consumption(client_setup):
"""Test synchronous reset_consumption - resets consumption status for a partition"""
client, _, _ = client_setup

# Test synchronous reset_consumption with task_name
success = client.reset_consumption(partition_id="train_0", task_name="generate_sequences")
assert success is True

print("✓ reset_consumption with task_name returns True")


def test_reset_consumption_all_tasks(client_setup):
"""Test synchronous reset_consumption without task_name (resets all tasks)"""
client, _, _ = client_setup

# Test synchronous reset_consumption without task_name (reset all tasks)
success = client.reset_consumption(partition_id="train_0")
assert success is True

print("✓ reset_consumption without task_name (all tasks) returns True")


@pytest.mark.asyncio
async def test_async_reset_consumption(client_setup):
"""Test async reset_consumption - resets consumption status for a partition"""
client, _, _ = client_setup

# Test async_reset_consumption with task_name
success = await client.async_reset_consumption(partition_id="train_0", task_name="generate_sequences")
assert success is True

print("✓ async_reset_consumption with task_name returns True")


@pytest.mark.asyncio
async def test_async_reset_consumption_all_tasks(client_setup):
"""Test async reset_consumption without task_name (resets all tasks)"""
client, _, _ = client_setup

# Test async_reset_consumption without task_name (reset all tasks)
success = await client.async_reset_consumption(partition_id="train_0")
assert success is True

print("✓ async_reset_consumption without task_name (all tasks) returns True")


@pytest.mark.asyncio
async def test_async_check_consumption_status(client_setup):
"""Test async consumption status checking"""
Expand Down
157 changes: 157 additions & 0 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,163 @@ def test_controller_with_single_partition(self, ray_setup):
assert partition is None
print("✓ Clear partition correct")

def test_controller_reset_consumption(self, ray_setup):
"""Test reset_consumption functionality - allows data to be re-consumed"""
gbs = 4
num_n_samples = 2
partition_id = "test_reset_consumption"

tq_controller = TransferQueueController.remote()

# Step 1: Create metadata in insert mode
data_fields = ["prompt_ids", "attention_mask"]
metadata = ray.get(
tq_controller.get_metadata.remote(
data_fields=data_fields,
batch_size=gbs * num_n_samples,
partition_id=partition_id,
mode="insert",
)
)
assert metadata.global_indexes == list(range(gbs * num_n_samples))

# Step 2: Update production status
dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes}
shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes}
success = ray.get(
tq_controller.update_production_status.remote(
partition_id=partition_id,
global_indexes=metadata.global_indexes,
field_names=metadata.field_names,
dtypes=dtypes,
shapes=shapes,
)
)
assert success

# Step 3: Verify consumption status BEFORE consumption (should be all zeros)
global_index, consumption_status = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
expected_consumption_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8)
assert torch.equal(consumption_status, expected_consumption_before)
print("✓ Consumption status before fetch is all zeros")

# Step 4: Fetch data (mark as consumed)
gen_meta = ray.get(
tq_controller.get_metadata.remote(
data_fields=["prompt_ids"],
batch_size=gbs * num_n_samples,
partition_id=partition_id,
mode="fetch",
task_name="generate_sequences",
)
)
assert gen_meta.global_indexes == list(range(gbs * num_n_samples))

# Step 5: Verify consumption status AFTER consumption (should be all ones)
global_index, consumption_status = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
expected_consumption_after = torch.ones(gbs * num_n_samples, dtype=torch.int8)
assert torch.equal(consumption_status, expected_consumption_after)
print("✓ Consumption status after fetch is all ones")

# Step 6: Reset consumption for specific task
ray.get(
tq_controller.reset_consumption.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)

# Step 7: Verify consumption status is reset (should be all zeros again)
global_index, consumption_status = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
expected_consumption_reset = torch.zeros(gbs * num_n_samples, dtype=torch.int8)
assert torch.equal(consumption_status, expected_consumption_reset)
print("✓ Consumption status after reset is all zeros")

# Step 8: Consume again and test reset all tasks
gen_meta_2 = ray.get(
tq_controller.get_metadata.remote(
data_fields=["prompt_ids"],
batch_size=gbs * num_n_samples,
partition_id=partition_id,
mode="fetch",
task_name="generate_sequences",
)
)
assert gen_meta_2.global_indexes == list(range(gbs * num_n_samples))

# Also consume with another task
gen_meta_3 = ray.get(
tq_controller.get_metadata.remote(
data_fields=["attention_mask"],
batch_size=gbs * num_n_samples,
partition_id=partition_id,
mode="fetch",
task_name="another_task",
)
)
assert gen_meta_3.global_indexes == list(range(gbs * num_n_samples))

# Verify both tasks have consumed
_, consumption_status_task1 = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
_, consumption_status_task2 = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="another_task",
)
)
assert torch.equal(consumption_status_task1, torch.ones(gbs * num_n_samples, dtype=torch.int8))
assert torch.equal(consumption_status_task2, torch.ones(gbs * num_n_samples, dtype=torch.int8))
print("✓ Both tasks consumed successfully")

# Step 9: Reset all tasks (task_name=None)
ray.get(
tq_controller.reset_consumption.remote(
partition_id=partition_id,
task_name=None, # Reset all tasks
)
)

# Step 10: Verify all tasks are reset
_, consumption_status_task1_reset = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
_, consumption_status_task2_reset = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="another_task",
)
)
assert torch.equal(consumption_status_task1_reset, torch.zeros(gbs * num_n_samples, dtype=torch.int8))
assert torch.equal(consumption_status_task2_reset, torch.zeros(gbs * num_n_samples, dtype=torch.int8))
print("✓ Reset all tasks successful - both tasks have zero consumption status")

# Clean up
ray.get(tq_controller.clear_partition.remote(partition_id))
print("✓ Reset consumption test completed successfully")

def test_controller_with_multi_partitions(self, ray_setup):
gbs_1 = 8
num_n_samples_1 = 4
Expand Down
70 changes: 70 additions & 0 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,63 @@ async def async_check_consumption_status(
return False
return torch.all(consumption_status == 1).item()

@dynamic_socket(socket_name="request_handle_socket")
async def async_reset_consumption(
self,
partition_id: str,
task_name: Optional[str] = None,
socket: Optional[zmq.asyncio.Socket] = None,
) -> bool:
"""Reset consumption status for a partition, allowing data to be re-consumed.
This is useful for debugging scenarios where the same rollout data needs to be
trained multiple times without regenerating the data.
Args:
partition_id: Partition id to reset consumption status for
task_name: Name of the task to reset. If None, resets all tasks.
socket: ZMQ async socket for message transmission (injected by decorator)
Returns:
bool: True if reset was successful, False otherwise
Raises:
RuntimeError: If communication fails or controller returns error response
Example:
>>> # Reset consumption for train task to re-train on same data
>>> success = asyncio.run(client.async_reset_consumption(
... partition_id="train_0",
... task_name="train"
... ))
>>> print(f"Reset successful: {success}")
"""
assert socket is not None
body = {"partition_id": partition_id}
if task_name is not None:
body["task_name"] = task_name
request_msg = ZMQMessage.create(
request_type=ZMQRequestType.RESET_CONSUMPTION,
sender_id=self.client_id,
receiver_id=self._controller.id,
body=body,
)
try:
await socket.send_multipart(request_msg.serialize())
response_serialized = await socket.recv_multipart()
response_msg = ZMQMessage.deserialize(response_serialized)
logger.debug(
f"[{self.client_id}]: Client reset consumption response: {response_msg} "
f"from controller {self._controller.id}"
)
if response_msg.request_type == ZMQRequestType.RESET_CONSUMPTION_RESPONSE:
success = response_msg.body.get("success", False)
if not success:
logger.warning(f"[{self.client_id}]: Reset consumption failed: {response_msg.body.get('message')}")
return success
else:
raise RuntimeError(
f"[{self.client_id}]: Failed to reset consumption from controller {self._controller.id}: "
f"{response_msg.body.get('message', 'Unknown error')}"
)
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e

async def async_check_production_status(
self,
data_fields: list[str],
Expand Down Expand Up @@ -917,6 +974,7 @@ def wrapper(*args, **kwargs):
self._check_production_status = _make_sync(self.async_check_production_status)
self._get_partition_list = _make_sync(self.async_get_partition_list)
self._set_custom_meta = _make_sync(self.async_set_custom_meta)
self._reset_consumption = _make_sync(self.async_reset_consumption)

def put(
self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
Expand Down Expand Up @@ -1138,6 +1196,18 @@ def get_consumption_status(
"""
return self._get_consumption_status(task_name, partition_id)

def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
"""Synchronously reset consumption status for a partition.
This allows the same data to be re-consumed, useful for debugging scenarios
where the same rollout data needs to be trained multiple times.
Args:
partition_id: Partition id to reset consumption status for
task_name: Name of the task to reset. If None, resets all tasks.
Returns:
bool: True if reset was successful, False otherwise
"""
return self._reset_consumption(partition_id, task_name)

def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
"""Synchronously check if all samples for a partition are ready (produced) for consumption.

Expand Down
Loading