Qwen3 refactoring with compile-time keyword arguments#582
Conversation
📝 WalkthroughWalkthroughIntroduces a ChangesShared primitives library and model refactors
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request consolidates model-specific code by introducing a shared library of reusable @pl.inline primitives (such as RMSNorm, tiled matmuls, and SiLU activations) and test specification builders under models/shared/. The reviewer feedback highlights several opportunities for improvement: refactoring matmul_tiled_4d to keep the conditional check inside the pipeline loop for better hardware pipelining; lazily importing torch in specs.py to prevent module-level import overhead; adding and utilizing a b0 row offset parameter in rmsnorm_recip for consistency; and cleaning up unused slice variables left over from the refactoring in qwen3_32b_decode.py.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| tile_a = a[0:1, :, :, :] | ||
| tile_b = b[0:1, n1_offset:n1_offset + 1, :, :] | ||
| acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) | ||
| for kb in pl.pipeline(1, k_blocks, stage=stages): | ||
| tile_a = a[kb:kb + 1, :, :, :] | ||
| tile_b = b[kb:kb + 1, n1_offset:n1_offset + 1, :, :] | ||
| acc = pl.matmul_acc(acc, tile_a, tile_b) | ||
| return acc |
There was a problem hiding this comment.
According to the project convention for PyPTO on CANN/Ascend hardware, keeping a conditional check (e.g., if kb == 0) inside a pl.pipeline loop is preferred over peeling the first iteration. This allows the first chunk's load to overlap with the rest of the pipeline rather than running as an un-pipelined prologue. Please refactor matmul_tiled_4d to keep the conditional check inside the loop, similar to matmul_tiled.
| tile_a = a[0:1, :, :, :] | |
| tile_b = b[0:1, n1_offset:n1_offset + 1, :, :] | |
| acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) | |
| for kb in pl.pipeline(1, k_blocks, stage=stages): | |
| tile_a = a[kb:kb + 1, :, :, :] | |
| tile_b = b[kb:kb + 1, n1_offset:n1_offset + 1, :, :] | |
| acc = pl.matmul_acc(acc, tile_a, tile_b) | |
| return acc | |
| acc = pl.create_tensor([1, 1, batch, n1_chunk], dtype=pl.FP32) | |
| for kb in pl.pipeline(0, k_blocks, stage=stages): | |
| tile_a = a[kb:kb + 1, :, :, :] | |
| tile_b = b[kb:kb + 1, n1_offset:n1_offset + 1, :, :] | |
| if kb == 0: | |
| acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) | |
| else: | |
| acc = pl.matmul_acc(acc, tile_a, tile_b) | |
| return acc |
References
- In PyPTO on CANN/Ascend hardware, keeping a conditional check (e.g.,
if db == 0) inside apl.pipelineloop can be preferred over peeling the first iteration. This allows the first chunk's load to overlap with the rest of the pipeline rather than running as an un-pipelined prologue, provided the compiler successfully pipelines the loop-index branch.
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch |
There was a problem hiding this comment.
Importing torch at the module level violates the project convention for spec/builder files. Lazy imports for 'torch' and 'pypto.runtime' should be used to avoid import overhead when only builder functions are imported but not immediately executed.
References
- Lazy imports for 'torch' and 'pypto.runtime' are a project convention to avoid import overhead when only builder functions are used.
| @pl.inline | ||
| def rmsnorm_recip( | ||
| x, *, rows: int, k_chunk: int, eps: float, hidden: int, | ||
| stages: int = 2, | ||
| cast_input: bool = True, | ||
| ): |
There was a problem hiding this comment.
For consistency and better reusability, rmsnorm_recip should accept a b0 row offset parameter (defaulting to 0) just like rmsnorm does. This allows callers to compute the reciprocal on sliced or tiled hidden states with non-zero row offsets.
@pl.inline
def rmsnorm_recip(
x, b0: int = 0,
*, rows: int, k_chunk: int, eps: float, hidden: int,
stages: int = 2,
cast_input: bool = True,
):| if cast_input: | ||
| chunk = pl.cast( | ||
| pl.slice(x, [rows, k_chunk], [0, k0]), | ||
| target_type=pl.FP32, | ||
| ) | ||
| else: | ||
| chunk = pl.slice(x, [rows, k_chunk], [0, k0]) |
There was a problem hiding this comment.
Use the b0 parameter instead of the hardcoded 0 offset in pl.slice to support non-zero row offsets.
| if cast_input: | |
| chunk = pl.cast( | |
| pl.slice(x, [rows, k_chunk], [0, k0]), | |
| target_type=pl.FP32, | |
| ) | |
| else: | |
| chunk = pl.slice(x, [rows, k_chunk], [0, k0]) | |
| if cast_input: | |
| chunk = pl.cast( | |
| pl.slice(x, [rows, k_chunk], [b0, k0]), | |
| target_type=pl.FP32, | |
| ) | |
| else: | |
| chunk = pl.slice(x, [rows, k_chunk], [b0, k0]) |
| post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) | ||
| post_chunk_1 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, K_CHUNK]) | ||
| wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) | ||
| gate_acc = pl.matmul(post_chunk_0, wg_0, out_dtype=pl.FP32) | ||
|
|
||
| wg_1 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [K_CHUNK, o0]) | ||
| gate_acc = pl.matmul_acc(gate_acc, post_chunk_1, wg_1) | ||
|
|
||
| for kb in pl.pipeline(2, HIDDEN_BLOCKS, stage=2): | ||
| k0 = kb * K_CHUNK | ||
| post_chunk = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, k0]) | ||
| wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) | ||
| gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg) | ||
| gate_acc = matmul_tiled( | ||
| post_norm_tile, w_gate, o0, | ||
| m=BATCH, k_chunk=K_CHUNK, n_chunk=MLP_OUT_CHUNK, | ||
| k_blocks=HIDDEN_BLOCKS, stages=2, | ||
| ) |
There was a problem hiding this comment.
The slice variables post_chunk_0, post_chunk_1, and wg_0 are left over from the original hand-written loop and are completely unused now that matmul_tiled is used. Please remove them to keep the code clean.
gate_acc = matmul_tiled(
post_norm_tile, w_gate, o0,
m=BATCH, k_chunk=K_CHUNK, n_chunk=MLP_OUT_CHUNK,
k_blocks=HIDDEN_BLOCKS, stages=2,
)| post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) | ||
| post_chunk_1 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, K_CHUNK]) | ||
| wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [0, o0]) | ||
| up_acc = pl.matmul(post_chunk_0, wu_0, out_dtype=pl.FP32) | ||
|
|
||
| wu_1 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [K_CHUNK, o0]) | ||
| up_acc = pl.matmul_acc(up_acc, post_chunk_1, wu_1) | ||
|
|
||
| for kb in pl.pipeline(2, HIDDEN_BLOCKS, stage=2): | ||
| k0 = kb * K_CHUNK | ||
| post_chunk = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, k0]) | ||
| wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [k0, o0]) | ||
| up_acc = pl.matmul_acc(up_acc, post_chunk, wu) | ||
| up_acc = matmul_tiled( | ||
| post_norm_tile, w_up, o0, | ||
| m=BATCH, k_chunk=K_CHUNK, n_chunk=MLP_OUT_CHUNK, | ||
| k_blocks=HIDDEN_BLOCKS, stages=2, | ||
| ) |
There was a problem hiding this comment.
The slice variables post_chunk_0, post_chunk_1, and wu_0 are left over from the original hand-written loop and are completely unused now that matmul_tiled is used. Please remove them to keep the code clean.
up_acc = matmul_tiled(
post_norm_tile, w_up, o0,
m=BATCH, k_chunk=K_CHUNK, n_chunk=MLP_OUT_CHUNK,
k_blocks=HIDDEN_BLOCKS, stages=2,
)There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (4)
models/shared/golden_ref.py (2)
269-281: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueDead locals in the 4D variant:
hidden(line 270) and the cos/sin halves (lines 280-281) are unused.
attn_proj_tileis shaped fromout_proj_k_blocks/batch/out_proj_k_chunk, sohiddenis never read; the cos/sin halves are unused for the same reason as scope2.half(269) only feeds the dead split, so it can go too.♻️ Proposed cleanup
- half = head_dim // 2 - hidden = num_heads * head_dim attn_proj_tile = torch.zeros(out_proj_k_blocks, batch, out_proj_k_chunk, dtype=torch.bfloat16) @@ cos_row = rope_cos[pos, 0, :, :] sin_row = rope_sin[pos, 0, :, :] - cos_lo, cos_hi = cos_row[:, :half], cos_row[:, half:] - sin_lo, sin_hi = sin_row[:, :half], sin_row[:, half:]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/shared/golden_ref.py` around lines 269 - 281, Remove the unused dead variables in the 4D variant code block. Delete the unused `hidden` variable assignment, the `half` variable calculation and the subsequent cos/sin splits (`cos_lo`, `cos_hi`, `sin_lo`, `sin_hi`) that are never referenced in the code that follows. These variables are computed but never used and should be removed to clean up the code.
178-179: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueDead locals:
cos_lo/cos_hi/sin_lo/sin_hiare computed but never used.
golden_rope_rotate_halfsplitscos_row/sin_rowinternally (lines 68-73), so these four halves are unused here. Same pattern is duplicated ingolden_decode_scope2_4d(lines 280-281). Drop them to reduce noise.♻️ Proposed cleanup
cos_row = rope_cos[pos: pos + 1, :] sin_row = rope_sin[pos: pos + 1, :] - cos_lo, cos_hi = cos_row[:, :half], cos_row[:, half:] - sin_lo, sin_hi = sin_row[:, :half], sin_row[:, half:]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/shared/golden_ref.py` around lines 178 - 179, The variables cos_lo, cos_hi, sin_lo, and sin_hi are being computed by splitting cos_row and sin_row at lines 178-179, but they are never used anywhere after assignment. Since golden_rope_rotate_half already performs this splitting internally, these lines are redundant. Remove both line 178 and line 179 entirely. Additionally, locate and remove the identical duplicate pattern in the golden_decode_scope2_4d function at lines 280-281 where the same unused variables are created.models/shared/specs.py (1)
89-99: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
num_headsis accepted but never used in either builder.Both
build_qwen3_decode_specsandbuild_qwen3_prefill_specssize weights fromhidden/kv_hiddenand never referencenum_heads. It's harmless for API symmetry, but if it's purely vestigial consider dropping it (or add a brief comment that it's intentionally retained for caller convenience) to avoid the impression it drives any shape.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/shared/specs.py` around lines 89 - 99, The num_heads parameter in both build_qwen3_decode_specs and build_qwen3_prefill_specs functions is accepted but never used in the implementation, as weight sizing is computed from hidden/kv_hidden instead. Either remove the num_heads parameter from both function signatures to eliminate the unused parameter, or add a brief comment above the parameter in the function docstring explaining it is intentionally retained for API symmetry and caller convenience. Choose the approach that best fits your codebase conventions.models/shared/matmul.py (1)
24-25: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winNormalize ambiguous Unicode punctuation in docstrings.
Ruff is flagging RUF002 on these lines; replacing non-breaking hyphens/spaces with plain ASCII keeps doc lint clean and avoids editor-dependent rendering.
Suggested docstring cleanup
- Both use ``@pl.inline`` with keyword-only CT kwargs (``k_chunk``, ``n_chunk``, - ``k_blocks``) for compile‑time tile sizes, and a **positional** offset parameter - for the dynamic N‑dimension (which is computed from SPMD/parallel loop vars). + Both use ``@pl.inline`` with keyword-only CT kwargs (``k_chunk``, ``n_chunk``, + ``k_blocks``) for compile-time tile sizes, and a **positional** offset parameter + for the dynamic N-dimension (which is computed from SPMD/parallel loop vars). - dynamic ``n_offset`` (the N‑tile start, e.g. ``qi * N_CHUNK`` from an - SPMD/parallel loop). The inline owns the K‑loop slicing internally, - generating per‑iteration Mat‑space tiles that fit within 512 KB. + dynamic ``n_offset`` (the N-tile start, e.g. ``qi * N_CHUNK`` from an + SPMD/parallel loop). The inline owns the K-loop slicing internally, + generating per-iteration Mat-space tiles that fit within 512 KB. - a: Activation tensor ``[M, K]`` — **full** K‑dimension. No pre-slicing. - b: Weight tensor ``[K, N]`` — **full** K‑dimension. No pre-slicing. - n_offset: Dynamic N‑dimension offset into ``b`` (a Scalar, e.g. from + a: Activation tensor ``[M, K]`` — **full** K-dimension. No pre-slicing. + b: Weight tensor ``[K, N]`` — **full** K-dimension. No pre-slicing. + n_offset: Dynamic N-dimension offset into ``b`` (a Scalar, e.g. from ... - k_chunk: K‑dimension chunk size for the reduction pipeline (CT kwarg). - n_chunk: N‑dimension chunk size for the output tile (CT kwarg). + k_chunk: K-dimension chunk size for the reduction pipeline (CT kwarg). + n_chunk: N-dimension chunk size for the output tile (CT kwarg). ... - k_chunk: K‑dimension chunk size (4th axis of the 4D tensors, CT kwarg). + k_chunk: K-dimension chunk size (4th axis of the 4D tensors, CT kwarg).Also applies to: 50-52, 55-57, 60-61, 101-101
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/shared/matmul.py` around lines 24 - 25, The docstring in the matmul.py file contains non-breaking Unicode hyphens (‑) instead of standard ASCII hyphens (-) which are flagged by Ruff's RUF002 rule. Replace all instances of non-breaking hyphens with regular ASCII hyphens throughout the docstrings, specifically in the phrases "compile‑time" and "N‑dimension" on lines 24-25, and apply this same normalization to all other occurrences mentioned across lines 50-52, 55-57, 60-61, and 101. Also check for and replace any non-breaking spaces with regular ASCII spaces to ensure the documentation is clean and renders consistently across different editors.Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/qwen3/32b/qwen3_32b_decode_4d.py`:
- Line 33: The `golden_decode_scope1` function at line 567 expects input shape
`[batch, hidden]` but is receiving block-major formatted `hidden_states`
directly, which will cause shape mismatches or incorrect validation. Before
calling `golden_decode_scope1`, reshape the block-major `hidden_states` to the
expected `[batch, hidden]` format. Additionally, the Scope 3 formula
implementation in lines 593-604 is incomplete: it is missing the output
projection step, lacks the proper residual connection, and skips the
post-RMSNorm operation. Fix this by adding the output projection, ensuring the
`down` projection is added to `resid1` (not directly to raw attention blocks),
and including the final RMSNorm normalization to match the complete formula
structure.
In `@models/qwen3/32b/qwen3_32b_decode.py`:
- Around line 425-467: The golden_qwen3_decode function computes tensors["out"]
using the refactored scope helper functions (golden_decode_scope1,
golden_decode_scope2, golden_decode_scope3), but execution continues to fall
through into the old stale inline golden body implementation that overwrites the
result. Add a return statement immediately after the tensors["out"] assignment
in the scope3 block, or delete the entire old inline golden implementation block
that follows (lines 468-588) to prevent the stale code from shadowing the
refactored path during validation.
In `@models/qwen3/32b/qwen3_32b_prefill_draft.py`:
- Around line 716-723: The help text for the --smoke argument states that smoke
behavior is implicit on *sim platforms, but the condition at line 723 only
checks if args.smoke is true and does not detect or check for sim platforms. To
fix this inconsistency, either remove the phrase "also the implicit behavior on
*sim platforms" from the help text in parser.add_argument for the smoke argument
to reflect the actual behavior of only checking args.smoke, or modify the
conditional logic to detect sim platforms and include that check in the if
statement alongside the args.smoke check. Choose the simpler option: update the
help text to remove the sim platform reference since the code comment explicitly
states that sim runs the full pipeline.
In `@models/shared/golden_ref.py`:
- Around line 13-18: The import statement at the top of golden_ref.py imports
from models.shared.golden, but since this file is golden_ref.py itself, the
import path is incorrect and won't resolve. Change the import statement to
import from models.shared.golden_ref instead of models.shared.golden to ensure
the documented usage example actually works and matches the correct module
location.
---
Nitpick comments:
In `@models/shared/golden_ref.py`:
- Around line 269-281: Remove the unused dead variables in the 4D variant code
block. Delete the unused `hidden` variable assignment, the `half` variable
calculation and the subsequent cos/sin splits (`cos_lo`, `cos_hi`, `sin_lo`,
`sin_hi`) that are never referenced in the code that follows. These variables
are computed but never used and should be removed to clean up the code.
- Around line 178-179: The variables cos_lo, cos_hi, sin_lo, and sin_hi are
being computed by splitting cos_row and sin_row at lines 178-179, but they are
never used anywhere after assignment. Since golden_rope_rotate_half already
performs this splitting internally, these lines are redundant. Remove both line
178 and line 179 entirely. Additionally, locate and remove the identical
duplicate pattern in the golden_decode_scope2_4d function at lines 280-281 where
the same unused variables are created.
In `@models/shared/matmul.py`:
- Around line 24-25: The docstring in the matmul.py file contains non-breaking
Unicode hyphens (‑) instead of standard ASCII hyphens (-) which are flagged by
Ruff's RUF002 rule. Replace all instances of non-breaking hyphens with regular
ASCII hyphens throughout the docstrings, specifically in the phrases
"compile‑time" and "N‑dimension" on lines 24-25, and apply this same
normalization to all other occurrences mentioned across lines 50-52, 55-57,
60-61, and 101. Also check for and replace any non-breaking spaces with regular
ASCII spaces to ensure the documentation is clean and renders consistently
across different editors.
In `@models/shared/specs.py`:
- Around line 89-99: The num_heads parameter in both build_qwen3_decode_specs
and build_qwen3_prefill_specs functions is accepted but never used in the
implementation, as weight sizing is computed from hidden/kv_hidden instead.
Either remove the num_heads parameter from both function signatures to eliminate
the unused parameter, or add a brief comment above the parameter in the function
docstring explaining it is intentionally retained for API symmetry and caller
convenience. Choose the approach that best fits your codebase conventions.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 37b05a1b-effc-4f3b-9b2e-a4cb17560e2c
📒 Files selected for processing (16)
models/__init__.pymodels/qwen3/14b/decode_layer.pymodels/qwen3/14b/prefill_fwd.pymodels/qwen3/14b/rms_lm_head.pymodels/qwen3/32b/qwen3_32b_decode.pymodels/qwen3/32b/qwen3_32b_decode_4d.pymodels/qwen3/32b/qwen3_32b_prefill_draft.pymodels/shared/__init__.pymodels/shared/attention.pymodels/shared/golden_ref.pymodels/shared/matmul.pymodels/shared/rmsnorm.pymodels/shared/rope.pymodels/shared/silu.pymodels/shared/specs.pymodels/shared/test_matmul_tiled.py
| import pypto.language as pl | ||
| from models.shared.silu import silu_activation | ||
| from models.shared.matmul import matmul_tiled_4d | ||
| from models.shared.golden_ref import golden_decode_scope1, golden_decode_scope2_4d, golden_rmsnorm, golden_swiglu |
There was a problem hiding this comment.
🎯 Functional Correctness | 🔴 Critical | 🏗️ Heavy lift
Fix the 4D golden path shape conversion and Scope 3 formula.
Line 567 passes block-major hidden_states directly to golden_decode_scope1, whose contract expects [batch, hidden]; this will either fail broadcasting/matmul or validate the wrong layout. Lines 593-604 also skip output projection + residual + post-RMSNorm, using input RMSNorm and adding down to raw attention blocks instead of resid1.
🐛 Suggested direction
-from models.shared.golden_ref import golden_decode_scope1, golden_decode_scope2_4d, golden_rmsnorm, golden_swiglu
+from models.shared.golden_ref import golden_decode_scope1, golden_decode_scope2_4d, golden_decode_scope3- hidden_states = tensors["hidden_states"]
+ hidden_states_chunked = tensors["hidden_states"]
@@
out = tensors["out"]
+ hidden_states = hidden_states_chunked[:, 0, :, :].permute(1, 0, 2).reshape(BATCH, HIDDEN)
+
@@
- normed_bf16 = golden_rmsnorm(hidden_states, input_rms_weight, eps=EPS)
-
- gate = torch.matmul(normed_bf16.float(), w_gate.float())
- up = torch.matmul(normed_bf16.float(), w_up.float())
- mlp_bf16 = golden_swiglu(gate, up).bfloat16()
- mlp_blocks = mlp_bf16.reshape(BATCH, MLP_OUT_BLOCKS, MLP_OUT_CHUNK).permute(1, 0, 2)
-
- attn_proj_blocks = attn_proj_tile # [OUT_PROJ_K_BLOCKS, BATCH, OUT_PROJ_K_CHUNK]
- down = torch.matmul(mlp_blocks.float().reshape(MLP_OUT_BLOCKS * BATCH, MLP_OUT_CHUNK), w_down.float()).reshape(MLP_OUT_BLOCKS, BATCH, DOWN_N_CHUNK).permute(1, 0, 2)
-
- out_blocks = down + attn_proj_blocks.permute(1, 0, 2)
- out[:] = out_blocks.bfloat16()
+ attn_out = attn_proj_tile.permute(1, 0, 2).reshape(BATCH, HIDDEN)
+ out_2d = golden_decode_scope3(
+ attn_out, hidden_states, wo, post_rms_weight, w_gate, w_up, w_down,
+ hidden=HIDDEN, eps=EPS,
+ )
+ out[:] = out_2d.reshape(BATCH, DOWN_N_BLOCKS, DOWN_N_CHUNK).permute(1, 0, 2).unsqueeze(1)Also applies to: 543-604
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/32b/qwen3_32b_decode_4d.py` at line 33, The
`golden_decode_scope1` function at line 567 expects input shape `[batch,
hidden]` but is receiving block-major formatted `hidden_states` directly, which
will cause shape mismatches or incorrect validation. Before calling
`golden_decode_scope1`, reshape the block-major `hidden_states` to the expected
`[batch, hidden]` format. Additionally, the Scope 3 formula implementation in
lines 593-604 is incomplete: it is missing the output projection step, lacks the
proper residual connection, and skips the post-RMSNorm operation. Fix this by
adding the output projection, ensuring the `down` projection is added to
`resid1` (not directly to raw attention blocks), and including the final RMSNorm
normalization to match the complete formula structure.
| def golden_qwen3_decode(tensors): | ||
| """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP).""" | ||
| import math | ||
|
|
||
| def init_w_up(): | ||
| return (torch.rand(HIDDEN, INTERMEDIATE) - 0.5) / HIDDEN ** 0.5 | ||
| hidden_states = tensors["hidden_states"] | ||
| input_rms_weight = tensors["input_rms_weight"] | ||
| wq = tensors["wq"] | ||
| wk = tensors["wk"] | ||
| wv = tensors["wv"] | ||
| seq_lens = tensors["seq_lens"] | ||
| rope_cos = tensors["rope_cos"] | ||
| rope_sin = tensors["rope_sin"] | ||
| k_cache = tensors["k_cache"] | ||
| v_cache = tensors["v_cache"] | ||
| wo = tensors["wo"] | ||
| post_rms_weight = tensors["post_rms_weight"] | ||
| w_gate = tensors["w_gate"] | ||
| w_up = tensors["w_up"] | ||
| w_down = tensors["w_down"] | ||
|
|
||
| def init_w_down(): | ||
| return (torch.rand(INTERMEDIATE, HIDDEN) - 0.5) / INTERMEDIATE ** 0.5 | ||
| attn_scale = 1.0 / math.sqrt(HEAD_DIM) | ||
|
|
||
| return [ | ||
| TensorSpec("hidden_states", [BATCH, HIDDEN], torch.bfloat16, init_value=init_hidden_states), | ||
| TensorSpec("input_rms_weight", [1, HIDDEN], torch.float32, init_value=init_rms_weight), | ||
| TensorSpec("wq", [HIDDEN, HIDDEN], torch.bfloat16, init_value=init_wq), | ||
| TensorSpec("wk", [HIDDEN, KV_HIDDEN], torch.bfloat16, init_value=init_wk), | ||
| TensorSpec("wv", [HIDDEN, KV_HIDDEN], torch.bfloat16, init_value=init_wv), | ||
| TensorSpec("seq_lens", [BATCH], torch.int32, init_value=init_seq_lens), | ||
| TensorSpec("rope_cos", [MAX_SEQ, HEAD_DIM], torch.float32, init_value=init_rope_cos), | ||
| TensorSpec("rope_sin", [MAX_SEQ, HEAD_DIM], torch.float32, init_value=init_rope_sin), | ||
| TensorSpec("k_cache", [CACHE_ROWS, HEAD_DIM], torch.bfloat16, init_value=init_k_cache), | ||
| TensorSpec("v_cache", [CACHE_ROWS, HEAD_DIM], torch.bfloat16, init_value=init_v_cache), | ||
| TensorSpec("wo", [HIDDEN, HIDDEN], torch.bfloat16, init_value=init_wo), | ||
| TensorSpec("post_rms_weight", [1, HIDDEN], torch.float32, init_value=init_post_rms_weight), | ||
| TensorSpec("w_gate", [HIDDEN, INTERMEDIATE], torch.bfloat16, init_value=init_w_gate), | ||
| TensorSpec("w_up", [HIDDEN, INTERMEDIATE], torch.bfloat16, init_value=init_w_up), | ||
| TensorSpec("w_down", [INTERMEDIATE, HIDDEN], torch.bfloat16, init_value=init_w_down), | ||
| TensorSpec("out", [BATCH, HIDDEN], torch.bfloat16, is_output=True), | ||
| ] | ||
| # ── Scope 1: RMSNorm + Q/K/V projection ── | ||
| _, q_proj, k_proj, v_proj = golden_decode_scope1( | ||
| hidden_states, input_rms_weight, wq, wk, wv, | ||
| hidden=HIDDEN, eps=EPS, | ||
| ) | ||
|
|
||
| # ── Scope 2: RoPE + cache update + attention ── | ||
| attn_out = golden_decode_scope2( | ||
| q_proj, k_proj, v_proj, k_cache, v_cache, seq_lens, | ||
| rope_cos, rope_sin, | ||
| batch=BATCH, num_heads=NUM_HEADS, num_kv_heads=NUM_KV_HEADS, | ||
| head_dim=HEAD_DIM, max_seq=MAX_SEQ, seq_tile=SEQ_TILE, | ||
| q_per_kv=Q_PER_KV, q_head_batch=Q_HEAD_BATCH, | ||
| attn_scale=attn_scale, | ||
| ) | ||
|
|
||
| def golden_qwen3_decode(tensors): | ||
| # ── Scope 3: output projection + residual + post RMSNorm + MLP + residual ── | ||
| tensors["out"][:] = golden_decode_scope3( | ||
| attn_out, hidden_states, wo, post_rms_weight, w_gate, w_up, w_down, | ||
| hidden=HIDDEN, eps=EPS, | ||
| ) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Stop falling through into the stale golden implementation.
Lines 425-467 compute tensors["out"] through the shared scope helpers, but execution then continues into the old inline golden body at Lines 468-588 and overwrites the result. This masks the refactored golden path during validation; either delete the stale block or return after the helper path.
🐛 Minimal fix
tensors["out"][:] = golden_decode_scope3(
attn_out, hidden_states, wo, post_rms_weight, w_gate, w_up, w_down,
hidden=HIDDEN, eps=EPS,
)
+ return📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def golden_qwen3_decode(tensors): | |
| """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP).""" | |
| import math | |
| def init_w_up(): | |
| return (torch.rand(HIDDEN, INTERMEDIATE) - 0.5) / HIDDEN ** 0.5 | |
| hidden_states = tensors["hidden_states"] | |
| input_rms_weight = tensors["input_rms_weight"] | |
| wq = tensors["wq"] | |
| wk = tensors["wk"] | |
| wv = tensors["wv"] | |
| seq_lens = tensors["seq_lens"] | |
| rope_cos = tensors["rope_cos"] | |
| rope_sin = tensors["rope_sin"] | |
| k_cache = tensors["k_cache"] | |
| v_cache = tensors["v_cache"] | |
| wo = tensors["wo"] | |
| post_rms_weight = tensors["post_rms_weight"] | |
| w_gate = tensors["w_gate"] | |
| w_up = tensors["w_up"] | |
| w_down = tensors["w_down"] | |
| def init_w_down(): | |
| return (torch.rand(INTERMEDIATE, HIDDEN) - 0.5) / INTERMEDIATE ** 0.5 | |
| attn_scale = 1.0 / math.sqrt(HEAD_DIM) | |
| return [ | |
| TensorSpec("hidden_states", [BATCH, HIDDEN], torch.bfloat16, init_value=init_hidden_states), | |
| TensorSpec("input_rms_weight", [1, HIDDEN], torch.float32, init_value=init_rms_weight), | |
| TensorSpec("wq", [HIDDEN, HIDDEN], torch.bfloat16, init_value=init_wq), | |
| TensorSpec("wk", [HIDDEN, KV_HIDDEN], torch.bfloat16, init_value=init_wk), | |
| TensorSpec("wv", [HIDDEN, KV_HIDDEN], torch.bfloat16, init_value=init_wv), | |
| TensorSpec("seq_lens", [BATCH], torch.int32, init_value=init_seq_lens), | |
| TensorSpec("rope_cos", [MAX_SEQ, HEAD_DIM], torch.float32, init_value=init_rope_cos), | |
| TensorSpec("rope_sin", [MAX_SEQ, HEAD_DIM], torch.float32, init_value=init_rope_sin), | |
| TensorSpec("k_cache", [CACHE_ROWS, HEAD_DIM], torch.bfloat16, init_value=init_k_cache), | |
| TensorSpec("v_cache", [CACHE_ROWS, HEAD_DIM], torch.bfloat16, init_value=init_v_cache), | |
| TensorSpec("wo", [HIDDEN, HIDDEN], torch.bfloat16, init_value=init_wo), | |
| TensorSpec("post_rms_weight", [1, HIDDEN], torch.float32, init_value=init_post_rms_weight), | |
| TensorSpec("w_gate", [HIDDEN, INTERMEDIATE], torch.bfloat16, init_value=init_w_gate), | |
| TensorSpec("w_up", [HIDDEN, INTERMEDIATE], torch.bfloat16, init_value=init_w_up), | |
| TensorSpec("w_down", [INTERMEDIATE, HIDDEN], torch.bfloat16, init_value=init_w_down), | |
| TensorSpec("out", [BATCH, HIDDEN], torch.bfloat16, is_output=True), | |
| ] | |
| # ── Scope 1: RMSNorm + Q/K/V projection ── | |
| _, q_proj, k_proj, v_proj = golden_decode_scope1( | |
| hidden_states, input_rms_weight, wq, wk, wv, | |
| hidden=HIDDEN, eps=EPS, | |
| ) | |
| # ── Scope 2: RoPE + cache update + attention ── | |
| attn_out = golden_decode_scope2( | |
| q_proj, k_proj, v_proj, k_cache, v_cache, seq_lens, | |
| rope_cos, rope_sin, | |
| batch=BATCH, num_heads=NUM_HEADS, num_kv_heads=NUM_KV_HEADS, | |
| head_dim=HEAD_DIM, max_seq=MAX_SEQ, seq_tile=SEQ_TILE, | |
| q_per_kv=Q_PER_KV, q_head_batch=Q_HEAD_BATCH, | |
| attn_scale=attn_scale, | |
| ) | |
| def golden_qwen3_decode(tensors): | |
| # ── Scope 3: output projection + residual + post RMSNorm + MLP + residual ── | |
| tensors["out"][:] = golden_decode_scope3( | |
| attn_out, hidden_states, wo, post_rms_weight, w_gate, w_up, w_down, | |
| hidden=HIDDEN, eps=EPS, | |
| ) | |
| def golden_qwen3_decode(tensors): | |
| """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP).""" | |
| import math | |
| hidden_states = tensors["hidden_states"] | |
| input_rms_weight = tensors["input_rms_weight"] | |
| wq = tensors["wq"] | |
| wk = tensors["wk"] | |
| wv = tensors["wv"] | |
| seq_lens = tensors["seq_lens"] | |
| rope_cos = tensors["rope_cos"] | |
| rope_sin = tensors["rope_sin"] | |
| k_cache = tensors["k_cache"] | |
| v_cache = tensors["v_cache"] | |
| wo = tensors["wo"] | |
| post_rms_weight = tensors["post_rms_weight"] | |
| w_gate = tensors["w_gate"] | |
| w_up = tensors["w_up"] | |
| w_down = tensors["w_down"] | |
| attn_scale = 1.0 / math.sqrt(HEAD_DIM) | |
| # ── Scope 1: RMSNorm + Q/K/V projection ── | |
| _, q_proj, k_proj, v_proj = golden_decode_scope1( | |
| hidden_states, input_rms_weight, wq, wk, wv, | |
| hidden=HIDDEN, eps=EPS, | |
| ) | |
| # ── Scope 2: RoPE + cache update + attention ── | |
| attn_out = golden_decode_scope2( | |
| q_proj, k_proj, v_proj, k_cache, v_cache, seq_lens, | |
| rope_cos, rope_sin, | |
| batch=BATCH, num_heads=NUM_HEADS, num_kv_heads=NUM_KV_HEADS, | |
| head_dim=HEAD_DIM, max_seq=MAX_SEQ, seq_tile=SEQ_TILE, | |
| q_per_kv=Q_PER_KV, q_head_batch=Q_HEAD_BATCH, | |
| attn_scale=attn_scale, | |
| ) | |
| # ── Scope 3: output projection + residual + post RMSNorm + MLP + residual ── | |
| tensors["out"][:] = golden_decode_scope3( | |
| attn_out, hidden_states, wo, post_rms_weight, w_gate, w_up, w_down, | |
| hidden=HIDDEN, eps=EPS, | |
| ) | |
| return |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/32b/qwen3_32b_decode.py` around lines 425 - 467, The
golden_qwen3_decode function computes tensors["out"] using the refactored scope
helper functions (golden_decode_scope1, golden_decode_scope2,
golden_decode_scope3), but execution continues to fall through into the old
stale inline golden body implementation that overwrites the result. Add a return
statement immediately after the tensors["out"] assignment in the scope3 block,
or delete the entire old inline golden implementation block that follows (lines
468-588) to prevent the stale code from shadowing the refactored path during
validation.
| parser.add_argument("--smoke", action="store_true", default=False, | ||
| help="compile-only (no device); also the implicit behavior on *sim platforms.") | ||
| args = parser.parse_args() | ||
|
|
||
| specs = build_tensor_specs(use_max_seq=args.max_seq) | ||
|
|
||
| # Compile-only smoke: explicit --smoke (not sim platform — sim runs full pipeline). | ||
| if args.smoke: |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Align the --smoke help with the actual branch.
The help says smoke is implicit on *sim platforms, but Line 723 only checks args.smoke and the comment says sim runs the full pipeline. Please either remove that phrase or make the condition include sim platforms.
📝 Proposed help-text fix
parser.add_argument("--smoke", action="store_true", default=False,
- help="compile-only (no device); also the implicit behavior on *sim platforms.")
+ help="compile-only (no device)")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| parser.add_argument("--smoke", action="store_true", default=False, | |
| help="compile-only (no device); also the implicit behavior on *sim platforms.") | |
| args = parser.parse_args() | |
| specs = build_tensor_specs(use_max_seq=args.max_seq) | |
| # Compile-only smoke: explicit --smoke (not sim platform — sim runs full pipeline). | |
| if args.smoke: | |
| parser.add_argument("--smoke", action="store_true", default=False, | |
| help="compile-only (no device)") | |
| args = parser.parse_args() | |
| specs = build_tensor_specs(use_max_seq=args.max_seq) | |
| # Compile-only smoke: explicit --smoke (not sim platform — sim runs full pipeline). | |
| if args.smoke: |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/32b/qwen3_32b_prefill_draft.py` around lines 716 - 723, The help
text for the --smoke argument states that smoke behavior is implicit on *sim
platforms, but the condition at line 723 only checks if args.smoke is true and
does not detect or check for sim platforms. To fix this inconsistency, either
remove the phrase "also the implicit behavior on *sim platforms" from the help
text in parser.add_argument for the smoke argument to reflect the actual
behavior of only checking args.smoke, or modify the conditional logic to detect
sim platforms and include that check in the if statement alongside the
args.smoke check. Choose the simpler option: update the help text to remove the
sim platform reference since the code comment explicitly states that sim runs
the full pipeline.
| from models.shared.golden import ( | ||
| golden_rmsnorm, | ||
| golden_rope_rotate_half, | ||
| golden_online_softmax_step, | ||
| golden_swiglu, | ||
| ) |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Docstring import path doesn't match the module name.
This file is golden_ref.py, so from models.shared.golden import (...) in the usage example won't resolve. Update to models.shared.golden_ref so the documented snippet works.
📝 Proposed fix
- from models.shared.golden import (
+ from models.shared.golden_ref import (
golden_rmsnorm,
golden_rope_rotate_half,
golden_online_softmax_step,
golden_swiglu,
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| from models.shared.golden import ( | |
| golden_rmsnorm, | |
| golden_rope_rotate_half, | |
| golden_online_softmax_step, | |
| golden_swiglu, | |
| ) | |
| from models.shared.golden_ref import ( | |
| golden_rmsnorm, | |
| golden_rope_rotate_half, | |
| golden_online_softmax_step, | |
| golden_swiglu, | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/shared/golden_ref.py` around lines 13 - 18, The import statement at
the top of golden_ref.py imports from models.shared.golden, but since this file
is golden_ref.py itself, the import path is incorrect and won't resolve. Change
the import statement to import from models.shared.golden_ref instead of
models.shared.golden to ensure the documented usage example actually works and
matches the correct module location.
This PR consolidates duplicated leaf-level arithmetic across Qwen3-14B and Qwen3-32B model
files into shared @pl.inline primitives. The result is roughly 40% reduction in model code at
zero runtime overhead.
Dependency
This work depends on the CT kwargs mechanism added in hw-native-sys/pypto#1830, which allows compile-time keyword arguments to be folded to constants at parse time. Without it, parameterized primitives like
rmsnorm(rows=BATCH, k_chunk=512, eps=1e-6)would not produce the same IR as the hand-written loops they replace. Every primitive in this PR uses keyword-only CT kwargs for all per-call-site constants (tile sizes, loop bounds, epsilon values). The expanded IR is identical to the original code — verified by diffing compiler pass dumps against baselines captured before any changes.Changes
6 atomic commits, each independently verifiable (ruff + --smoke compile):
Primitives are ordered by complexity: silu (simplest, single tensor return) to rmsnorm
(two-pass reduction with CT kwargs) to matmul_tiled (pipelined K-accumulation) to
matmul_tiled_4d (block-major variant) to specs/golden (pure Python, no IR impact). Each commit
adds the primitive alongside its call sites — no dead code at intermediate states. The --smoke
flag is present from commit 1.