Add Qwen3 14B A8W8 kernels#642
Conversation
📝 WalkthroughWalkthroughExtends ChangesGolden Runner Compatibility and save_actual_data
Qwen3-14B A8W8 Decode Kernel
Qwen3-14B A8W8 Prefill Kernel and rms_lm_head Fix
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces full-layer decode and prefill forward passes for Qwen3-14B with A8W8 quantization, alongside compatibility shims and patching utilities in the test runner. The review feedback focuses on performance optimizations in the prefill implementation, suggesting the vectorization of row-by-row dequantization loops (for Q, K, and V projections) and the SiLU activation loop to eliminate scalar loop overhead on the NPU. Additionally, the reviewer recommends removing redundant pre-zeroing of the SiLU tile and simplifying a conditional boolean assignment in the runner utility.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| q_deq = pl.create_tensor([TOK_TILE, Q_OUT_CHUNK], dtype=pl.FP32) | ||
| for q_deq_ti in pl.range(TOK_TILE): | ||
| q_deq_row = pl.slice(q_deq_weighted, [1, Q_OUT_CHUNK], [q_deq_ti, 0]) | ||
| q_deq_scale = pl.read(act_scales, q_deq_ti) | ||
| q_deq = pl.assemble(q_deq, pl.mul(q_deq_row, q_deq_scale), [q_deq_ti, 0]) |
There was a problem hiding this comment.
The row-by-row dequantization loop can be completely vectorized using elementwise multiplication with a reshaped act_scales tensor. This avoids the overhead of slicing, reading, and assembling row-by-row on the NPU, significantly improving performance.
| q_deq = pl.create_tensor([TOK_TILE, Q_OUT_CHUNK], dtype=pl.FP32) | |
| for q_deq_ti in pl.range(TOK_TILE): | |
| q_deq_row = pl.slice(q_deq_weighted, [1, Q_OUT_CHUNK], [q_deq_ti, 0]) | |
| q_deq_scale = pl.read(act_scales, q_deq_ti) | |
| q_deq = pl.assemble(q_deq, pl.mul(q_deq_row, q_deq_scale), [q_deq_ti, 0]) | |
| q_deq = pl.mul(q_deq_weighted, pl.reshape(act_scales, [TOK_TILE, 1])) |
| k_deq = pl.create_tensor([TOK_TILE, KV_OUT_CHUNK], dtype=pl.FP32) | ||
| for k_deq_ti in pl.range(TOK_TILE): | ||
| k_deq_row = pl.slice(k_deq_weighted, [1, KV_OUT_CHUNK], [k_deq_ti, 0]) | ||
| k_deq_scale = pl.read(act_scales, k_deq_ti) | ||
| k_deq = pl.assemble(k_deq, pl.mul(k_deq_row, k_deq_scale), [k_deq_ti, 0]) |
There was a problem hiding this comment.
The row-by-row dequantization loop for the K projection can be vectorized using elementwise multiplication with a reshaped act_scales tensor, eliminating the scalar loop overhead.
| k_deq = pl.create_tensor([TOK_TILE, KV_OUT_CHUNK], dtype=pl.FP32) | |
| for k_deq_ti in pl.range(TOK_TILE): | |
| k_deq_row = pl.slice(k_deq_weighted, [1, KV_OUT_CHUNK], [k_deq_ti, 0]) | |
| k_deq_scale = pl.read(act_scales, k_deq_ti) | |
| k_deq = pl.assemble(k_deq, pl.mul(k_deq_row, k_deq_scale), [k_deq_ti, 0]) | |
| k_deq = pl.mul(k_deq_weighted, pl.reshape(act_scales, [TOK_TILE, 1])) |
| v_deq = pl.create_tensor([TOK_TILE, KV_OUT_CHUNK], dtype=pl.FP32) | ||
| for v_deq_ti in pl.range(TOK_TILE): | ||
| v_deq_row = pl.slice(v_deq_weighted, [1, KV_OUT_CHUNK], [v_deq_ti, 0]) | ||
| v_deq_scale = pl.read(act_scales, v_deq_ti) | ||
| v_deq = pl.assemble(v_deq, pl.mul(v_deq_row, v_deq_scale), [v_deq_ti, 0]) |
There was a problem hiding this comment.
The row-by-row dequantization loop for the V projection can be vectorized using elementwise multiplication with a reshaped act_scales tensor, eliminating the scalar loop overhead.
| v_deq = pl.create_tensor([TOK_TILE, KV_OUT_CHUNK], dtype=pl.FP32) | |
| for v_deq_ti in pl.range(TOK_TILE): | |
| v_deq_row = pl.slice(v_deq_weighted, [1, KV_OUT_CHUNK], [v_deq_ti, 0]) | |
| v_deq_scale = pl.read(act_scales, v_deq_ti) | |
| v_deq = pl.assemble(v_deq, pl.mul(v_deq_row, v_deq_scale), [v_deq_ti, 0]) | |
| v_deq = pl.mul(v_deq_weighted, pl.reshape(act_scales, [TOK_TILE, 1])) |
| for debug_kb in pl.range(HIDDEN_BLOCKS): | ||
| debug_mlp_k0 = debug_kb * K_CHUNK | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="debug_mlp_silu_out"): | ||
| debug_mlp_chunk = pl.slice(mlp_silu_tile, [TOK_TILE, K_CHUNK], [0, debug_mlp_k0]) | ||
| out = pl.assemble(out, debug_mlp_chunk, [token_p0, debug_mlp_k0]) | ||
|
|
||
| if DEBUG_STAGE_ID == 10: | ||
| pass | ||
| elif DEBUG_STAGE_ID == 12: | ||
| debug_down_partial0_tensor = pl.create_tensor([TOK_TILE, K_CHUNK], dtype=pl.FP32) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="debug_down_partial0"): | ||
| debug_mlp_chunk0 = pl.slice(mlp_down_tile, [TOK_TILE, MLP_OUT_CHUNK], [0, 0]) | ||
| debug_w_down_chunk0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [layer_inter_base, 0]) | ||
| debug_down_partial0 = pl.matmul( | ||
| debug_mlp_chunk0, |
There was a problem hiding this comment.
The SiLU activation loop can be fully vectorized across the 2D tensor chunks without row-by-row slicing. Additionally, pre-zeroing the mlp_silu_tile with mlp_zero is redundant because pl.assemble completely overwrites the destination slice.
| for debug_kb in pl.range(HIDDEN_BLOCKS): | |
| debug_mlp_k0 = debug_kb * K_CHUNK | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="debug_mlp_silu_out"): | |
| debug_mlp_chunk = pl.slice(mlp_silu_tile, [TOK_TILE, K_CHUNK], [0, debug_mlp_k0]) | |
| out = pl.assemble(out, debug_mlp_chunk, [token_p0, debug_mlp_k0]) | |
| if DEBUG_STAGE_ID == 10: | |
| pass | |
| elif DEBUG_STAGE_ID == 12: | |
| debug_down_partial0_tensor = pl.create_tensor([TOK_TILE, K_CHUNK], dtype=pl.FP32) | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="debug_down_partial0"): | |
| debug_mlp_chunk0 = pl.slice(mlp_down_tile, [TOK_TILE, MLP_OUT_CHUNK], [0, 0]) | |
| debug_w_down_chunk0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [layer_inter_base, 0]) | |
| debug_down_partial0 = pl.matmul( | |
| debug_mlp_chunk0, | |
| mlp_silu_sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) | |
| mlp_silu_chunk = pl.mul(pl.mul(gate_acc, mlp_silu_sigmoid), up_acc) | |
| mlp_silu_tile = pl.assemble( | |
| mlp_silu_tile, | |
| pl.cast(mlp_silu_chunk, target_type=pl.BF16), | |
| [0, mlp_out_o0], | |
| ) |
| if hasattr(Worker, "chip_contexts"): | ||
| chip_contexts_installed = True | ||
| else: | ||
| chip_contexts_installed = False |
There was a problem hiding this comment.
Actionable comments posted: 8
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@golden/runner.py`:
- Line 317: Ruff B010 is flagging the constant-name setattr usage in the Worker
setup. Replace the setattr-based property assignment in the code that defines
chip_contexts, and the similar assignment mentioned for the other location, with
direct class attribute assignment while preserving the same property behavior.
Use the existing Worker and property(_chip_contexts) symbols to locate the
affected spots and update both occurrences consistently.
- Around line 804-805: The docstring for the data-saving behavior is incomplete:
`save_actual_data` is documented only for the `golden_data` path, but `Runner`
also writes `data/actual` in the non-golden path even when it is false. Update
the documentation near the `Runner`/data persistence logic to describe both
branches clearly, or adjust the condition so `data/actual` is only saved when
intended; make sure the explanation matches the behavior in the relevant
`save_actual_data` handling and `golden_data` flow.
- Around line 391-395: The import handling in runner.py is too broad because the
ModuleNotFoundError around rebuild_kernel_cpp_from_pto also hides missing
transitive dependencies inside pypto.runtime.debug.pto_rebuild. Update the
fallback in the code path that imports rebuild_kernel_cpp_from_pto so it only
returns work_dir when the debug module itself is unavailable, and let other
import failures surface instead of silently reusing stale artifacts.
- Around line 1100-1101: run_jit() is missing the same generated-artifact
patching that run() applies, so JIT-compiled or reloaded runtime_dir artifacts
can still fail on bitcast/host_orch compatibility. After
_compile_jit_with_compat() and after any runtime_dir reload logic in run_jit(),
invoke the same patch helpers used by run() on the compiled output_dir /
work_dir before proceeding. Use the existing run(), _compile_jit_with_compat(),
and runtime_dir handling paths as the reference points so the JIT flow matches
the non-JIT artifact normalization.
- Around line 965-975: The JIT fallback compile path in runner.py is dropping
distributed settings, so entries that rely on L3/distributed mode are compiled
with the wrong configuration. Update the ir.compile call in the fallback that
rebuilds the ir.Program to forward compile_cfg["distributed_config"] (or the
equivalent distributed config value) alongside the existing run_config fields,
using the ir.compile entry point and the surrounding fallback logic as the place
to fix it.
In `@models/qwen3/14b/decode_layer_a8w8.py`:
- Line 193: The comments in the affected decode layer sections use ambiguous
multiplication glyphs that Ruff flags as RUF003. Update the marked comment text
in decode_layer_a8w8.py to replace each non-ASCII multiplication symbol with
plain ASCII wording like x or by, and apply the same cleanup in the other
flagged comment blocks near the related SEQ_TILE/head references.
- Around line 19-25: The copied harness docstring in decode_layer_a8w8.py is
stale and conflicts with the current serving path. Update or remove the header
block near the decode harness so it no longer says the program does not compile
and no longer references obsolete standalone script names; instead, describe the
native serving flow through decode_fwd and keep the docstring aligned with the
current kernel entrypoint and behavior.
In `@models/qwen3/14b/prefill_fwd_a8w8.py`:
- Around line 1120-1125: The final writeback in the prefill path is assembling a
full token tile into dynamic output even when the last tile is partial. Update
the writeback logic around out_chunk, out_out_quant_chunk_bf16, and pl.assemble
in prefill_fwd_a8w8 to trim the chunk with the same set_validshape handling used
for resid1_chunk before any out writeback, and make sure the same fix is applied
in the debug branches as well.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 25cbdbf0-d212-4a1a-b664-b3666d82adc1
📒 Files selected for processing (5)
golden/runner.pymodels/qwen3/14b/decode_layer_a8w8.pymodels/qwen3/14b/prefill_fwd_a8w8.pymodels/qwen3/14b/rms_lm_head.pytests/golden/test_runner.py
| def _chip_contexts(self: Any) -> list[Any]: # noqa: ANN001 - runtime compatibility shim | ||
| return [] | ||
|
|
||
| setattr(Worker, "chip_contexts", property(_chip_contexts)) |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Use direct attribute assignment to satisfy Ruff B010.
The constant-name setattr() calls are flagged by Ruff and can be replaced without changing behavior.
Proposed fix
- setattr(Worker, "chip_contexts", property(_chip_contexts))
+ Worker.chip_contexts = property(_chip_contexts)
- setattr(_submit_next_level_compat, "_pypto_legacy_chip_callable_compat", True)
+ _submit_next_level_compat._pypto_legacy_chip_callable_compat = TrueAlso applies to: 339-339
🧰 Tools
🪛 Ruff (0.15.18)
[warning] 317-317: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@golden/runner.py` at line 317, Ruff B010 is flagging the constant-name
setattr usage in the Worker setup. Replace the setattr-based property assignment
in the code that defines chip_contexts, and the similar assignment mentioned for
the other location, with direct class attribute assignment while preserving the
same property behavior. Use the existing Worker and property(_chip_contexts)
symbols to locate the affected spots and update both occurrences consistently.
Source: Linters/SAST tools
| try: | ||
| from pypto.runtime.debug.pto_rebuild import rebuild_kernel_cpp_from_pto | ||
| except ModuleNotFoundError: | ||
| print("[runtime_only] pypto.runtime.debug unavailable; using existing runtime artifacts", flush=True) | ||
| return work_dir |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Avoid swallowing transitive import failures.
This catches any ModuleNotFoundError raised while importing pto_rebuild, including missing dependencies inside that module, and then silently reuses stale artifacts. Only fall back when the debug module itself is unavailable.
Proposed fix
- except ModuleNotFoundError:
+ except ModuleNotFoundError as e:
+ if e.name not in {"pypto.runtime.debug", "pypto.runtime.debug.pto_rebuild"}:
+ raise
print("[runtime_only] pypto.runtime.debug unavailable; using existing runtime artifacts", flush=True)
return work_dir📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| try: | |
| from pypto.runtime.debug.pto_rebuild import rebuild_kernel_cpp_from_pto | |
| except ModuleNotFoundError: | |
| print("[runtime_only] pypto.runtime.debug unavailable; using existing runtime artifacts", flush=True) | |
| return work_dir | |
| try: | |
| from pypto.runtime.debug.pto_rebuild import rebuild_kernel_cpp_from_pto | |
| except ModuleNotFoundError as e: | |
| if e.name not in {"pypto.runtime.debug", "pypto.runtime.debug.pto_rebuild"}: | |
| raise | |
| print("[runtime_only] pypto.runtime.debug unavailable; using existing runtime artifacts", flush=True) | |
| return work_dir |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@golden/runner.py` around lines 391 - 395, The import handling in runner.py is
too broad because the ModuleNotFoundError around rebuild_kernel_cpp_from_pto
also hides missing transitive dependencies inside
pypto.runtime.debug.pto_rebuild. Update the fallback in the code path that
imports rebuild_kernel_cpp_from_pto so it only returns work_dir when the debug
module itself is unavailable, and let other import failures surface instead of
silently reusing stale artifacts.
| save_actual_data: When True with *golden_data*, also persist runtime | ||
| outputs to ``{work_dir}/data/actual`` for downstream consumers. |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Clarify when actual outputs are saved.
The implementation saves data/actual by default when golden_data is not used, even if save_actual_data=False; the docstring only describes the golden_data + save_actual_data=True case. Please document both branches or tighten the condition.
Also applies to: 1030-1031
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@golden/runner.py` around lines 804 - 805, The docstring for the data-saving
behavior is incomplete: `save_actual_data` is documented only for the
`golden_data` path, but `Runner` also writes `data/actual` in the non-golden
path even when it is false. Update the documentation near the `Runner`/data
persistence logic to describe both branches clearly, or adjust the condition so
`data/actual` is only saved when intended; make sure the explanation matches the
behavior in the relevant `save_actual_data` handling and `golden_data` flow.
| return ir.compile( | ||
| program, | ||
| output_dir=run_config.save_kernels_dir, | ||
| strategy=run_config.strategy, | ||
| backend_type=run_config.backend_type, | ||
| dump_passes=run_config.dump_passes, | ||
| diagnostic_phase=run_config.diagnostic_phase, | ||
| disabled_diagnostics=run_config.disabled_diagnostics, | ||
| platform=run_config.platform, | ||
| profiling=run_config.compile_profiling, | ||
| ) |
There was a problem hiding this comment.
🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify the PyPTO compile config contract without executing repo code.
rg -n "distributed_config|def compile|class RunConfig" -C3 .Repository: hw-native-sys/pypto-lib
Length of output: 11268
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== golden/runner.py around run_jit / fallback =="
sed -n '930,990p' golden/runner.py
echo
echo "== golden/runner.py around JIT docs / compile_cfg handling =="
sed -n '990,1145p' golden/runner.py
echo
echo "== tests in tests/golden/test_runner.py around JIT / distributed_config =="
sed -n '108,140p' tests/golden/test_runner.py
sed -n '1228,1264p' tests/golden/test_runner.pyRepository: hw-native-sys/pypto-lib
Length of output: 13393
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== files under pypto-related paths =="
git ls-files | rg '(^|/)(pypto|golden|tests/golden)/'
echo
echo "== search for ir.compile / distributed_config in source and tests =="
rg -n "ir\.compile\(|distributed_config|DistributedCompiledProgram|_maybe_reload_l3|_compile_jit_with_compat" golden tests -C 2Repository: hw-native-sys/pypto-lib
Length of output: 13597
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== golden/runner.py around the non-JIT compile path =="
sed -n '836,874p' golden/runner.py
echo
echo "== tests for the fallback branch of _compile_jit_with_compat =="
sed -n '140,190p' tests/golden/test_runner.pyRepository: hw-native-sys/pypto-lib
Length of output: 4483
Forward distributed_config in golden/runner.py:965-975.
The JIT fallback rebuilds an ir.Program and calls ir.compile() without passing compile_cfg["distributed_config"], so JIT entries that need L3/distributed settings will compile in the wrong mode.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@golden/runner.py` around lines 965 - 975, The JIT fallback compile path in
runner.py is dropping distributed settings, so entries that rely on
L3/distributed mode are compiled with the wrong configuration. Update the
ir.compile call in the fallback that rebuilds the ir.Program to forward
compile_cfg["distributed_config"] (or the equivalent distributed config value)
alongside the existing run_config fields, using the ir.compile entry point and
the surrounding fallback logic as the place to fix it.
| compiled = _compile_jit_with_compat(fn, dummy_args, cfg) | ||
| work_dir = Path(compiled.output_dir) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Apply the generated-artifact patches in run_jit() too.
run() patches the compiled/runtime directory before execution, but run_jit() skips those helpers after JIT compilation or runtime_dir reload. JIT L3 artifacts can still hit the same bitcast/host_orch compatibility failures.
Proposed fix
if compile_only:
total = time.time() - start
print(f"[RUN] PASS ({total:.2f}s)", flush=True)
return RunResult(passed=True, execution_time=total, work_dir=work_dir)
+
+ if work_dir is not None:
+ _patch_aicore_bitcast_helpers(work_dir)
+ _patch_l3_single_submit_host_orch(work_dir)
+ _patch_l3_host_orch_ssa_aliases(work_dir)
# Generate Inputs📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| compiled = _compile_jit_with_compat(fn, dummy_args, cfg) | |
| work_dir = Path(compiled.output_dir) | |
| compiled = _compile_jit_with_compat(fn, dummy_args, cfg) | |
| work_dir = Path(compiled.output_dir) | |
| if work_dir is not None: | |
| _patch_aicore_bitcast_helpers(work_dir) | |
| _patch_l3_single_submit_host_orch(work_dir) | |
| _patch_l3_host_orch_ssa_aliases(work_dir) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@golden/runner.py` around lines 1100 - 1101, run_jit() is missing the same
generated-artifact patching that run() applies, so JIT-compiled or reloaded
runtime_dir artifacts can still fail on bitcast/host_orch compatibility. After
_compile_jit_with_compat() and after any runtime_dir reload logic in run_jit(),
invoke the same patch helpers used by run() on the compiled output_dir /
work_dir before proceeding. Use the existing run(), _compile_jit_with_compat(),
and runtime_dir handling paths as the reference points so the JIT flow matches
the non-JIT artifact normalization.
| EXPECTED / INTENT program (the dense block-level load balancer). NOTE: this does | ||
| NOT compile on the current toolchain — the data-dependent ``pl.read`` scalar that | ||
| feeds the store offset (``g_base + sb * Q_HEAD_PAD``) trips a PTO codegen | ||
| limitation (``GetOrCreateTensorView`` / ptoas ``index vs i64``; see | ||
| ``KNOWN_ISSUES.md``). It is written to capture the desired structure; the | ||
| affine fallback that DOES compile lives in | ||
| ``qwen3_manual_scope_fused_kvsplit_static.py`` (coprime-stride, ~1.9x balance). |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Refresh the copied harness docstring.
This new serving kernel is expected to compile/run via decode_fwd, but the header still says the program does not compile and points users at obsolete standalone script names. Please update or remove this stale usage block so it matches the PR’s native serving path.
Also applies to: 52-55
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/decode_layer_a8w8.py` around lines 19 - 25, The copied
harness docstring in decode_layer_a8w8.py is stale and conflicts with the
current serving path. Update or remove the header block near the decode harness
so it no longer says the program does not compile and no longer references
obsolete standalone script names; instead, describe the native serving flow
through decode_fwd and keep the docstring aligned with the current kernel
entrypoint and behavior.
| # ragged seq_lens, at the cost of more fa_fused tasks | ||
| # (BATCH * (NUM_KV_HEADS // 2) * KV_SPLITS). | ||
| # Dispatch unit = ONE seq block (TOKENS_PER_SPLIT == SEQ_TILE). Every fa_fused | ||
| # work item is then a single SEQ_TILE block (×2 heads) — equal cost regardless of |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Replace ambiguous multiplication-sign glyphs flagged by Ruff.
Ruff reports RUF003 on these comment lines. Use plain x/by wording to keep lint clean.
Proposed cleanup
-# Dispatch unit = ONE seq block (TOKENS_PER_SPLIT == SEQ_TILE). Every fa_fused
-# work item is then a single SEQ_TILE block (×2 heads)
+# Dispatch unit = ONE seq block (TOKENS_PER_SPLIT == SEQ_TILE). Every fa_fused
+# work item is then a single SEQ_TILE block (by 2 heads)Apply the same replacement to the other flagged comment lines.
Also applies to: 215-215, 227-227, 817-817
🧰 Tools
🪛 Ruff (0.15.18)
[warning] 193-193: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/decode_layer_a8w8.py` at line 193, The comments in the
affected decode layer sections use ambiguous multiplication glyphs that Ruff
flags as RUF003. Update the marked comment text in decode_layer_a8w8.py to
replace each non-ASCII multiplication symbol with plain ASCII wording like x or
by, and apply the same cleanup in the other flagged comment blocks near the
related SEQ_TILE/head references.
Source: Linters/SAST tools
| out_chunk = pl.add( | ||
| down_acc, | ||
| pl.slice(resid1_tile, [TOK_TILE, K_CHUNK], [0, down_proj_d0]), | ||
| ) | ||
| out_out_quant_chunk_bf16 = pl.cast(out_chunk, target_type=pl.BF16) | ||
| out = pl.assemble(out, out_out_quant_chunk_bf16, [token_p0, down_proj_d0]) |
There was a problem hiding this comment.
🗄️ Data Integrity & Integration | 🔴 Critical | ⚡ Quick win
Trim the final partial tile before assembling into out.
Line 1125 writes a full [TOK_TILE, K_CHUNK] chunk into dynamic out; when valid_tok < TOK_TILE, this can write invalid rows past the packed prefill output. Apply the same set_validshape trim used for resid1_chunk before all out writebacks, including debug branches.
🐛 Proposed fix for the final writeback
- out_out_quant_chunk_bf16 = pl.cast(out_chunk, target_type=pl.BF16)
- out = pl.assemble(out, out_out_quant_chunk_bf16, [token_p0, down_proj_d0])
+ out_out_quant_chunk_bf16 = pl.cast(out_chunk, target_type=pl.BF16)
+ out_out_quant_chunk_valid = pl.tensor.set_validshape(
+ out_out_quant_chunk_bf16,
+ valid_tok,
+ K_CHUNK,
+ )
+ out = pl.assemble(out, out_out_quant_chunk_valid, [token_p0, down_proj_d0])📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| out_chunk = pl.add( | |
| down_acc, | |
| pl.slice(resid1_tile, [TOK_TILE, K_CHUNK], [0, down_proj_d0]), | |
| ) | |
| out_out_quant_chunk_bf16 = pl.cast(out_chunk, target_type=pl.BF16) | |
| out = pl.assemble(out, out_out_quant_chunk_bf16, [token_p0, down_proj_d0]) | |
| out_chunk = pl.add( | |
| down_acc, | |
| pl.slice(resid1_tile, [TOK_TILE, K_CHUNK], [0, down_proj_d0]), | |
| ) | |
| out_out_quant_chunk_bf16 = pl.cast(out_chunk, target_type=pl.BF16) | |
| out_out_quant_chunk_valid = pl.tensor.set_validshape( | |
| out_out_quant_chunk_bf16, | |
| valid_tok, | |
| K_CHUNK, | |
| ) | |
| out = pl.assemble(out, out_out_quant_chunk_valid, [token_p0, down_proj_d0]) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/prefill_fwd_a8w8.py` around lines 1120 - 1125, The final
writeback in the prefill path is assembling a full token tile into dynamic
output even when the last tile is partial. Update the writeback logic around
out_chunk, out_out_quant_chunk_bf16, and pl.assemble in prefill_fwd_a8w8 to trim
the chunk with the same set_validshape handling used for resid1_chunk before any
out writeback, and make sure the same fix is applied in the debug branches as
well.
c0e0128 to
c7ada3a
Compare
Add an explicit batched Q RoPE path guarded by QWEN_A8W8_Q_ROPE_BATCH_EXPLICIT and keep fused QK norm dependencies explicit so the A8W8 decode fast path preserves text quality during generation.
Add the Qwen3-14B A8W8 prefill_hidden and decode_fwd PyPTO kernels used by the native serving path.
Keep the existing BF16 Qwen3 path isolated while wiring the A8W8 kernel constants, runner support, and golden runner compile fixes needed by the generated callables.
Slim the lib-side delivery by removing obsolete standalone golden/smoke harnesses and unused JIT entry points after the scheduling path moved to pypto-serving. The debug-stage branches remain available for future numerical diagnosis.
Related: