TraceML Multi-Node Distributed Architecture design #37

@ppraneth


TraceML — Multi-Node & Multi-GPU Architecture Design

Status: Design Proposal
Scope: Extends TraceML from single-node DDP to multi-node DDP (M nodes × N GPUs) and FSDP2.


1. Background & Goals

1.1 What TraceML Does Today

TraceML surfaces step-level training visibility for PyTorch workloads:

  • Step breakdown: dataloader → forward → backward → optimizer → overhead
  • Per-layer memory and timing (Deep-Dive mode)
  • System telemetry: CPU, RAM, GPU utilisation, temperature, power
  • DDP rank comparison: median vs worst rank, skew %

1.2 Compute Topology Target

| Topology | GPUs | Status |
| --- | --- | --- |
| Single GPU | 1 × 1 | ✅ Supported |
| Single-node DDP | 1 × N | ✅ Supported |
| Multi-node DDP | M × N | ❌ This document |
| FSDP2 | M × N, sharded | ❌ This document (§14) |
| Tensor Parallel | Intra-layer splits | 🔲 Out of scope v1 |
| Pipeline Parallel | Stage-split layers | 🔲 Out of scope v1 |

1.3 Design Goals

| # | Goal |
| --- | --- |
| G1 | Every rank attributed to the correct (node, local_rank, global_rank) triple |
| G2 | < 0.5% overhead per training step |
| G3 | Fail-open: an unreachable aggregator never slows training |
| G4 | Single user-facing entrypoint: traceml run |
| G5 | Zero behaviour change for existing single-node users |
| G6 | FSDP2-aware shard memory reporting without broken hooks |

2. Current Architecture

2.1 Process Topology

graph TD
    subgraph "Operator Machine"
        CLI["traceml run (cli.py)"]
    end

    subgraph "Single Node"
        CLI -->|"spawn"| AGG["AggregatorProcess<br/>(aggregator_main.py)<br/>127.0.0.1:29765"]
        CLI -->|"torchrun"| TRUN["torchrun → executor.py"]

        subgraph "Worker Processes"
            TRUN --> R0["Rank 0<br/>LocalRank=0"]
            TRUN --> R1["Rank 1<br/>LocalRank=1"]
            TRUN --> R2["Rank 2<br/>LocalRank=2"]
            TRUN --> R3["Rank 3<br/>LocalRank=3"]
        end

        R0 -->|"TCP loopback"| AGG
        R1 -->|"TCP loopback"| AGG
        R2 -->|"TCP loopback"| AGG
        R3 -->|"TCP loopback"| AGG
    end

2.2 Data Flow (Current)

sequenceDiagram
    participant TC as Training Code
    participant D  as decorators.py
    participant Q  as In-Process Queues
    participant RT as TraceMLRuntime._tick()
    participant DB as Local Database
    participant S  as DBIncrementalSender
    participant AGG as Aggregator / RemoteDBStore
    participant UI  as CLI / NiceGUI

    TC->>D: with trace_step(model)
    D->>Q: emit TimeEvent, MemoryEvent
    loop every 1s (background thread)
        RT->>Q: sampler.sample() — drain queue
        Q-->>RT: events
        RT->>DB: db.add_record()
        RT->>S: sender.flush()
        S->>AGG: TCP — {rank, sampler, tables:{...}}
    end
    AGG->>UI: display_driver.tick()

2.3 Current Module Map

graph LR
    subgraph "Training Process (per rank)"
        DEC["decorators.py<br/>trace_step / trace_model_instance"]
        RT["runtime/runtime.py<br/>TraceMLRuntime"]
        SAMP["samplers/*<br/>System / Process / Time / Memory / Layer"]
        DB["database/database.py<br/>Bounded in-memory deque store"]
        SEND["database/database_sender.py<br/>DBIncrementalSender"]
        TCP_C["transport/tcp_transport.py<br/>TCPClient (loopback only)"]
        DIST["transport/distributed.py<br/>get_ddp_info() → LOCAL_RANK only"]
    end

    subgraph "Aggregator Process"
        TCP_S["TCPServer"]
        STORE["database/remote_database_store.py<br/>RemoteDBStore: rank→sampler→DB"]
        DISP["aggregator/display_drivers/*<br/>CLI / NiceGUI"]
    end

    DEC --> RT
    RT --> SAMP
    SAMP --> DB
    DB --> SEND
    SEND --> TCP_C
    TCP_C -->|"loopback TCP"| TCP_S
    TCP_S --> STORE
    STORE --> DISP
    DEC --> DIST

3. Gap Analysis

3.1 The Identity Crisis

With 2 nodes × 4 GPUs = 8 total ranks, distributed.py only reads LOCAL_RANK:

Node 0:  LOCAL_RANK 0,1,2,3  →  GLOBAL_RANK 0,1,2,3
Node 1:  LOCAL_RANK 0,1,2,3  →  GLOBAL_RANK 4,5,6,7

Node 1's LOCAL_RANK=2 and Node 0's LOCAL_RANK=2 are indistinguishable in the current wire format. The aggregator merges both into the same rank=2 slot in RemoteDBStore — corrupting all telemetry.
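Under torchrun's usual homogeneous layout (equal GPUs per node), the missing mapping is a one-liner; a minimal sketch of the identity math, assuming each node owns a contiguous block of global ranks:

```python
def to_global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    """Homogeneous-cluster identity math: each node owns a contiguous rank block."""
    return node_rank * nproc_per_node + local_rank

# The two colliding LOCAL_RANK=2 workers above resolve to distinct global
# ranks once node_rank enters the key:
assert to_global_rank(0, 2, 4) == 2   # Node 0
assert to_global_rank(1, 2, 4) == 6   # Node 1
```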

3.2 Transport Cannot Cross Node Boundaries

TCPClient always connects to a single configurable host (default 127.0.0.1). Workers on Node 1 physically cannot reach an aggregator on Node 0 over loopback. There is also no reconnect logic — a network blip silently discards all subsequent telemetry.

3.3 SystemSampler Is Node-Blind

SystemSampler runs only on local_rank == 0 to avoid duplicating host-level metrics. With M nodes this means only Node 0's CPU/RAM/GPU health is ever collected — Nodes 1..M-1 are completely dark.

3.4 Full Gap Table

| Component | Gap | Severity |
| --- | --- | --- |
| transport/distributed.py | No NODE_RANK, GLOBAL_RANK, hostname | 🔴 Blocker |
| transport/tcp_transport.py | Loopback-only, no reconnect, backlog=16 | 🔴 Blocker |
| runtime/settings.py | Single TCP host, no nnodes config | 🔴 Blocker |
| database/remote_database_store.py | rank → sampler, no node dimension | 🔴 Blocker |
| cli.py | No multi-node torchrun launch support | 🔴 Blocker |
| Wire format | Missing global_rank, node_rank, hostname | 🔴 Blocker |
| samplers/system_sampler.py | Blind to non-zero nodes | 🟠 High |
| samplers/layer_* | FSDP2 hook incompatibility | 🟠 High (see §14) |
| Aggregator backlog | 16 connections, far too small for 64+ ranks | 🟡 Medium |

4. Proposed Architecture

4.1 Design Philosophy

One global aggregator per session. One relay per node. All workers connect to their node-local relay.

Rather than having every rank connect across the network directly to the aggregator, we use a two-tier relay pattern:

  • Tier 1 — NodeRelay (per node, embedded in rank0's process): Accepts connections from co-located workers on loopback, tags messages with full identity, forwards upstream to the global aggregator.
  • Tier 2 — GlobalAggregator (one per session, on node0): Receives tagged rows from all node relays, drives the UI.

4.2 Full Multi-Node Topology

graph TD
    CLI["traceml run (Operator)"]

    subgraph Node0 ["Node 0  (node-0.cluster)"]
        subgraph "rank 0 process  [NODE LEADER]"
            R0_RT["TraceMLRuntime<br/>global_rank=0"]
            RELAY0["NodeRelay<br/>:29766 (loopback)"]
            AGG_G["GlobalAggregator<br/>0.0.0.0:29765"]
        end
        R1["rank 1<br/>global_rank=1"]
        R2["rank 2<br/>global_rank=2"]
        R3["rank 3<br/>global_rank=3"]
    end

    subgraph Node1 ["Node 1  (node-1.cluster)"]
        subgraph "rank 4 process  [NODE LEADER]"
            R4_RT["TraceMLRuntime<br/>global_rank=4"]
            RELAY1["NodeRelay<br/>:29766 (loopback)"]
        end
        R5["rank 5<br/>global_rank=5"]
        R6["rank 6<br/>global_rank=6"]
        R7["rank 7<br/>global_rank=7"]
    end

    CLI -->|"spawn aggregator"| AGG_G
    CLI -->|"torchrun --nnodes=2"| Node0
    CLI -->|"torchrun --nnodes=2"| Node1

    R0_RT -->|"loopback TCP"| RELAY0
    R1 -->|"loopback TCP"| RELAY0
    R2 -->|"loopback TCP"| RELAY0
    R3 -->|"loopback TCP"| RELAY0

    R4_RT -->|"loopback TCP"| RELAY1
    R5 -->|"loopback TCP"| RELAY1
    R6 -->|"loopback TCP"| RELAY1
    R7 -->|"loopback TCP"| RELAY1

    RELAY0 -->|"loopback → AGG"| AGG_G
    RELAY1 -->|"inter-node TCP"| AGG_G

4.3 Current vs Proposed — Side-by-Side

| Dimension | Current | Proposed |
| --- | --- | --- |
| Worker routing | Directly to loopback aggregator | Workers → NodeRelay (loopback) → GlobalAggregator (network) |
| Primary key | LOCAL_RANK (ambiguous) | GLOBAL_RANK (unique across job) |
| Message identity | {rank, sampler, ...} | {global_rank, local_rank, node_rank, hostname, ...} |
| Aggregator listener | 127.0.0.1 | 0.0.0.0 (when nnodes > 1) |
| Store index | rank → sampler → DB | global_rank → sampler → DB + node index |
| System telemetry | Only node 0 | Every node-leader runs SystemSampler |
| TCP backlog | 16 | max(64, world_size × 2) |
| Reconnect | None | Exponential backoff in ResilientTCPClient |
| CLI launch | torchrun --nproc_per_node=N | torchrun --nnodes=M --nproc_per_node=N --master_addr=... |

5. Component Design — Distributed Identity

5.1 Problem

transport/distributed.py::get_ddp_info() returns (is_ddp, local_rank, world_size). This loses NODE_RANK and produces identity ambiguity in multi-node mode.

5.2 New RankIdentity Dataclass

File: traceml/transport/distributed.py

import os
import socket
from dataclasses import dataclass

import torch

@dataclass(frozen=True)
class RankIdentity:
    global_rank: int        # unique across entire job (0..world_size-1)
    local_rank: int         # GPU index on this node (0..nproc_per_node-1)
    node_rank: int          # node index in the cluster (0..nnodes-1)
    world_size: int         # total number of ranks
    local_world_size: int   # ranks per node (nproc_per_node)
    nnodes: int             # total node count
    hostname: str           # socket.gethostname()
    is_global_rank0: bool   # global_rank == 0
    is_node_leader: bool    # local_rank == 0

def get_rank_identity() -> RankIdentity:
    """
    Resolve full rank identity. Priority order:
      1. torch.distributed (most authoritative when initialized)
      2. torchrun env vars (RANK, LOCAL_RANK, NODE_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE)
      3. Defaults for non-distributed (single-GPU) runs
    """
    dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
    global_rank = torch.distributed.get_rank()   if dist_ok else int(os.environ.get("RANK",        "0"))
    world_size  = torch.distributed.get_world_size() if dist_ok else int(os.environ.get("WORLD_SIZE", "1"))
    local_rank  = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank   = int(os.environ.get("NODE_RANK",  "0"))
    lws         = int(os.environ.get("LOCAL_WORLD_SIZE", str(world_size)))
    nnodes      = int(os.environ.get("NNODES", str(max(1, world_size // lws))))

    return RankIdentity(
        global_rank=global_rank, local_rank=local_rank, node_rank=node_rank,
        world_size=world_size, local_world_size=lws, nnodes=nnodes,
        hostname=socket.gethostname(),
        is_global_rank0=(global_rank == 0),
        is_node_leader=(local_rank == 0),
    )

5.3 Propagation Through the Stack

RankIdentity is resolved once in TraceMLRuntime.__init__() and threaded into every sampler and sender at construction time. DBIncrementalSender.rank changes from local_rank to global_rank. All downstream code that previously read the raw rank field from a message now reads global_rank.

graph LR
    RI["RankIdentity<br/>(resolved once at startup)"]
    RI --> RT["TraceMLRuntime"]
    RT --> SAMP["All Samplers<br/>(store global_rank in every row)"]
    RT --> SEND["DBIncrementalSender<br/>rank = global_rank"]
    RT --> RELAY["NodeRelay<br/>(tags forwarded msgs)"]

6. Component Design — Transport Layer

6.1 Architecture Overview

graph LR
    subgraph "Per-Rank Process"
        RC["ResilientTCPClient"]
    end
    subgraph "Node-Leader Process (local_rank=0)"
        NR["NodeRelay<br/>:29766 loopback server<br/>+upstream client"]
        RC2["ResilientTCPClient<br/>(own telemetry)"]
    end
    subgraph "GlobalAggregator"
        HS["HighCapacityTCPServer<br/>(epoll/select)<br/>0.0.0.0:29765"]
    end

    RC  -->|"loopback :29766"| NR
    RC2 -->|"loopback :29766"| NR
    NR  -->|"inter-node TCP"| HS

6.2 ResilientTCPClient

Replaces the current no-reconnect TCPClient.

File: traceml/transport/tcp_transport.py

Policy:
  - On connect failure: retry at 0.1s → 0.2s → 0.4s → ... capped at 30s
  - After 10 consecutive failures: log warning, drop frame, keep retrying
  - send() never raises, never blocks training
class ResilientTCPClient:
    INITIAL_BACKOFF = 0.1   # seconds
    MAX_BACKOFF     = 30.0  # seconds

    def send(self, payload: dict) -> None:
        if not self._ensure_connected():
            self._logger.warning("[TraceML] TCPClient: dropping frame (not connected)")
            return
        data   = msgspec.msgpack.encode(payload)
        header = struct.pack("!I", len(data))
        with self._lock:
            self._sock.sendall(header + data)

    def _ensure_connected(self) -> bool:
        if self._connected:
            return True
        now = time.monotonic()
        if now < self._next_retry_at:
            return False          # backoff not elapsed yet
        try:
            self._connect()
            self._backoff = self.INITIAL_BACKOFF
            return True
        except Exception:
            self._backoff = min(self._backoff * 2, self.MAX_BACKOFF)
            self._next_retry_at = now + self._backoff
            return False
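For reference, the receiving side only needs to invert the 4-byte big-endian length prefix that send() writes. A dependency-free sketch of the framing round-trip (json stands in here for the msgpack encoding actually used on the wire):

```python
import io
import json
import struct

def encode_frame(payload: dict) -> bytes:
    # Mirrors send(): 4-byte big-endian length header + encoded body.
    body = json.dumps(payload).encode()
    return struct.pack("!I", len(body)) + body

def read_frame(stream) -> dict:
    # Invert the header to know exactly how many body bytes to consume.
    (size,) = struct.unpack("!I", stream.read(4))
    return json.loads(stream.read(size))

stream = io.BytesIO(encode_frame({"rank": 5, "sampler": "TimeSampler"}))
assert read_frame(stream) == {"rank": 5, "sampler": "TimeSampler"}
```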

6.3 NodeRelay

Runs as a daemon thread inside the node-leader process (local_rank == 0).

File: traceml/transport/node_relay.py

class NodeRelay:
    """
    In-process relay: accepts local-rank telemetry on loopback,
    tags each message with node identity, forwards to GlobalAggregator.

    Why in-process (not a separate process)?
      - Shares the training lifecycle automatically.
      - No extra firewall rule needed for a new sidecar port.
      - Simpler orchestration from the CLI.
    """

    def _tag_message(self, msg: dict) -> dict:
        msg["node_rank"]        = self._identity.node_rank
        msg["hostname"]         = self._identity.hostname
        msg["local_world_size"] = self._identity.local_world_size
        return msg

    def _forward_loop(self) -> None:
        while not self._stop_event.is_set():
            for msg in self._local_server.poll():
                self._upstream.send(self._tag_message(msg))
            self._stop_event.wait(0.05)   # 50 ms forward cadence

Why 50 ms cadence? It adds at most 50 ms of display lag (imperceptible at human-scale refresh rates) while batching many small sends into fewer TCP writes, reducing per-message overhead across 64+ ranks.

6.4 HighCapacityTCPServer

Replaces the current thread-per-connection server for the GlobalAggregator.

File: traceml/transport/tcp_server_async.py

Current:  One daemon thread spawned per client connection → O(world_size) threads
Proposed: Single thread using selectors.DefaultSelector (epoll on Linux)
          → O(1) threads regardless of world_size
class HighCapacityTCPServer:
    def __init__(self, cfg: TCPConfig, world_size: int):
        self._backlog = max(64, world_size * 2)
        self._sel     = selectors.DefaultSelector()
        self._queue   = queue.Queue(maxsize=10_000)

    def _run(self) -> None:
        self._sel.register(self._sock, selectors.EVENT_READ, data="accept")
        while not self._stop_event.is_set():
            events = self._sel.select(timeout=0.5)
            for key, _ in events:
                if key.data == "accept":
                    self._accept()
                else:
                    self._recv(key)

6.5 Settings Extension

File: traceml/runtime/settings.py

@dataclass(frozen=True)
class TraceMLNodeRelaySettings:
    local_port:    int = 29766   # loopback port workers connect to on each node
    upstream_host: str = ""      # GlobalAggregator host (auto-filled from NODE_RANK=0 IP)
    upstream_port: int = 29765   # GlobalAggregator port

@dataclass(frozen=True)
class TraceMLSettings:
    # ... existing fields unchanged ...

    # NEW
    aggregator:       TraceMLTCPSettings       = TraceMLTCPSettings()
    relay:            TraceMLNodeRelaySettings = TraceMLNodeRelaySettings()
    world_size:       int  = 1
    nnodes:           int  = 1
    nproc_per_node:   int  = 1
    parallelism_mode: str  = "auto"   # "ddp" | "fsdp2" | "auto"

7. Component Design — Samplers

7.1 Change Matrix

| Sampler | Change Required | Effort |
| --- | --- | --- |
| TimeSampler | Use global_rank in payload row | Trivial |
| StepMemorySampler | Use global_rank | Trivial |
| ProcessSampler | Tag with global_rank; keep local_rank for CUDA device | Trivial |
| StdoutStderrSampler | Use global_rank | Trivial |
| LayerForward/BackwardTimeSampler | Use global_rank | Trivial |
| LayerForward/BackwardMemorySampler | Use global_rank | Trivial |
| SystemSampler | Run on every node-leader, not just global_rank==0; add node_rank to wire schema | Small |

7.2 SystemSampler — Multi-Node Fix

File: traceml/runtime/runtime.py::_build_samplers()

def _build_samplers(self, identity: RankIdentity) -> List[BaseSampler]:
    samplers: List[BaseSampler] = []

    # Changed: every node leader (not just global rank0) runs SystemSampler
    if identity.is_node_leader:
        samplers.append(SystemSampler(node_rank=identity.node_rank,
                                      hostname=identity.hostname))

    samplers += [
        ProcessSampler(identity=identity),
        LayerMemorySampler(),
        # ... rest unchanged ...
    ]
    return samplers

Wire schema addition in SystemSample.to_wire():

{ ..., "node_rank": self.node_rank, "hostname": self.hostname }

8. Component Design — Aggregator & Store

8.1 RemoteDBStore — Adding the Node Dimension

File: traceml/database/remote_database_store.py

@dataclass
class RankMetadata:
    global_rank:     int
    local_rank:      int
    node_rank:       int
    hostname:        str
    local_world_size: int

class RemoteDBStore:
    def __init__(self, max_rows: int = 500):
        # Primary store: global_rank is unambiguous across the whole job
        self._dbs:        Dict[int, Dict[str, Database]] = {}
        self._last_seen:  Dict[int, float] = {}

        # NEW: identity index consumed by renderers
        self._rank_meta:      Dict[int, RankMetadata] = {}   # global_rank → meta
        self._node_to_ranks:  Dict[int, List[int]]    = {}   # node_rank   → [global_ranks]

    def ingest(self, message: dict) -> None:
        global_rank = message.get("rank")     # sender always sets this to global_rank (v2)
        node_rank   = message.get("node_rank",  0)
        hostname    = message.get("hostname",  "unknown")
        local_rank  = message.get("local_rank", global_rank)

        # Register identity
        if global_rank not in self._rank_meta:
            self._rank_meta[global_rank] = RankMetadata(
                global_rank=global_rank, local_rank=local_rank,
                node_rank=node_rank, hostname=hostname,
                local_world_size=message.get("local_world_size", 1),
            )
            self._node_to_ranks.setdefault(node_rank, []).append(global_rank)

        # Rest unchanged — tables ingested into per-global_rank Database
        ...

    def get_nodes(self) -> List[int]:
        return sorted(self._node_to_ranks.keys())

    def get_ranks_for_node(self, node_rank: int) -> List[int]:
        return sorted(self._node_to_ranks.get(node_rank, []))

8.2 Cross-Node Summary Computation

flowchart TD
    STORE["RemoteDBStore"]
    subgraph "MultiNodeAggregator.compute_step_summary()"
        A["Collect step_time_ms<br/>per global_rank"]
        B["Group by node_rank<br/>→ per-node median"]
        C["Global median<br/>Global worst rank<br/>Global skew %"]
        D["Node skew %<br/>(slowest node vs fastest node)"]
    end
    STORE --> A --> B --> C --> D
    D -->|"StepSummary"| DISP["DisplayDriver.tick()"]
class MultiNodeAggregator:
    def compute_step_summary(self, store: RemoteDBStore, step: int) -> StepSummary:
        per_rank_ms = {
            r: self._get_step_ms(store, r, step)
            for r in store.ranks()
        }
        per_rank_ms = {r: v for r, v in per_rank_ms.items() if v is not None}

        per_node_median = {
            n: statistics.median([per_rank_ms[r] for r in store.get_ranks_for_node(n)
                                  if r in per_rank_ms])
            for n in store.get_nodes()
        }

        global_median    = statistics.median(per_rank_ms.values())
        worst_rank       = max(per_rank_ms, key=per_rank_ms.get)
        global_skew_pct  = (per_rank_ms[worst_rank] - global_median) / global_median * 100

        slowest_node     = max(per_node_median, key=per_node_median.get)
        node_skew_pct    = (
            (per_node_median[slowest_node] - min(per_node_median.values()))
            / min(per_node_median.values()) * 100
        )

        return StepSummary(
            per_rank_ms=per_rank_ms,
            per_node_median=per_node_median,
            global_median=global_median,
            worst_global_rank=worst_rank,
            global_skew_pct=global_skew_pct,
            node_skew_pct=node_skew_pct,
            slowest_node=slowest_node,
        )

9. Component Design — Renderers

9.1 CLI Dashboard — Node Summary Panel

The CLI (aggregator/display_drivers/cli.py) gains a Node Summary row rendered with Rich:

┌─ TraceML — Multi-Node DDP  (2 nodes × 4 GPUs = 8 ranks) ──────────────────────────┐
│  SESSION abc123  │  STEP 142  │  GLOBAL SKEW 3.2%  │  STRAGGLER rank5             │
├─ NODE SUMMARY ─────────────────────────────────────────────────────────────────────┤
│  node-0.cluster   median=134ms  worst=139ms (rank2)  skew=3.7%                    │
│  node-1.cluster   median=136ms  worst=141ms (rank5)  skew=3.7%  ← +1.5% vs node0 │
├─ RANK GRID (step time ms) ─────────────────────────────────────────────────────────┤
│  rank0  132   rank1  131   rank2  139   rank3  135                                 │
│  rank4  134   rank5  141 ▲ rank6  137   rank7  133                                 │
└────────────────────────────────────────────────────────────────────────────────────┘

9.2 Web Dashboard — Node Tabs

In the NiceGUI dashboard (aggregator/display_drivers/nicegui.py), a top-level ui.tabs component contains:

  • Overview tab — global step timeline, rank grid heatmap, skew history
  • Node N tabs (one per node) — per-rank step sparklines, that node's system metrics (CPU/RAM/GPU from its own SystemSampler), per-layer timing
graph LR
    overview["Overview Tab<br/>• Global step timeline<br/>• Rank grid heatmap<br/>• Skew history"]
    node0["Node 0 Tab<br/>• Ranks 0-3 sparklines<br/>• CPU/RAM/GPU (node-0)<br/>• Layer timing"]
    node1["Node 1 Tab<br/>• Ranks 4-7 sparklines<br/>• CPU/RAM/GPU (node-1)<br/>• Layer timing"]
    tabs["ui.tabs"] --> overview
    tabs --> node0
    tabs --> node1

10. CLI & Launcher

10.1 New Flags

traceml run train.py \
    --nnodes           2              \
    --nproc-per-node   4              \
    --master-addr      node0.cluster  \
    --master-port      29500          \
    --aggregator-host  node0.cluster  \
    --aggregator-port  29765          \
    --relay-port       29766

  • --nnodes / --nproc-per-node: number of nodes and GPUs per node
  • --master-addr / --master-port: torchrun rendezvous; the master address doubles as the default aggregator host
  • --aggregator-host / --aggregator-port: explicit aggregator override (optional) and aggregator TCP port
  • --relay-port: per-node relay loopback port

All existing single-node flags remain unchanged.

10.2 Multi-Node Launch Flow

sequenceDiagram
    participant OP  as Operator
    participant CLI as cli.py
    participant AGG as GlobalAggregator (node0)
    participant TR  as torchrun (all nodes)
    participant EX  as executor.py (per rank)
    participant NR  as NodeRelay (node leaders)

    OP->>CLI: traceml run train.py --nnodes=2 ...
    CLI->>CLI: generate SESSION_ID
    CLI->>AGG: spawn aggregator on AGGREGATOR_HOST:29765
    CLI->>CLI: wait for AGG TCP ready
    CLI->>TR:  torchrun --nnodes=2 --master_addr=... executor.py
    TR->>EX:   start worker processes on every node
    EX->>EX:   get_rank_identity()
    EX->>NR:   start NodeRelay (if is_node_leader and nnodes > 1)
    EX->>EX:   start TraceMLRuntime
    EX->>NR:   workers connect to :29766 (loopback)
    NR->>AGG:  forward tagged rows (inter-node TCP)
    CLI->>TR:  wait for torchrun exit
    TR-->>CLI: exit code
    CLI->>AGG: terminate
    CLI-->>OP: exit with training code

10.3 Backward Compatibility

When --nnodes=1 (the default):

  • NodeRelay is not started
  • ResilientTCPClient connects directly to TRACEML_AGGREGATOR_HOST (127.0.0.1 default)
  • The aggregator listens on 127.0.0.1 (not 0.0.0.0)
  • Zero behaviour change for existing users
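The single-node/multi-node listener decision reduces to one predicate; a minimal sketch (the helper name is illustrative, not part of the current codebase):

```python
def aggregator_bind_host(nnodes: int) -> str:
    """Sketch of the §10.3 rule: only expose the aggregator on the
    network when the job actually spans multiple nodes."""
    return "0.0.0.0" if nnodes > 1 else "127.0.0.1"

assert aggregator_bind_host(1) == "127.0.0.1"   # single-node: loopback only
assert aggregator_bind_host(2) == "0.0.0.0"     # multi-node: reachable by relays
```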

11. Wire Protocol & Schema Versioning

11.1 Version Field

A _v field is added to every message payload for forward compatibility.

| Field | v1 (current) | v2 (proposed) |
| --- | --- | --- |
| _v | (absent) | 2 |
| rank | LOCAL_RANK | GLOBAL_RANK |
| global_rank | (absent) | GLOBAL_RANK |
| local_rank | (absent) | LOCAL_RANK |
| node_rank | (absent) | NODE_RANK |
| hostname | (absent) | socket.gethostname() |
| local_world_size | (absent) | nproc_per_node |

11.2 Aggregator Compatibility Rule

def ingest(self, message: dict) -> None:
    v = message.get("_v", 1)
    if v == 1:
        # Legacy single-node: assume global_rank == rank, node_rank == 0
        global_rank = message.get("rank", 0)
        node_rank   = 0
        hostname    = "localhost"
    else:
        global_rank = message["rank"]   # sender sets rank = global_rank in v2
        node_rank   = message.get("node_rank", 0)
        hostname    = message.get("hostname", "unknown")
    ...

Old single-node clients connecting to a new aggregator work transparently.
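The compatibility rule is small enough to isolate and unit-test; a self-contained sketch (normalize_identity is a hypothetical helper mirroring the branch above):

```python
def normalize_identity(message: dict) -> tuple:
    """Map v1 and v2 payloads to a (global_rank, node_rank, hostname) triple,
    following the §11.2 rule: v1 senders are assumed single-node."""
    if message.get("_v", 1) == 1:
        # Legacy single-node sender: rank was LOCAL_RANK; assume one node.
        return message.get("rank", 0), 0, "localhost"
    return message["rank"], message.get("node_rank", 0), message.get("hostname", "unknown")

assert normalize_identity({"rank": 2}) == (2, 0, "localhost")
assert normalize_identity({"_v": 2, "rank": 6, "node_rank": 1,
                           "hostname": "node-1"}) == (6, 1, "node-1")
```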


12. Migration Plan

Migration is structured in four independent phases. Each phase ships a working, tested, backward-compatible system.

gantt
    title TraceML Multi-Node Migration Phases
    dateFormat  YYYY-MM-DD
    section Phase 1 — Identity Fix
    distributed.py RankIdentity      :p1a, 2025-01-01, 5d
    global_rank in wire format       :p1b, after p1a, 3d
    RemoteDBStore global_rank key    :p1c, after p1b, 3d
    Renderer label updates           :p1d, after p1c, 2d

    section Phase 2 — Resilient Transport
    ResilientTCPClient               :p2a, after p1d, 5d
    HighCapacityTCPServer            :p2b, after p2a, 5d
    Settings expansion               :p2c, after p2b, 2d

    section Phase 3 — Multi-Node Launch
    NodeRelay implementation         :p3a, after p2c, 7d
    CLI --nnodes flags               :p3b, after p3a, 4d
    SystemSampler all-node fix       :p3c, after p3a, 3d
    Node tabs in renderers           :p3d, after p3b, 5d

    section Phase 4 — FSDP2
    detect_parallelism_mode          :p4a, after p3d, 3d
    FSDP2Sampler                     :p4b, after p4a, 7d
    _trace_step_fsdp2                :p4c, after p4b, 4d
    FSDP2 renderer view              :p4d, after p4c, 4d
| Phase | Deliverable | Risk | Outcome |
| --- | --- | --- | --- |
| 1 | RankIdentity, v2 wire format, global_rank store key | Low | Single-node DDP correctly labelled |
| 2 | ResilientTCPClient, HighCapacityTCPServer | Medium | Transport robust to 64+ ranks and blips |
| 3 | NodeRelay, --nnodes CLI, per-node SystemSampler, node tabs | Medium-High | Multi-node DDP fully supported |
| 4 | FSDP2Sampler, _trace_step_fsdp2, FSDP2 renderer | High | FSDP2 step + shard visibility |

13. Performance Analysis

13.1 Overhead Budget

Target: < 0.5% step-time overhead added by TraceML on multi-node.

| Cost Source | Estimate | Where it runs |
| --- | --- | --- |
| get_rank_identity() | < 1 µs | Called once, cached in TraceMLRuntime |
| NodeRelay._forward_loop() | ~50 µs per 50 ms tick | Background thread — off the training critical path |
| ResilientTCPClient.send() | 100–500 µs per call | Background sampler thread, sendall with header |
| Extra wire bytes per message | +~100 bytes | Identity fields in msgpack |
| SystemSampler (each node-leader) | ~500 µs per sample | Runs in sampler thread, not training thread |
At 100 ms step times and 1 s sampler interval, the per-step overhead budget is 500 µs. All costs above are either < 500 µs total or are off the training thread entirely.
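The budget arithmetic can be checked directly (numbers taken from the table above):

```python
step_ms, budget_pct = 100.0, 0.5
budget_us = step_ms * 1_000 * budget_pct / 100
assert budget_us == 500.0                # 0.5% of a 100 ms step

# A 1 s sampler interval amortises one flush over ~10 such steps, so even a
# worst-case 500 µs send() contributes only ~50 µs per step on average.
steps_per_flush = 1_000 / step_ms
assert 500 / steps_per_flush == 50.0
```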

13.2 NodeRelay Latency

The relay adds at most 50 ms of end-to-end display lag (the forward tick cadence). At a human-perceptible refresh rate (1–2 seconds), this is imperceptible. Training code never waits on the relay.

13.3 Aggregator Memory at Scale

world_size × samplers × max_rows × avg_row_bytes
= 64 × 10 × 200 × 1,024
≈ 131 MB

Above 32 ranks, we recommend reducing --remote-max-rows to 50 (default will adapt automatically when world_size > 32), bringing memory to ~33 MB.
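The estimate above, expressed as a helper (a sketch; the function name and defaults are illustrative):

```python
def aggregator_memory_bytes(world_size: int, samplers: int = 10,
                            max_rows: int = 200, avg_row_bytes: int = 1024) -> int:
    """Upper bound from §13.3: assumes every per-rank, per-sampler table is full."""
    return world_size * samplers * max_rows * avg_row_bytes

assert aggregator_memory_bytes(64) == 131_072_000               # ≈ 131 MB at 64 ranks
assert aggregator_memory_bytes(64, max_rows=50) == 32_768_000   # ≈ 33 MB with max_rows=50
```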


14. FSDP2 Support

Note: FSDP2 (torch.distributed.fsdp.fully_shard(), available since PyTorch 2.2) uses a fundamentally different parameter ownership model. This section covers the specific changes required and why naive hook-based approaches fail.

14.1 How FSDP2 Changes Training

sequenceDiagram
    participant TR   as Training Step
    participant F2   as FSDP2 Runtime
    participant ALLG as All-Gather NCCL
    participant FW   as Layer Forward
    participant RS   as Reduce-Scatter NCCL
    participant OPTIM as Optimizer

    TR->>F2:   loss = model(batch)
    loop For each FSDP layer in forward order
        F2->>ALLG:  All-gather shard to full params
        ALLG-->>F2: params materialised
        F2->>FW:    layer.forward(input)
        Note over F2: params optionally freed back to shard
    end
    TR->>F2:   loss.backward()
    loop For each FSDP layer in backward order
        F2->>ALLG:  All-gather shard to full params
        ALLG-->>F2: params materialised
        F2->>RS:    Reduce-scatter grads to shard
    end
    TR->>OPTIM: optimizer.step()
    Note over OPTIM: Works on local shards only

14.2 Why Layer Hooks Break Under FSDP2

| Hook type | What happens under FSDP2 |
| --- | --- |
| register_forward_hook on nn.Linear | nn.Linear is now inside an FSDPModule. The hook may fire before the all-gather completes — parameters not yet materialised |
| register_forward_pre_hook | Same issue — fires before un-sharding |
| Memory reading inside hook | Only shard-local fragments visible; reads partial parameter memory |
| Layer name lookup | FSDP2 may rename submodules during fully_shard() — name-to-module mapping breaks |

Consequence: In FSDP2 mode, trace_model_instance() must skip all layer-level hooks and fall back to root-only + shard-aware measurement.

14.3 What We Can and Cannot Measure

| Metric | Feasible? | Method |
| --- | --- | --- |
| Total step time | ✅ Easy | CUDA events on trace_step boundary |
| Dataloader time | ✅ Easy | DataLoader patch (unchanged) |
| Forward time (total) | ✅ Easy | Single hook on FSDP root module only |
| Backward time | ✅ Easy | loss.backward() CUDA events |
| Optimizer step time | ✅ Easy | Optimizer wrapper hook |
| Local shard memory | ✅ Easy | torch.cuda.memory_allocated() at step end |
| Estimated full model memory | ✅ Easy | local_shard × world_size |
| Per-layer forward time | ⚠️ Unsafe | Hooks fire during incomplete un-shard; skip in FSDP2 |
| Per-layer activation memory | ⚠️ Unsafe | Only shard fragments visible; skip in FSDP2 |
| All-gather duration | ⚠️ Experimental | FSDP2 internal event hooks (version-gated, may break) |

14.4 Auto-Detection

File: traceml/utils/fsdp_utils.py

def detect_parallelism_mode(model: nn.Module) -> str:
    """Returns: "fsdp2" | "ddp" | "single" """
    try:
        from torch.distributed.fsdp import FSDPModule
        if any(isinstance(m, FSDPModule) for m in model.modules()):
            return "fsdp2"
    except (ImportError, AttributeError):
        pass

    try:
        from torch.nn.parallel import DistributedDataParallel as DDP
        if isinstance(model, DDP):
            return "ddp"
    except ImportError:
        pass

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return "ddp"

    return "single"

Result is cached on the trace_step context manager after the first call.

14.5 FSDP2Sampler

File: traceml/samplers/fsdp2_sampler.py

class FSDP2Sampler(BaseSampler):
    """
    Shard-aware memory sampler for FSDP2 workloads.

    Tracks:
      - local_param_bytes  : bytes of parameter shards owned by this rank
      - estimated_full_bytes: local × world_size (assumes even sharding)
      - activation_bytes   : peak torch.cuda.memory_allocated() during forward
      - step               : current global step count

    Design constraints:
      - NO per-layer hooks (they break under FSDP2)
      - Single root-module forward hook for activation sampling
      - All NVML / torch.cuda calls are best-effort
    """

    def __init__(self, model: nn.Module, identity: RankIdentity):
        super().__init__(sampler_name="FSDP2Sampler")
        self._model    = model
        self._identity = identity
        self._post_forward_mem: float = 0.0
        self._attach_root_hook(model)

    def _attach_root_hook(self, model: nn.Module) -> None:
        """Single post-forward hook on the FSDP root — safe and correct."""
        def _hook(module, inputs, output):
            try:
                # All-gather buffers are still live here → peak shard activation
                self._post_forward_mem = float(torch.cuda.memory_allocated())
            except Exception:
                pass
        model.register_forward_hook(_hook)

    def _local_param_bytes(self) -> int:
        """Sum of all parameter shard sizes on this rank."""
        return sum(
            p.numel() * p.element_size()
            for p in self._model.parameters()
        )

    def sample(self) -> None:
        try:
            local_bytes = self._local_param_bytes()
            row = {
                "step":                  TraceState.step,
                "global_rank":           self._identity.global_rank,
                "node_rank":             self._identity.node_rank,
                "local_param_bytes":     local_bytes,
                "estimated_full_bytes":  local_bytes * self._identity.world_size,
                "activation_bytes":      self._post_forward_mem,
                "timestamp":             time.time(),
            }
            self.db.add_record("FSDP2Table", row)
        except Exception as e:
            self.logger.error(f"[TraceML] FSDP2Sampler error: {e}")
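The even-sharding estimate can be sanity-checked with a worked example (hypothetical numbers: a 7B-parameter fp16 model evenly sharded over 8 ranks):

```python
# Hypothetical numbers: 7B params in fp16, evenly sharded over 8 ranks.
world_size = 8
full_params = 7_000_000_000
bytes_per_param = 2  # fp16

local_param_bytes = full_params * bytes_per_param // world_size   # what sample() measures
estimated_full_bytes = local_param_bytes * world_size             # what it reports

assert local_param_bytes == 1_750_000_000        # ~1.75 GB per rank
assert estimated_full_bytes == 14_000_000_000    # recovers the 14 GB full model
```

When sharding is uneven (e.g. replicated parameter groups mixed in), the product over-counts, which is why the docstring flags the even-sharding assumption and the field is named `estimated_full_bytes`.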

14.6 FSDP2-Safe trace_step

File: traceml/decorators.py

@contextmanager
def trace_step(model: nn.Module):
    mode = _get_parallelism_mode(model)   # cached
    if mode == "fsdp2":
        with _trace_step_fsdp2(model):
            yield
    else:
        with _trace_step_ddp(model):      # existing path, unchanged
            yield

@contextmanager
def _trace_step_fsdp2(model: nn.Module):
    """
    FSDP2 step boundary.
    - No layer hooks.
    - Shard memory tracked via FSDP2Sampler (registered separately).
    - Step timing via CUDA events on the outer boundary.
    """
    step_completed = False
    try:
        with timed_region("_traceml_internal:step_time", scope="step", use_gpu=True):
            with forward_auto_timer(), backward_auto_timer():
                ensure_optimizer_timing_installed()
                yield
                step_completed = True
    finally:
        if step_completed:
            TraceState.step += 1
        flush_step_events(model, TraceState.step)
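A sketch of what the `timed_region` boundary could look like. This is hypothetical (the real `timed_region` lives inside TraceML); it uses CUDA events when a GPU is present and falls back to wall-clock otherwise, and the only synchronisation happens after the step body has run:

```python
import time
from contextlib import contextmanager

@contextmanager
def timed_region_sketch(name: str, result: dict):
    """Record elapsed ms into result[name] via CUDA events when possible,
    wall-clock otherwise. Timing never raises into the training step."""
    use_cuda = False
    try:
        import torch
        use_cuda = torch.cuda.is_available()
    except ImportError:
        pass
    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        try:
            yield
        finally:
            end.record()
            end.synchronize()  # only the profiler waits, after the step body
            result[name] = start.elapsed_time(end)
    else:
        t0 = time.perf_counter()
        try:
            yield
        finally:
            result[name] = (time.perf_counter() - t0) * 1000.0
```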

14.7 trace_model_instance Safety Guard

def trace_model_instance(model: nn.Module, ...) -> None:
    if _is_fsdp2(model):
        print(
            "[TraceML] FSDP2 detected — layer-level hooks are not attached "
            "(parameter shards not materialised during hook execution). "
            "Use FSDP2Sampler for shard-aware memory tracking.",
            file=sys.stderr,
        )
        return
    # ... existing DDP hook path unchanged ...

14.8 FSDP2 Display

The renderer replaces the layer timing table with a shard memory summary when parallelism_mode == "fsdp2":

FSDP2 Shard Memory  (global_rank=2, node_rank=0)
  Local shard:           3.9 GB
  Estimated full model: 31.2 GB  (local × 8 ranks)
  Peak activation:       5.1 GB  (post-forward, all-gather buffers live)
  Step time:           142 ms
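A hypothetical renderer helper showing how one `FSDP2Table` row maps onto that summary. Field names follow the `sample()` row above; the exact formatting is illustrative:

```python
def format_shard_summary(row: dict) -> str:
    """Render one FSDP2Table row as the shard memory summary block."""
    gb = lambda b: b / 1e9  # decimal GB, matching the display above
    return (
        f"FSDP2 Shard Memory  (global_rank={row['global_rank']}, "
        f"node_rank={row['node_rank']})\n"
        f"  Local shard:          {gb(row['local_param_bytes']):.1f} GB\n"
        f"  Estimated full model: {gb(row['estimated_full_bytes']):.1f} GB\n"
        f"  Peak activation:      {gb(row['activation_bytes']):.1f} GB\n"
    )
```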

End-State Architecture Diagram

graph TD
    subgraph "Operator"
        CLI["traceml run"]
    end

    subgraph GlobalAggregator["GlobalAggregator  node0:29765"]
        HS["HighCapacityTCPServer<br/>epoll, backlog=max(64,WS×2)"]
        STORE["RemoteDBStore<br/>global_rank × sampler × DB<br/>+ node index"]
        MNA["MultiNodeAggregator<br/>cross-rank + cross-node stats"]
        UI["CLI / NiceGUI<br/>Node tabs, Rank grid,<br/>Shard memory view"]
        HS --> STORE --> MNA --> UI
    end

    subgraph Node0["Node 0"]
        subgraph "rank 0  [NODE LEADER]"
            R0["TraceMLRuntime<br/>identity: global_rank=0"]
            NR0["NodeRelay :29766"]
            R0 -->|"loopback"| NR0
        end
        R1["rank 1"] -->|"loopback :29766"| NR0
        R2["rank 2"] -->|"loopback :29766"| NR0
        R3["rank 3"] -->|"loopback :29766"| NR0
        NR0 -->|"loopback"| HS
    end

    subgraph Node1["Node 1"]
        subgraph "rank 4  [NODE LEADER]"
            R4["TraceMLRuntime<br/>identity: global_rank=4"]
            NR1["NodeRelay :29766"]
            R4 -->|"loopback"| NR1
        end
        R5["rank 5"] -->|"loopback :29766"| NR1
        R6["rank 6"] -->|"loopback :29766"| NR1
        R7["rank 7"] -->|"loopback :29766"| NR1
        NR1 -->|"inter-node TCP"| HS
    end

    CLI --> GlobalAggregator
    CLI -->|"torchrun"| Node0
    CLI -->|"torchrun"| Node1
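The NodeRelay's fan-in behaviour, and in particular the fail-open property from goal G3, can be sketched as a small TCP loop. Everything here is a simplification: framing, reconnection, and ports are hypothetical, and the sketch binds an ephemeral port instead of the fixed :29766 so it can run anywhere:

```python
import socket
import threading

def run_node_relay(agg_addr, info, stop_event):
    """Hypothetical NodeRelay sketch: accept loopback connections from
    local ranks and forward each payload upstream on one TCP connection.
    Fail-open (G3): an unreachable aggregator means telemetry is dropped,
    never that a rank blocks."""
    try:
        upstream = socket.create_connection(agg_addr, timeout=2)
    except OSError:
        upstream = None  # aggregator down -> drop data, keep training
    server = socket.socket()
    server.bind(("127.0.0.1", 0))  # the real design binds the fixed :29766
    server.listen(8)
    server.settimeout(0.2)  # lets the loop notice stop_event
    info["port"] = server.getsockname()[1]
    while not stop_event.is_set():
        try:
            conn, _ = server.accept()
        except socket.timeout:
            continue
        conn.settimeout(2.0)
        with conn:
            try:
                data = conn.recv(65536)
            except OSError:
                continue
            if data and upstream is not None:
                try:
                    upstream.sendall(data)
                except OSError:
                    upstream = None  # fail open on mid-run aggregator loss
    server.close()
    if upstream is not None:
        upstream.close()
```

The per-connection `recv` keeps the sketch short; the actual relay would need length-prefixed framing and a bounded queue so a slow aggregator cannot back-pressure the ranks.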
