feat(wo_split=8): opt-in K-parallel W_O GEMV prototype #8
Adds Phase_E_Beta_Kernel.wo_split (env CUTE_WO_SPLIT, default 1,
bounded by slice_ctas) and threads it through _coop_full_compile_key.
At wo_split=1 the kernel behavior is unchanged. Disk cache will
distinguish wo_split variants once subsequent tasks add a kernel
body change.
Verified: cache MISS observed on first launch with new key ("first
call for this config" log at phase_e_kernel.py:3170); smoke probe
returned coherent /v1/completions output post-warmup.
Task 1 of 12 (wo_split=8 production prototype, plan at
/home/natfii/.claude/plans/sorted-crafting-rainbow.md).
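A minimal sketch of the env-gated, slice_ctas-bounded default described above (hypothetical helper name; the commit does not show the actual parsing code):

```python
import os

def resolve_wo_split(slice_ctas: int, env=None) -> int:
    # Hypothetical sketch: CUTE_WO_SPLIT (default 1), bounded by slice_ctas.
    env = os.environ if env is None else env
    raw = int(env.get("CUTE_WO_SPLIT", "1"))
    return max(1, min(raw, slice_ctas))
```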
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Unbundles total_ctas_per_seq_attn into two concepts:
- total_ctas_per_seq_attn (= num_kv_heads = 4): R1 attn-producer mask
- total_wo_slots (= num_kv_heads * wo_split = 4 at wo_split=1):
drives wo_output stride, gather loop, election target, counter reset
Slot-index formula bx*num_kv_heads+by stays legacy in this task; Task 8
lifts to by*wo_split+bx alongside the K-parallel kernel body.
At wo_split=1, total_wo_slots == 4 == legacy total_ctas_per_seq_attn,
so the address math is bit-exact. Cache key picks up new function
fingerprint via the Int32 arg addition.
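The unbundling above can be sketched as a pure-Python helper (hypothetical function; names taken from the commit message):

```python
def wo_slot_counts(num_kv_heads: int, wo_split: int):
    # Attn-producer CTA count: fixed at num_kv_heads (R1 producer mask).
    total_ctas_per_seq_attn = num_kv_heads
    # W_O slot count: drives wo_output stride, gather loop, election target.
    total_wo_slots = num_kv_heads * wo_split
    return total_ctas_per_seq_attn, total_wo_slots
```

At wo_split=1 both values are 4, which is why the address math stays bit-exact in this task.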
Task 2 of 12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the literal `4` in the β-coop _phase_e_coop_wo_output allocation
with self.num_kv_heads * self._phase_e_coop_kernel.wo_split. At wo_split=1
dim 1 is still 4 (a no-op refactor); at wo_split=8 later, this expands to
32 slots.
β-lite wo_output allocation at _backend.py:399-403 is unchanged (β-lite
uses a different code path, out of scope for this plan). Reset op
_wo_output_reset_op.py shape preconditions are already generic
(dim()==3); only the error-message string is updated to reflect the new
shape semantic.
Task 3 of 12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds R11 (phase1_pre_wo_wait) and R12 (phase1_gather_reduce) to the
region-timing taxonomy. Host- and kernel-side constants bump together
to avoid the buffer-stride mismatch that an unbundled change would
introduce (host expects 13×16=208 bytes per CTA stride, kernel writes
11×16=176 — out-of-bounds or under-read between tasks).
Region classification is split: PHASE1_REGIONS stays {1,2,3};
WAIT_NOT_WORK_REGIONS gains R11 (consumer wait); new
DYNAMIC_SINGLE_CTA_REGIONS = {12} for the elected single-CTA gather.
This avoids the reducer's first-match-wins if/elif misclassifying
R11 as parallel phase1 work.
Files:
region_timing.py - REGION_NAMES (+2), region-class sets,
_phase1_wo_split_cta_ids helper, reducer branch
_backend.py - _REGION_TIMING_NUM_REGIONS = 13
phase_e_kernel.py - _region_timing_num_regions = 13;
_REGION_TIMING_PER_CTA_STRIDE module constant;
22 Int64(11*2*8) sites → Int64(_REGION_TIMING_PER_CTA_STRIDE)
extract_regions.py - --wo-split arg; dispatch to wo_split helper
for R2/R3/R11/R12 when wo_split>1
At wo_split=1: R11 mask never fires (bx>0 && bx<1 empty), R12 records
gather time, R0-R10 numerically identical to prior baseline. No
behavioral change at default.
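The host/kernel stride agreement above reduces to one expression (hypothetical helper name; the 2 samples × 8 bytes per region factor follows from the `Int64(11*2*8)` sites named above):

```python
def region_timing_stride_bytes(num_regions: int) -> int:
    # 2 samples (entry/exit) per region, 8 bytes (Int64 tick) each.
    return num_regions * 2 * 8
```

Bumping both sides together keeps host reads and kernel writes on the same 208-byte stride.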
Tasks 4+5 of 12 (combined for atomicity).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lays the pre-W_O barrier infrastructure for the K-parallel W_O GEMV
that lands in Task 8. At wo_split=1 the infrastructure is dormant:
the consumer mask `bx > 0 && bx < 1` is empty, so no CTA spins and
R11 buffer rows stay zero (host nonzero filter drops them).
Counter: _phase_e_coop_pre_wo_arrival_count, allocated in attach_mlp_fusion,
zeroed per-launch via host .zero_() inside run_beta_coop_full (mirrors
phase1_arrival_count reset pattern at line ~3120).
Producer (bx==0 && by<num_kv_heads): _threadfence + sync_threads
+ tid0 atomic_add 1 to pre_wo_arrival_count after attn output written.
Placed inside the existing bx==0 parent block, between R1 exit and
R2 entry timing.
Consumer (bx>0 && bx<wo_split && by<num_kv_heads): R11 entry sample,
spin-wait via _ld_volatile_u32 until counter == num_kv_heads,
_acquire_fence + sync_threads, R11 exit sample. Placed at kernel-
level (outside bx==0 parent) immediately before the existing R4
grid-barrier entry. Dead at wo_split=1.
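The producer/consumer masks above can be sketched as a pure-Python role function (hypothetical helper; the mask logic is transcribed from the commit text). At wo_split=1 the consumer mask is provably empty:

```python
def pre_wo_role(bx: int, by: int, wo_split: int, num_kv_heads: int) -> str:
    # Producer: bx==0 attn CTAs that bump the arrival counter after
    # writing their attn output.
    if bx == 0 and by < num_kv_heads:
        return "producer"
    # Consumer: extra W_O CTAs that spin until counter == num_kv_heads.
    if 0 < bx < wo_split and by < num_kv_heads:
        return "consumer"
    return "idle"
```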
Cooperative=True invariant preserved on β-coop launch (CLAUDE.md
rule 8: atomic-counter spin-wait barriers must run cooperative).
Cache MISS confirmed on first launch ("first call for this config"
log line) — function fingerprint shifted by the new pre_wo_arrival_ptr
arg and new R11 timing sites.
Task 6 of 12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wraps the existing is_last_cta gather block in _kernel_phase_0_to_4 with
R12 entry/exit timing samples. Only the elected CTA writes a tick; all
other CTAs leave their R12 slots at zero. The host reducer
(region_timing.py:208) drops zeros and reports the elected tick as
median/mean for R12.
R12 is in DYNAMIC_SINGLE_CTA_REGIONS (region_timing.py); the reducer
classifies it as "dynamic_single" with NaN frac_of_kernel (not parallel
work).
At wo_split=1 the gather sums num_kv_heads=4 partials (legacy);
post-Task-8 at wo_split=8 it sums 32 partials. R12 captures the duration
in either case.
Task 7 of 12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lifts the W_O block out of the legacy `bx == 0 && by < num_kv_heads`
attn-producer parent. New gate: `bx < wo_split && by < num_kv_heads`.
At wo_split=1 the new gate is equivalent to the old (bx<1 ≡ bx==0), so
behavior is preserved bit-exactly against the pre-Task-8 baseline.
K-range slicing follows torch_reference.py:443-446 exactly:
K_per_head = K // num_kv_heads
k_start_in_head = (K_per_head * bx) // wo_split
k_end_in_head = (K_per_head * (bx + 1)) // wo_split
k_start = by * K_per_head + k_start_in_head
k_end = by * K_per_head + k_end_in_head
Slot index: slot_idx = by * wo_split + bx (matches torch_reference.py:438-439
slot_id // wo_split == by, slot_id % wo_split == bx).
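The K-range slicing and slot-index formulas quoted above, transcribed as plain Python (hypothetical helper names; the formulas are verbatim from the commit). The integer-divide slicing tiles each head's K range exactly, with no gaps or overlap:

```python
def wo_k_range(bx, by, K, num_kv_heads, wo_split):
    # CTA bx owns 1/wo_split of KV-head by's K range (integer-divide slicing).
    K_per_head = K // num_kv_heads
    k_start_in_head = (K_per_head * bx) // wo_split
    k_end_in_head = (K_per_head * (bx + 1)) // wo_split
    return by * K_per_head + k_start_in_head, by * K_per_head + k_end_in_head

def wo_slot_idx(bx, by, wo_split):
    # Slot layout: slot_id // wo_split == by, slot_id % wo_split == bx.
    return by * wo_split + bx
```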
`wo_split_const: cutlass.Constexpr[int]` is threaded through both the
@cute.jit host wrapper and the @cute.kernel body, sourced from
self.wo_split at trace time. Cache key already includes self.wo_split
(Task 1) so flipping the env spawns a fresh compile.
The legacy W_O+gather block is moved out of the attn-producer parent
and placed at kernel-level after R11 (pre_wo_wait) so all W_O CTAs
(bx ∈ [0, wo_split), by < num_kv_heads) execute it. The election
counter target stays at total_wo_slots-1 (= num_kv_heads*wo_split - 1)
which scales naturally with the new W_O CTA count.
Bit-exact algorithm gate against reference_split_order:
- Harness microkernel @ docs/research/2026-05-03-w-o-k-parallel-harness/
reproduces this exact K-range/slot formula and reports
max_abs == 0.000e+00 vs reference_split_order(wo_split=N) at both
wo_split=1 and wo_split=8.
Production kernel verification (synthetic repro at /tmp/wo_split_repro.py):
- wo_split=1: run_beta_coop_full completes; wo_output[:,0,:] FINITE;
identical stats min=-7.84e+06 max=8.22e+06 mean=2.60e+04.
- wo_split=8: run_beta_coop_full completes; wo_output[:,0,:] FINITE;
identical stats; max_abs vs wo_split=1 = 7.0 (≈1 ULP at FP32 8M
magnitude — expected K-parallel reorder noise on K=6144 mixed-sign
random data; identical to harness wo_split=1-vs-wo_split=8 drift).
Serve smoke at both wo_split=1 (default) and CUTE_WO_SPLIT=8 produces
identical coherent /v1/completions output across the three test prompts:
- "What is 2+2?" → "2+2 equals 4..."
- "Capital of France?" → "<think>...What is the capital of France?..."
- "Write a haiku about coding." → "<think>...Topic: Coding. Format: Haiku..."
Cache MISS confirmed for the wo_split=8 + Task-8 config:
"Compiling PhaseE_Beta_Kernel β-coop full (first call for this config)…"
Task 8 of 12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The kernel's wo_split bounds are robust for arbitrary 1..slice_ctas,
but only the powers-of-2 subset {1, 2, 4, 8} has the bench/correctness
story this PR ships. reference_split_order at torch_reference.py only
validates these four values, and the harness sweep evidence likewise
covers only this set.
Don't expose unevidenced settings (3/5/6/7) accidentally — the assert
fails fast on init if a user sets CUTE_WO_SPLIT to a non-evidenced
value. Comment block at the field documents the intent so a future
contributor knows the kernel logic itself isn't the gating constraint.
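A minimal sketch of the fail-fast assert described above (hypothetical names; the committed assert and message may differ):

```python
EVIDENCED_WO_SPLITS = {1, 2, 4, 8}

def check_wo_split(wo_split: int) -> int:
    # Fail fast on init for values without bench/correctness evidence;
    # the kernel logic itself is not the gating constraint.
    assert wo_split in EVIDENCED_WO_SPLITS, (
        f"CUTE_WO_SPLIT={wo_split} is not in the evidenced set "
        f"{sorted(EVIDENCED_WO_SPLITS)}"
    )
    return wo_split
```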
Pre-Task-10 cleanup. Subsequent baseline (Task 10) and graduation
(Task 11) traces will be captured against this restricted set.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Original Task 9 plan ("parameterize host Phase 1 mask helpers on
wo_split") was already subsumed by the Task 4+5 combined dispatch.
Repurposed to address the kernel-side cleanups flagged by the Task 8
spec+quality review:
#2 (Important): R11 timing/spin/exit gates now use wo_split_const
instead of self.wo_split, matching the W_O block (Task 8). Both
are bound from int(self.wo_split) in the same JIT compile call,
but mixing the two in the same kernel body forced readers to
verify equivalence. Now uniform across the kernel body.
#3 (Minor): Hoisted single pre_wo_consumer_active = (bx>0 &&
bx<wo_split_const && by<num_kv_heads) above the R11 entry; reused
at entry timing, spin gate, and exit timing. Removes the duplicate
pre_wo_consumer_active2 copy-paste artifact.
#4 (Minor): Dropped "# NEW:" prefix from the wo_split cache-key inline
comment — the marker would go stale at PR.
#5 (Real, fixed in same diff via the L253 comment block): bound-
restriction comment now points to docs/research/2026-05-03-w-o-k-
parallel-harness/torch_reference.py (the committed path) instead
of /tmp/wo_split_repro_workdir/torch_reference.py (machine-local
transient).
#6 (Minor): Added 3-line comment block before the new pre_wo_consumer_active
declaration explaining bx==0 producers skip R11 because their
attn_output reads are intra-CTA — the cross-CTA safety derivation
that the spec reviewer pointed out was undocumented.
Deferred to merge-prep (per user direction):
- #1: total_ctas_per_seq_attn dead-arg cleanup (Task 12 PR-prep)
- #7: cutlass.const_expr gate on wo_split=1 producer fence/atomic
(revisit if Task 10/11 evidence shows wo_split=1 overhead matters)
Pure refactor — bit-exact gate at wo_split=1 AND wo_split=8 still
passes with max_abs == 0.0 against reference_split_order. Cache MISS
on first launch (wo_split_const reference and mask hoist change the
PTX even though numerics are identical at runtime).
Task 9 of 12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
vLLM's EngineCore subprocess (typically pid 146) strips most docker -e
env vars from its parent (per feedback_vllm_enginecore_env_strip in
project memory). Without this workaround, CUTE_WO_SPLIT=8 set on docker
run never reaches Phase_E_Beta_Kernel.__init__ and the kernel falls back
to the default wo_split=1.
Workaround mirrors the existing CUTE_C2_DIAG_* sentinel pattern:
1. scripts/serve-cute.sh writes CUTE_WO_SPLIT=${CUTE_WO_SPLIT:-1} to
the /tmp/c2_diag/ENV file (already bind-mounted into the container).
2. vllm/nvllm/models/qwen3_5.py reads /tmp/c2_diag/ENV at module import
and calls os.environ.setdefault for any line matching CUTE_C2_* OR
CUTE_WO_SPLIT=. The setdefault skips when the var is already set, so
real env wins.
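The sentinel read in step 2 can be sketched as follows (hypothetical function name; setdefault skips keys already present, so a real env var wins over the file, as described above):

```python
import os

def apply_env_sentinel(lines, env=None):
    # Hypothetical sketch of the import-time /tmp/c2_diag/ENV read.
    env = os.environ if env is None else env
    for line in lines:
        key, sep, value = line.strip().partition("=")
        if sep and (key.startswith("CUTE_C2_") or key == "CUTE_WO_SPLIT"):
            env.setdefault(key, value)  # real env wins over the sentinel
    return env
```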
Verified end-to-end on the live container today: with CUTE_WO_SPLIT=8
set on the host shell, serve-cute.sh writes the sentinel, EngineCore
reads it, PhaseE_Beta_Kernel constructs with self.wo_split=8, region
timing buffer shows R2 active CTAs = 32 (was 4), R11 active CTAs = 28
(consumer mask fires for bx>0).
Pre-Task-12 cleanup. Required for the wo_split=8 graduation evidence
in Tasks 10/11 to be reproducible.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Task 8 quality review flagged total_ctas_per_seq_attn as dead plumbing:
defined in the run_beta_coop_full host wrapper, plumbed through 4 kernel
signature levels, but never consumed in the kernel body (the R1
attn-pre-W_O mask uses the literal `by < Int32(4)` directly, not the
arg). This was acceptable when Task 8 shipped because removing it is a
separate refactor and the bit-exact gate already verified the kernel
produces correct output regardless of the dead arg. Removed now as
merge-prep cleanup before Task 12 (PR open).
Removed 5 sites:
- Host wrapper definition
- all_args tuple pack
- _jit_launch_phase_0_to_4 sig
- _jit_launch_phase_0_to_4 forward to inner kernel call
- _kernel_phase_0_to_4 sig
A comment block at the host wrapper now documents the literal `4` in the
R1 mask: it stays a literal because wo_split scales the W_O CTA count,
NOT the attn-producer count (which is always num_kv_heads = 4 for
Qwen3.5-27B).
Bit-exact gate against reference_split_order still passes at both
wo_split=1 (max_abs=0) and wo_split=8 (max_abs=0).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- L51: drop the inline `hf_...` token-shaped example; say "pass HF_TOKEN
  via env or mount credentials"
- L91: serve.sh row points to ig1/Qwen3.5-27B-NVFP4 (the script's actual
  default), not natfii/Opus-GB10
- L92: serve-cute.sh row points to the same default model with an
  HF_MODEL override note (was a broken 35B-A3B link)
- L123: stream-K perf table gets a committed trace link to
  benchmarks/nvllm/traces/gemm_stream_k_cudagraph/2026-04-21/
- L131: stream-K warning drops the --debug recommendation (eager is
  known-bad on SM120)
- L137: CuTe status updated from "PyTorch prototype / CuTe replacement in
  progress" to the current state: experimental CuTe DSL backend,
  production decode path since v0.3.0, β-coop fused kernel default,
  CUTE_WO_SPLIT=8 opt-in prototype
- L139: launch path is `./scripts/serve-cute.sh` (PIECEWISE default);
  --debug explicitly called out as gibberish-producing on SM120
- Roadmap "Now" section: wo_split=8 opt-in prototype with region-cluster
  + GSM8K numbers and a link to the evidence summary
- Veitner ack: now points to the wo_split=8 evidence summary (the
  K-parallel patterns the blog described, applied)
Pre-PR cleanup. No behavior change; documentation hygiene only.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Task 10/11 graduation evidence for the wo_split=8 K-parallel W_O GEMV
opt-in prototype. All artifacts under
benchmarks/nvllm/traces/wo_k_parallel_audit/2026-05-03-wo-split-8-prod/.
Quality (GSM8K-50 full-think, seed=42, max_tokens=512, timeout=600s):
- baseline wo_split=1: 48/50 (96.0%), 0 errors, 65.7s OK-question median
- changed wo_split=8: 47/50 (94.0%), 0 errors, 62.3s OK-question median
- Δ: -1 question (within ±2% noise), -5.2% per-question, -2.6% wall
Region timing (5-completion synthetic, ignore_eos):
- R2 phase1_wo_gemv: 14121 → 2360 us = -11761 us (5.99x)
- R4 grid_barrier_wait: 15211 → 1753 us = -13458 us (8.68x)
- R11 phase1_pre_wo_wait: 0 → 250 us = +250 us (new consumer wait)
- R12 phase1_gather_reduce: 73 → 167 us = +94 us (32 vs 4 partials)
- Cluster (R2+R4+R11+R12): 29405 → 4530 us = -24875 us (6.49x / -84.6%)
nsys (harness microkernel at production grid shape; 50-launch median):
- kernel_cutlass__wo_kernel_body: 13715 → 1598 us = -12117 us (8.58x)
- Stddev collapse: 2102 → 5.7 us at wo_split=8
Bit-exact correctness: max_abs == 0.0 vs reference_split_order at both
wo_splits, via the V=constant trick (FP8 V-cache = +1.0).
See summary.md and nsys_summary.md for full reproduction commands and
methodology caveats.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
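The V=constant trick's determinism claim follows from softmax weights summing to 1: with a constant V, the attention output equals that constant regardless of the scores. A small sketch (hypothetical helper using a plain-Python softmax, not the FP8 kernel path):

```python
import math

def attn_out_with_constant_v(scores, v_const=1.0):
    # Softmax weights sum to 1, so a constant V collapses the weighted
    # sum to v_const (exactly, in real arithmetic), independent of scores.
    m = max(scores)  # max-subtraction for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return sum((e / total) * v_const for e in exps)
```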
Summary
Opt-in `CUTE_WO_SPLIT=8` K-parallel W_O GEMV prototype for the production
β-coop fused kernel on Qwen3.5-27B-NVFP4 (serve-cute.sh). Default stays
`wo_split=1`: no behavior change for callers who don't set the env var.
Restricted to `{1, 2, 4, 8}` (the evidenced set).
Builds on PR #7 (audit + harness): the parity audit identified W_O GEMV
(R2) at 4 active CTAs as the W_O bottleneck. This PR widens R2 to 32
active CTAs via K-parallel split (each CTA owns 1/8 of one KV-head's K
range; gather of 32 partials in slot 0; new pre-W_O barrier for bx>0 W_O
CTAs to wait on bx==0 attn producers).
Why this isn't duplicate work
No existing upstream vllm-project/vllm PR covers SM120 K-parallel W_O
GEMV for the cute_paged backend (this fork's custom kernel). PR #7 (just
merged) shipped the validation harness + parity audit as a separate,
smaller-scope artifact; this PR ships the production prototype that
consumes that audit's findings.

Performance evidence (scoped to measured region cluster + nsys)
All artifacts under
benchmarks/nvllm/traces/wo_k_parallel_audit/2026-05-03-wo-split-8-prod/
(see summary.md for the full breakdown).

Region timing (5-completion synthetic ignore_eos load, R2/R4/R11/R12 only):

| Region | Name | wo_split=1 | wo_split=8 | Δ |
|---|---|---|---|---|
| R2 | phase1_wo_gemv | 14121 us | 2360 us | −11761 us (5.99×) |
| R4 | grid_barrier_wait | 15211 us | 1753 us | −13458 us (8.68×) |
| R11 | phase1_pre_wo_wait | 0 us | 250 us | +250 us (new consumer wait) |
| R12 | phase1_gather_reduce | 73 us | 167 us | +94 us (32 vs 4 partials) |

Cluster (R2+R4+R11+R12): 29405 → 4530 us = −24875 us (6.49× / −84.6%).

R4 shrinks because R2 finishes faster, so all CTAs hit the grid barrier
sooner. Other regions (R0, R1, R3, R5–R10) unchanged.
nsys total-kernel (harness microkernel at production grid shape, 50-launch median):
kernel_cutlass__wo_kernel_body: 13715.248 → 1598.064 μs = −12117 μs
(8.58×), stddev 2102 → 5.7 μs (much more stable at wo_split=8). Per
feedback_vllm_profiling, vLLM V1 nsys against EngineCore is blocked by
CUPTI inheritance, so the harness microkernel under nsys is the canonical
production-grid measurement; the harness reproduces the production
W_O+gather math bit-exactly (verified below). NCU NOT used (per repo
policy).

End-to-end serving (Qwen3.5-27B-NVFP4 PIECEWISE, max-num-seqs=4, GSM8K-50
full-think workload): per-question median 65.7s → 62.3s = −5.2%,
50-question wall 3760s → 3664s = −2.6%. Translation from kernel speedup
to e2e is bounded by what fraction of kernel time the W_O cluster
represents at this workload.
Quality (parity, not improvement)
GSM8K-50 full-think (seed=42, max_tokens=512, timeout=600s):
wo_split=1: 48/50 (96.0%), 0 errors; wo_split=8: 47/50 (94.0%), 0 errors.
−1 question is within ±2% noise; 0 errors both runs (well above the
kernel-change ≥30/50 floor). Quality parity, not improvement.
Bit-exact correctness gate
Kernel reproduces `reference_split_order(wo_split=N)` from
docs/research/2026-05-03-w-o-k-parallel-harness/torch_reference.py
bit-exactly at both wo_split=1 AND wo_split=8 (max_abs == 0.0 against
wo_output[seq, 0, :] post-gather). Methodology: V=constant trick (FP8
V-cache = +1.0 makes Phase 1 attn output deterministic).

Test commands run
Full reproduction in summary.md.
Implementation notes
- serve-cute.sh:118 adds `-e CUTE_WO_SPLIT="${CUTE_WO_SPLIT:-1}"` and
  writes the value to /tmp/c2_diag/ENV (the existing sentinel pattern).
  vllm/nvllm/models/qwen3_5.py:52 reads it at module import via
  os.environ.setdefault. Without this sentinel pattern, vLLM V1's
  EngineCore subprocess strips the env var (per
  feedback_vllm_enginecore_env_strip).
- phase_e_kernel.py: the K-parallel W_O block is lifted out of the legacy
  `bx==0 && by<num_kv_heads` parent into a new gate
  `bx<wo_split && by<num_kv_heads`. K-range integer-divide slicing per
  torch_reference.py:443-446. Slot index `by*wo_split+bx`. New pre-W_O
  arrival counter (_phase_e_coop_pre_wo_arrival_count) zeroed per-launch.
- R11 is in WAIT_NOT_WORK_REGIONS only (avoids first-match-wins
  misclassification as work). R12 is in the new
  DYNAMIC_SINGLE_CTA_REGIONS.
- The β-coop launch stays cooperative (cooperative=True).

Caveats
- R2 5.99× is the production-side measurement.
- CUTE_WO_SPLIT=8 requires the sentinel-file workaround to reach
  EngineCore; a bare `-e CUTE_WO_SPLIT=8` doesn't propagate.

AI assistance disclosure (per AGENTS.md §1)
This PR was developed with AI assistance (Claude Opus 4.6/4.7). The submitting human reviewed every changed line, ran the test commands listed above, and is responsible for defending the change end-to-end. No code-agent autonomy beyond the bounded plan in
/home/natfii/.claude/plans/sorted-crafting-rainbow.md.

🤖 Generated with Claude Code