From 4d894e2f4d65243b4cc3cbe88ca64b3d010e1e5c Mon Sep 17 00:00:00 2001
From: ppraneth
Date: Sun, 24 May 2026 12:14:12 +0530
Subject: [PATCH 1/2] test

Signed-off-by: ppraneth
---
Notes (review):
    Adds tests/test_partition_results.py and makes
    AsyncTrainingController._partition_results length-aware.

    - New static helper _seq_len() returns the last dimension of
      tensor_shapes["input_ids"]. It returns 0 when tensor_shapes is None
      or empty, or when the "input_ids" entry is missing or empty.
    - _partition_results() keeps the previous round-robin placement when
      dp_size <= 1 or len(results) <= dp_size. Otherwise it visits the
      results longest-first and appends each one to the least-loaded rank
      that still holds fewer than len(results) // dp_size items.
    - As of this commit, a batch that is larger than dp_size but not a
      multiple of it runs out of ranks with spare capacity: the generator
      handed to min() is empty for the leftover items and min() raises
      ValueError. Patch 2/2 adds the missing guard, so this commit should
      not be applied on its own.
    - The tests drop torchspec.controller.training_controller from
      sys.modules and re-import it with ray.remote patched to an identity
      function, presumably so that AsyncTrainingController can be built as
      a plain class (confirm against the controller's decorator).
    - NOTE(review): the TestPartitionFallback docstring only mentions the
      "at most one sample per rank" case, but the class also covers
      dp_size=1 with several samples and the empty batch.
    - NOTE(review): the subject "test" does not describe the change;
      consider rewording it before merging.

 tests/test_partition_results.py             | 167 ++++++++++++++++++++
 torchspec/controller/training_controller.py |  32 +++-
 2 files changed, 196 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_partition_results.py

diff --git a/tests/test_partition_results.py b/tests/test_partition_results.py
new file mode 100644
index 00000000..6db593c4
--- /dev/null
+++ b/tests/test_partition_results.py
@@ -0,0 +1,167 @@
+import importlib
+import sys
+from dataclasses import dataclass
+from unittest.mock import patch
+
+import torch
+
+from torchspec.utils.types import InferenceOutput
+
+
+@dataclass
+class MockControllerArgs:
+    per_dp_rank_batch_size: int = 2
+    max_sample_pool_size: int = 0
+
+
+def _create_controller_class():
+    module_name = "torchspec.controller.training_controller"
+    if module_name in sys.modules:
+        del sys.modules[module_name]
+    with patch("ray.remote", lambda cls: cls):
+        module = importlib.import_module(module_name)
+    return module.AsyncTrainingController
+
+
+def _make_output(data_id: str, seq_len: int) -> InferenceOutput:
+    return InferenceOutput(
+        data_id=data_id,
+        mooncake_key=f"key-{data_id}",
+        tensor_shapes={"input_ids": (1, seq_len), "hidden_states": (1, seq_len, 4096)},
+        tensor_dtypes={"input_ids": torch.int64, "hidden_states": torch.bfloat16},
+    )
+
+
+def _make_controller(dp_size: int, per_dp_rank_batch_size: int):
+    AsyncTrainingController = _create_controller_class()
+    args = MockControllerArgs(per_dp_rank_batch_size=per_dp_rank_batch_size)
+    return AsyncTrainingController(args, dp_size=dp_size)
+
+
+class TestPartitionFallback:
+    """When at most one sample per rank, partition is round-robin."""
+
+    def test_single_dp_rank_keeps_all_samples_together(self):
+        controller = _make_controller(dp_size=1, per_dp_rank_batch_size=4)
+        results = [_make_output(f"s{i}", seq_len=100 + i) for i in range(4)]
+
+        partitions = controller._partition_results(results)
+
+        assert len(partitions) == 1
+        assert [r.data_id for r in partitions[0]] == ["s0", "s1", "s2", "s3"]
+
+    def test_one_sample_per_rank_uses_round_robin(self):
+        controller = _make_controller(dp_size=4, per_dp_rank_batch_size=1)
+        results = [_make_output(f"s{i}", seq_len=1000 - 100 * i) for i in range(4)]
+
+        partitions = controller._partition_results(results)
+
+        assert [p[0].data_id for p in partitions] == ["s0", "s1", "s2", "s3"]
+
+    def test_empty_results_returns_empty_partitions(self):
+        controller = _make_controller(dp_size=4, per_dp_rank_batch_size=2)
+
+        partitions = controller._partition_results([])
+
+        assert partitions == [[], [], [], []]
+
+
+class TestPartitionBinPacking:
+    """When per-rank capacity > 1, partition balances total sequence load."""
+
+    def test_capacity_is_exactly_results_per_rank(self):
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=2)
+        results = [_make_output(f"s{i}", seq_len=100) for i in range(4)]
+
+        partitions = controller._partition_results(results)
+
+        assert len(partitions) == 2
+        assert len(partitions[0]) == 2
+        assert len(partitions[1]) == 2
+
+    def test_longest_first_balances_load_across_ranks(self):
+        # Lengths 1000, 800, 200, 100 across dp=2 mbs=2:
+        # Greedy LPT pairs 1000+100 and 800+200 (loads 1100 and 1000),
+        # which is more balanced than round-robin's (1000+200, 800+100) = (1200, 900).
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=2)
+        results = [
+            _make_output("a", 1000),
+            _make_output("b", 800),
+            _make_output("c", 200),
+            _make_output("d", 100),
+        ]
+
+        partitions = controller._partition_results(results)
+
+        loads = [sum(r.tensor_shapes["input_ids"][-1] for r in p) for p in partitions]
+        assert sorted(loads) == [1000, 1100]
+
+    def test_each_rank_receives_exactly_capacity_samples(self):
+        # Stress test: skewed lengths must not violate per-rank capacity.
+        controller = _make_controller(dp_size=4, per_dp_rank_batch_size=3)
+        lengths = [2000, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
+        results = [_make_output(f"s{i}", L) for i, L in enumerate(lengths)]
+
+        partitions = controller._partition_results(results)
+
+        assert len(partitions) == 4
+        for p in partitions:
+            assert len(p) == 3
+        # Every sample is assigned exactly once.
+        assigned = sorted(r.data_id for p in partitions for r in p)
+        assert assigned == sorted(r.data_id for r in results)
+
+    def test_outlier_does_not_starve_other_ranks(self):
+        # One huge sample plus many small ones — the rank holding the
+        # outlier should still receive `capacity` samples, not all of them.
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=4)
+        results = [_make_output("big", 5000)] + [_make_output(f"s{i}", 100) for i in range(7)]
+
+        partitions = controller._partition_results(results)
+
+        assert len(partitions[0]) == 4
+        assert len(partitions[1]) == 4
+
+    def test_partition_is_deterministic_for_fixed_input(self):
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=2)
+        results = [_make_output(f"s{i}", L) for i, L in enumerate([300, 200, 400, 100])]
+
+        p1 = controller._partition_results(results)
+        p2 = controller._partition_results(results)
+
+        ids1 = [[r.data_id for r in part] for part in p1]
+        ids2 = [[r.data_id for r in part] for part in p2]
+        assert ids1 == ids2
+
+
+class TestPartitionDefensiveFallback:
+    """`_partition_results` should not crash when `input_ids` shape is missing."""
+
+    def test_missing_input_ids_shape_treated_as_zero_length(self):
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=2)
+        results = [
+            InferenceOutput(
+                data_id=f"s{i}",
+                mooncake_key=f"k{i}",
+                tensor_shapes={"hidden_states": (1, 100, 4096)},  # no "input_ids"
+                tensor_dtypes={"hidden_states": torch.bfloat16},
+            )
+            for i in range(4)
+        ]
+
+        partitions = controller._partition_results(results)
+
+        assert len(partitions) == 2
+        assert len(partitions[0]) == 2
+        assert len(partitions[1]) == 2
+
+    def test_none_tensor_shapes_treated_as_zero_length(self):
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=2)
+        results = [
+            InferenceOutput(data_id=f"s{i}", mooncake_key=f"k{i}", tensor_shapes=None)
+            for i in range(4)
+        ]
+
+        partitions = controller._partition_results(results)
+
+        assert sum(len(p) for p in partitions) == 4
diff --git a/torchspec/controller/training_controller.py b/torchspec/controller/training_controller.py
index 1c944619..736e191c 100644
--- a/torchspec/controller/training_controller.py
+++ b/torchspec/controller/training_controller.py
@@ -472,11 +472,37 @@ def try_dispatch_batch(self) -> bool:
         self.batch_id += 1
         return True
 
+    @staticmethod
+    def _seq_len(result: InferenceOutput) -> int:
+        shapes = result.tensor_shapes or {}
+        ids_shape = shapes.get("input_ids")
+        return ids_shape[-1] if ids_shape else 0
+
     def _partition_results(self, results: list[InferenceOutput]) -> list[list[InferenceOutput]]:
-        """Partition InferenceOutputs across DP ranks."""
+        """Partition InferenceOutputs across DP ranks.
+
+        When each rank receives more than one sample per dispatch, uses
+        longest-first greedy bin-packing with a per-rank capacity cap so
+        that ranks see similar total sequence load. Falls back to
+        round-robin when there is at most one sample per rank (e.g. eval
+        dispatch, or training with per_dp_rank_batch_size=1) because no
+        balancing is possible in that case.
+        """
         partitions: list[list[InferenceOutput]] = [[] for _ in range(self.dp_size)]
-        for i, result in enumerate(results):
-            partitions[i % self.dp_size].append(result)
+        if self.dp_size <= 1 or len(results) <= self.dp_size:
+            for i, result in enumerate(results):
+                partitions[i % self.dp_size].append(result)
+            return partitions
+
+        capacity = len(results) // self.dp_size
+        loads = [0] * self.dp_size
+        for result in sorted(results, key=self._seq_len, reverse=True):
+            min_rank = min(
+                (r for r in range(self.dp_size) if len(partitions[r]) < capacity),
+                key=lambda r: loads[r],
+            )
+            partitions[min_rank].append(result)
+            loads[min_rank] += self._seq_len(result)
         return partitions
 
     def _dispatch_to_queues(

From 7dfda81fbbb365e3ba236c3473133a393902613d Mon Sep 17 00:00:00 2001
From: ppraneth
Date: Sun, 24 May 2026 13:39:32 +0530
Subject: [PATCH 2/2] fix non-divisible batches

Signed-off-by: ppraneth
---
Notes (review):
    - Extends the round-robin fallback in _partition_results() to batches
      where len(results) % dp_size != 0, and updates the docstring to
      match. This removes the ValueError that patch 1/2 raises from min()
      on such batches.
    - Adds test_non_divisible_batch_falls_back_to_round_robin: 5 results
      on 2 ranks are expected to land as s0/s2/s4 and s1/s3.
    - Consequence: length-based balancing now runs only when the batch
      size is an exact multiple of dp_size; every other batch keeps the
      original arrival-order round-robin placement.

 tests/test_partition_results.py             | 12 ++++++++++++
 torchspec/controller/training_controller.py |  7 ++++---
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/tests/test_partition_results.py b/tests/test_partition_results.py
index 6db593c4..21aafb2a 100644
--- a/tests/test_partition_results.py
+++ b/tests/test_partition_results.py
@@ -65,6 +65,18 @@ def test_empty_results_returns_empty_partitions(self):
 
         assert partitions == [[], [], [], []]
 
+    def test_non_divisible_batch_falls_back_to_round_robin(self):
+        # 5 results over 2 ranks: capacity would floor to 2 and the
+        # greedy generator would empty out on the 5th item. Fall back
+        # to round-robin instead of crashing.
+        controller = _make_controller(dp_size=2, per_dp_rank_batch_size=2)
+        results = [_make_output(f"s{i}", seq_len=100 + i) for i in range(5)]
+
+        partitions = controller._partition_results(results)
+
+        assert [r.data_id for r in partitions[0]] == ["s0", "s2", "s4"]
+        assert [r.data_id for r in partitions[1]] == ["s1", "s3"]
+
 
 class TestPartitionBinPacking:
     """When per-rank capacity > 1, partition balances total sequence load."""
diff --git a/torchspec/controller/training_controller.py b/torchspec/controller/training_controller.py
index 736e191c..7298db91 100644
--- a/torchspec/controller/training_controller.py
+++ b/torchspec/controller/training_controller.py
@@ -485,11 +485,12 @@ def _partition_results(self, results: list[InferenceOutput]) -> list[list[Infere
         longest-first greedy bin-packing with a per-rank capacity cap so
         that ranks see similar total sequence load. Falls back to
         round-robin when there is at most one sample per rank (e.g. eval
-        dispatch, or training with per_dp_rank_batch_size=1) because no
-        balancing is possible in that case.
+        dispatch, or training with per_dp_rank_batch_size=1) or when
+        len(results) is not divisible by dp_size — preserving the old
+        round-robin behavior for irregular batch sizes.
         """
         partitions: list[list[InferenceOutput]] = [[] for _ in range(self.dp_size)]
-        if self.dp_size <= 1 or len(results) <= self.dp_size:
+        if self.dp_size <= 1 or len(results) <= self.dp_size or len(results) % self.dp_size != 0:
             for i, result in enumerate(results):
                 partitions[i % self.dp_size].append(result)
             return partitions