perf(qwen3-14b decode): fuse rope+attn+softmax into one fa_fused kernel #656
Hzfengsy wants to merge 1 commit into
Collapse rope_qkv, fa_fused (attention) and online_softmax into a single mixed cube+vector root, replacing two grid dispatches + their cross-kernel dep edges with in-kernel pl.system.syncall barriers:
- qk_norm folded in-register into the RoPE step (no q/k_proj_norm GM round-trip).
- Per-region dual-AIV split (pl.split_aiv): rope phase-0 (NONE, 32-lane pipeline) -> syncall -> attn phase-1 (UP_DOWN row-halving) -> syncall -> online-softmax phase-2 (NONE, 48-way).
- rope folded in as phase-0; standalone rope_qkv dispatch + rope->fa dep removed.

Requires pypto PR #1894 (split_aiv SplitMode.NONE + cross-half GM base-param repoint fix); does not build on the currently pinned pypto.
📝 Walkthrough
The fa_fused SPMD kernel in models/qwen3/14b/decode_layer.py is refactored to internally perform RoPE, QK-norm, and online-softmax reduction, removing separate rope_qkv and online_softmax dispatches. Multiple task-id arrays are consolidated into a single rope_dep_tids array. A test harness gains a dump_passes compile flag.
Changes: Fused Attention Refactor
Estimated code review effort: 4 (Complex) | ~60 minutes
Sequence Diagram(s)
sequenceDiagram
participant Projections as Q/K/V Projections
participant RopeDepTids as rope_dep_tids
participant FaFused as fa_fused kernel
participant Cache as k_cache/v_cache/all_q_padded
Projections->>RopeDepTids: write Q/K/V/RMS/work task ids
RopeDepTids->>FaFused: gate phase 0
FaFused->>FaFused: RoPE + QK-norm (phase 0)
FaFused->>Cache: write k_cache, v_cache, all_q_padded
FaFused->>FaFused: syncall (mix)
FaFused->>FaFused: phase 1 QK/softmax/SV (split_aiv, pipeline)
FaFused->>FaFused: syncall (hard)
FaFused->>FaFused: phase 2 online-softmax reduction (fused)
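Phase 2 in the diagram is an online-softmax reduction. The following is a minimal sketch of that idea in plain Python, assuming the phase merges per-chunk partials of the form (running max, exp-sum, unnormalised output) in the usual online-softmax way. The function names `partial_attention`, `merge_partials`, and `reference_attention` are illustrative and do not come from decode_layer.py.

```python
import math


def partial_attention(scores, values):
    """Per-chunk partial: max m, exp-sum l, and unnormalised output o."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    dim = len(values[0])
    o = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return m, l, o


def merge_partials(partials):
    """Combine per-chunk (m, l, o) triples into the softmax-weighted output."""
    m_all = max(m for m, _, _ in partials)
    l_all = 0.0
    o_all = [0.0] * len(partials[0][2])
    for m, l, o in partials:
        scale = math.exp(m - m_all)  # rescale this chunk to the shared max
        l_all += scale * l
        o_all = [acc + scale * x for acc, x in zip(o_all, o)]
    return [x / l_all for x in o_all]


def reference_attention(scores, values):
    """Single-pass softmax(scores) @ values, used only to check the merge."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
```

Splitting the scores into chunks, calling `partial_attention` on each, and passing the results to `merge_partials` should agree with `reference_attention` up to floating-point rounding, which is why the reduction can run as a separate phase after a barrier.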
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
models/qwen3/14b/decode_layer.py (1)
1674-1674: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win
Avoid enabling pass dumps by default.
`dump_passes=True` is forwarded by `run_jit`, so the default golden path will always emit compiler dumps. Gate this behind an explicit debug flag or remove it before merging.
Proposed fix
- compile_cfg=dict(dump_passes=True),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/qwen3/14b/decode_layer.py` at line 1674, The `run_jit` path in `decode_layer.py` is always enabling compiler pass dumps via `compile_cfg=dict(dump_passes=True)`, which makes the default golden path emit debug artifacts. Remove the hardcoded `dump_passes=True` from the JIT compile configuration, or only set it when an explicit debug flag is enabled in the relevant `run_jit`/decode-layer setup. Use the existing `run_jit` and `compile_cfg` wiring to keep dump behavior opt-in by default.
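One way to keep the dump opt-in is sketched below. The helper `build_compile_cfg` and the environment variable `QWEN3_DUMP_PASSES` are hypothetical names introduced for illustration; only the `dump_passes` key and the `compile_cfg` argument come from the review comment.

```python
import os
from typing import Optional


def build_compile_cfg(debug_dump: Optional[bool] = None) -> dict:
    """Build the compile config; pass dumps stay off unless explicitly requested."""
    if debug_dump is None:
        # Hypothetical environment switch; the default is "no dumps".
        debug_dump = os.environ.get("QWEN3_DUMP_PASSES", "0") == "1"
    cfg = {}
    if debug_dump:
        cfg["dump_passes"] = True
    return cfg
```

The call site would then pass `compile_cfg=build_compile_cfg()`, so the default golden path emits no debug artifacts while a developer can still turn dumps on for one run.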
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: c7f970e0-2c59-4332-9235-6a6bf2394542
📒 Files selected for processing (1)
models/qwen3/14b/decode_layer.py
# each); the compiler inserts aiv_shard at the QK C->V boundary and
# aic_gather at the exp->SV V->C boundary. aiv_id is unused (the halving is
# automatic via the region's own subblock index).
for aiv_id in pl.split_aiv(2, mode=pl.SplitMode.UP_DOWN):
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the relevant file around the reported line.
sed -n '690,740p' models/qwen3/14b/decode_layer.py | cat -n
# Find all references to the loop variable name in the file.
rg -n '\baiv_id\b' models/qwen3/14b/decode_layer.py
# Show a compact outline of the file to understand surrounding structure.
ast-grep outline models/qwen3/14b/decode_layer.py --view expanded
Repository: hw-native-sys/pypto-lib
Length of output: 9853
Rename the unused split-lane variable. aiv_id is unused in this pl.split_aiv(..., mode=pl.SplitMode.UP_DOWN) loop, so rename it to _aiv_id to keep the intent clear and avoid the lint warning. models/qwen3/14b/decode_layer.py:717
🧰 Tools
🪛 Ruff (0.15.20)
[warning] 717-717: Loop control variable aiv_id not used within loop body
(B007)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/decode_layer.py` at line 717, The loop variable in the
pl.split_aiv(..., mode=pl.SplitMode.UP_DOWN) iteration is unused, so rename
aiv_id to _aiv_id in the relevant loop inside decode_layer to make the intent
explicit and silence the lint warning.
Source: Linters/SAST tools
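The rename relies on the convention that a leading underscore marks a loop variable as intentionally unused, which Ruff's B007 check accepts. A minimal sketch, using a hypothetical generator `split_lanes` as a stand-in for `pl.split_aiv` (which is not imported here):

```python
def split_lanes(n):
    """Hypothetical stand-in for pl.split_aiv: yields lane ids 0..n-1."""
    yield from range(n)


def count_lanes(n=2):
    """Run the loop body once per lane without ever reading the lane id."""
    visited = 0
    for _lane_id in split_lanes(n):  # underscore prefix: intentionally unused
        visited += 1
    return visited
```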
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3e00096b98
# NONE region (both lanes, disjoint (b, kvh) work). A function-level pl.split
# cannot express this: it would also try to halve phase-2's un-halvable
# [5, 128] / [1, 640] reduction tiles ("even split dimension").
deps=[rope_dep_tids[i] for i in range(FA_NDEPS)], # rope deps + work_tid (rope folded into fa_fused phase 0)
Fence seed tasks before entering syncall
When the scheduler overlaps the dependency-free down_seed/gate_seed/up_seed zero-fill tasks just above with fa_fused (the comments explicitly place them here for that overlap), the new in-kernel syncall barriers require all 24 fa_fused blocks to be resident and to reach the barrier. Any seed task still occupying cores can leave some fa_fused participants unscheduled, so the barrier can wait forever; add an explicit dependency/fence or move these seeds after attention when using syncall.
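The hazard can be illustrated with a host-side analogy. This is not pypto code: worker threads stand in for cores, `threading.Barrier` stands in for the in-kernel syncall, and a seed task is modelled as holding its core for the whole run. The function name `all_blocks_pass_barrier` is illustrative.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def all_blocks_pass_barrier(num_cores, num_blocks, num_seeds, timeout_s=1.0):
    """Return True if every block got through the barrier, False on timeout."""
    barrier = threading.Barrier(num_blocks)  # stand-in for an in-kernel syncall
    release_seeds = threading.Event()

    def seed():
        # Stand-in for a zero-fill task that keeps one core busy.
        release_seeds.wait()

    def block():
        try:
            barrier.wait(timeout=timeout_s)
            return True
        except threading.BrokenBarrierError:
            return False

    with ThreadPoolExecutor(max_workers=num_cores) as pool:
        try:
            for _ in range(num_seeds):
                pool.submit(seed)  # seeds are scheduled first and grab cores
            futures = [pool.submit(block) for _ in range(num_blocks)]
            return all(f.result() for f in futures)
        finally:
            release_seeds.set()  # let the seed workers exit
```

With as many free cores as blocks the barrier completes; with one core held by a seed, one block is never resident and the remaining blocks time out. In the analogy the timeout plays the role of the AICore timeout, and an explicit dependency on the seeds corresponds to submitting the blocks only after the seeds have finished.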
Summary
- … `rope_qkv`, attention, and online-softmax into a single mixed cube+vector `fa_fused` root, replacing two grid dispatches + their cross-kernel dep edges with in-kernel `pl.system.syncall` barriers.
- … `pl.split_aiv`: rope phase-0 (NONE, 32-lane pipeline) → syncall → attention phase-1 (`UP_DOWN` row-halving) → syncall → online-softmax phase-2 (NONE, 48-way).
- … `q/k_proj_norm` GM round-trip); adopts main's runtime-dynamic paged KV cache (#637, "feat(qwen3): runtime-dynamic paged KV cache for decode + restore defe…").
- … `--max-seq`) with `--no-dep-gen`. The big win is the attn+softmax syncall fusion; the rope→fa merge is dispatch-overhead-neutral.
Related Issues
- … `split_aiv` `SplitMode.NONE` no-halve dual-AIV + cross-half GM base-param repoint in `ExpandMixedKernel`). Does not build on the currently pinned pypto, hence draft.
- … `--no-dep-gen`: the dep-gen (DFX) instrumentation perturbs core occupancy and trips `fa_fused`'s full-occupancy `syncall` (AICore timeout 507018). The kernel itself runs correctly.