perf(qwen3-14b decode): fuse rope+attn+softmax into one fa_fused kernel #656

Draft
Hzfengsy wants to merge 1 commit into hw-native-sys:main from Hzfengsy:perf/fuse-attn-softmax-syncall

@Hzfengsy Hzfengsy commented Jul 1, 2026

Member

Summary

  • Fuse rope_qkv, attention, and online-softmax into a single mixed cube+vector fa_fused root, replacing two grid dispatches + their cross-kernel dep edges with in-kernel pl.system.syncall barriers.
  • Per-region dual-AIV split via pl.split_aiv: rope phase-0 (NONE, 32-lane pipeline) → syncall → attention phase-1 (UP_DOWN row-halving) → syncall → online-softmax phase-2 (NONE, 48-way).
  • qk-norm is kept in-register (no q/k_proj_norm GM round-trip); adopts the runtime-dynamic paged KV cache from main (#637, "feat(qwen3): runtime-dynamic paged KV cache for decode + restore defe…").
  • Golden PASS on a2a3 (varied seq + --max-seq) with --no-dep-gen. The big win is the attn+softmax syncall fusion; the rope→fa merge is dispatch-overhead-neutral.

Related Issues

  • Depends on pypto PR #1894 (split_aiv SplitMode.NONE no-halve dual-AIV + cross-half GM base-param repoint in ExpandMixedKernel). Does not build on the currently pinned pypto — hence draft.
  • Golden/CI must run with --no-dep-gen: the dep-gen (DFX) instrumentation perturbs core occupancy and trips fa_fused's full-occupancy syncall (AICore timeout 507018). The kernel itself runs correctly.

Collapse rope_qkv, fa_fused (attention) and online_softmax into a single mixed
cube+vector root, replacing two grid dispatches + their cross-kernel dep edges
with in-kernel pl.system.syncall barriers:

- qk_norm folded in-register into the RoPE step (no q/k_proj_norm GM round-trip).
- Per-region dual-AIV split (pl.split_aiv): rope phase-0 (NONE, 32-lane pipeline)
  -> syncall -> attn phase-1 (UP_DOWN row-halving) -> syncall -> online-softmax
  phase-2 (NONE, 48-way).
- rope folded in as phase-0; standalone rope_qkv dispatch + rope->fa dep removed.

Requires pypto PR #1894 (split_aiv SplitMode.NONE + cross-half GM base-param
repoint fix); does not build on the currently pinned pypto.
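
The phase structure described above can be sketched in plain Python. This is only an analogy under stated assumptions: `threading.Barrier` stands in for `pl.system.syncall`, the block count of 4 is hypothetical, and the phase bodies are empty placeholders rather than pypto code. It shows that once every block waits at a barrier between phases, no cross-kernel dependency edge is needed to order the phases.

```python
import threading

NUM_BLOCKS = 4  # hypothetical; not the real fa_fused block count


def run_fused(num_blocks: int = NUM_BLOCKS) -> list[str]:
    """Run three phases on every block, separated by all-block barriers."""
    barrier = threading.Barrier(num_blocks)  # stand-in for pl.system.syncall
    log: list[str] = []
    lock = threading.Lock()

    def block() -> None:
        for phase in ("rope", "attention", "softmax"):
            with lock:
                log.append(phase)  # placeholder for the phase's real work
            # No block enters the next phase until all blocks finish this one.
            barrier.wait()

    threads = [threading.Thread(target=block) for _ in range(num_blocks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    print(run_fused())
```

Because each block logs its phase before waiting, every "rope" entry precedes every "attention" entry, and likewise for "softmax", regardless of thread scheduling.
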
@gemini-code-assist

Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@coderabbitai

coderabbitai Bot commented Jul 1, 2026


Walkthrough

The fa_fused SPMD kernel in models/qwen3/14b/decode_layer.py is refactored to internally perform RoPE, QK-norm, and online-softmax reduction, removing separate rope_qkv and online_softmax dispatches. Multiple task-id arrays are consolidated into a single rope_dep_tids array. A test harness gains a dump_passes compile flag.

Changes

Fused Attention Refactor

| Layer / File(s) | Summary |
|---|---|
| Consolidated task-id constants and array (models/qwen3/14b/decode_layer.py) | Adds K_TID_BASE, V_TID_BASE, RMS_TID_IDX, ROPE_NDEPS, FA_NDEPS, K_RED_ROWS constants with 32B-alignment assertions, and replaces per-tile task-id arrays with a single rope_dep_tids array. |
| Producer task-id wiring (models/qwen3/14b/decode_layer.py) | Q/K/V projection loops, rms_recip, and fa_work_build now write their produced task ids into fixed rope_dep_tids offsets instead of separate arrays. |
| In-kernel RoPE/QK-norm phase 0 (models/qwen3/14b/decode_layer.py) | fa_fused gates on rope_dep_tids, computes RoPE and QK-norm in-kernel, and writes k_cache, v_cache, and all_q_padded before a syncall. |
| Phase-1 attention and fused softmax reduction (models/qwen3/14b/decode_layer.py) | Phase-1 QK/softmax/SV uses split_aiv/pipeline staging; online-softmax reduction is fused post-syncall into fa_fused, removing the standalone online_softmax dispatch; test harness adds dump_passes compile config. |
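
A rough sketch of what a single consolidated dependency array with fixed offsets could look like. Everything numeric here is an assumption made for illustration: the tile counts, the 8-byte task-id size, and the names Q_TID_BASE and WORK_TID_IDX are hypothetical and not taken from decode_layer.py. Only the constant names K_TID_BASE, V_TID_BASE, RMS_TID_IDX, ROPE_NDEPS, FA_NDEPS and the 32-byte alignment requirement come from the walkthrough.

```python
# Hypothetical sizes; the real values live in decode_layer.py.
Q_TILES, K_TILES, V_TILES = 8, 4, 4
TID_BYTES = 8      # assumed size of one task id
ALIGN_BYTES = 32   # alignment the walkthrough says is asserted

# Fixed offsets into one flat array instead of one array per producer.
Q_TID_BASE = 0                        # hypothetical name
K_TID_BASE = Q_TID_BASE + Q_TILES
V_TID_BASE = K_TID_BASE + K_TILES
RMS_TID_IDX = V_TID_BASE + V_TILES
ROPE_NDEPS = RMS_TID_IDX + 1          # q + k + v + rms
WORK_TID_IDX = ROPE_NDEPS             # hypothetical name for the work_tid slot
FA_NDEPS = ROPE_NDEPS + 1             # rope deps + work_tid

# Each sub-range must start on a 32-byte boundary.
for base in (K_TID_BASE, V_TID_BASE, RMS_TID_IDX):
    assert (base * TID_BYTES) % ALIGN_BYTES == 0, base


def fa_deps(rope_dep_tids: list[int]) -> list[int]:
    """Collect the dependency list the fused kernel would gate on."""
    return [rope_dep_tids[i] for i in range(FA_NDEPS)]
```

Producers would write into their own slice (for example `rope_dep_tids[K_TID_BASE + tile]`), and the consumer reads the first FA_NDEPS entries in one pass.
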

Estimated code review effort: 4 (Complex) | ~60 minutes

Sequence Diagram(s)

sequenceDiagram
  participant Projections as Q/K/V Projections
  participant RopeDepTids as rope_dep_tids
  participant FaFused as fa_fused kernel
  participant Cache as k_cache/v_cache/all_q_padded

  Projections->>RopeDepTids: write Q/K/V/RMS/work task ids
  RopeDepTids->>FaFused: gate phase 0
  FaFused->>FaFused: RoPE + QK-norm (phase 0)
  FaFused->>Cache: write k_cache, v_cache, all_q_padded
  FaFused->>FaFused: syncall (mix)
  FaFused->>FaFused: phase 1 QK/softmax/SV (split_aiv, pipeline)
  FaFused->>FaFused: syncall (hard)
  FaFused->>FaFused: phase 2 online-softmax reduction (fused)
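
For context on the phase-2 step, the following is a minimal scalar sketch of an online-softmax reduction: each chunk of the key sequence produces a running max, a sum of exponentials, and an unnormalised weighted value sum, and the partials are merged by rescaling to a common max. It is the standard technique written in plain Python, not the kernel's tiling or data layout; the function names are illustrative.

```python
import math


def partial_softmax(scores, values):
    """Return (max, sum of exp, unnormalised weighted value sum) for one chunk."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return m, sum(weights), sum(w * v for w, v in zip(weights, values))


def merge(a, b):
    """Combine two partial results by rescaling both to the shared max."""
    m_a, l_a, o_a = a
    m_b, l_b, o_b = b
    m = max(m_a, m_b)
    scale_a = math.exp(m_a - m)
    scale_b = math.exp(m_b - m)
    return m, l_a * scale_a + l_b * scale_b, o_a * scale_a + o_b * scale_b


def attention_output(partials):
    """Reduce all chunk partials, then normalise once at the end."""
    acc = partials[0]
    for p in partials[1:]:
        acc = merge(acc, p)
    _, l, o = acc
    return o / l
```

Merging is associative up to floating-point rounding, which is what allows the partial results to be produced independently in phase 1 and reduced after the barrier.
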

Possibly related PRs

  • hw-native-sys/pypto-lib#420: Refactors the same decode layer to integrate online-softmax recurrence directly into fa_fused, removing the standalone online_softmax region.
  • hw-native-sys/pypto-lib#602: Modifies the RoPE + QK-norm + cache-write path in the same file that this PR further fuses into fa_fused.
  • hw-native-sys/pypto-lib#651: Removes a standalone RoPE dispatch by fusing RoPE into the main attention/merge kernel with rewired task dependencies.

Suggested labels: enhancement

Poem

A rabbit hops through kernel code,
Fusing phases on one bright road. 🐇
Rope and softmax, hand in paw,
No more dispatch, just one clean draw.
Sync, compute, and cache anew—
Faster hops for me and you!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly summarizes the main change: fusing rope, attention, and softmax into one fa_fused kernel. |
| Description check | ✅ Passed | The description is directly related to the changeset and accurately describes the fused kernel and performance intent. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@Hzfengsy Hzfengsy marked this pull request as ready for review July 1, 2026 02:51
@Hzfengsy Hzfengsy marked this pull request as draft July 1, 2026 02:57

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
models/qwen3/14b/decode_layer.py (1)

1674-1674: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win

Avoid enabling pass dumps by default.

dump_passes=True is forwarded by run_jit, so the default golden path will always emit compiler dumps. Gate this behind an explicit debug flag or remove it before merging.

Proposed fix
-            compile_cfg=dict(dump_passes=True),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/decode_layer.py` at line 1674, The `run_jit` path in
`decode_layer.py` is always enabling compiler pass dumps via
`compile_cfg=dict(dump_passes=True)`, which makes the default golden path emit
debug artifacts. Remove the hardcoded `dump_passes=True` from the JIT compile
configuration, or only set it when an explicit debug flag is enabled in the
relevant `run_jit`/decode-layer setup. Use the existing `run_jit` and
`compile_cfg` wiring to keep dump behavior opt-in by default.
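
One way the opt-in behaviour could look, as a sketch only: the environment variable name DECODE_DUMP_PASSES and the helper build_compile_cfg are hypothetical, and the only thing taken from the PR is that `compile_cfg` is a dict that may carry `dump_passes=True`.

```python
import os


def build_compile_cfg(env=None) -> dict:
    """Return a compile config that dumps passes only when explicitly requested."""
    env = os.environ if env is None else env
    cfg = {}
    # Hypothetical switch: dumps stay off unless the variable is set to "1".
    if env.get("DECODE_DUMP_PASSES", "0") == "1":
        cfg["dump_passes"] = True
    return cfg
```

The call site would then pass `compile_cfg=build_compile_cfg()` so the default golden path emits no compiler dumps.
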
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: c7f970e0-2c59-4332-9235-6a6bf2394542

📥 Commits

Reviewing files that changed from the base of the PR and between 284d6b4 and 3e00096.

📒 Files selected for processing (1)
  • models/qwen3/14b/decode_layer.py

# each); the compiler inserts aiv_shard at the QK C->V boundary and
# aic_gather at the exp->SV V->C boundary. aiv_id is unused (the halving is
# automatic via the region's own subblock index).
for aiv_id in pl.split_aiv(2, mode=pl.SplitMode.UP_DOWN):


📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the relevant file around the reported line.
sed -n '690,740p' models/qwen3/14b/decode_layer.py | cat -n

# Find all references to the loop variable name in the file.
rg -n '\baiv_id\b' models/qwen3/14b/decode_layer.py

# Show a compact outline of the file to understand surrounding structure.
ast-grep outline models/qwen3/14b/decode_layer.py --view expanded

Repository: hw-native-sys/pypto-lib

Length of output: 9853


Rename the unused split-lane variable. aiv_id is unused in this pl.split_aiv(..., mode=pl.SplitMode.UP_DOWN) loop, so rename it to _aiv_id to keep the intent clear and avoid the lint warning. models/qwen3/14b/decode_layer.py:717

🧰 Tools
🪛 Ruff (0.15.20)

[warning] 717-717: Loop control variable aiv_id not used within loop body

(B007)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/decode_layer.py` at line 717, The loop variable in the
pl.split_aiv(..., mode=pl.SplitMode.UP_DOWN) iteration is unused, so rename
aiv_id to _aiv_id in the relevant loop inside decode_layer to make the intent
explicit and silence the lint warning.

Source: Linters/SAST tools
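
The rename convention can be illustrated in isolation. Here `split_lanes` is a hypothetical stand-in for `pl.split_aiv` that simply yields lane indices; the point is only that a leading underscore marks the loop variable as intentionally unused, which is what Ruff's B007 check looks for.

```python
def split_lanes(n: int):
    """Hypothetical stand-in for pl.split_aiv: yields one index per lane."""
    yield from range(n)


def count_lane_bodies(n: int = 2) -> int:
    executed = 0
    # The leading underscore signals the index is deliberately ignored.
    for _aiv_id in split_lanes(n):
        executed += 1  # body never reads the lane index
    return executed
```
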

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3e00096b98


# NONE region (both lanes, disjoint (b, kvh) work). A function-level pl.split
# cannot express this — it would also try to halve phase-2's un-halvable
# [5, 128] / [1, 640] reduction tiles ("even split dimension").
deps=[rope_dep_tids[i] for i in range(FA_NDEPS)], # rope deps + work_tid (rope folded into fa_fused phase 0)


P1 Badge Fence seed tasks before entering syncall

When the scheduler overlaps the dependency-free down_seed/gate_seed/up_seed zero-fill tasks just above with fa_fused (the comments explicitly place them here for that overlap), the new in-kernel syncall barriers require all 24 fa_fused blocks to be resident and to reach the barrier. Any seed task still occupying cores can leave some fa_fused participants unscheduled, so the barrier can wait forever; add an explicit dependency/fence or move these seeds after attention when using syncall.
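
The hazard described here can be reproduced with a small host-side model. This is an analogy under assumptions, not the device scheduler: a thread pool with one worker per core plays the role of the cores, `threading.Barrier` plays the role of syncall, and a timeout stands in for what would otherwise be an indefinite wait. The core count of 4 and all function names are hypothetical.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def run(cores: int, seed_running: bool, timeout: float = 0.5) -> bool:
    """Return True if every block passed the barrier, False if it broke."""
    barrier = threading.Barrier(cores)  # every block must arrive
    release_seed = threading.Event()

    def seed_task() -> None:
        release_seed.wait()  # occupies one core until released

    def block() -> bool:
        try:
            barrier.wait(timeout=timeout)
            return True
        except threading.BrokenBarrierError:
            return False

    with ThreadPoolExecutor(max_workers=cores) as pool:
        if seed_running:
            pool.submit(seed_task)  # queued first, so it takes a core first
        futures = [pool.submit(block) for _ in range(cores)]
        try:
            results = [f.result() for f in futures]
        finally:
            release_seed.set()  # let the seed task finish so the pool can exit
    return all(results)


if __name__ == "__main__":
    print(run(4, seed_running=False))
    print(run(4, seed_running=True))
```

With a free pool all blocks become resident and the barrier releases. With one core held by the seed task, only `cores - 1` blocks can arrive, so the barrier never fills; making the seed task a dependency of the fused kernel, or running it afterwards, removes that case.
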

