feat(comm): support source-Partial placements#98
📝 Walkthrough
Extends Etha's tensor redistribution to support source-side Partial placements by collapsing them to Replicate via chunk-level all-reduce before sending, while rejecting Partial on targets. Integrates routing, distributed IR, agent process-groups, chunk/bucket execution, and comprehensive tests covering symmetric/asymmetric meshes and benchmark configurations.
Changes
Partial Placement Support for Cross-Process-Group Redistribution
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 4
🧹 Nitpick comments (3)
src/etha/comm/get_m2m_map.py (1)
39-42: 💤 Low value
Potential runtime error if rank not in mesh tensor.
`_rank_multi_idx` will raise `ValueError` if `rank` is not present in `mesh_tensor` (from `list.index()`). This appears intentional since callers only pass ranks known to be in the mesh, but the error message would be cryptic.
Consider if a clearer error is warranted, or document this precondition.
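A minimal sketch of the clearer-error option, assuming the mesh is available as a flat, row-major list of ranks plus a shape tuple. The real `_rank_multi_idx` works on `mesh_tensor` and its signature is not shown in this review, so the name `rank_multi_idx_checked` and its parameters are hypothetical:

```python
from typing import Sequence, Tuple


def rank_multi_idx_checked(
    flat_ranks: Sequence[int], mesh_shape: Tuple[int, ...], rank: int
) -> Tuple[int, ...]:
    """Return the multi-dimensional mesh coordinate of `rank` (row-major)."""
    if rank not in flat_ranks:
        # Explicit diagnostic instead of the bare ValueError from list.index().
        raise ValueError(
            f"rank {rank} is not in the mesh (shape={tuple(mesh_shape)}, "
            f"ranks={list(flat_ranks)})"
        )
    flat_idx = list(flat_ranks).index(rank)
    coords = []
    # Unravel the flat index, last mesh dimension varying fastest.
    for size in reversed(mesh_shape):
        flat_idx, coord = divmod(flat_idx, size)
        coords.append(coord)
    return tuple(reversed(coords))


# A 2x2 mesh holding ranks 4..7.
print(rank_multi_idx_checked([4, 5, 6, 7], (2, 2), 6))  # (1, 0)
```

Documenting the precondition in the docstring remains the cheaper alternative if every caller already guarantees membership.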
tests/test_communication_symmetric_mesh.py (1)
72-80: 💤 Low value
Prefix unused variables with underscore to satisfy linter.
Static analysis correctly identifies that `m2m_map_b_to_a`, `source_slicers_b`, and `target_slicers_a` are unused. The call is intentional (deadlock test), but the variables should be prefixed with `_` for clarity and to silence the warnings.
🧹 Proposed fix
  # Direction 2: B -> A (reversed to test deadlock)
- m2m_map_b_to_a, source_slicers_b, target_slicers_a, _ = get_m2m_map(
+ _m2m_map_b_to_a, _source_slicers_b, _target_slicers_a, _ = get_m2m_map(
      source_mesh=mesh_b,
      source_placements=specs,
      target_mesh=mesh_a,
      target_placements=specs,
      group=dist.group.WORLD,
      device=device,
  )
src/etha/tensor_bus/agent.py (1)
693-707: 💤 Low value
Missing validation for unknown reduce_op strings.
If `op_str` is not in `_REDUCE_OP_MAP`, line 705 will raise a `KeyError` with a potentially confusing message. This could happen if `Partial` is constructed with an unsupported reduce op (e.g., `"bor"`, `"band"`).
Consider adding a guard with a clearer error message, or document the supported ops.
🛡️ Proposed defensive check
  reduced = tensor.detach().clone()
  for group, op_str in pair_state.source_partial_groups:
+     if op_str not in _REDUCE_OP_MAP:
+         raise ValueError(
+             f"Unsupported Partial reduce_op '{op_str}'; "
+             f"supported: {list(_REDUCE_OP_MAP.keys())}"
+         )
      dist.all_reduce(reduced, op=_REDUCE_OP_MAP[op_str], group=group)
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@bench/partial_reduce_prototype.py`:
- Around line 347-353: Validate that CLI arguments for --mib and --chunk-mib are
positive integers before using them: add a small validator (e.g., a positive_int
type function used in parser.add_argument or explicit checks after parse_args)
that rejects values <= 0 with an argparse-friendly error (ArgumentTypeError or
sys.exit with message). Ensure you reference and validate the values produced
for args.mib and args.chunk_mib (and any subsequent local variables derived from
them that compute chunk_numel or shapes) so zero or negative inputs are rejected
early and cannot produce step=0 in range() or invalid shapes; apply the same
guard for the duplicate occurrences around the later block noted (the other
parser.add_argument uses of "--mib"/"--chunk-mib").
In `@bench/run_partial_prototype.sh`:
- Around line 25-27: Replace the nonstandard variables used to export NODE_RANK
and MASTER_ADDR in run_partial_prototype.sh (and mirror the same change in
run_benchmark.sh): use the standard SLURM variable SLURM_PROCID (or SLURM_NODEID
if you want node index) to set NODE_RANK instead of JOB_COMPLETION_INDEX, derive
MASTER_ADDR from the canonical SLURM_JOB_NODELIST (e.g., resolve the first
hostname from SLURM_JOB_NODELIST using scontrol/show hostnames or equivalent)
rather than SLURM_JOB_FIRST_NODE_IP, and keep MASTER_PORT as a fixed port;
update the export lines that define NODE_RANK, MASTER_ADDR, and MASTER_PORT
accordingly so the scripts no longer rely on undeclared/nonportable env vars and
won’t fail under set -euo pipefail.
In `@README.md`:
- Around line 36-40: The docs currently state that source `Partial` is collapsed
via a "chunk-level all-reduce" which contradicts the implementation; update the
README text that describes "Supported source placements" so it says `Partial` is
collapsed to `Replicate` on the source mesh via a whole-tensor all-reduce
performed before chunking/sending (replace "chunk-level" with "whole-tensor"),
and keep the note that `Partial` on the target side is rejected; update the
sentence around the `Partial`/`Replicate` wording to reflect this exact ordering
and granularity.
In `@tests/test_partial_chunk_reduce.py`:
- Around line 3-6: Update the module docstring to remove the outdated claim that
get_m2m_map "currently rejects Partial placements with NotImplementedError" and
instead state that get_m2m_map now implements source-side Partial support;
clarify that these tests specifically validate the standalone streaming
chunk-level reduce algorithm (comparing it to PyTorch's DTensor.redistribute)
and are not full etha integration tests, so they remain focused on chunk-level
behavior despite the added Partial handling in get_m2m_map.
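For the `--mib` / `--chunk-mib` validation requested above for bench/partial_reduce_prototype.py, one possible shape is an argparse `type` callable. This is only a sketch: `positive_int` and `build_parser` are illustrative names, and the default values are placeholders rather than the benchmark's actual defaults.

```python
import argparse


def positive_int(value: str) -> int:
    """argparse type that accepts only integers greater than zero."""
    try:
        parsed = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}")
    if parsed <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {parsed}")
    return parsed


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Placeholder defaults; the real script's defaults are not shown in the review.
    parser.add_argument("--mib", type=positive_int, default=64)
    parser.add_argument("--chunk-mib", type=positive_int, default=4)
    return parser


args = build_parser().parse_args(["--mib", "128", "--chunk-mib", "8"])
print(args.mib, args.chunk_mib)  # 128 8
```

Rejecting the value at parse time means a zero or negative size never reaches the code that derives `chunk_numel` or a `range()` step.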
---
Nitpick comments:
In `@src/etha/comm/get_m2m_map.py`:
- Around line 39-42: _rank_multi_idx currently uses list.index() which raises a
cryptic ValueError if the requested rank isn't present; update _rank_multi_idx
to explicitly check whether rank exists in mesh_tensor (e.g., via
torch.eq(mesh_tensor, rank).any() or checking the flattened list) and if not
raise a ValueError with a clear message including the requested rank and
mesh_tensor.shape, or alternatively document the precondition in the function
docstring stating that rank must be present; ensure references to the symbol
_rank_multi_idx and the rank/mesh_tensor values appear in the message or
docstring so callers get a helpful diagnostic.
In `@src/etha/tensor_bus/agent.py`:
- Around line 693-707: Add a defensive check before using _REDUCE_OP_MAP in the
loop over pair_state.source_partial_groups: for each op_str validate it exists
in _REDUCE_OP_MAP and if not raise a clear ValueError that includes the
offending op_str, the pair_name (or pair_state) and a list of supported ops
(sorted(_REDUCE_OP_MAP.keys())); only call dist.all_reduce with
_REDUCE_OP_MAP[op_str] after validation so the error is informative rather than
a KeyError when producing reduced, and keep the rest of the logic that appends
to batch_state.pair_reduced_sources and assigns source_tensor_for_send
unchanged.
In `@tests/test_communication_symmetric_mesh.py`:
- Around line 72-80: The variables m2m_map_b_to_a, source_slicers_b, and
target_slicers_a returned from get_m2m_map are intentionally unused for the
deadlock test—prefix them with an underscore (e.g., _m2m_map_b_to_a,
_source_slicers_b, _target_slicers_a) in the assignment to silence the linter
and clarify intent while leaving the get_m2m_map(...) call unchanged.
📒 Files selected for processing (14)
- README.md
- bench/partial_reduce_prototype.py
- bench/run_partial_prototype.sh
- bench/transfer_benchmark.py
- docs/index.md
- src/etha/comm/get_m2m_map.py
- src/etha/tensor_bus/agent.py
- src/etha/tensor_bus/batch_state.py
- src/etha/tensor_bus/client.py
- src/etha/tensor_bus/pair_state.py
- tests/test_communication_cpu.py
- tests/test_communication_replicate_shard.py
- tests/test_communication_symmetric_mesh.py
- tests/test_partial_chunk_reduce.py
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/etha/comm/ir.py (1)
65-67: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Reduce Partial chunks in the source dtype, not the wire dtype.
`prepare()` converts source buffers to `transfer_dtype` before either `Chunk.apply_partial_reduce()` or `Bucket._reduce_partial()` runs. For mixed-dtype transfers that changes the logical reduction result, not just the bytes on the wire. For example, an FP32 Partial sum will be collapsed in FP16 if the target side negotiated a narrower transfer dtype. The Partial all-reduce needs to happen in the original source dtype, and only the reduced buffer should be cast for transport.
Also applies to: 70-83, 146-151
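The precision loss can be illustrated without any distributed setup. The sketch below is a simplification, not the library's code path: it emulates half precision with `struct`'s `"e"` format, uses Python floats (double precision) as a stand-in for the FP32 source dtype, and models the all-reduce as a sequential sum. The helper names are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def reduce_in_fp16(contributions):
    """Cast each contribution to fp16 first, then accumulate in fp16."""
    total = 0.0
    for c in contributions:
        total = to_fp16(total + to_fp16(c))
    return total


def reduce_then_cast(contributions):
    """Accumulate at higher precision, then cast the result once."""
    return to_fp16(sum(contributions))


# One large contribution plus four small ones, as if from five Partial ranks.
contributions = [2048.0, 0.5, 0.5, 0.5, 0.5]
print(reduce_in_fp16(contributions))    # 2048.0
print(reduce_then_cast(contributions))  # 2050.0
```

Half precision has a spacing of 2 between representable values at 2048, so each `+ 0.5` is rounded away when the accumulation itself runs in the narrow dtype.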
🧹 Nitpick comments (1)
tests/test_communication_replicate_shard.py (1)
32-42: ⚡ Quick win
Fix `_local_shape` to handle uneven `Shard` chunking per-rank
`_local_shape` in tests/test_communication_replicate_shard.py (lines ~32-42) uses rank-agnostic floor-division (`local[p.dim] //= mesh_shape[mesh_dim]`). DTensor `Shard` local sizes for non-even splits follow torch.chunk-style uneven partitioning (ceil-based start/end per rank), so this helper would compute the wrong `DTensor.from_local` shape for future cases where a sharded tensor dim isn't divisible by the mesh shard count (leading to false failures/mismatched ground truth).
Current parametrized cases appear to use evenly divisible shapes, so this won't trip today, but it's unsafe for extension. Make `_local_shape` take the source rank / mesh coordinates and compute per-rank shard lengths (or derive shapes by constructing a dummy `DTensor` and calling `.to_local()`).
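A small sketch of the per-rank length computation, assuming torch.chunk-style splitting as described above (every non-trailing chunk has size `ceil(dim_size / num_shards)`). `chunk_local_size` is a hypothetical helper, not a function from the test file:

```python
import math


def chunk_local_size(dim_size: int, num_shards: int, shard_idx: int) -> int:
    """Length of shard `shard_idx` when `dim_size` is split torch.chunk-style."""
    full = math.ceil(dim_size / num_shards)  # size of every non-trailing chunk
    start = min(shard_idx * full, dim_size)
    end = min(start + full, dim_size)
    return end - start


dim_size, num_shards = 10, 4
print([chunk_local_size(dim_size, num_shards, i) for i in range(num_shards)])  # [3, 3, 3, 1]
print(dim_size // num_shards)  # 2, the rank-agnostic floor-division result
```

For evenly divisible sizes both approaches agree, which is consistent with the current parametrized cases passing.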
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/etha/comm/get_m2m_map.py`:
- Around line 180-186: The round-robin cursor must be lifted out of the per-cell
loop so primary selection rotates across the entire (cell,owner) stream:
introduce a persistent cursor (e.g., rr_cursor) initialized once for the
subgroup before iterating target_owners_per_cell, replace members[slot_idx %
n_members] with members[rr_cursor] when assigning primary in the loop that
currently uses slot_idx, and increment rr_cursor = (rr_cursor + 1) % n_members
for each owner entry processed; keep the rest of the logic
(m2m_map[primary][local_pos].append((tr, tlocal)), members, n_members,
local_pos) the same so shadow assignment still happens for non-primary members.
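The difference between a per-cell slot index and a persistent cursor can be sketched in isolation. The data shapes below are assumptions for illustration (a flat `members` list and a list of owner lists per cell); the real loop in `get_m2m_map` also records `(tr, tlocal)` entries and shadow assignments, which are omitted here.

```python
from typing import List


def primaries_per_cell_cursor(members: List[int], owners_per_cell: List[List[str]]) -> List[int]:
    """Cursor restarts for every cell: slot_idx counts owners within one cell."""
    picks = []
    for owners in owners_per_cell:
        for slot_idx, _owner in enumerate(owners):
            picks.append(members[slot_idx % len(members)])
    return picks


def primaries_persistent_cursor(members: List[int], owners_per_cell: List[List[str]]) -> List[int]:
    """Cursor persists across the whole (cell, owner) stream."""
    picks = []
    rr_cursor = 0
    for owners in owners_per_cell:
        for _owner in owners:
            picks.append(members[rr_cursor])
            rr_cursor = (rr_cursor + 1) % len(members)
    return picks


members = [0, 1, 2]  # source sub-group ranks
owners_per_cell = [["t0"], ["t1"], ["t2"], ["t3"]]  # one target owner per cell

print(primaries_per_cell_cursor(members, owners_per_cell))    # [0, 0, 0, 0]
print(primaries_persistent_cursor(members, owners_per_cell))  # [0, 1, 2, 0]
```

With one owner per cell, the per-cell index never advances past 0, so a single member ends up primary for every cell.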
---
Outside diff comments:
In `@src/etha/comm/ir.py`:
- Around line 65-67: prepare() currently casts source buffers to transfer_dtype
before running partial reductions, which causes reductions
(Chunk.apply_partial_reduce and Bucket._reduce_partial) to operate in the wire
dtype; change the flow so that when self.is_source is true you do NOT convert
buffer to self.transfer_dtype until after any partial-reduce logic has run
(i.e., run Chunk.apply_partial_reduce and Bucket._reduce_partial using the
original buffer.dtype), then cast the already-reduced buffer to transfer_dtype
for transport; apply the same fix to the other occurrences noted (the blocks
around lines 70-83 and 146-151) so partial reductions always use the source
dtype and only the final reduced buffer is converted for the wire.
---
Nitpick comments:
In `@tests/test_communication_replicate_shard.py`:
- Around line 32-42: The helper _local_shape currently uses floor-division
(local[p.dim] //= mesh_shape[mesh_dim]) which fails for uneven Shard splits;
update _local_shape to compute per-rank shard lengths using
torch.chunk/torch.tensor_split semantics (or by constructing a dummy DTensor and
calling .to_local()) based on the rank/mesh coordinates instead of global
floor-division so each Shard dim uses the correct ceil/uneven chunk size; locate
_local_shape and change its logic to accept or derive the target mesh coordinate
(rank) for each mesh_dim and compute the start/end (or chunk sizes) for that
rank for p.dim, returning the per-rank tuple required by DTensor.from_local/MS
tests.
📒 Files selected for processing (8)
- src/etha/comm/execution.py
- src/etha/comm/get_buckets.py
- src/etha/comm/get_chunks.py
- src/etha/comm/get_m2m_map.py
- src/etha/comm/ir.py
- src/etha/comm/transfer.py
- src/etha/tensor_bus/agent.py
- tests/test_communication_replicate_shard.py
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/etha/comm/ir.py (1)
70-88: ⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift
Move the Partial all-reduce ahead of chunking.
`Chunk.apply_partial_reduce()` and `Bucket._reduce_partial()` make the collective sequence depend on local chunk/bucket segmentation. In mixed `Shard + Partial` routes, subgroup members are not guaranteed to build a 1:1 chunk layout, so ranks can reduce different slices or issue a different number/order of collectives. That turns into silent corruption or a hang. The safer contract is the one described in the PR objectives: clone the local source buffer once, all-reduce it once per partial subgroup, then derive chunks from that reduced tensor.
Also applies to: 145-169
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/test_distributed_model_transfer.py`:
- Around line 73-79: The four ports picked with successive calls to
_find_free_port() (agent_port, train_port, inference_port, store_port) can
collide because _find_free_port may return duplicates; change the logic to
produce four unique ports per test case by repeatedly calling _find_free_port()
into a set until its size is 4 (or implement a helper like
_find_n_unique_free_ports(n) that loops/validates uniqueness) and then assign
the resulting unique values to agent_port, train_port, inference_port, and
store_port so no two variables hold the same port in a single run.
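One way to sketch the helper suggested above. Note that this variant differs from the set-based loop in the comment: it keeps all sockets bound until every port has been collected, so the OS cannot hand out the same port twice. The name `find_n_unique_free_ports` and the loopback default are assumptions, and another process can still claim a port after the sockets are closed.

```python
import socket
from typing import List


def find_n_unique_free_ports(n: int, host: str = "127.0.0.1") -> List[int]:
    """Return `n` distinct ports that were free at the time of the call."""
    sockets = []
    try:
        for _ in range(n):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind((host, 0))  # port 0 asks the OS for any free port
            sockets.append(s)
        # Every socket is still bound here, so all port numbers are distinct.
        return [s.getsockname()[1] for s in sockets]
    finally:
        for s in sockets:
            s.close()


agent_port, train_port, inference_port, store_port = find_n_unique_free_ports(4)
```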
---
Outside diff comments:
In `@src/etha/comm/ir.py`:
- Around line 70-88: The current implementation performs Partial all-reduces
after chunking which can make collectives depend on local chunk/bucket
segmentation; change apply_partial_reduce so it clones the full source buffer if
it aliases the tensor (keep the existing untyped_storage().data_ptr() check),
then perform one dist.all_reduce per subgroup on that full buffer (iterate
source_partial_groups and call dist.all_reduce on self.buffer before any
chunking), and ensure any subsequent chunking derives from the already-reduced
self.buffer rather than issuing per-chunk collectives; apply the same refactor
pattern to Chunk.apply_partial_reduce and Bucket._reduce_partial so all Partial
subgroup reductions happen once on the full source buffer prior to slicing into
chunks.
⛔ Files ignored due to path filters (1)
- pixi.lock is excluded by !**/*.lock
📒 Files selected for processing (5)
- pixi.toml
- src/etha/comm/ir.py
- tests/distributed_model_transfer/common.py
- tests/test_communication_replicate_shard.py
- tests/test_distributed_model_transfer.py
✅ Files skipped from review due to trivial changes (1)
- tests/distributed_model_transfer/common.py
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@scripts/verify_partial_trace.py`:
- Line 53: The code currently recomputes partial_red after calling get_m2m_map;
instead use the partial_red value returned by get_m2m_map (the variable m2m_map,
src_sl, tgt_sl, partial_red) and remove the later derivation that re-creates
partial_red from inputs. Locate the block that recomputes/derives partial_red
(the lines later in the file that read inputs and build a new partial_red) and
replace those uses with the existing partial_red variable, deleting the
redundant computation and ensuring any assertions or routing metadata
comparisons reference the returned partial_red.
- Line 82: The print uses an f-string with no placeholders which triggers Ruff
F541; replace the placeholderless f-string in verify_partial_trace.py (the print
statement that currently reads the message about every source rank having at
least one primary entry) with a normal string literal (e.g., use print(" every
source rank has at least one primary entry")) so the f-prefix is removed.
In `@src/etha/comm/utils.py`:
- Around line 31-33: Validate mesh_dim_idx before computing
other_dims/other_sizes: check mesh_tensor.dim() (ndim) and raise a clear
exception (e.g., ValueError or IndexError) if mesh_dim_idx is out of range or
negative (require 0 <= mesh_dim_idx < ndim). Place this check immediately before
the lines that compute ndim, other_dims, and other_sizes so invalid indices fail
fast and include the values of mesh_dim_idx and ndim in the error message for
easier debugging.
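The bounds check for `mesh_dim_idx` might look like the sketch below. It takes a plain shape tuple instead of `mesh_tensor` to stay self-contained, and `split_other_dims` is a hypothetical name; only `ndim`, `other_dims`, and `other_sizes` come from the comment above.

```python
from typing import List, Sequence, Tuple


def split_other_dims(mesh_shape: Sequence[int], mesh_dim_idx: int) -> Tuple[List[int], List[int]]:
    """Return the mesh dims other than `mesh_dim_idx`, and their sizes."""
    ndim = len(mesh_shape)
    # Fail fast: negative indices are rejected too, since 0 <= idx < ndim is required.
    if not 0 <= mesh_dim_idx < ndim:
        raise ValueError(
            f"mesh_dim_idx={mesh_dim_idx} is out of range for a mesh with ndim={ndim}"
        )
    other_dims = [d for d in range(ndim) if d != mesh_dim_idx]
    other_sizes = [mesh_shape[d] for d in other_dims]
    return other_dims, other_sizes


print(split_other_dims((2, 3, 4), 1))  # ([0, 2], [2, 4])
```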
📒 Files selected for processing (5)
- .gitignore
- scripts/verify_partial_trace.py
- src/etha/comm/get_m2m_map.py
- src/etha/comm/utils.py
- src/etha/tensor_bus/agent.py
✅ Files skipped from review due to trivial changes (1)
- .gitignore
- test_distributed_model_transfer: dedup ports via set so two rapid bind(0) calls can't return the same number.
- test_partial_chunk_reduce: drop stale docstring claim that get_m2m_map rejects Partial (Partial is supported as of this PR).
Replaces bench/partial_reduce_prototype.py + run_partial_prototype.sh with real-pipeline coverage in transfer_benchmark.py:
- BENCH_CONFIGS loop runs the existing mesh × shape sweep once per source placement (no_partial baseline + partial_dp).
- Partial sources build NCCL sub-groups via a benchmark-local mirror of agent._create_partial_groups, fed to map_to_chunk_ops.
- Baseline path becomes "user-side DTensor.redistribute(Partial -> Replicate) + gather_broadcast" timed end-to-end, which is the honest comparison for what a user would do without etha's in-pipeline reduce. Replaces the prototype's three hypothetical paths (which never called etha and were a design-time artifact).
- Plot legend / filename pick up config_name + has_partial so the no_partial and partial_dp runs land in distinct charts.
- _placements_to_str migrated to match-case; _StridedShard branch moved above Shard since the former subclasses the latter and match cases check in order.
The prototype shipped chunk-level reduce as a design proposal; once the design landed in PR #98 it had no runtime relationship with the actual code and was just bit-rot waiting to happen. Git history keeps it.
🧹 Nitpick comments (3)
bench/transfer_benchmark.py (3)
664-674: 💤 Low value
Division by `profile_iter` is incorrect for single-call timing.
`get_m2m_map` is called once, not `profile_iter` times, so dividing the elapsed time by `profile_iter` produces an artificially small value. This doesn't affect benchmark correctness (the real timing uses CUDA events), but the printed diagnostic is misleading.
Suggested fix
- map_time = (time.perf_counter() - start_time) / profile_iter
+ map_time = time.perf_counter() - start_time
604-623: 💤 Low value
Consider using ASCII `x` instead of Unicode `×`.
Static analysis flags the multiplication sign at line 613. While visually nicer, `×` can render incorrectly in some terminal encodings.
Suggested fix
- print(f"Total {len(mesh_combinations)} mesh combinations × {len(BENCH_CONFIGS)} configs to test")
+ print(f"Total {len(mesh_combinations)} mesh combinations x {len(BENCH_CONFIGS)} configs to test")
702-731: 💤 Low value
Division by `profile_iter` is incorrect for IR generation timing.
The loop iterates `num_tensors_per_batch` times, not `profile_iter` times. Dividing by `profile_iter` produces a misleading per-iteration time in the diagnostic print.
Suggested fix
- ir_gen_time = (time.perf_counter() - start_time) / profile_iter
+ ir_gen_time = time.perf_counter() - start_time
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@bench/transfer_benchmark.py`:
- Around line 664-674: The printed diagnostic divides the single-call elapsed
time by profile_iter, misreporting get_m2m_map timing; change the computation of
map_time (where get_m2m_map is invoked and start_time is recorded) to compute
the raw elapsed time without dividing by profile_iter (i.e., map_time =
time.perf_counter() - start_time) and keep the existing print of map_time for
rank 0 so the diagnostic reflects the actual single-call duration.
- Around line 604-623: The print statement that shows total tests uses the
Unicode multiplication sign (×) which can break in some terminals; update the
f-string that prints f"Total {len(mesh_combinations)} mesh combinations ×
{len(BENCH_CONFIGS)} configs to test" to use an ASCII "x" instead (e.g., " x ")
to avoid encoding issues—locate the print near the BENCH_CONFIGS declaration and
the loop that prints config headers and replace the Unicode character
accordingly.
- Around line 702-731: The IR generation timing divides by profile_iter but the
measured block runs num_tensors_per_batch times; update the calculation of
ir_gen_time to divide elapsed time by num_tensors_per_batch (or remove the
division to report total time) where ir_gen_time is computed (after start_time
is set and after map_to_chunk_ops / chunk_to_bucket_ops are called); ensure the
printed value uses the corrected divisor so the reported IR generation time
reflects per-tensor or total timing as intended (references: start_time,
ir_gen_time, num_tensors_per_batch, profile_iter, map_to_chunk_ops,
chunk_to_bucket_ops).
📒 Files selected for processing (7)
- README.md
- bench/transfer_benchmark.py
- src/etha/comm/get_m2m_map.py
- src/etha/tensor_bus/agent.py
- tests/test_communication_replicate_shard.py
- tests/test_distributed_model_transfer.py
- tests/test_partial_chunk_reduce.py
✅ Files skipped from review due to trivial changes (1)
- README.md
1b7f0f1 to d4ad52e
@codex review |
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d4ad52e90e
@codex review |
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4a612b4fa0
@codex review |
…nsion

Extends ``get_m2m_map`` to handle source-side ``Partial`` placements by substituting ``Partial → Replicate`` for the cross-process-group trace and inserting SHADOW entries for the dropped sub-group peers so every member participates in the chunk-level all-reduce that collapses Partial to Replicate before send. Target-side ``Partial`` is rejected: cross-PG decomposition of a logical tensor into a Partial contribution is not uniquely defined.

Key changes:
* ``get_m2m_map`` substitutes Partial -> Replicate, computes ``source_partial_reductions`` per Partial dim, and calls ``_expand_partial_shadows`` to add SHADOW peers (transitive closure via union-find over the Partial sub-groups).
* ``Chunk.prepare`` performs the source-side all-reduce in source dtype before casting to ``transfer_dtype``, matching DTensor's ``Partial -> Replicate`` semantics. Mixed-dtype transfers (fp32 source, bf16 wire) reduce at full precision.
* ``Chunk.bucket_key`` adds ``cell_key`` for SHADOW chunks (dst_ranks=()) so each SHADOW cell forms its own bucket. PRIMARY chunks bucket normally; SHADOW has no wire transfer, so per-cell bucketing only affects launch count, not bandwidth.
* ``enumerate_partial_subgroup_ranks`` validates ``mesh_dim_idx`` bounds.
* New ``transfer_type=SHADOW`` for chunks that participate in the reduce but not in send/recv.

Tests:
* ``test_partial_chunk_reduce`` validates the standalone chunk-level reduce algorithm against DTensor.redistribute.
* ``test_communication_replicate_shard`` adds five Partial cases (1D/2D meshes, single/multi Partial, asymmetric mesh sizes) covering both ``chunk_comm`` and ``bucket_comm``, plus a fp32-source/bf16-target mixed-dtype regression test that fails under the previous reduce-after-cast order.
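To see why the reduce has to happen before the cast to the wire dtype, here is a small self-contained sketch that emulates bfloat16 rounding in pure Python. It is only an illustration: `to_bf16`, `reduce_then_cast` and `cast_then_reduce` are hypothetical names, Python floats stand in for the fp32 source dtype, and no collective is involved.

```python
import struct


def to_bf16(x: float) -> float:
    """Round ``x`` to bfloat16 precision (round-to-nearest-even). NaN is not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the upper 16 bits of the float32 pattern.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def reduce_then_cast(partials):
    """Sum at full precision, then cast once (the order described in this commit)."""
    return to_bf16(sum(partials))


def cast_then_reduce(partials):
    """Cast each contribution first and accumulate in bf16 (the earlier order)."""
    total = 0.0
    for p in partials:
        total = to_bf16(total + to_bf16(p))
    return total


# Two peers hold Partial("sum") contributions that nearly cancel.
partials = [1.0 + 2.0**-9, -1.0]
print(reduce_then_cast(partials))  # keeps the small residual
print(cast_then_reduce(partials))  # residual is rounded away before the sum
```

With these inputs the early cast rounds `1 + 2**-9` down to `1.0`, so the residual is lost entirely, which is the kind of discrepancy a mixed-dtype regression test would catch.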
… point

Propagates source-Partial support from the comm layer into the ``TensorBusClient`` / ``TensorBusAgent`` entry point.

* Add ``_create_partial_groups`` to bootstrap NCCL sub-groups for each Partial dim. ``new_group`` is WORLD-collective, so every rank must participate, even non-members; the helper reuses the full-source group when a sub-group spans the entire source side.
* ``_handle_init_pair`` detects Partial on each side and skips the direction whose target side has Partial (``m2m_send`` or ``m2m_recv`` stays ``None``) rather than crashing in ``get_m2m_map``. Without this skip, source-Partial pairs could not be initialised at all because the reverse-direction map construction would raise ``NotImplementedError``.
* ``register_tensors`` builds chunks per available direction (``and`` → ``or``); ``_handle_transfer`` raises if the user calls the disabled direction.
* ``PairState.source_partial_groups`` stores the local-side groups for use by ``register_tensors``.
* Consolidate the canonical (first, second) swap across mesh, placements, ranks, group, and partial groups behind a single ``_order(loc, rem)`` helper at the top of ``_handle_init_pair``.
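The sub-groups that need bootstrapping for a Partial mesh dim can be pictured with the sketch below. It is an assumption-laden illustration, not the agent's code: `subgroups_along_dim` is a hypothetical name, and ranks are assumed to be laid out row-major as `0..N-1` over the mesh shape, whereas a real mesh tensor may hold arbitrary ranks.

```python
import itertools


def subgroups_along_dim(mesh_shape, mesh_dim_idx):
    """List the rank groups that vary only along ``mesh_dim_idx``.

    Assumes a row-major mesh whose entries are 0..N-1 (hypothetical layout).
    """
    ndim = len(mesh_shape)
    if not 0 <= mesh_dim_idx < ndim:
        raise ValueError(f"mesh_dim_idx {mesh_dim_idx} out of bounds for {ndim}-D mesh")
    # Row-major strides: the last dimension is contiguous.
    strides = [1] * ndim
    for i in reversed(range(ndim - 1)):
        strides[i] = strides[i + 1] * mesh_shape[i + 1]
    other_dims = [range(s) for i, s in enumerate(mesh_shape) if i != mesh_dim_idx]
    groups = []
    for fixed in itertools.product(*other_dims):
        idx = list(fixed)
        idx.insert(mesh_dim_idx, 0)  # start of the line along the Partial dim
        base = sum(i * s for i, s in zip(idx, strides))
        step = strides[mesh_dim_idx]
        groups.append([base + k * step for k in range(mesh_shape[mesh_dim_idx])])
    return groups


print(subgroups_along_dim((2, 2), 1))  # groups along the last mesh dim
print(subgroups_along_dim((4,), 0))    # 1D mesh: a single group spanning every rank
```

In the 1D case the single sub-group covers the whole source side, which is the situation where the commit message says the full-source group is reused instead of creating a new one.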
* Add ``transformers`` to the dev environment so the test (which loads a HF model) can run without manual package install.
* Allocate four free ephemeral ports per case (agent / train / inference / store) via a set-based ``_find_n_free_ports`` helper, instead of hard-coded master ports that collided with the agent's pair-handshake TCPStores between parametrised cases.
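A set-based free-port helper of the kind described above could look roughly like this. It is a sketch under assumptions: the name `find_n_free_ports` and the loopback bind address are illustrative, and the test suite's actual `_find_n_free_ports` may differ.

```python
import socket


def find_n_free_ports(n):
    """Return ``n`` distinct ports the OS reported as free at call time."""
    ports = set()
    while len(ports) < n:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            # Port 0 asks the OS to pick an unused ephemeral port.
            sock.bind(("127.0.0.1", 0))
            ports.add(sock.getsockname()[1])
    return sorted(ports)


agent_port, train_port, inference_port, store_port = find_n_free_ports(4)
```

The set guarantees the four ports differ from each other; another process can still grab a port between discovery and use, so this narrows the collision window rather than closing it.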
Extends ``bench/transfer_benchmark.py`` from a single source placement to
three configs swept per mesh:
* ``no_partial`` — ``[Replicate, Shard(0), Replicate, Shard(1)]``
* ``replicate_dp`` — ``[Replicate, Replicate, Replicate, Shard(1)]``
* ``partial_dp`` — ``[Replicate, Partial("sum"), Replicate, Shard(1)]``
For ``partial_dp`` the baseline is amended to ``pre-reduce → gather-broadcast``
so the Partial → Replicate cost is counted on both sides for a fair
comparison. Plot legend distinguishes the chunk-based path with
"(Partial in-pipeline)".
Memory: keep only ``origin_tensors[0]`` per shape (used as the baseline
reference) and pass ``num_tensors`` explicitly to
``benchmark_single_shape`` — the previous list of 25 full-shape tensors
on every rank consumed ~56 GiB per rank at shape (24576, 24576) and OOM'd
node027.
Plots: replace the previous 8 single-config images with 24 (8 meshes × 3
configs) and lay them out three-across in ``bench/README.md``. Cleaned up
the placement description there too — it described an older single-config
run.
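The memory figure quoted above can be checked with a few lines of arithmetic. The sketch assumes 4 bytes per element (fp32), which the commit message does not state; that assumption reproduces the quoted ~56 GiB.

```python
def full_tensor_gib(shape, bytes_per_element, num_tensors=1):
    """Memory in GiB for ``num_tensors`` dense tensors of the given shape."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_element * num_tensors / 2**30


# Assumption: fp32 elements (4 bytes each).
print(full_tensor_gib((24576, 24576), 4, num_tensors=25))  # previous behaviour
print(full_tensor_gib((24576, 24576), 4, num_tensors=1))   # keeping only origin_tensors[0]
```

Under that assumption a single (24576, 24576) tensor is 2.25 GiB, so 25 of them come to 56.25 GiB per rank.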
* README ``Placements`` note: source supports Shard / Replicate / Partial (collapsed to Replicate via a source-side all-reduce before send); target Partial is rejected.
* docs/index.md: mirror the same note.
* .gitignore: exclude ``.claude/`` (scheduled tasks, agent transcripts).
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3beb004d9b
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
mesh_2_partial_groups = _create_partial_groups(
    second_mesh_tensor, partial_red_2, self.rank, second_ranks, second_group
Skip partial-group creation when reductions are unavailable
When one side uses `Partial` placements, `get_m2m_map` is intentionally skipped for the reverse direction, leaving `partial_red_1`/`partial_red_2` as `None` (set earlier in this function). This call passes that `None` directly into `_create_partial_groups`, which iterates `partial_reductions` and raises `TypeError: 'NoneType' object is not iterable`. As a result, `init_pair` fails for valid one-sided-Partial setups before any transfer can start.
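One way to address this finding is to guard the call so a skipped direction yields no groups. The sketch below is hypothetical: `create_groups_if_needed` and the stub factory are illustrative stand-ins, and the real `_create_partial_groups` takes more arguments and may return a different structure.

```python
def create_groups_if_needed(partial_reductions, create_groups):
    """Call ``create_groups`` only when there are reductions to build groups for.

    ``partial_reductions`` is ``None`` when the direction was skipped and an
    empty list when the source side has no Partial dims.
    """
    if not partial_reductions:
        return {}
    return create_groups(partial_reductions)


def stub_create_groups(partial_reductions):
    # Stand-in for the real group factory; iterating ``None`` here would raise TypeError.
    return {dim: f"group-for-dim-{dim}" for dim, _op in partial_reductions}


print(create_groups_if_needed(None, stub_create_groups))
print(create_groups_if_needed([(1, "sum")], stub_create_groups))
```

A caveat: if the underlying group creation is collective across all ranks, skipping it on only some ranks could hang, so the guard would need to be applied consistently on every participant.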
3beb004 to
3daa188
Compare
What did you do
Replace the `NotImplementedError` from PR #97 with a real implementation: source `Partial` placements now flow through the M2M routing + chunk pipeline. Target `Partial` is still rejected (cross-PG decomposition isn't well-defined).

Design (after a few wrong turns; see `bench/partial_reduce_prototype.py` for the data):

- Hand-rolled routing for Partial source (in `get_m2m_map`): when source placements contain Partial, we skip the DTensor-encoding/full_tensor routing trace (which collapses Replicate-equivalent peers down to one rank) and compute the m2m_map deterministically from mesh + placements. Each Partial sub-group's (cell, target_owner) ship slots are round-robin assigned to sub-group members so every source rank gets a unique ship task: no shadow chunks, no redundant transfers, load balanced across peers.
- Whole-tensor reduce before chunking (in `agent._handle_transfer`'s `register_tensors` path): for each source-Partial pair, clone the user's local Partial buffer and run an all-reduce on each Partial sub-group; chunks then point at the reduced buffer. The clone survives in `BatchState.pair_reduced_sources` until transfer completes.
  We tried chunk-level reduce first. It's mathematically incorrect for mixed Shard+Partial: peers ship different cells, so reducing their chunk buffers in a single collective mixes different logical positions across peers. The fix would require shadow-chunk participation from non-shipping peers; for Phase 1 we accept the +1x tensor memory cost and leave shadow-chunk optimization to a follow-up.
- NCCL sub-group creation (in `agent._handle_init_pair`): for each Partial mesh dim, create the corresponding source-side NCCL sub-group via `get_or_create_process_group` (cached). Reuses `local_group` when the sub-group spans the entire source side (the common 1D-mesh case), avoiding a redundant new_group bootstrap.

API changes:

- `get_m2m_map` returns a 4-tuple now; the new `source_partial_reductions: list[(mesh_dim_idx, reduce_op_str)]` lets callers know which dims need a sub-group reduce. All callers in this PR are updated.
- `M2MMap` / `PairState` / `BatchState` extended with corresponding fields.
- `init_pair` docstring + README + `docs/index.md` updated to say Partial is supported on the source side.

New test cases

- `tests/test_communication_replicate_shard.py`: 4 new Partial cases (1D Partial → Replicate / Shard, 2D Shard+Partial, double Partial), exercising the full chunk pipeline through `chunk_comm`.
- `tests/test_partial_chunk_reduce.py` (new, 21 cases): mesh ndim × reduce_op × dtype × non-contiguous × chunk-size edge sweep, validating Partial-reduce correctness against PyTorch's `DTensor.redistribute` ground truth.

Test results
ruff format + ruff check clean; pre-commit hooks pass.
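For readers following the design notes in this description, the round-robin ship-slot assignment might be sketched as below. This is an illustration only, not the code in `get_m2m_map`: `assign_ship_slots` and its arguments are hypothetical, and the early `NotImplementedError` mirrors the sub-case where there are fewer (cell, target_owner) slots than sub-group members.

```python
import itertools


def assign_ship_slots(cells, target_owners, subgroup_ranks):
    """Distribute (cell, target_owner) ship slots round-robin over a Partial sub-group."""
    slots = list(itertools.product(cells, target_owners))
    if len(slots) < len(subgroup_ranks):
        # Some member would get no ship task; left unsupported in this sketch.
        raise NotImplementedError("fewer ship slots than sub-group members")
    assignment = {rank: [] for rank in subgroup_ranks}
    for i, slot in enumerate(slots):
        assignment[subgroup_ranks[i % len(subgroup_ranks)]].append(slot)
    return assignment


# Two cells, two target owners, a sub-group of two source ranks.
print(assign_ship_slots([0, 1], [10, 11], [0, 1]))
```

Each slot is shipped by exactly one member, so there are no redundant transfers and the load differs by at most one slot between members.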
Other comments
`NotImplementedError` is still raised in two known sub-cases (Phase 2 followups):
- target `Partial` (rejected by design)
- `cells * target_owners < sub_group_size` (rare in practice)

The prototype (`bench/partial_reduce_prototype.py`) shows the design trade-offs that motivated this approach. Chunk-overlap optimization (where reduce/send overlap on independent NVLink/IB paths) is left for a follow-up issue.

Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests
Chores