TraceML Multi-Node Distributed Architecture design #37
TraceML — Multi-Node & Multi-GPU Architecture Design
Status: Design Proposal
Scope: Extends TraceML from single-node DDP to multi-node DDP (M nodes × N GPUs) and FSDP2.
1. Background & Goals
1.1 What TraceML Does Today
TraceML surfaces step-level training visibility for PyTorch workloads:
- Step breakdown: dataloader → forward → backward → optimizer → overhead
- Per-layer memory and timing (Deep-Dive mode)
- System telemetry: CPU, RAM, GPU utilisation, temperature, power
- DDP rank comparison: median vs worst rank, skew %
1.2 Compute Topology Target
| Topology | GPUs | Status |
|---|---|---|
| Single GPU | 1 × 1 | ✅ Supported |
| Single-node DDP | 1 × N | ✅ Supported |
| Multi-node DDP | M × N | ❌ This document |
| FSDP2 | M × N, sharded | ❌ This document (§14) |
| Tensor Parallel | Intra-layer splits | 🔲 Out of scope v1 |
| Pipeline Parallel | Stage-split layers | 🔲 Out of scope v1 |
1.3 Design Goals
| # | Goal |
|---|---|
| G1 | Every rank attributed to the correct (node, local_rank, global_rank) triple |
| G2 | < 0.5% overhead per training step |
| G3 | Fail-open: unreachable aggregator never slows training |
| G4 | Single user-facing entrypoint: traceml run |
| G5 | Zero behaviour change for existing single-node users |
| G6 | FSDP2-aware shard memory reporting without broken hooks |
2. Current Architecture
2.1 Process Topology
graph TD
subgraph "Operator Machine"
CLI["traceml run (cli.py)"]
end
subgraph "Single Node"
CLI -->|"spawn"| AGG["AggregatorProcess<br/>(aggregator_main.py)<br/>127.0.0.1:29765"]
CLI -->|"torchrun"| TRUN["torchrun → executor.py"]
subgraph "Worker Processes"
TRUN --> R0["Rank 0<br/>LocalRank=0"]
TRUN --> R1["Rank 1<br/>LocalRank=1"]
TRUN --> R2["Rank 2<br/>LocalRank=2"]
TRUN --> R3["Rank 3<br/>LocalRank=3"]
end
R0 -->|"TCP loopback"| AGG
R1 -->|"TCP loopback"| AGG
R2 -->|"TCP loopback"| AGG
R3 -->|"TCP loopback"| AGG
end
2.2 Data Flow (Current)
sequenceDiagram
participant TC as Training Code
participant D as decorators.py
participant Q as In-Process Queues
participant RT as TraceMLRuntime._tick()
participant DB as Local Database
participant S as DBIncrementalSender
participant AGG as Aggregator / RemoteDBStore
participant UI as CLI / NiceGUI
TC->>D: with trace_step(model)
D->>Q: emit TimeEvent, MemoryEvent
loop every 1s (background thread)
RT->>Q: sampler.sample() — drain queue
Q-->>RT: events
RT->>DB: db.add_record()
RT->>S: sender.flush()
S->>AGG: TCP — {rank, sampler, tables:{...}}
end
AGG->>UI: display_driver.tick()
2.3 Current Module Map
graph LR
subgraph "Training Process (per rank)"
DEC["decorators.py<br/>trace_step / trace_model_instance"]
RT["runtime/runtime.py<br/>TraceMLRuntime"]
SAMP["samplers/*<br/>System / Process / Time / Memory / Layer"]
DB["database/database.py<br/>Bounded in-memory deque store"]
SEND["database/database_sender.py<br/>DBIncrementalSender"]
TCP_C["transport/tcp_transport.py<br/>TCPClient (loopback only)"]
DIST["transport/distributed.py<br/>get_ddp_info() → LOCAL_RANK only"]
end
subgraph "Aggregator Process"
TCP_S["TCPServer"]
STORE["database/remote_database_store.py<br/>RemoteDBStore: rank→sampler→DB"]
DISP["aggregator/display_drivers/*<br/>CLI / NiceGUI"]
end
DEC --> RT
RT --> SAMP
SAMP --> DB
DB --> SEND
SEND --> TCP_C
TCP_C -->|"loopback TCP"| TCP_S
TCP_S --> STORE
STORE --> DISP
DEC --> DIST
3. Gap Analysis
3.1 The Identity Crisis
With 2 nodes × 4 GPUs = 8 total ranks, distributed.py only reads LOCAL_RANK:
Node 0: LOCAL_RANK 0,1,2,3 → GLOBAL_RANK 0,1,2,3
Node 1: LOCAL_RANK 0,1,2,3 → GLOBAL_RANK 4,5,6,7
Node 1's LOCAL_RANK=2 and Node 0's LOCAL_RANK=2 are indistinguishable in the current wire format. The aggregator merges both into the same rank=2 slot in RemoteDBStore — corrupting all telemetry.
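The collision is easy to model in a few lines (`ingest_v1` is a hypothetical stand-in for the store's ingest path, not TraceML's actual API):

```python
# Minimal model of the current store: keyed by the wire "rank" field alone,
# which today carries only LOCAL_RANK.
store = {}

def ingest_v1(msg: dict) -> None:
    store.setdefault(msg["rank"], []).append(msg["step_ms"])

ingest_v1({"rank": 2, "step_ms": 139})  # Node 0, LOCAL_RANK=2
ingest_v1({"rank": 2, "step_ms": 141})  # Node 1, LOCAL_RANK=2 — collides
print(store)  # {2: [139, 141]} — two physically distinct GPUs merged into one slot
```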
3.2 Transport Cannot Cross Node Boundaries
TCPClient always connects to a single configurable host (default 127.0.0.1). Workers on Node 1 physically cannot reach an aggregator on Node 0 over loopback. There is also no reconnect logic — a network blip silently discards all subsequent telemetry.
3.3 SystemSampler Is Node-Blind
SystemSampler runs only on local_rank == 0 to avoid duplicating host-level metrics. With M nodes this means only Node 0's CPU/RAM/GPU health is ever collected — Nodes 1..M-1 are completely dark.
3.4 Full Gap Table
| Component | Gap | Severity |
|---|---|---|
| transport/distributed.py | No NODE_RANK, GLOBAL_RANK, hostname | 🔴 Blocker |
| transport/tcp_transport.py | Loopback-only, no reconnect, backlog=16 | 🔴 Blocker |
| runtime/settings.py | Single TCP host, no nnodes config | 🔴 Blocker |
| database/remote_database_store.py | rank → sampler, no node dimension | 🔴 Blocker |
| cli.py | No multi-node torchrun launch support | 🔴 Blocker |
| Wire format | Missing global_rank, node_rank, hostname | 🔴 Blocker |
| samplers/system_sampler.py | Blind to non-zero nodes | 🟠 High |
| samplers/layer_* | FSDP2 hook incompatibility | 🟠 High (see §14) |
| Aggregator backlog | 16 connections — far too small for 64+ ranks | 🟡 Medium |
4. Proposed Architecture
4.1 Design Philosophy
One global aggregator per session. One relay per node. All workers connect to their node-local relay.
Rather than having every rank connect across the network directly to the aggregator, we use a two-tier relay pattern:
- Tier 1 — NodeRelay (per node, embedded in rank0's process): Accepts connections from co-located workers on loopback, tags messages with full identity, forwards upstream to the global aggregator.
- Tier 2 — GlobalAggregator (one per session, on node0): Receives tagged rows from all node relays, drives the UI.
4.2 Full Multi-Node Topology
graph TD
CLI["traceml run (Operator)"]
subgraph Node0 ["Node 0 (node-0.cluster)"]
subgraph "rank 0 process [NODE LEADER]"
R0_RT["TraceMLRuntime<br/>global_rank=0"]
RELAY0["NodeRelay<br/>:29766 (loopback)"]
AGG_G["GlobalAggregator<br/>0.0.0.0:29765"]
end
R1["rank 1<br/>global_rank=1"]
R2["rank 2<br/>global_rank=2"]
R3["rank 3<br/>global_rank=3"]
end
subgraph Node1 ["Node 1 (node-1.cluster)"]
subgraph "rank 4 process [NODE LEADER]"
R4_RT["TraceMLRuntime<br/>global_rank=4"]
RELAY1["NodeRelay<br/>:29766 (loopback)"]
end
R5["rank 5<br/>global_rank=5"]
R6["rank 6<br/>global_rank=6"]
R7["rank 7<br/>global_rank=7"]
end
CLI -->|"spawn aggregator"| AGG_G
CLI -->|"torchrun --nnodes=2"| Node0
CLI -->|"torchrun --nnodes=2"| Node1
R0_RT -->|"loopback TCP"| RELAY0
R1 -->|"loopback TCP"| RELAY0
R2 -->|"loopback TCP"| RELAY0
R3 -->|"loopback TCP"| RELAY0
R4_RT -->|"loopback TCP"| RELAY1
R5 -->|"loopback TCP"| RELAY1
R6 -->|"loopback TCP"| RELAY1
R7 -->|"loopback TCP"| RELAY1
RELAY0 -->|"loopback → AGG"| AGG_G
RELAY1 -->|"inter-node TCP"| AGG_G
4.3 Current vs Proposed — Side-by-Side
| Dimension | Current | Proposed |
|---|---|---|
| Worker routing | Directly to loopback aggregator | Workers → NodeRelay (loopback) → GlobalAggregator (network) |
| Primary key | LOCAL_RANK (ambiguous) | GLOBAL_RANK (unique across job) |
| Message identity | {rank, sampler, ...} | {global_rank, local_rank, node_rank, hostname, ...} |
| Aggregator listener | 127.0.0.1 | 0.0.0.0 (when nnodes > 1) |
| Store index | rank → sampler → DB | global_rank → sampler → DB + node index |
| System telemetry | Only node 0 | Every node-leader runs SystemSampler |
| TCP backlog | 16 | max(64, world_size × 2) |
| Reconnect | None | Exponential backoff in ResilientTCPClient |
| CLI launch | torchrun --nproc_per_node=N | torchrun --nnodes=M --nproc_per_node=N --master_addr=... |
5. Component Design — Distributed Identity
5.1 Problem
transport/distributed.py::get_ddp_info() returns (is_ddp, local_rank, world_size). This loses NODE_RANK and produces identity ambiguity in multi-node mode.
5.2 New RankIdentity Dataclass
File: traceml/transport/distributed.py
@dataclass(frozen=True)
class RankIdentity:
    global_rank: int        # unique across entire job (0..world_size-1)
    local_rank: int         # GPU index on this node (0..nproc_per_node-1)
    node_rank: int          # node index in the cluster (0..nnodes-1)
    world_size: int         # total number of ranks
    local_world_size: int   # ranks per node (nproc_per_node)
    nnodes: int             # total node count
    hostname: str           # socket.gethostname()
    is_global_rank0: bool   # global_rank == 0
    is_node_leader: bool    # local_rank == 0

def get_rank_identity() -> RankIdentity:
    """
    Resolve full rank identity. Priority order:
    1. torch.distributed (most authoritative when initialized)
    2. torchrun env vars (RANK, LOCAL_RANK, NODE_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE)
    3. Defaults for non-distributed (single-GPU) runs
    """
    dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
    global_rank = torch.distributed.get_rank() if dist_ok else int(os.environ.get("RANK", "0"))
    world_size = torch.distributed.get_world_size() if dist_ok else int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    lws = int(os.environ.get("LOCAL_WORLD_SIZE", str(world_size)))
    nnodes = int(os.environ.get("NNODES", str(max(1, world_size // lws))))
    return RankIdentity(
        global_rank=global_rank, local_rank=local_rank, node_rank=node_rank,
        world_size=world_size, local_world_size=lws, nnodes=nnodes,
        hostname=socket.gethostname(),
        is_global_rank0=(global_rank == 0),
        is_node_leader=(local_rank == 0),
    )

5.3 Propagation Through the Stack
RankIdentity is resolved once in TraceMLRuntime.__init__() and threaded into every sampler and sender at construction time. DBIncrementalSender.rank changes from local_rank → global_rank. All downstream code that previously read the raw rank field from a message now reads global_rank.
graph LR
RI["RankIdentity<br/>(resolved once at startup)"]
RI --> RT["TraceMLRuntime"]
RT --> SAMP["All Samplers<br/>(store global_rank in every row)"]
RT --> SEND["DBIncrementalSender<br/>rank = global_rank"]
RT --> RELAY["NodeRelay<br/>(tags forwarded msgs)"]
6. Component Design — Transport Layer
6.1 Architecture Overview
graph LR
subgraph "Per-Rank Process"
RC["ResilientTCPClient"]
end
subgraph "Node-Leader Process (local_rank=0)"
NR["NodeRelay<br/>:29766 loopback server<br/>+upstream client"]
RC2["ResilientTCPClient<br/>(own telemetry)"]
end
subgraph "GlobalAggregator"
HS["HighCapacityTCPServer<br/>(epoll/select)<br/>0.0.0.0:29765"]
end
RC -->|"loopback :29766"| NR
RC2 -->|"loopback :29766"| NR
NR -->|"inter-node TCP"| HS
6.2 ResilientTCPClient
Replaces the current no-reconnect TCPClient.
File: traceml/transport/tcp_transport.py
Policy:
- On connect failure: retry at 0.1s → 0.2s → 0.4s → ... capped at 30s
- After 10 consecutive failures: log warning, drop frame, keep retrying
- send() never raises, never blocks training
class ResilientTCPClient:
    INITIAL_BACKOFF = 0.1  # seconds
    MAX_BACKOFF = 30.0     # seconds

    def send(self, payload: dict) -> None:
        if not self._ensure_connected():
            self._logger.warning("[TraceML] TCPClient: dropping frame (not connected)")
            return
        data = msgspec.msgpack.encode(payload)
        header = struct.pack("!I", len(data))
        with self._lock:
            self._sock.sendall(header + data)

    def _ensure_connected(self) -> bool:
        if self._connected:
            return True
        now = time.monotonic()
        if now < self._next_retry_at:
            return False  # backoff not elapsed yet
        try:
            self._connect()
            self._backoff = self.INITIAL_BACKOFF
            return True
        except Exception:
            self._backoff = min(self._backoff * 2, self.MAX_BACKOFF)
            self._next_retry_at = now + self._backoff
            return False

6.3 NodeRelay
Runs as a daemon thread inside the node-leader process (local_rank == 0).
File: traceml/transport/node_relay.py
class NodeRelay:
    """
    In-process relay: accepts local-rank telemetry on loopback,
    tags each message with node identity, forwards to GlobalAggregator.

    Why in-process (not a separate process)?
    - Shares the training lifecycle automatically.
    - No extra firewall rule needed for a new sidecar port.
    - Simpler orchestration from the CLI.
    """

    def _tag_message(self, msg: dict) -> dict:
        msg["node_rank"] = self._identity.node_rank
        msg["hostname"] = self._identity.hostname
        msg["local_world_size"] = self._identity.local_world_size
        return msg

    def _forward_loop(self) -> None:
        while not self._stop_event.is_set():
            for msg in self._local_server.poll():
                self._upstream.send(self._tag_message(msg))
            self._stop_event.wait(0.05)  # 50 ms forward cadence

Why a 50 ms cadence? It adds at most 50 ms of display lag (imperceptible at human-scale refresh rates) while batching many small sends into fewer TCP writes, reducing per-message overhead across 64+ ranks.
6.4 HighCapacityTCPServer
Replaces the current thread-per-connection server for the GlobalAggregator.
File: traceml/transport/tcp_server_async.py
Current: One daemon thread spawned per client connection → O(world_size) threads
Proposed: Single thread using selectors.DefaultSelector (epoll on Linux)
→ O(1) threads regardless of world_size
class HighCapacityTCPServer:
    def __init__(self, cfg: TCPConfig, world_size: int):
        self._backlog = max(64, world_size * 2)
        self._sel = selectors.DefaultSelector()
        self._queue = queue.Queue(maxsize=10_000)

    def _run(self) -> None:
        self._sel.register(self._sock, selectors.EVENT_READ, data="accept")
        while not self._stop_event.is_set():
            events = self._sel.select(timeout=0.5)
            for key, _ in events:
                if key.data == "accept":
                    self._accept()
                else:
                    self._recv(key)

6.5 Settings Extension
File: traceml/runtime/settings.py
@dataclass(frozen=True)
class TraceMLNodeRelaySettings:
    local_port: int = 29766     # loopback port workers connect to on each node
    upstream_host: str = ""     # GlobalAggregator host (auto-filled from NODE_RANK=0 IP)
    upstream_port: int = 29765  # GlobalAggregator port

@dataclass(frozen=True)
class TraceMLSettings:
    # ... existing fields unchanged ...
    # NEW
    aggregator: TraceMLTCPSettings = TraceMLTCPSettings()
    relay: TraceMLNodeRelaySettings = TraceMLNodeRelaySettings()
    world_size: int = 1
    nnodes: int = 1
    nproc_per_node: int = 1
    parallelism_mode: str = "auto"  # "ddp" | "fsdp2" | "auto"

7. Component Design — Samplers
7.1 Change Matrix
| Sampler | Change Required | Effort |
|---|---|---|
| TimeSampler | Use global_rank in payload row | Trivial |
| StepMemorySampler | Use global_rank | Trivial |
| ProcessSampler | Tag with global_rank; keep local_rank for CUDA device | Trivial |
| StdoutStderrSampler | Use global_rank | Trivial |
| LayerForward/BackwardTimeSampler | Use global_rank | Trivial |
| LayerForward/BackwardMemorySampler | Use global_rank | Trivial |
| SystemSampler | Run on every node-leader, not just global_rank==0; add node_rank to wire schema | Small |
7.2 SystemSampler — Multi-Node Fix
File: traceml/runtime/runtime.py::_build_samplers()
def _build_samplers(self, identity: RankIdentity) -> List[BaseSampler]:
    samplers: List[BaseSampler] = []
    # Changed: every node leader (not just global rank0) runs SystemSampler
    if identity.is_node_leader:
        samplers.append(SystemSampler(node_rank=identity.node_rank,
                                      hostname=identity.hostname))
    samplers += [
        ProcessSampler(identity=identity),
        LayerMemorySampler(),
        # ... rest unchanged ...
    ]
    return samplers

Wire schema addition in SystemSample.to_wire():

{ ..., "node_rank": self.node_rank, "hostname": self.hostname }

8. Component Design — Aggregator & Store
8.1 RemoteDBStore — Adding the Node Dimension
File: traceml/database/remote_database_store.py
@dataclass
class RankMetadata:
    global_rank: int
    local_rank: int
    node_rank: int
    hostname: str
    local_world_size: int

class RemoteDBStore:
    def __init__(self, max_rows: int = 500):
        # Primary store: global_rank is unambiguous across the whole job
        self._dbs: Dict[int, Dict[str, Database]] = {}
        self._last_seen: Dict[int, float] = {}
        # NEW: identity index consumed by renderers
        self._rank_meta: Dict[int, RankMetadata] = {}    # global_rank → meta
        self._node_to_ranks: Dict[int, List[int]] = {}   # node_rank → [global_ranks]

    def ingest(self, message: dict) -> None:
        global_rank = message.get("rank")  # sender always sets this to global_rank (v2)
        node_rank = message.get("node_rank", 0)
        hostname = message.get("hostname", "unknown")
        local_rank = message.get("local_rank", global_rank)
        # Register identity (once per global_rank)
        if global_rank not in self._rank_meta:
            self._rank_meta[global_rank] = RankMetadata(
                global_rank=global_rank, local_rank=local_rank,
                node_rank=node_rank, hostname=hostname,
                local_world_size=message.get("local_world_size", 1),
            )
            self._node_to_ranks.setdefault(node_rank, []).append(global_rank)
        # Rest unchanged — tables ingested into per-global_rank Database
        ...

    def get_nodes(self) -> List[int]:
        return sorted(self._node_to_ranks.keys())

    def get_ranks_for_node(self, node_rank: int) -> List[int]:
        return sorted(self._node_to_ranks.get(node_rank, []))

8.2 Cross-Node Summary Computation
flowchart TD
STORE["RemoteDBStore"]
subgraph "MultiNodeAggregator.compute_step_summary()"
A["Collect step_time_ms<br/>per global_rank"]
B["Group by node_rank<br/>→ per-node median"]
C["Global median<br/>Global worst rank<br/>Global skew %"]
D["Node skew %<br/>(slowest node vs fastest node)"]
end
STORE --> A --> B --> C --> D
D -->|"StepSummary"| DISP["DisplayDriver.tick()"]
class MultiNodeAggregator:
    def compute_step_summary(self, store: RemoteDBStore, step: int) -> StepSummary:
        per_rank_ms = {
            r: self._get_step_ms(store, r, step)
            for r in store.ranks()
        }
        per_rank_ms = {r: v for r, v in per_rank_ms.items() if v is not None}
        per_node_median = {
            n: statistics.median([per_rank_ms[r]
                                  for r in store.get_ranks_for_node(n)
                                  if r in per_rank_ms])
            for n in store.get_nodes()
        }
        global_median = statistics.median(per_rank_ms.values())
        worst_rank = max(per_rank_ms, key=per_rank_ms.get)
        global_skew_pct = (per_rank_ms[worst_rank] - global_median) / global_median * 100
        slowest_node = max(per_node_median, key=per_node_median.get)
        node_skew_pct = (
            (per_node_median[slowest_node] - min(per_node_median.values()))
            / min(per_node_median.values()) * 100
        )
        return StepSummary(
            per_rank_ms=per_rank_ms,
            per_node_median=per_node_median,
            global_median=global_median,
            worst_global_rank=worst_rank,
            global_skew_pct=global_skew_pct,
            node_skew_pct=node_skew_pct,
            slowest_node=slowest_node,
        )

9. Component Design — Renderers
9.1 CLI Dashboard — Node Summary Panel
The CLI (aggregator/display_drivers/cli.py) gains a Node Summary row rendered with Rich:
┌─ TraceML — Multi-Node DDP (2 nodes × 4 GPUs = 8 ranks) ──────────────────────────┐
│ SESSION abc123 │ STEP 142 │ GLOBAL SKEW 3.2% │ STRAGGLER rank5 │
├─ NODE SUMMARY ─────────────────────────────────────────────────────────────────────┤
│ node-0.cluster median=134ms worst=139ms (rank2) skew=3.7% │
│ node-1.cluster median=136ms worst=141ms (rank5) skew=3.7% ← +1.5% vs node0 │
├─ RANK GRID (step time ms) ─────────────────────────────────────────────────────────┤
│ rank0 132 rank1 131 rank2 139 rank3 135 │
│ rank4 134 rank5 141 ▲ rank6 137 rank7 133 │
└────────────────────────────────────────────────────────────────────────────────────┘
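The skew figures in the panel come straight from the aggregator's formula, (worst − median) / median × 100; checking the sample numbers shown above:

```python
def skew_pct(worst_ms: float, median_ms: float) -> float:
    # Same formula MultiNodeAggregator.compute_step_summary() applies.
    return (worst_ms - median_ms) / median_ms * 100

print(round(skew_pct(139, 134), 1))  # node-0 skew → 3.7
print(round(skew_pct(141, 136), 1))  # node-1 skew → 3.7
print(round(skew_pct(136, 134), 1))  # node-1 median vs node-0 median → 1.5
```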
9.2 Web Dashboard — Node Tabs
In the NiceGUI dashboard (aggregator/display_drivers/nicegui.py), a top-level ui.tabs component contains:
- Overview tab — global step timeline, rank grid heatmap, skew history
- Node N tabs (one per node) — per-rank step sparklines, that node's system metrics (CPU/RAM/GPU from its own
SystemSampler), per-layer timing
graph LR
overview["Overview Tab<br/>• Global step timeline<br/>• Rank grid heatmap<br/>• Skew history"]
node0["Node 0 Tab<br/>• Ranks 0-3 sparklines<br/>• CPU/RAM/GPU (node-0)<br/>• Layer timing"]
node1["Node 1 Tab<br/>• Ranks 4-7 sparklines<br/>• CPU/RAM/GPU (node-1)<br/>• Layer timing"]
tabs["ui.tabs"] --> overview
tabs --> node0
tabs --> node1
10. CLI & Launcher
10.1 New Flags
traceml run train.py \
  --nnodes 2 \                       # number of nodes
  --nproc-per-node 4 \               # GPUs per node
  --master-addr node0.cluster \      # torchrun master (also default aggregator host)
  --master-port 29500 \              # torchrun rendezvous
  --aggregator-host node0.cluster \  # explicit aggregator override (optional)
  --aggregator-port 29765 \          # aggregator TCP port
  --relay-port 29766                 # per-node relay loopback port

All existing single-node flags remain unchanged.
10.2 Multi-Node Launch Flow
sequenceDiagram
participant OP as Operator
participant CLI as cli.py
participant AGG as GlobalAggregator (node0)
participant TR as torchrun (all nodes)
participant EX as executor.py (per rank)
participant NR as NodeRelay (node leaders)
OP->>CLI: traceml run train.py --nnodes=2 ...
CLI->>CLI: generate SESSION_ID
CLI->>AGG: spawn aggregator on AGGREGATOR_HOST:29765
CLI->>CLI: wait for AGG TCP ready
CLI->>TR: torchrun --nnodes=2 --master_addr=... executor.py
TR->>EX: start worker processes on every node
EX->>EX: get_rank_identity()
EX->>NR: start NodeRelay (if is_node_leader and nnodes > 1)
EX->>EX: start TraceMLRuntime
EX->>NR: workers connect to :29766 (loopback)
NR->>AGG: forward tagged rows (inter-node TCP)
CLI->>TR: wait for torchrun exit
TR-->>CLI: exit code
CLI->>AGG: terminate
CLI-->>OP: exit with training code
10.3 Backward Compatibility
When --nnodes=1 (the default):
- NodeRelay is not started
- ResilientTCPClient connects directly to TRACEML_AGGREGATOR_HOST (default 127.0.0.1)
- The aggregator listens on 127.0.0.1 (not 0.0.0.0)
- Zero behaviour change for existing users
11. Wire Protocol & Schema Versioning
11.1 Version Field
A _v field is added to every message payload for forward compatibility.
| Field | v1 (current) | v2 (proposed) |
|---|---|---|
| _v | (absent) | 2 |
| rank | LOCAL_RANK | GLOBAL_RANK |
| global_rank | (absent) | GLOBAL_RANK |
| local_rank | (absent) | LOCAL_RANK |
| node_rank | (absent) | NODE_RANK |
| hostname | (absent) | socket.gethostname() |
| local_world_size | (absent) | nproc_per_node |
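A representative v2 payload for LOCAL_RANK=2 on node 1 of the 2 × 4 example (field values illustrative; table contents elided):

```python
msg_v2 = {
    "_v": 2,
    "rank": 6,               # kept equal to global_rank so legacy readers still work
    "global_rank": 6,
    "local_rank": 2,         # CUDA device index on its own node
    "node_rank": 1,
    "hostname": "node-1.cluster",
    "local_world_size": 4,
    "sampler": "TimeSampler",
    "tables": {},            # table rows elided
}

# torchrun assigns contiguous global ranks per node, so this invariant holds:
assert msg_v2["global_rank"] == (msg_v2["node_rank"] * msg_v2["local_world_size"]
                                 + msg_v2["local_rank"])
```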
11.2 Aggregator Compatibility Rule
def ingest(self, message: dict) -> None:
    v = message.get("_v", 1)
    if v == 1:
        # Legacy single-node: assume global_rank == rank, node_rank == 0
        global_rank = message.get("rank", 0)
        node_rank = 0
        hostname = "localhost"
    else:
        global_rank = message["rank"]  # sender sets rank = global_rank in v2
        node_rank = message.get("node_rank", 0)
        hostname = message.get("hostname", "unknown")
    ...

Old single-node clients connecting to a new aggregator work transparently.
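The same dispatch rule as a standalone sketch (`resolve_identity` is a hypothetical helper name; the real ingest continues into table merging):

```python
def resolve_identity(message: dict) -> tuple:
    """Normalise (global_rank, node_rank, hostname) across wire versions."""
    if message.get("_v", 1) == 1:
        # Legacy single-node: rank carried LOCAL_RANK, only one node existed.
        return message.get("rank", 0), 0, "localhost"
    return (message["rank"],
            message.get("node_rank", 0),
            message.get("hostname", "unknown"))

print(resolve_identity({"rank": 2}))                     # (2, 0, 'localhost')
print(resolve_identity({"_v": 2, "rank": 6,
                        "node_rank": 1,
                        "hostname": "node-1.cluster"}))  # (6, 1, 'node-1.cluster')
```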
12. Migration Plan
Migration is structured in four independent phases. Each phase ships a working, tested, backward-compatible system.
gantt
title TraceML Multi-Node Migration Phases
dateFormat YYYY-MM-DD
section Phase 1 — Identity Fix
distributed.py RankIdentity :p1a, 2025-01-01, 5d
global_rank in wire format :p1b, after p1a, 3d
RemoteDBStore global_rank key :p1c, after p1b, 3d
Renderer label updates :p1d, after p1c, 2d
section Phase 2 — Resilient Transport
ResilientTCPClient :p2a, after p1d, 5d
HighCapacityTCPServer :p2b, after p2a, 5d
Settings expansion :p2c, after p2b, 2d
section Phase 3 — Multi-Node Launch
NodeRelay implementation :p3a, after p2c, 7d
CLI --nnodes flags :p3b, after p3a, 4d
SystemSampler all-node fix :p3c, after p3a, 3d
Node tabs in renderers :p3d, after p3b, 5d
section Phase 4 — FSDP2
detect_parallelism_mode :p4a, after p3d, 3d
FSDP2Sampler :p4b, after p4a, 7d
_trace_step_fsdp2 :p4c, after p4b, 4d
FSDP2 renderer view :p4d, after p4c, 4d
| Phase | Deliverable | Risk | Outcome |
|---|---|---|---|
| 1 | RankIdentity, v2 wire format, global_rank store key | Low | Single-node DDP correctly labelled |
| 2 | ResilientTCPClient, HighCapacityTCPServer | Medium | Transport robust to 64+ ranks and blips |
| 3 | NodeRelay, --nnodes CLI, per-node SystemSampler, node tabs | Medium-High | Multi-node DDP fully supported |
| 4 | FSDP2Sampler, _trace_step_fsdp2, FSDP2 renderer | High | FSDP2 step + shard visibility |
13. Performance Analysis
13.1 Overhead Budget
Target: < 0.5% step-time overhead added by TraceML on multi-node.
| Cost Source | Estimate | Path |
|---|---|---|
| get_rank_identity() | < 1 µs | Called once, cached in TraceMLRuntime |
| NodeRelay._forward_loop() | ~50 µs per 50 ms tick | Background thread — off the training critical path |
| ResilientTCPClient.send() | 100–500 µs per call | Background sampler thread, sendall with header |
| Extra wire bytes per message | +~100 bytes | Identity fields in msgpack |
| SystemSampler (each node-leader) | ~500 µs per sample | Runs in sampler thread, not training thread |
At 100 ms step times and 1 s sampler interval, the per-step overhead budget is 500 µs. All costs above are either < 500 µs total or are off the training thread entirely.
13.2 NodeRelay Latency
The relay adds at most 50 ms of end-to-end display lag (the forward tick cadence). Against the dashboard's 1–2 second refresh interval this is imperceptible, and training code never waits on the relay.
13.3 Aggregator Memory at Scale
world_size × samplers × max_rows × avg_row_bytes
= 64 × 10 × 200 × 1,024
≈ 131 MB
Above 32 ranks, we recommend reducing --remote-max-rows to 50 (default will adapt automatically when world_size > 32), bringing memory to ~33 MB.
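The sizing rule above as runnable arithmetic (a sketch using the ~1 KB average row size from the formula):

```python
def aggregator_resident_bytes(world_size: int, samplers: int = 10,
                              max_rows: int = 200, avg_row_bytes: int = 1024) -> int:
    # Worst-case resident set of the aggregator's per-rank ring buffers.
    return world_size * samplers * max_rows * avg_row_bytes

print(round(aggregator_resident_bytes(64) / 1e6))               # 131 (MB, default rows)
print(round(aggregator_resident_bytes(64, max_rows=50) / 1e6))  # 33 (MB, reduced cap)
```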
14. FSDP2 Support
Note: FSDP2 (torch.distributed.fsdp.fully_shard(), available since PyTorch 2.2) uses a fundamentally different parameter ownership model. This section covers the specific changes required and why naive hook-based approaches fail.
14.1 How FSDP2 Changes Training
sequenceDiagram
participant TR as Training Step
participant F2 as FSDP2 Runtime
participant ALLG as All-Gather NCCL
participant FW as Layer Forward
participant RS as Reduce-Scatter NCCL
participant OPTIM as Optimizer
TR->>F2: loss = model(batch)
loop For each FSDP layer in forward order
F2->>ALLG: All-gather shard to full params
ALLG-->>F2: params materialised
F2->>FW: layer.forward(input)
Note over F2: params optionally freed back to shard
end
TR->>F2: loss.backward()
loop For each FSDP layer in backward order
F2->>ALLG: All-gather shard to full params
ALLG-->>F2: params materialised
F2->>RS: Reduce-scatter grads to shard
end
TR->>OPTIM: optimizer.step()
Note over OPTIM: Works on local shards only
14.2 Why Layer Hooks Break Under FSDP2
| Hook type | What happens under FSDP2 |
|---|---|
| register_forward_hook on nn.Linear | nn.Linear is now inside an FSDPModule. The hook may fire before the all-gather completes — parameters not yet materialised |
| register_forward_pre_hook | Same issue — fires before un-sharding |
| Memory reading inside hook | Only shard-local fragments visible; reads partial parameter memory |
| Layer name lookup | FSDP2 may rename submodules during fully_shard() — name-to-module mapping breaks |
Consequence: In FSDP2 mode, trace_model_instance() must skip all layer-level hooks and fall back to root-only + shard-aware measurement.
14.3 What We Can and Cannot Measure
| Metric | Feasible? | Method |
|---|---|---|
| Total step time | ✅ Easy | CUDA events on trace_step boundary |
| Dataloader time | ✅ Easy | DataLoader patch (unchanged) |
| Forward time (total) | ✅ Easy | Single hook on FSDP root module only |
| Backward time | ✅ Easy | loss.backward() CUDA events |
| Optimizer step time | ✅ Easy | Optimizer wrapper hook |
| Local shard memory | ✅ Easy | torch.cuda.memory_allocated() at step end |
| Estimated full model memory | ✅ Easy | local_shard × world_size |
| Per-layer forward time | ❌ Skip | Hooks fire during incomplete un-shard; skip in FSDP2 |
| Per-layer activation memory | ❌ Skip | Only shard fragments visible; skip in FSDP2 |
| All-gather duration | ⚠️ Risky | FSDP2 internal event hooks (version-gated, may break) |
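The "estimated full model memory" row relies on even sharding; a sketch of that estimate (the even-sharding assumption is the table's, and ranks holding uneven tail shards would skew it):

```python
def estimate_full_model_bytes(local_shard_bytes: int, world_size: int) -> int:
    # Assumes FSDP2 shards parameters evenly across all ranks.
    return local_shard_bytes * world_size

GiB = 1024 ** 3
local = int(3.9 * GiB)  # e.g. the shard size shown in the §14.8 display
print(round(estimate_full_model_bytes(local, 8) / GiB, 1))  # 31.2
```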
14.4 Auto-Detection
File: traceml/utils/fsdp_utils.py
def detect_parallelism_mode(model: nn.Module) -> str:
    """Returns: "fsdp2" | "ddp" | "single" """
    try:
        from torch.distributed.fsdp import FSDPModule
        if any(isinstance(m, FSDPModule) for m in model.modules()):
            return "fsdp2"
    except (ImportError, AttributeError):
        pass
    try:
        from torch.nn.parallel import DistributedDataParallel as DDP
        if isinstance(model, DDP):
            return "ddp"
    except ImportError:
        pass
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return "ddp"
    return "single"

The result is cached on the trace_step context manager after the first call.
14.5 FSDP2Sampler
File: traceml/samplers/fsdp2_sampler.py
class FSDP2Sampler(BaseSampler):
    """
    Shard-aware memory sampler for FSDP2 workloads.

    Tracks:
    - local_param_bytes   : bytes of parameter shards owned by this rank
    - estimated_full_bytes: local × world_size (assumes even sharding)
    - activation_bytes    : peak torch.cuda.memory_allocated() during forward
    - step                : current global step count

    Design constraints:
    - NO per-layer hooks (they break under FSDP2)
    - Single root-module forward hook for activation sampling
    - All NVML / torch.cuda calls are best-effort
    """

    def __init__(self, model: nn.Module, identity: RankIdentity):
        super().__init__(sampler_name="FSDP2Sampler")
        self._model = model
        self._identity = identity
        self._post_forward_mem: float = 0.0
        self._attach_root_hook(model)

    def _attach_root_hook(self, model: nn.Module) -> None:
        """Single post-forward hook on the FSDP root — safe and correct."""
        def _hook(module, inputs, output):
            try:
                # All-gather buffers are still live here → peak shard activation
                self._post_forward_mem = float(torch.cuda.memory_allocated())
            except Exception:
                pass
        model.register_forward_hook(_hook)

    def _local_param_bytes(self) -> int:
        """Sum of all parameter shard sizes on this rank."""
        return sum(
            p.numel() * p.element_size()
            for p in self._model.parameters()
        )

    def sample(self) -> None:
        try:
            local_bytes = self._local_param_bytes()
            row = {
                "step": TraceState.step,
                "global_rank": self._identity.global_rank,
                "node_rank": self._identity.node_rank,
                "local_param_bytes": local_bytes,
                "estimated_full_bytes": local_bytes * self._identity.world_size,
                "activation_bytes": self._post_forward_mem,
                "timestamp": time.time(),
            }
            self.db.add_record("FSDP2Table", row)
        except Exception as e:
            self.logger.error(f"[TraceML] FSDP2Sampler error: {e}")

14.6 FSDP2-Safe trace_step
File: traceml/decorators.py
@contextmanager
def trace_step(model: nn.Module):
    mode = _get_parallelism_mode(model)  # cached
    if mode == "fsdp2":
        with _trace_step_fsdp2(model):
            yield
    else:
        with _trace_step_ddp(model):  # existing path, unchanged
            yield

@contextmanager
def _trace_step_fsdp2(model: nn.Module):
    """
    FSDP2 step boundary.
    - No layer hooks.
    - Shard memory tracked via FSDP2Sampler (registered separately).
    - Step timing via CUDA events on the outer boundary.
    """
    step_completed = False
    try:
        with timed_region("_traceml_internal:step_time", scope="step", use_gpu=True):
            with forward_auto_timer(), backward_auto_timer():
                ensure_optimizer_timing_installed()
                yield
                step_completed = True
    finally:
        if step_completed:
            TraceState.step += 1
            flush_step_events(model, TraceState.step)

14.7 trace_model_instance Safety Guard
def trace_model_instance(model: nn.Module, ...) -> None:
    if _is_fsdp2(model):
        print(
            "[TraceML] FSDP2 detected — layer-level hooks are not attached "
            "(parameter shards not materialised during hook execution). "
            "Use FSDP2Sampler for shard-aware memory tracking.",
            file=sys.stderr,
        )
        return
    # ... existing DDP hook path unchanged ...

14.8 FSDP2 Display
The renderer replaces the layer timing table with a shard memory summary when parallelism_mode == "fsdp2":
FSDP2 Shard Memory (global_rank=2, node_rank=0)
Local shard: 3.9 GB
Estimated full model: 31.2 GB (local × 8 ranks)
Peak activation: 5.1 GB (post-forward, all-gather buffers live)
Step time: 142 ms
End-State Architecture Diagram
graph TD
subgraph "Operator"
CLI["traceml run"]
end
subgraph GlobalAggregator ["GlobalAggregator node0:29765"]
HS["HighCapacityTCPServer<br/>epoll, backlog=max(64,WS×2)"]
STORE["RemoteDBStore<br/>global_rank × sampler × DB<br/>+ node index"]
MNA["MultiNodeAggregator<br/>cross-rank + cross-node stats"]
UI["CLI / NiceGUI<br/>Node tabs, Rank grid,<br/>Shard memory view"]
HS --> STORE --> MNA --> UI
end
subgraph "Node 0"
subgraph "rank 0 [NODE LEADER]"
R0["TraceMLRuntime<br/>identity: global_rank=0"]
NR0["NodeRelay :29766"]
R0 -->|"loopback"| NR0
end
R1["rank 1"] -->|"loopback :29766"| NR0
R2["rank 2"] -->|"loopback :29766"| NR0
R3["rank 3"] -->|"loopback :29766"| NR0
NR0 -->|"loopback"| HS
end
subgraph "Node 1"
subgraph "rank 4 [NODE LEADER]"
R4["TraceMLRuntime<br/>identity: global_rank=4"]
NR1["NodeRelay :29766"]
R4 -->|"loopback"| NR1
end
R5["rank 5"] -->|"loopback :29766"| NR1
R6["rank 6"] -->|"loopback :29766"| NR1
R7["rank 7"] -->|"loopback :29766"| NR1
NR1 -->|"inter-node TCP"| HS
end
CLI --> GlobalAggregator
CLI -->|"torchrun"| Node0
CLI -->|"torchrun"| Node1