Skip to content

Support co-locate training and inference (#81)#92

Open
zhubohao911 wants to merge 14 commits into
lightseekorg:mainfrom
zhubohao911:feature/colocate-training-inference
Open

Support co-locate training and inference (#81)#92
zhubohao911 wants to merge 14 commits into
lightseekorg:mainfrom
zhubohao911:feature/colocate-training-inference

Conversation

@zhubohao911
Copy link
Copy Markdown
Collaborator

@zhubohao911 zhubohao911 commented May 7, 2026

Implements #81 — colocate the Eagle3 / DFlash draft trainer and the sglang target engine on the same physical GPUs via CUDA MPS, removing the disaggregated layout's idle-GPU waste.

Gated behind colocate_strategy=mps + transfer_mode=nccl; the disaggregated baseline is untouched when the flags are off.

📋 Full iteration history — every phase, follow-up round, N>1 bug fix, architectural correction, and rented-GPU validation run — is preserved in docs/colocate/implementation_log/pr92_detail.md. This description is the concise current summary.

Status — feature-complete, GPU-validated at production scale

Phases 0–8 plus follow-up rounds 1–12. run_smoke_host.sh --full is green on 4×H100 under the CUDA IPC default transport (single-node). Two production-scale 20000-step / 40k-sample Qwen3-8B colocate runs landed clean on 2×H100 with rc=0 — Eagle3 (CE1, round 11) and DFlash (C1, round 12) — each beating the same-SGLang disagg baseline rerun on main in GPU-h.

How it works

Layer Mechanism
GPU sharing CUDA MPS — trainer + engine share each physical GPU; fractional Ray bundles; memory split train_frac + infer_frac + 0.10 ≤ 1
Distributed Union NCCL world (2N ranks = N trainers + N engines) → trainer FSDP subgroup + engine TP + an all-rank gloo meta_group
Hidden-state transport CUDA IPC zero-copy (default) — engine exports a CUDA IPC handle, trainer maps it + one on-device D→D copy. TORCHSPEC_COLOCATE_IPC=0 → gloo CPU-staging fallback. ~170× faster than gloo on realistic payloads. Optional TORCHSPEC_COLOCATE_IPC_PIPELINE=1 adds a send-buffer pool + ack pipelining
sglang side vendored colocate.patch (v0.5.10.post1 default; v0.5.8.post1 via SGLANG_PATCH_VERSION)

Production-scale benchmark — colocate beats disagg on GPU-h

Both cells matched against the same-SGLang disagg rerun on origin/main (dflash_eagle3_disagg_modal_rerun_on_main.md), retiring the cross-branch confound earlier benchmark versions carried.

Cell Steps Samples Throughput GPU-h Disagg baseline Win
CE1 — Eagle3 2+2 colocate (round 11) 20000 40000 ~13.25 samples/s 1.68 / 40k (2 GPU) E1-rerun = 12.72 samples/s, 3.49 / 40k (4 GPU) ~2.1× less GPU-h
C1 — DFlash 2+2 colocate (round 12) 20000 40000 7.51 samples/s 2.96 / 40k (2 GPU) D1-rerun = 10.00 samples/s, 4.44 / 40k (4 GPU) ~1.5× less GPU-h

Wins decompose as 2.0× (half the GPU count via MPS sharing) × r (raw-throughput ratio): Eagle3 r ≈ 1.0; DFlash r ≈ 0.75 (heavier trainer → more MPS contention). The architectural saving is reclaiming the idle disagg inference GPUs. Convergence holds for both cells — CE1's final rolling loss (~2.09) matches disagg E1 (2.24 / 1.98); C1's (~3.81) sits inside the disagg D1 noise band (3.67 / 4.89). Full analysis: colocate_benchmark.md.

Validation highlights

  • --full green on 4×H100 under CUDA IPC default — test_phase4_tiny_one_step, test_phase7_tiny_loss_decreases, 4-GPU test_phase4_one_step_completes_end_to_end, grad parity (smoke/determinism/full), checkpoint save+resume, test_colocate_ipc, test_colocate_tp2, test_colocate_multi_engine, 200-step test_phase6_peak_alloc_flatness, 50-step test_phase7_convergence_loss_decreases.
  • 4 latent dp_size>1 bugs found by --full that the 1-GPU smoke could not surface, plus 1 same-shape bug found preemptively by audit — all fixed (detail in pr92_detail.md).
  • Round 9 — CUDA IPC default hung at step 0 under MPS. Root cause: a destructive capability probe (reduce_tensor smoke test) shared a tensor over IPC and discarded it, poisoning the MPS context. Fixed — non-destructive probe (e166c21) + active expandable-segments handling for IPC actors (e62c941).
  • Round 10 — transport optimization investigated. No hand-written C++/CUDA/Triton kernel is needed (the path is a bandwidth-bound D→D copy plus driver-API calls); ipc-pipe ack-pipelining is a 3.9× protocol-level win on the engine send() stall but low-priority (the transport is ~1 % of a colocate step). A 3000-step 4-GPU stability soak ran clean.
  • Round 11 — ipc-pipe productionized + issue-[Feature] Support co-locate training and inference #81 follow-ups GPU-validated in one 4×H100 pod session: --stability 1000-step (green), convergence vs Mooncake-disagg (loss curves overlap 0.006 % mean / 0.219 % max over 1000 steps), Qwen3-8B grad-parity smoke (green). The pipelined-transport --full run found + fixed one OOM on the memory-tight 8B config (pool retired-buffer leak + grow overshoot).
  • Round 12 — DFlash colocate brought up at production scale. Two distinct, sequential DFlash-only deadlocks were surfaced and fixed in f28dc73:
    1. DFlashTrainer._init_target_lm_head — 5 trainer-side collectives (dist.barrier / broadcast / 3× all_reduce) ran with no group=, defaulting to the union world; only trainers entered the method → deadlock. Scoped to get_gloo_group() (same shape as round-7 set_model_state_dict / dcp.save fixes; Eagle3Trainer already carried it).
    2. colocate_loop._build_tensor_specs — trainer declared 2 tensors, engine always sent 3 (last_hidden_states is unconditional in the colocate engine), per-tensor IPC ack handshake blocked on the third. Spec now always declares last_hidden_states; store_last_hidden_states removed.

Docs

Doc What
pr92_detail.md full iteration history of this PR, in depth (rounds 1–12)
implementation_log.md debug log — RunPod / Vast sessions + follow-up rounds 1–12
colocate_benchmark.md disagg-vs-colocate study (CE1 + C1 done; CE2 / C2 pending)
transport_benchmark.md gloo-vs-CUDA-IPC transport benchmark
transport_optimization.md transport kernel-vs-protocol investigation + MPS-validated A/B
handoff_followups.md open follow-ups, self-contained handoff

Open follow-ups (tracked, not blocking this PR)

  • CE2 / C2 benchmark cells (4+4 colocate) — code-ready, unrun; needs one 4×H100 pod and a matched 40k-sample run per cell against the existing disagg E2 / D2 rerun-on-main baselines. Next productive item that does not need new hardware beyond a 4-GPU pod.
  • Multi-node 2-node colocate run — code-complete (ensure_mps_on_all_nodes, 2-node config) but untested at scale; needs a 2-node rented cluster.
  • Large engine_tp_size (8-GPU TP per engine) — rank math + data plane handle any TP size, only GPU-tested at engine_tp_size=2.
  • Colocate fail-fast for spec / default-PG mismatches — round 12's two silent-deadlock failure modes could be turned into immediate errors via a runtime check in Trainer.__init__ (assert default PG ≠ union world) and a step-0 watchdog that dumps both sides' tensor specs on mismatch. Small change, high value.
  • draft_accumulation_steps > 1 in colocate_loop.py — guarded with NotImplementedError("Multi-step accumulation is parked"); out of scope unless the benchmark needs cell-for-cell parity with disagg §8.
  • v0.5.10 pp_size>1 — blocked by an explicit guard, out of scope for the current colocate plan.

Environment

The bundled sgl_kernel wheel ships sm90+ kernels only (no Ampere / Ada) — real GPU testing is H100 / H200 / B200.

@zhubohao911 zhubohao911 force-pushed the feature/colocate-training-inference branch from bf2d468 to 927beaa Compare May 14, 2026 20:28
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 20, 2026
…s snapshot

* implementation_log.md -- new "Follow-up issues — PR lightseekorg#92 review
  items" section covering all seven follow-ups (commit, status,
  rationale) plus a validation matrix. The status-snapshot table at
  the top is corrected: phases 2/4/5/6/7 are done and green (the
  colocate.patch is vendored in-repo, not a pending upstream
  dependency) -- the old "pending upstream patch" / NotImplementedError
  notes were stale relative to sessions lightseekorg#1-lightseekorg#5.
* usage.md -- Known limitations refreshed: multi-node is implemented
  (untested at scale) not "single-node only"; the NotImplementedError
  note is gone; engine tp_size>1 status added; new section on the
  gloo-vs-CUDA-IPC transport and the TORCHSPEC_COLOCATE_IPC opt-in.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 20, 2026
All seven PR lightseekorg#92 follow-ups validated on rented GPUs:
* 1xH100 — patch apply, tiny smoke, TP rank math, grad-parity
  determinism, checkpoint save/resume.
* 2xH100 — grad-parity determinism re-confirmed; grad-parity-full's
  Mooncake disagg baseline SIGSEGVs (third-party-lib env issue).
* 4xH200 — run_smoke_host.sh --full: 10 passed, 1 skipped, exit 0
  (incl. 4-GPU one_step + grad_parity_smoke, 200-step stability,
  convergence, CUDA IPC e2e).

Records the 8 bugs found+fixed during validation and the CUDA IPC /
pidfd_getfd / CAP_SYS_PTRACE capability finding.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…nds 6 & 7)

The log folded the Mooncake-disagg crash FIX into round 4 and titled the
CUDA-IPC-default work as an unnumbered section, while PR lightseekorg#92 numbered
them round 6 and round 7. Split the Mooncake FIX out of round 4 into its
own "Follow-up round 6" (after round 5, matching the chronology — the
fix landed after round 5 was recorded) and renumber the CUDA-IPC-default
section "Follow-up round 7". This also makes store.py's cross-reference
to "implementation_log.md Follow-up round 6" resolve.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…be fix

Adds the implementation-log entry for the IPC-default hang: the 4xH100
--full run hung at colocate step 0; isolated on 1xH100 (gloo passes, IPC
hangs); root-caused to probe_ipc_capability()'s destructive reduce_tensor
smoke test wedging CUDA under MPS; fixed in e166c21 (non-destructive
probe). Pairs with PR lightseekorg#92 follow-up round 9.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…writeup

The PR lightseekorg#92 description had accumulated ~270 lines of phase / round /
N>1-bug / validation detail. Preserve the full detailed narrative here
(through round 10) so the PR body can be trimmed to a concise current
summary without losing the history.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
Commits the colocate handoff doc (previously untracked) and brings it
current: round-10 transport-optimization grounding, a leftover-items
row for productionizing ipc-pipe, and a note that the PR lightseekorg#92 body was
rewritten concise with full detail preserved in pr92_detail.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@zhubohao911 zhubohao911 force-pushed the feature/colocate-training-inference branch from c2d5d71 to 338fccc Compare May 21, 2026 08:58
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…s snapshot

* implementation_log.md -- new "Follow-up issues — PR lightseekorg#92 review
  items" section covering all seven follow-ups (commit, status,
  rationale) plus a validation matrix. The status-snapshot table at
  the top is corrected: phases 2/4/5/6/7 are done and green (the
  colocate.patch is vendored in-repo, not a pending upstream
  dependency) -- the old "pending upstream patch" / NotImplementedError
  notes were stale relative to sessions lightseekorg#1-lightseekorg#5.
* usage.md -- Known limitations refreshed: multi-node is implemented
  (untested at scale) not "single-node only"; the NotImplementedError
  note is gone; engine tp_size>1 status added; new section on the
  gloo-vs-CUDA-IPC transport and the TORCHSPEC_COLOCATE_IPC opt-in.
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
All seven PR lightseekorg#92 follow-ups validated on rented GPUs:
* 1xH100 — patch apply, tiny smoke, TP rank math, grad-parity
  determinism, checkpoint save/resume.
* 2xH100 — grad-parity determinism re-confirmed; grad-parity-full's
  Mooncake disagg baseline SIGSEGVs (third-party-lib env issue).
* 4xH200 — run_smoke_host.sh --full: 10 passed, 1 skipped, exit 0
  (incl. 4-GPU one_step + grad_parity_smoke, 200-step stability,
  convergence, CUDA IPC e2e).

Records the 8 bugs found+fixed during validation and the CUDA IPC /
pidfd_getfd / CAP_SYS_PTRACE capability finding.
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…nds 6 & 7)

The log folded the Mooncake-disagg crash FIX into round 4 and titled the
CUDA-IPC-default work as an unnumbered section, while PR lightseekorg#92 numbered
them round 6 and round 7. Split the Mooncake FIX out of round 4 into its
own "Follow-up round 6" (after round 5, matching the chronology — the
fix landed after round 5 was recorded) and renumber the CUDA-IPC-default
section "Follow-up round 7". This also makes store.py's cross-reference
to "implementation_log.md Follow-up round 6" resolve.
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…be fix

Adds the implementation-log entry for the IPC-default hang: the 4xH100
--full run hung at colocate step 0; isolated on 1xH100 (gloo passes, IPC
hangs); root-caused to probe_ipc_capability()'s destructive reduce_tensor
smoke test wedging CUDA under MPS; fixed in e166c21 (non-destructive
probe). Pairs with PR lightseekorg#92 follow-up round 9.
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
…writeup

The PR lightseekorg#92 description had accumulated ~270 lines of phase / round /
N>1-bug / validation detail. Preserve the full detailed narrative here
(through round 10) so the PR body can be trimmed to a concise current
summary without losing the history.
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request May 21, 2026
Commits the colocate handoff doc (previously untracked) and brings it
current: round-10 transport-optimization grounding, a leftover-items
row for productionizing ipc-pipe, and a note that the PR lightseekorg#92 body was
rewritten concise with full detail preserved in pr92_detail.md.
@zhubohao911 zhubohao911 force-pushed the feature/colocate-training-inference branch 2 times, most recently from c2d5d71 to 6c0e9c7 Compare May 21, 2026 09:15
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
@zhubohao911 zhubohao911 force-pushed the feature/colocate-training-inference branch from 6c0e9c7 to b82d64b Compare May 21, 2026 09:29
@zhubohao911 zhubohao911 changed the title [WIP] Support co-locate training and inference (#81) Support co-locate training and inference (#81) May 23, 2026
@zhubohao911 zhubohao911 marked this pull request as ready for review May 23, 2026 00:41
zhubohao911 and others added 2 commits May 22, 2026 17:45
DFlash training in colocate (MPS + NCCL) mode hung in two distinct,
sequential places. Both are fixed here; a 20000-step DFlash 2+2
colocate run now completes cleanly (rc=0, zero hang/NaN/OOM).

Hang lightseekorg#1 — DFlashTrainer._init_target_lm_head / metric reduction
  dist.barrier() / dist.broadcast() in _init_target_lm_head, plus the
  3 dist.all_reduce() in the per-position metric reduction, ran with no
  group= argument. In colocate mode that defaults to the union-world PG
  (trainer ranks [0,N) + engine ranks [N,2N)); only trainer ranks
  execute this code, so the engine ranks never arrive and the
  collective deadlocks. Scoped all five collectives to get_gloo_group()
  (the trainer-only group), mirroring Eagle3Trainer which already
  carries this fix. No-op for disagg, where get_gloo_group() is the
  whole world.

Hang lightseekorg#2 — colocate_loop._build_tensor_specs
  The trainer derived its per-step recv tensor specs from the
  training-side store_last_hidden_states flag, omitting
  last_hidden_states when false (DFlash's config). But the colocate
  engine always sends it: enable_return_hidden_states=True is set
  unconditionally, so sglang's _send_hidden_states_to_nccl always ships
  a non-None last_hidden_states. The per-tensor CUDA-IPC ack handshake
  then left the engine's send blocked forever waiting for an ack the
  trainer never sends (3 tensors sent, 2 declared). _build_tensor_specs
  now always declares last_hidden_states; draft trainers that do not
  consume it (DFlash) ignore the extra dict key.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
Signed-off-by: Xing Han <h13008009668@gmail.com>
@zhubohao911 zhubohao911 force-pushed the feature/colocate-training-inference branch from 1483973 to 2813fe3 Compare May 23, 2026 00:46
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 14839732e8

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +281 to +285
ColocateTrainSample(
step_id=completed_steps,
tensor_specs=specs,
packed_loss_mask=entry.packed_loss_mask,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Forward per-sample loss-mask flags in colocate samples

ColocateTrainSample is enqueued without last_turn_loss_only (or metadata-derived override), so colocate training drops the per-sample metadata.has_thinking behavior that the disaggregated path preserves in AsyncTrainingController._dispatch_to_queues. When dynamic_loss_mask is enabled on mixed datasets, those samples are trained with the wrong mask, changing supervised tokens and silently skewing loss/gradients.

Useful? React with 👍 / 👎.

Comment on lines +204 to +207
# Submit the dataset (epoch=0, skip=0). Resumption from non-zero
# step is handled the same way as the disagg loop, but we don't
# exercise it in tests yet.
ray.get(controller.submit_training_dataset.remote(epoch=0, skip=0))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Resume colocate datasets from checkpoint step

The colocate loop always re-submits the dataset as epoch=0, skip=0, even though completed_steps is read from trainer state. Resumed runs therefore restart data from the beginning instead of skipping already-consumed samples (unlike training_loop), which can replay data and alter training trajectories after checkpoint restore.

Useful? React with 👍 / 👎.

Conflict resolution:
- docs/ray.md: kept main's reworked Default/custom rows and colocate
  branch's mps-strategy row in one table.
- torchspec/ray/placement_group.py: kept the colocate gate as
  is_colocate_enabled(args), folded main's placement_strategy=="custom"
  branch into the colocate path, and preserved the colocate branch's
  is_mps_colocate engine-rank invariant check.

Signed-off-by: Xing Han <h13008009668@gmail.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c0646dda42

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +710 to +714
should_skip, skip_count = self._should_skip_for_loss_mask(
data, item.step_id, skip_count
)
if should_skip:
continue
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle skipped colocate samples without deadlocking

In colocate mode this branch drops the current sample, but the step loop only enqueues one ColocateTrainSample per rank and then waits for train_from_queue(num_batches=1) to finish. If dynamic_loss_mask or min_loss_tokens makes _should_skip_for_loss_mask true, that trainer blocks waiting for a replacement batch that is never queued until the next step, while the driver is already blocked waiting on this step’s trainer futures. This creates a deterministic deadlock on the first skipped sample.

Useful? React with 👍 / 👎.

Comment on lines +104 to +115
if not is_colocate_enabled(args):
# Disaggregated default: nothing to validate. We do, however, want to
# warn the user if they set strategy/frac fields by mistake without
# turning colocate on, since otherwise those fields silently no-op.
for stray in ("colocate_strategy", "train_frac", "infer_frac"):
if _get(args, stray) is not None:
raise ColocateConfigError(
f"training.{stray} was set but training.colocate=False. "
f"Either set training.colocate=true (or "
f"training.colocate_strategy=mps) or remove training.{stray}."
)
return
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reject nccl transfer mode when colocate is off

Validation returns early when colocate is disabled, so training.transfer_mode='nccl' is currently accepted in a disaggregated config. That activates colocate-only behavior (TrainerActor.init takes the union-world NCCL path and SglEngine.generate short-circuits with no mooncake outputs), which breaks the async inference pipeline that expects mooncake-keyed outputs and can stall training. This should fail fast in config validation unless the supported colocate tuple is selected.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant