perf: merge per-dim Partial all_reduces into one union-group all_reduce #100

@junjzhang

Description

Background

For sources with multiple Partial dimensions, the current chunk-level reduce iterates each Partial sub-group sequentially:

# Chunk.apply_partial_reduce
for group, op_str in self.source_partial_groups:
    dist.all_reduce(self.buffer, op=_REDUCE_OP_MAP[op_str], group=group)

For an N-Partial-dim source with sub-group sizes k1, k2, ..., kN, each rank issues N sequential all_reduce calls per chunk. This is correct and matches PyTorch DTensor's redistribute semantics, but it pays the NCCL launch overhead N times per chunk where a single launch could suffice.

Observation

When all Partial dimensions share the same reduce_op and that op is associative and commutative (SUM, MAX, MIN, PRODUCT; AVG also qualifies when the divisor accounting matches), the sequential per-dim reduces are mathematically equivalent, up to floating-point rounding order, to a single all_reduce on the union of all participating ranks. That union is exactly the connected component we already compute in _expand_partial_shadows.

Examples (using (Partial, Partial) on a (2, 3) mesh, full component {0..5}):

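  • SUM: reducing over the size-2 sub-groups first yields D_0+D_3, D_1+D_4, D_2+D_5; reducing over the size-3 sub-groups then sums those, so (D_0+D_3) + (D_1+D_4) + (D_2+D_5) = single all_reduce SUM on {0..5}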
  • AVG: sequential divides by 2 * 3 = 6; merged AVG on 6-rank group also divides by 6. |component| always equals the product of per-dim sub-group sizes when no Shard interferes.
  • MAX / MIN / PRODUCT: trivially associative.
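
The equivalence claimed in these examples can be checked without any GPUs. The following is a minimal sketch that simulates all_reduce on plain Python lists; the helper names (`mesh_subgroups`, `sequential`, `merged`) and the row-major rank layout are assumptions made for illustration, not part of the existing code base.

```python
from functools import reduce
from itertools import product
from math import isclose
import operator


def mesh_subgroups(shape):
    """For each mesh dim, list the rank sub-groups along that dim (row-major ranks)."""
    coords = list(product(*[range(s) for s in shape]))
    groups_per_dim = []
    for dim in range(len(shape)):
        groups = {}
        for rank, coord in enumerate(coords):
            # Ranks that agree on every coordinate except `dim` share a sub-group.
            key = coord[:dim] + coord[dim + 1:]
            groups.setdefault(key, []).append(rank)
        groups_per_dim.append(list(groups.values()))
    return groups_per_dim


def all_reduce(values, group, op):
    """Simulated in-place all_reduce: every rank in `group` ends with the reduced value."""
    result = op([values[r] for r in group])
    for r in group:
        values[r] = result


def sequential(values, shape, op):
    """Current behaviour: one all_reduce per Partial dim, in order."""
    values = list(values)
    for groups in mesh_subgroups(shape):
        for group in groups:
            all_reduce(values, group, op)
    return values


def merged(values, shape, op):
    """Proposed behaviour: one all_reduce over the whole component."""
    values = list(values)
    all_reduce(values, list(range(len(values))), op)
    return values


OPS = {
    "sum": sum,
    "max": max,
    "min": min,
    "product": lambda xs: reduce(operator.mul, xs),
    "avg": lambda xs: sum(xs) / len(xs),
}

data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # D_0 .. D_5 on a (2, 3) mesh
for name, op in OPS.items():
    seq, mrg = sequential(data, (2, 3), op), merged(data, (2, 3), op)
    assert all(isclose(a, b) for a, b in zip(seq, mrg)), name
```

The comparison uses `isclose` rather than `==` because the two paths apply floating-point operations in a different order, which is also what a real regression test for SUM, PRODUCT, and AVG would need to tolerate.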

Proposed change

In _create_partial_groups (or a new helper), if all partial_reductions use the same op:

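  1. Compute the connected component for this rank: start from the sub-groups containing this rank and keep adding sub-groups that overlap the set until it stops growing. (The direct union of the sub-groups containing the rank is not enough: on a (2, 3) mesh, rank 0 belongs to {0, 3} and {0, 1, 2}, whose union misses ranks 4 and 5.)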
  2. Create one NCCL communicator for the union group.
  3. Return [(union_group, op)] instead of the per-sub-group list.

Drop-in compatible with the existing chunk pipeline — apply_partial_reduce already iterates a list, just with one entry now.
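
The three steps above can be sketched in plain Python, operating on rank lists rather than real process groups. Everything here is hypothetical: the function names, and the assumed input shape of `partial_reductions` as a list of `(sub_groups_for_dim, op_str)` pairs, are illustrative and may not match what _create_partial_groups actually receives.

```python
def union_component(rank, subgroups):
    """Connected component of `rank`, treating each sub-group as a set of linked ranks."""
    component = {rank}
    frontier = [rank]
    while frontier:
        r = frontier.pop()
        for group in subgroups:
            if r in group:
                new = set(group) - component
                component |= new
                frontier.extend(new)
    return sorted(component)


def plan_partial_reduce(rank, partial_reductions):
    """Return the ordered list of (ranks, op_str) that `rank` should all_reduce over.

    partial_reductions: one (sub_groups, op_str) entry per Partial dim, where
    sub_groups is the list of rank lists along that dim (assumed layout).
    """
    ops = {op for _, op in partial_reductions}
    if len(partial_reductions) > 1 and len(ops) == 1:
        # Same op on every Partial dim: a single reduce over the whole component.
        all_groups = [g for groups, _ in partial_reductions for g in groups]
        return [(union_component(rank, all_groups), ops.pop())]
    # Mixed ops or a single Partial dim: keep the sequential per-dim plan.
    return [
        (sorted(next(g for g in groups if rank in g)), op)
        for groups, op in partial_reductions
    ]


dim0 = [[0, 3], [1, 4], [2, 5]]   # size-2 sub-groups of a (2, 3) mesh
dim1 = [[0, 1, 2], [3, 4, 5]]     # size-3 sub-groups
print(plan_partial_reduce(0, [(dim0, "sum"), (dim1, "sum")]))
print(plan_partial_reduce(0, [(dim0, "sum"), (dim1, "max")]))
```

In the real pipeline each returned rank list would presumably be turned into a process group, for example with torch.distributed.new_group. That call has to be made collectively by every rank in the default group, including ranks that are not members of the new group, so group creation needs to stay deterministic across ranks.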

Tradeoffs

| | sequential (current) | merged |
|---|---|---|
| NCCL launches per chunk | N (number of Partial dims) | 1 |
| Communicators | several smaller groups | one larger group |
| Latency | `N * (α + β*B/k_i)` ish | `α + β*B/…` |
| Bandwidth utilization | depends on topology of small groups | depends on topology of union |
| PyTorch DTensor parity | matches | diverges in mechanism (same end result) |

Topology caveat: small sub-groups can be intra-node (NVLink only) while the union group spans cross-node IB. In that case the sequential path can be faster despite more launches. Benchmark before committing to merge as the default.

Out of scope

  • Mixed reduce ops (e.g. Partial("sum"), Partial("max")): cannot merge. Fall back to sequential.
  • Single Partial dim: degenerate — already "merged" since there's only one sub-group level. No change needed.

Suggested investigation steps

  1. Add a benchmark that runs (Partial, Partial) reduce on the chunk-level path with both implementations, on a representative cluster topology (probably intra-node multi-Partial vs cross-node multi-Partial).
  2. Decide policy: merge unconditionally, merge only when union stays intra-node, or expose as flag.
  3. Implement in _create_partial_groups.
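
The three policy options in step 2 could be expressed as a small predicate that is evaluated once the union group is known. This is only a sketch: `MergePolicy`, `should_merge`, and the `gpus_per_node` parameter are hypothetical names, and the node lookup assumes ranks are assigned to nodes contiguously (`rank // gpus_per_node`), which is common with torchrun-style launchers but not guaranteed.

```python
from enum import Enum


class MergePolicy(Enum):
    NEVER = "never"                      # always use the sequential path
    ALWAYS = "always"                    # merge whenever the ops match
    INTRA_NODE_ONLY = "intra_node_only"  # merge only if the union stays on one node


def should_merge(union_ranks, policy, gpus_per_node=8):
    """Decide whether to replace per-dim reduces with one union-group reduce.

    Assumes contiguous rank-to-node placement; a real implementation would
    need to query the actual topology instead.
    """
    if policy is MergePolicy.NEVER:
        return False
    if policy is MergePolicy.ALWAYS:
        return True
    nodes = {rank // gpus_per_node for rank in union_ranks}
    return len(nodes) == 1


# A 6-rank union on one 8-GPU node merges; a union spanning two nodes does not.
print(should_merge(range(6), MergePolicy.INTRA_NODE_ONLY))
print(should_merge([0, 1, 8, 9], MergePolicy.INTRA_NODE_ONLY))
```

Keeping the decision in one function would let the benchmark from step 1 toggle policies without touching the reduce path itself.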

Acceptance

  • Benchmark numbers comparing sequential vs merged for at least one realistic multi-Partial config.
  • Correctness regression tests for both ops (case 9 / case 10 in test_communication_replicate_shard) pass under merged path.
  • AVG correctness explicitly covered (divisor accounting is the trickiest case).

Priority

Low. Multi-Partial sources are uncommon in real etha workloads (trainer ↔ vLLM weight sync rarely produces them); the merge optimization should be done only with concrete benchmark motivation.
