1. Executive Summary
This document describes the design for adding Context Parallelism (CP) to dnet, enabling long-context inference (128K+ tokens) by distributing sequence dimensions across multiple Apple Silicon devices. CP complements the existing RingStrategy (layer/pipeline parallelism) with a new axis of parallelization.
Goals
Primary: Enable 128K+ context inference across heterogeneous device clusters
Secondary: Achieve near-linear latency scaling with device count
Constraint: Zero approximations to attention computation (exact attention)
Non-Goals (v1)
Mixed CP + pipeline parallelism (future work)
Training support (inference-only)
CUDA/AMD backends (Apple Silicon only)
2. Background
2.1 Current Architecture
```mermaid
graph LR
    subgraph "Pipeline Parallelism"
        A[API] --> S1[Shard 1<br>Layers 0-10]
        S1 --> S2[Shard 2<br>Layers 11-20]
        S2 --> S3[Shard 3<br>Layers 21-31]
        S3 -->|token| A
    end
```
The current dnet uses pipeline parallelism: each shard owns a subset of layers, and activations flow through the ring. This works well for large models but does not reduce per-device context memory.
2.2 Problem Statement
| Context Length | KV Cache (FP16, 7B model) | Fits in 24GB RAM? |
|---|---|---|
| 8K | ~1 GB | Yes |
| 32K | ~4 GB | Yes |
| 128K | ~16 GB | Tight |
| 512K | ~64 GB | No |
| 1M | ~128 GB | No |
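The table's numbers can be reproduced with a short back-of-envelope function. The model shape used here (32 layers, 8 KV heads with GQA, head dim 128, fp16) is an assumption consistent with the table, not something the document states:

```python
# Back-of-envelope KV-cache size. Model shape is an illustrative assumption
# (32 layers, 8 KV heads with GQA, head dim 128, fp16).
def kv_cache_bytes(
    seq_len: int,
    num_layers: int = 32,
    num_kv_heads: int = 8,
    head_dim: int = 128,
    bytes_per_elt: int = 2,  # fp16
) -> int:
    # K and V tensors per layer, each [seq_len, num_kv_heads, head_dim]
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_elt

gib_128k = kv_cache_bytes(128 * 1024) / 2**30  # ~16 GiB, matching the table row
```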
Pipeline parallelism does not shard KV cache across devices. Context Parallelism solves this.
2.3 Ring Attention
Ring Attention (Liu et al., 2023) distributes the sequence dimension across devices:

```mermaid
graph LR
    subgraph "Context Parallelism"
        D1[Device 1<br>Tokens 0-32K] --> D2[Device 2<br>Tokens 32K-64K]
        D2 --> D3[Device 3<br>Tokens 64K-96K]
        D3 --> D4[Device 4<br>Tokens 96K-128K]
        D4 -->|KV blocks| D1
    end
```

Key insight: Blockwise attention is permutation invariant over KV blocks, so we can compute partial attention in any order and merge results.
3. Design Overview
3.1 High-Level Architecture

```mermaid
flowchart TB
    subgraph API["API Node"]
        direction TB
        CM["ClusterManager"]
        MM["ModelManager"]
        IM["InferenceManager"]
        CPS["ContextParallelStrategy"]
        CPTS["CPTopologySolver"]
        CPAA["CPApiAdapter"]
        CPS -->|solver| CPTS
        CPS -->|adapter| CPAA
        IM --> CPAA
    end
    subgraph Shards["Shard Nodes (CP Ring)"]
        direction LR
        subgraph S1["Shard 1"]
            CPA1["Adapter 1"]
            SR1["Runtime 1 (Full Model)"]
            CPA1 --> SR1
        end
        subgraph S2["Shard 2"]
            CPA2["Adapter 2"]
            SR2["Runtime 2 (Full Model)"]
            CPA2 --> SR2
        end
        subgraph S3["Shard 3"]
            CPA3["Adapter 3"]
            SR3["Runtime 3 (Full Model)"]
            CPA3 --> SR3
        end
        subgraph S4["Shard 4"]
            CPA4["Adapter 4"]
            SR4["Runtime 4 (Full Model)"]
            CPA4 --> SR4
        end
    end
    CPAA --> CPA1
    CPA1 <-.->|"KV/Q blocks"| CPA2
    CPA2 <-.->|"KV/Q blocks"| CPA3
    CPA3 <-.->|"KV/Q blocks"| CPA4
    CPA4 <-.->|"KV/Q blocks"| CPA1
```

Data Flow:
1. InferenceManager → CPApiAdapter
2. CPApiAdapter sends sharded tokens to Shard 1 (head of ring)
3. CPApiAdapter …

3.2 Key Differences from RingStrategy
- RingStrategy shards layers across devices; CP loads the full model on every device
- RingStrategy passes activations around the ring; CP passes KV or Q blocks
- CP shards the sequence (and therefore the KV cache); pipeline parallelism does not

4. Detailed Design
4.1 New Components
4.1.1 Load-Balanced Sharding
Causal attention has asymmetric compute: later tokens attend to more predecessors. Naive even partitioning causes load imbalance.

Solution: Partition the sequence into 2N chunks and assign complementary pairs (i, 2N-i-1) to rank i. Each device gets roughly equal compute load.

```python
# src/dnet/core/cp/sharding.py
import mlx.core as mx

def load_balanced_shard(
    tokens: mx.array,  # [seq_len, ...]
    num_ranks: int,
    rank_id: int,
) -> tuple[mx.array, list[int]]:
    """Shard tokens with load balancing for causal attention.

    Returns:
        sharded_tokens: tokens for this rank
        chunk_indices: original positions (for unsharding)
    """
    seq_len = tokens.shape[0]
    chunk_size = seq_len // (2 * num_ranks)

    # Assign chunks (i, 2N-i-1) to rank i
    chunk_a = rank_id
    chunk_b = 2 * num_ranks - rank_id - 1
    start_a = chunk_a * chunk_size
    end_a = start_a + chunk_size
    start_b = chunk_b * chunk_size
    end_b = start_b + chunk_size if chunk_b < 2 * num_ranks - 1 else seq_len

    sharded = mx.concatenate([tokens[start_a:end_a], tokens[start_b:end_b]])
    chunk_indices = list(range(start_a, end_a)) + list(range(start_b, end_b))
    return sharded, chunk_indices
```
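To see the balance concretely, here is a pure-Python sketch (no mlx required; `chunk_cost` is a hypothetical proxy for causal-attention work, counting the keys each query attends to):

```python
# Illustrative check of the (i, 2N - i - 1) pairing scheme.
def chunk_pairs(num_ranks: int) -> list[tuple[int, int]]:
    # rank i gets chunks (i, 2N - i - 1)
    return [(i, 2 * num_ranks - i - 1) for i in range(num_ranks)]

def rank_cost(rank_id: int, num_ranks: int, chunk_size: int = 4) -> int:
    # causal-attention work of a chunk ~ number of keys its queries attend to
    def chunk_cost(c: int) -> int:
        start = c * chunk_size
        return sum(p + 1 for p in range(start, start + chunk_size))
    a, b = chunk_pairs(num_ranks)[rank_id]
    return chunk_cost(a) + chunk_cost(b)

costs = [rank_cost(r, num_ranks=4) for r in range(4)]  # identical for every rank
```

Because chunk j's cost grows roughly linearly in j, pairing chunk i with chunk 2N-i-1 gives every rank the same total.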
4.1.2 Merge Attention Operator
When computing blockwise attention across distributed KV, each device produces partial outputs with local softmax denominators. These must be merged correctly.
Math: For blocks with outputs O_i, max scores m_i, and log-sum-exp l_i:
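The equation itself did not survive extraction. Under the convention that m_i is the block's running max, l_i its softmax denominator (sum of exponentials relative to m_i), and O_i its locally normalized output, the standard blockwise merge reads:

$$
m = \max_i m_i,
\qquad
l = \sum_i e^{\,m_i - m}\, l_i,
\qquad
O = \frac{\sum_i e^{\,m_i - m}\, l_i\, O_i}{l}
$$

This is term-by-term what the accumulation loop in `merge_partial_attention` computes.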
```python
# src/dnet/core/cp/merge_attention.py
from dataclasses import dataclass

import mlx.core as mx

@dataclass
class PartialAttentionOutput:
    output: mx.array        # [batch, seq, heads, dim]
    max_score: mx.array     # [batch, seq, heads]
    log_sum_exp: mx.array   # [batch, seq, heads]

def merge_partial_attention(
    partials: list[PartialAttentionOutput],
) -> mx.array:
    """Merge partial attention outputs with numerically stable rescaling."""
    # Find global max for stability
    m_global = partials[0].max_score
    for p in partials[1:]:
        m_global = mx.maximum(m_global, p.max_score)

    # Rescale and accumulate
    numerator = mx.zeros_like(partials[0].output)
    denominator = mx.zeros_like(partials[0].log_sum_exp)
    for p in partials:
        scale = mx.exp(p.max_score - m_global)
        numerator += scale[..., None] * p.log_sum_exp[..., None] * p.output
        denominator += scale * p.log_sum_exp
    return numerator / denominator[..., None]
```
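As a sanity check, here is a small NumPy-only replica (helper names are illustrative, not part of the proposed API) showing that merging two blocks' partial stats reproduces exact softmax attention:

```python
# NumPy check that blockwise merge == exact attention over the full KV.
import numpy as np

def partial_attn(scores, v):
    # local stats for one KV block: normalized output, running max, denominator
    m = scores.max(axis=-1)
    e = np.exp(scores - m[..., None])
    l = e.sum(axis=-1)
    return (e / l[..., None]) @ v, m, l

def merge(parts):
    m = np.maximum.reduce([p[1] for p in parts])
    num = sum(np.exp(p[1] - m)[..., None] * p[2][..., None] * p[0] for p in parts)
    den = sum(np.exp(p[1] - m) * p[2] for p in parts)
    return num / den[..., None]

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 8, 4))           # 8 tokens, head dim 4
scores = q @ k.T
w = np.exp(scores - scores.max(-1, keepdims=True))
ref = (w / w.sum(-1, keepdims=True)) @ v        # exact attention

parts = [partial_attn(scores[:, :4], v[:4]),    # KV block 1
         partial_attn(scores[:, 4:], v[4:])]    # KV block 2
merged = merge(parts)
```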
4.1.3 Ring Communication
gRPC-based ring for passing KV or Q blocks between CP ranks.
```python
# src/dnet/core/cp/ring_comm.py
import asyncio
from typing import Optional

from grpc import aio as aio_grpc

class CPRingCommunicator:
    """Manages ring communication for context parallelism."""

    def __init__(
        self,
        rank_id: int,
        num_ranks: int,
        discovery: AsyncDnetP2P,
    ):
        self.rank_id = rank_id
        self.num_ranks = num_ranks
        self._prev_rank = (rank_id - 1) % num_ranks
        self._next_rank = (rank_id + 1) % num_ranks
        self._discovery = discovery
        # gRPC channels
        self._prev_channel: Optional[aio_grpc.Channel] = None
        self._next_channel: Optional[aio_grpc.Channel] = None

    async def send_recv(
        self,
        send_data: bytes,
        tag: str,
    ) -> bytes:
        """Simultaneously send to next rank and receive from previous rank.

        Overlaps communication with computation when used correctly.
        """
        send_task = asyncio.create_task(self._send_to_next(send_data, tag))
        recv_task = asyncio.create_task(self._recv_from_prev(tag))
        await send_task
        return await recv_task
```
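The send/recv pattern can be exercised without gRPC. This toy asyncio ring uses queues as a stand-in for the channels (all names illustrative) and checks that every rank sees every block exactly once:

```python
# Toy asyncio ring: queues simulate the per-neighbor gRPC channels.
import asyncio

async def demo_ring(num_ranks: int = 3):
    # queues[i] carries data from rank i-1 to rank i
    queues = [asyncio.Queue() for _ in range(num_ranks)]

    async def send_recv(rank, data):
        await queues[(rank + 1) % num_ranks].put(data)  # send to next
        return await queues[rank].get()                  # receive from prev

    async def rank_step(rank):
        block = f"kv{rank}"
        seen = [block]
        for _ in range(num_ranks - 1):
            block = await send_recv(rank, block)
            seen.append(block)
        return seen

    return await asyncio.gather(*(rank_step(r) for r in range(num_ranks)))

results = asyncio.run(demo_ring())
```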
4.2 Ring Attention Variants
4.2.1 Pass-KV (Full Prefill)
Best for full prefill where KV is smaller than Q (GQA models: 8 KV heads vs 128 Q heads).
```python
# src/dnet/shard/adapters/context_parallel.py
async def ring_pass_kv_attention(
    self,
    query: mx.array,  # Local Q chunk
    key: mx.array,    # Local K chunk (will be rotated)
    value: mx.array,  # Local V chunk (will be rotated)
) -> mx.array:
    """Ring attention with KV rotation.

    Algorithm:
      1. Compute local attention: Attn(Q_local, KV_local)
      2. For i in 1..N-1:
         a. SendRecv: send KV to next, receive from prev
         b. Compute partial attention with received KV
         c. Accumulate partial outputs
      3. Merge all partial outputs
    """
    partials: list[PartialAttentionOutput] = []

    # Local attention first
    local_out = self._compute_partial_attention(query, key, value)
    partials.append(local_out)

    current_k, current_v = key, value
    for step in range(1, self.num_ranks):
        # Overlap: send current KV while computing with previous
        kv_bytes = self._serialize_kv(current_k, current_v)
        recv_bytes = await self.ring_comm.send_recv(kv_bytes, f"kv_{step}")
        current_k, current_v = self._deserialize_kv(recv_bytes)

        # Compute attention with received KV
        partial = self._compute_partial_attention(query, current_k, current_v)
        partials.append(partial)

    return merge_partial_attention(partials)
```
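A small sketch of the rotation bookkeeping (illustrative, not part of the proposed API): after step s, rank r holds the KV block that originated at rank (r - s) mod N, so N - 1 rotations cover all blocks:

```python
# Which source rank's KV block rank r holds at each step of pass-KV.
def kv_schedule(num_ranks: int) -> list[list[int]]:
    return [[(r - step) % num_ranks for step in range(num_ranks)]
            for r in range(num_ranks)]

sched = kv_schedule(4)
```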
4.2.2 Pass-Q (Decode / High Cache Hit)
Best for decode (single token Q) or partial prefill with high cache hit rate.
```python
async def ring_pass_q_attention(
    self,
    query: mx.array,  # Local Q chunk (will be rotated)
    key: mx.array,    # Full local K (stationary)
    value: mx.array,  # Full local V (stationary)
) -> mx.array:
    """Ring attention with Q rotation.

    Key difference: After the ring loop, partial outputs are scattered
    across ranks. Requires All2All to redistribute.
    """
    # Compute attention for local Q against local KV
    local_outputs: dict[int, PartialAttentionOutput] = {}
    current_q = query
    source_rank = self.rank_id
    for step in range(self.num_ranks):
        # Compute attention: Q from source_rank, KV from local
        partial = self._compute_partial_attention(current_q, key, value)
        local_outputs[source_rank] = partial

        if step < self.num_ranks - 1:
            q_bytes = self._serialize_q(current_q)
            recv_bytes = await self.ring_comm.send_recv(q_bytes, f"q_{step}")
            current_q = self._deserialize_q(recv_bytes)
            source_rank = (source_rank - 1) % self.num_ranks

    # All2All: redistribute partial outputs to source ranks
    my_partials = await self._all2all_outputs(local_outputs)
    return merge_partial_attention(my_partials)
```
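The scatter that motivates the All2All can be checked with toy bookkeeping (names illustrative): after the loop, rank r holds exactly one partial per Q source, so each source rank must gather N partials back before merging:

```python
# Toy bookkeeping for pass-Q: which Q sources each rank has computed for.
def pass_q_ownership(num_ranks: int) -> dict[int, list[int]]:
    holds: dict[int, list[int]] = {r: [] for r in range(num_ranks)}
    for r in range(num_ranks):
        source = r
        for _step in range(num_ranks):
            holds[r].append(source)            # rank r computed Attn(Q_source, KV_r)
            source = (source - 1) % num_ranks  # next Q chunk arrives from prev rank
    return holds

holds = pass_q_ownership(4)
```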
4.2.3 Adaptive Heuristic
```python
# src/dnet/core/cp/heuristics.py
from typing import Literal

def select_ring_algorithm(
    new_tokens: int,               # T
    cached_tokens: int,            # P
    num_kv_heads: int,             # NKV
    num_q_heads: int,              # NH
    num_ranks: int,                # N
    flops_per_device: float,       # C
    inter_device_bandwidth: float, # BW
) -> Literal["pass_kv", "pass_q"]:
    """Select optimal ring algorithm based on cache miss rate and arithmetic intensity.

    Heuristic (from Meta's paper):
    - pass-KV if T/(T+P) >= 2*NKV/NH (cache miss rate threshold)
    - pass-KV if T >= N * (C * NKV * e) / (2 * NH * BW) (sufficient compute)
    - pass-Q otherwise
    """
    total_tokens = new_tokens + cached_tokens
    miss_rate = new_tokens / total_tokens if total_tokens > 0 else 1.0

    # Threshold from GQA ratio
    gqa_threshold = 2 * num_kv_heads / num_q_heads  # e.g., 2*8/128 = 0.125
    if miss_rate >= gqa_threshold:
        return "pass_kv"

    # Check if sufficient compute to overlap pass-KV communication
    element_size = 2  # bfloat16
    min_tokens_for_overlap = num_ranks * (
        flops_per_device * num_kv_heads * element_size
    ) / (2 * num_q_heads * inter_device_bandwidth)
    if new_tokens >= min_tokens_for_overlap:
        return "pass_kv"
    return "pass_q"
```
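Plugging in numbers for a hypothetical 128-Q-head / 8-KV-head GQA model shows how the miss-rate branch separates full prefill from decode:

```python
# Worked numbers for the miss-rate branch (hypothetical 128 Q / 8 KV heads).
NH, NKV = 128, 8
gqa_threshold = 2 * NKV / NH  # 0.125

def miss_rate(new_tokens: int, cached_tokens: int) -> float:
    total = new_tokens + cached_tokens
    return new_tokens / total if total > 0 else 1.0

prefill = miss_rate(131072, 0)   # full prefill: miss rate 1.0 -> pass-KV
decode = miss_rate(1, 131072)    # decode: ~8e-6 -> falls through to compute check
```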
4.3 Strategy Integration
4.3.1 ContextParallelStrategy
```python
# src/dnet/api/strategies/context_parallel.py
from typing import Any, Dict, Literal

class CPTopologySolver(TopologySolver):
    """Topology solver for context parallelism."""

    async def solve(
        self,
        profiles: Dict[str, DeviceProfile],
        model_profile: Any,
        model_name: str,
        num_layers: int,
        kv_bits: Literal["4bit", "8bit", "fp16"],
        shards: Dict[str, DnetDeviceProperties],
        thunderbolts: Dict[str, Dict[str, ThunderboltConnection]],
    ) -> CPTopologyInfo:
        """For CP, all devices get the full model. Optimize ordering for ring bandwidth."""
        # Order devices by Thunderbolt connectivity for minimal latency
        ordered = self._optimize_ring_order(shards, thunderbolts)
        return CPTopologyInfo(
            model=model_name,
            kv_bits=kv_bits,
            num_layers=num_layers,
            devices=ordered,
            # Each device gets ALL layers (full model)
            assignments={name: list(range(num_layers)) for name in ordered},
            num_cp_ranks=len(ordered),
        )

class ContextParallelStrategy(Strategy):
    """Execution strategy using context parallelism."""

    def __init__(self):
        self._solver = CPTopologySolver()
        self._adapter = CPApiAdapter()

    @property
    def solver(self) -> TopologySolver:
        return self._solver

    @property
    def adapter(self) -> ApiAdapterBase:
        return self._adapter
```
4.3.2 Shard-Side CPAdapter
```python
# src/dnet/shard/adapters/context_parallel.py
class CPAdapter(ShardAdapterBase):
    """Context parallel adapter for shards."""

    def __init__(
        self,
        runtime: ShardRuntime,
        discovery: AsyncDnetP2P,
        rank_id: int,
        num_ranks: int,
    ):
        super().__init__(runtime, discovery)
        self.rank_id = rank_id
        self.num_ranks = num_ranks
        self.ring_comm = CPRingCommunicator(rank_id, num_ranks, discovery)
        self._algorithm: Literal["pass_kv", "pass_q"] = "pass_kv"

    async def configure_topology(self, req: ShardLoadModelRequest) -> None:
        """Configure CP topology from load request."""
        self.rank_id = req.cp_rank_id
        self.num_ranks = req.cp_num_ranks
        await self.ring_comm.connect_neighbors()

    async def process_activation(self, msg: ActivationMessage) -> ActivationMessage:
        """Process with context-parallel attention."""
        # 1. Load-balanced sharding to get local tokens
        local_tokens, indices = load_balanced_shard(
            msg.tokens, self.num_ranks, self.rank_id
        )

        # 2. Compute embeddings and projections locally
        hidden = self.runtime.compute_embeddings(local_tokens)
        q, k, v = self.runtime.compute_qkv(hidden)

        # 3. Ring attention (select algorithm dynamically)
        if self._algorithm == "pass_kv":
            attn_out = await self.ring_pass_kv_attention(q, k, v)
        else:
            attn_out = await self.ring_pass_q_attention(q, k, v)

        # 4. FFN + output projection (local compute)
        output = self.runtime.compute_ffn(attn_out)
        return msg.with_output(output, indices)
```
4.4 Configuration
Following the existing pattern in config.py, we use Literal types for constrained choices (which Pydantic validates) and integrate with the .env.example auto-generation via scripts/generate_env_example.py.
```python
# src/dnet/config.py (additions)
from enum import StrEnum

class CPAlgorithm(StrEnum):
    """Ring attention algorithm selection."""
    AUTO = "auto"        # Dynamic selection based on heuristics
    PASS_KV = "pass_kv"  # Rotate KV blocks (best for prefill)
    PASS_Q = "pass_q"    # Rotate Q blocks (best for decode)

class ContextParallelSettings(BaseSettings):
    """Context parallelism configuration."""
    model_config = SettingsConfigDict(env_prefix="DNET_CP_")

    enabled: bool = Field(
        default=False,
        description="Enable context parallelism mode",
    )
    algorithm: CPAlgorithm = Field(
        default=CPAlgorithm.AUTO,
        description="Ring attention algorithm (auto, pass_kv, pass_q)",
    )
    min_context_for_cp: int = Field(
        default=32768,
        description="Minimum context length to enable CP (below this, single-device)",
    )
    chunk_overlap: int = Field(
        default=0,
        description="Overlap between chunks for sliding window attention",
    )
```
.env.example Integration:
Add ContextParallelSettings to generate_env_example.py:
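For illustration, the regenerated section of .env.example would presumably look like the following; the variable names follow env_prefix "DNET_CP_" plus the field names and defaults, but the exact output depends on the generator script:

```
# Context parallelism configuration
DNET_CP_ENABLED=false
DNET_CP_ALGORITHM=auto
DNET_CP_MIN_CONTEXT_FOR_CP=32768
DNET_CP_CHUNK_OVERLAP=0
```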
Run make env-example to regenerate .env.example with the CP settings.

4.5 Protocol Changes
Decision: a separate proto file (dnet_cp.proto) vs. additions to the existing dnet_ring.proto (which would reuse existing messages such as ActivationRequest and need fewer imports).

Recommendation: Create dnet_cp.proto as a separate file because:
- KVBlockTransfer/QBlockTransfer are CP-specific and don't belong in ring transport

Minor addition to dnet_ring.proto (for CP-enabled requests):

5. Proposed Changes
5.1 New Files
- src/dnet/core/cp/__init__.py
- src/dnet/core/cp/sharding.py
- src/dnet/core/cp/merge_attention.py
- src/dnet/core/cp/ring_comm.py
- src/dnet/core/cp/heuristics.py
- src/dnet/api/strategies/context_parallel.py
- src/dnet/shard/adapters/context_parallel.py
- tests/subsystems/test_cp_sharding.py
- tests/subsystems/test_cp_merge.py
- tests/subsystems/test_cp_heuristics.py

5.2 Modified Files
[MODIFY] config.py
- Add ContextParallelSettings class
- Add context_parallel: ContextParallelSettings to DnetSettings

[MODIFY] dnet_ring.proto
- Add CPConfig, KVBlockTransfer, QBlockTransfer messages
- Add cp_config field to ActivationRequest

[MODIFY] api.py

[MODIFY] shard.py

[MODIFY] models.py
- Add cp_rank_id, cp_num_ranks to ShardLoadModelRequest

6. Implementation Phases
Phase 1: Core Infrastructure (2-3 days)
- Create the src/dnet/core/cp/ package
- sharding.py with load-balanced partitioning
- merge_attention.py with numerically stable merging

Phase 2: Ring Communication (2-3 days)
- ring_comm.py with gRPC send/recv

Phase 3: Ring Attention Variants (3-4 days)
- CPAdapter

Phase 4: Strategy Integration (2-3 days)
- ContextParallelStrategy class

Phase 5: Verification & Optimization (2-3 days)
7. Verification Plan
7.1 Unit Tests
Sharding Tests (tests/subsystems/test_cp_sharding.py):

Merge Attention Tests (tests/subsystems/test_cp_merge.py):

Heuristic Tests (tests/subsystems/test_cp_heuristics.py):

7.2 Integration Tests

Ring Communication (tests/integration/test_cp_ring.py):

7.3 CI Workflow for Coordinated Multi-Runner E2E Tests
Since dnet has 2 self-hosted macOS runners (mac2.metal), we can design a workflow that coordinates both runners for CP e2e tests.

Approach: Use a hostfile + static discovery pattern (similar to test-static-discovery.yml) where:

> Warning
> Challenge: GitHub Actions artifact uploads/downloads add latency. For reliable coordination, consider:
7.4 Manual Verification (Local Development)
Single-machine test (2 shards on localhost):
Cross-machine test (2 Apple Silicon devices on same network):
- … (192.168.1.10, 192.168.1.11)

8. Risks and Mitigations
9. Future Work
10. References