Add native Qwen3 14B A8W8 serving path #48
📝 Walkthrough
Adds an A8W8/BF16 checkpoint loader and decode-backend CLI options for Qwen3-14B, introduces L3-mode and A8W8-decode-loop execution paths with quantization-aware compilation and weight handling in the NPU executor, threads a
Changes
A8W8/L3 Decode Feature
Dependency-free Qwen Tokenizer Fallback
Sequence Diagram(s)
sequenceDiagram
participant Engine
participant Executor as Qwen314BPyptoExecutor
participant A8W8Loop as run_generate_a8w8_decode_loop
participant L3 as run_generate_l3
Engine->>Executor: try_generate_batch(record, requests, prefill_batch, config)
Executor->>Executor: validate_generate_batch / prompt_allocation_length
alt a8w8_decode_loop enabled
Executor->>A8W8Loop: run_generate_a8w8_decode_loop(...)
else l3_mode enabled
Executor->>L3: run_generate_l3(...)
end
Executor-->>Engine: GenerateResult(finish_reason)
sequenceDiagram
participant main as npu_generate.main
participant Loader as Qwen3A8W8DirectoryLoader
participant Index as _SafeTensorIndex
participant Runtime as RuntimeModel
main->>Loader: load(request)
Loader->>Loader: read config.json, build tokenizer
Loader->>Index: scan model-*.safetensors shards
Index-->>Loader: tensor offsets/dtype/shape
Loader->>Index: load(key) per tensor
Loader->>Loader: convert layouts (int8 kernel, bf16 dequant)
Loader->>Runtime: populate embeddings/norm/head/layers
Loader-->>main: LoadedModel
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~75 minutes
Possibly related PRs
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request introduces support for qwen3-a8w8 model quantization, including a new loader for compressed-tensors checkpoints and updated NPU execution and runner logic to handle A8W8 kernels and KV cache scaling. It also adds a dependency-free byte-level BPE tokenizer fallback. The review highlights significant code duplication in the NPU executor and runner logic, suggesting refactoring to improve maintainability. Additionally, it recommends performance optimizations for file I/O and tokenization, replacing SimpleNamespace with dataclass for better type safety, and avoiding bare except blocks.
| if quantization == "a8w8": | ||
| prefill = self._compile_prefill_fwd_callable_a8w8( | ||
| qwen3_prefill_fwd.prefill_hidden_a8w8, | ||
| batch=kernel_batch, | ||
| max_seq=model.runtime.max_seq_len, | ||
| hidden_size=model.config.hidden_size, | ||
| intermediate_size=model.config.intermediate_size, | ||
| num_heads=model.config.num_attention_heads, | ||
| num_kv_heads=model.config.num_key_value_heads, | ||
| head_dim=model.config.head_dim, | ||
| num_layers=min(model.config.num_hidden_layers, _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS), | ||
| vocab_size=padded_vocab, | ||
| block_table_stride=max_blocks_per_seq, | ||
| page_size=page_size, | ||
| ) | ||
| else: | ||
| prefill = self._compile_prefill_fwd_callable( | ||
| qwen3_prefill_fwd.prefill_fwd, | ||
| batch=kernel_batch, | ||
| max_seq=model.runtime.max_seq_len, | ||
| hidden_size=model.config.hidden_size, | ||
| intermediate_size=model.config.intermediate_size, | ||
| num_heads=model.config.num_attention_heads, | ||
| num_kv_heads=model.config.num_key_value_heads, | ||
| head_dim=model.config.head_dim, | ||
| num_layers=model.config.num_hidden_layers, | ||
| vocab_size=padded_vocab, | ||
| block_table_stride=max_blocks_per_seq, | ||
| page_size=page_size, | ||
| ) | ||
| _mark("compile_prefill") | ||
| decode = self._compile_decode_fwd_callable( | ||
| qwen3_l3_dispatch.qwen3_decode_host, | ||
| batch=kernel_batch, | ||
| max_seq=model.runtime.max_seq_len, | ||
| block_table_stride=max_blocks_per_seq, | ||
| hidden_size=model.config.hidden_size, | ||
| intermediate_size=model.config.intermediate_size, | ||
| num_heads=model.config.num_attention_heads, | ||
| num_kv_heads=model.config.num_key_value_heads, | ||
| head_dim=model.config.head_dim, | ||
| num_layers=model.config.num_hidden_layers, | ||
| vocab_size=padded_vocab, | ||
| page_size=page_size, | ||
| ) | ||
| if quantization == "a8w8": | ||
| decode = self._compile_decode_fwd_callable_a8w8( | ||
| qwen3_decode_layer.decode_fwd, | ||
| batch=kernel_batch, | ||
| max_seq=model.runtime.max_seq_len, | ||
| block_table_stride=max_blocks_per_seq, | ||
| hidden_size=model.config.hidden_size, | ||
| intermediate_size=model.config.intermediate_size, | ||
| num_heads=model.config.num_attention_heads, | ||
| num_kv_heads=model.config.num_key_value_heads, | ||
| head_dim=model.config.head_dim, | ||
| num_layers=model.config.num_hidden_layers, | ||
| vocab_size=padded_vocab, | ||
| page_size=page_size, | ||
| ) | ||
| else: | ||
| decode = self._compile_decode_fwd_callable( | ||
| qwen3_decode_layer.decode_fwd, | ||
| batch=kernel_batch, | ||
| max_seq=model.runtime.max_seq_len, | ||
| block_table_stride=max_blocks_per_seq, | ||
| hidden_size=model.config.hidden_size, | ||
| intermediate_size=model.config.intermediate_size, | ||
| num_heads=model.config.num_attention_heads, | ||
| num_kv_heads=model.config.num_key_value_heads, | ||
| head_dim=model.config.head_dim, | ||
| num_layers=model.config.num_hidden_layers, | ||
| vocab_size=padded_vocab, | ||
| page_size=page_size, | ||
| ) |
There is significant code duplication between the if and else branches for handling a8w8 quantization. The logic for compiling the prefill and decode callables is very similar. This makes the code harder to read and maintain.
Consider refactoring this to reduce duplication. You could use a helper function that takes quantization-specific parameters (like the jit function and some arguments) and returns the compiled callable.
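One way to picture the suggested helper is a small registry that maps each quantization mode to a compile function and kernel entry point, so the shared shape arguments are written once. This is a minimal sketch, not the executor's code: `compile_for_mode`, the registry layout, and the two `fake_compile_*` stand-ins are hypothetical names introduced only for illustration.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: quantization mode -> (compile function, kernel entry point).
Registry = Dict[str, Tuple[Callable[..., Any], Any]]


def compile_for_mode(registry: Registry, quantization: str, **shape_kwargs: Any) -> Any:
    """Pick the compiler/kernel pair for a mode and call it with the shared kwargs."""
    # Unknown modes fall back to "default", mirroring the original else branch.
    key = quantization if quantization in registry else "default"
    compile_fn, kernel = registry[key]
    return compile_fn(kernel, **shape_kwargs)


# Stand-ins for the real compile methods; they only record what they received.
def fake_compile_bf16(kernel: Any, **kwargs: Any) -> tuple:
    return ("bf16", kernel, kwargs)


def fake_compile_a8w8(kernel: Any, **kwargs: Any) -> tuple:
    return ("a8w8", kernel, kwargs)


decode_registry: Registry = {
    "a8w8": (fake_compile_a8w8, "decode_fwd"),
    "default": (fake_compile_bf16, "decode_fwd"),
}

# The shape arguments are listed once and reused for every mode.
shared = dict(batch=4, max_seq=128, page_size=16)
decode = compile_for_mode(decode_registry, "a8w8", **shared)
```

The per-mode differences (for example the capped `num_layers` on the A8W8 prefill path) would still have to be passed explicitly; the sketch only removes the repeated keyword list.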
| if compiled.quantization == "a8w8": | ||
| if kv_scales is None: | ||
| raise RuntimeError(f"missing A8W8 KV scale cache for model {model.config.model_id!r}") | ||
| k_cache_scale, v_cache_scale = kv_scales | ||
| rows_per_layer = k_cache.shape[0] // model.config.num_hidden_layers | ||
| hidden = prefill_inputs.hidden | ||
|
|
||
| def weight_slice(name: str, start: int, layers: int, rows_per_layer_: int = 1) -> WorkerTensor: | ||
| tensor = dw[name][start * rows_per_layer_ : (start + layers) * rows_per_layer_] | ||
| return self._l2_child_tensor(compiled.prefill.runtime_name, tensor) | ||
|
|
||
| for layer_start in range(0, model.config.num_hidden_layers, _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS): | ||
| layer_count = min( | ||
| _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS, | ||
| model.config.num_hidden_layers - layer_start, | ||
| ) | ||
| cache_row_start = layer_start * rows_per_layer | ||
| cache_rows = layer_count * rows_per_layer | ||
| hidden_out = torch.empty_like(hidden).share_memory_() | ||
| scratch_logits = torch.empty_like(logits_padded).share_memory_() | ||
| self._run_l2_program( | ||
| compiled.prefill, | ||
| hidden, | ||
| prefill_inputs.seq_lens, | ||
| prefill_inputs.chunk_lens, | ||
| prefill_inputs.chunk_offsets, | ||
| weight_slice("decode_input_rms_weight", layer_start, layer_count), | ||
| weight_slice("decode_wq", layer_start, layer_count, model.config.hidden_size), | ||
| weight_slice("decode_wk", layer_start, layer_count, model.config.hidden_size), | ||
| weight_slice("decode_wv", layer_start, layer_count, model.config.hidden_size), | ||
| weight_slice("decode_wq_scale", layer_start, layer_count), | ||
| weight_slice("decode_wk_scale", layer_start, layer_count), | ||
| weight_slice("decode_wv_scale", layer_start, layer_count), | ||
| weight_slice("decode_q_norm_weight", layer_start, layer_count), | ||
| weight_slice("decode_k_norm_weight", layer_start, layer_count), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_cos), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_sin), | ||
| prefill_inputs.block_table, | ||
| prefill_inputs.slot_mapping, | ||
| self._worker_tensor_view( | ||
| k_cache, | ||
| cache_row_start * model.config.head_dim, | ||
| (cache_rows, model.config.head_dim), | ||
| 1, | ||
| ), | ||
| self._worker_tensor_view( | ||
| v_cache, | ||
| cache_row_start * model.config.head_dim, | ||
| (cache_rows, model.config.head_dim), | ||
| 1, | ||
| ), | ||
| self._worker_tensor_view(k_cache_scale, cache_row_start * 8, (cache_rows, 8), 4), | ||
| self._worker_tensor_view(v_cache_scale, cache_row_start * 8, (cache_rows, 8), 4), | ||
| weight_slice("decode_wo", layer_start, layer_count, model.config.hidden_size), | ||
| weight_slice("decode_wo_scale", layer_start, layer_count), | ||
| weight_slice("decode_post_rms_weight", layer_start, layer_count), | ||
| weight_slice("decode_w_gate", layer_start, layer_count, model.config.hidden_size), | ||
| weight_slice("decode_w_up", layer_start, layer_count, model.config.hidden_size), | ||
| weight_slice("decode_w_down", layer_start, layer_count, model.config.intermediate_size), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.final_norm_weight), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.padded_lm_head_weight), | ||
| scratch_logits, | ||
| hidden_out, | ||
| ) | ||
| hidden = hidden_out | ||
| logits_padded = self._project_logits_host(model, compiled, prefill_inputs, hidden) | ||
| else: | ||
| self._run_l2_program( | ||
| compiled.prefill, | ||
| prefill_inputs.hidden, | ||
| prefill_inputs.seq_lens, | ||
| prefill_inputs.chunk_lens, | ||
| prefill_inputs.chunk_offsets, | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_input_rms_weight"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wq"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wk"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wv"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_q_norm_weight"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_k_norm_weight"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_cos), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_sin), | ||
| prefill_inputs.block_table, | ||
| prefill_inputs.slot_mapping, | ||
| k_cache, | ||
| v_cache, | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wo"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_post_rms_weight"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_w_gate"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_w_up"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_w_down"]), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.final_norm_weight), | ||
| self._l2_child_tensor(compiled.prefill.runtime_name, compiled.padded_lm_head_weight), | ||
| logits_padded, | ||
| ) |
This if/else block for a8w8 quantization introduces a large amount of duplicated code. The logic inside both branches is very similar, differing mainly in the arguments passed to _run_l2_program and the chunking logic for a8w8.
To improve maintainability, this should be refactored. For example, you could prepare a dictionary of arguments for _run_l2_program based on the quantization mode, and then have a more unified control flow.
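A possible first step toward a unified control flow is to isolate the chunking arithmetic, so that the BF16 path becomes the special case of a single chunk covering every layer. The sketch below assumes nothing about the executor beyond the `range`/`min` pattern visible in the A8W8 branch; `layer_chunks` is a hypothetical helper name.

```python
from typing import Iterator, Tuple


def layer_chunks(num_layers: int, chunk_layers: int) -> Iterator[Tuple[int, int]]:
    """Yield (layer_start, layer_count) pairs that cover num_layers in order."""
    if chunk_layers <= 0:
        raise ValueError("chunk_layers must be positive")
    for layer_start in range(0, num_layers, chunk_layers):
        # The last chunk may be shorter than chunk_layers.
        yield layer_start, min(chunk_layers, num_layers - layer_start)


# Chunked (A8W8-style) iteration and single-chunk (BF16-style) iteration.
chunked = list(layer_chunks(40, 16))
single = list(layer_chunks(40, 40))
```

The two paths still differ in how logits are produced (the A8W8 branch projects them on the host after the loop), so the loop body itself would need mode-specific argument preparation.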
| if compiled.quantization == "a8w8": | ||
| if kv_scales is None: | ||
| raise RuntimeError(f"missing A8W8 KV scale cache for model {model_id!r}") | ||
| k_cache_scale, v_cache_scale = kv_scales | ||
| self._run_l2_program( | ||
| compiled.decode, | ||
| hidden, | ||
| self._l2_child_tensor(rt, dw["decode_input_rms_weight"]), | ||
| self._l2_child_tensor(rt, dw["decode_wq"]), | ||
| self._l2_child_tensor(rt, dw["decode_wk"]), | ||
| self._l2_child_tensor(rt, dw["decode_wv"]), | ||
| self._l2_child_tensor(rt, dw["decode_wq_scale"]), | ||
| self._l2_child_tensor(rt, dw["decode_wk_scale"]), | ||
| self._l2_child_tensor(rt, dw["decode_wv_scale"]), | ||
| self._l2_child_tensor(rt, dw["decode_q_norm_weight"]), | ||
| self._l2_child_tensor(rt, dw["decode_k_norm_weight"]), | ||
| seq_lens, | ||
| block_table, | ||
| slot_mapping, | ||
| self._l2_child_tensor(rt, compiled.rope_cos), | ||
| self._l2_child_tensor(rt, compiled.rope_sin), | ||
| k_cache, | ||
| v_cache, | ||
| k_cache_scale, | ||
| v_cache_scale, | ||
| self._l2_child_tensor(rt, dw["decode_wo"]), | ||
| self._l2_child_tensor(rt, dw["decode_wo_scale"]), | ||
| self._l2_child_tensor(rt, dw["decode_w_gate"]), | ||
| self._l2_child_tensor(rt, dw["decode_w_up"]), | ||
| self._l2_child_tensor(rt, dw["decode_w_down"]), | ||
| self._l2_child_tensor(rt, dw["decode_post_rms_weight"]), | ||
| self._l2_child_tensor(rt, compiled.final_norm_weight), | ||
| self._l2_child_tensor(rt, compiled.padded_lm_head_weight), | ||
| logits_padded, | ||
| ) | ||
| else: | ||
| # Argument order MUST match decode_layer.decode_fwd (PAGED): | ||
| # hidden_states, input_rms_weight, wq, wk, wv, q_norm_weight, | ||
| # k_norm_weight, seq_lens, block_table, slot_mapping, rope_cos, rope_sin, | ||
| # k_cache, v_cache, wo, w_gate, w_up, w_down, post_rms_weight, | ||
| # final_norm_weight, lm_head_weight, out. | ||
| self._run_l2_program( | ||
| compiled.decode, | ||
| hidden, | ||
| self._l2_child_tensor(rt, dw["decode_input_rms_weight"]), | ||
| self._l2_child_tensor(rt, dw["decode_wq"]), | ||
| self._l2_child_tensor(rt, dw["decode_wk"]), | ||
| self._l2_child_tensor(rt, dw["decode_wv"]), | ||
| self._l2_child_tensor(rt, dw["decode_q_norm_weight"]), | ||
| self._l2_child_tensor(rt, dw["decode_k_norm_weight"]), | ||
| seq_lens, | ||
| block_table, | ||
| slot_mapping, | ||
| self._l2_child_tensor(rt, compiled.rope_cos), | ||
| self._l2_child_tensor(rt, compiled.rope_sin), | ||
| k_cache, | ||
| v_cache, | ||
| self._l2_child_tensor(rt, dw["decode_wo"]), | ||
| self._l2_child_tensor(rt, dw["decode_w_gate"]), | ||
| self._l2_child_tensor(rt, dw["decode_w_up"]), | ||
| self._l2_child_tensor(rt, dw["decode_w_down"]), | ||
| self._l2_child_tensor(rt, dw["decode_post_rms_weight"]), | ||
| self._l2_child_tensor(rt, compiled.final_norm_weight), | ||
| self._l2_child_tensor(rt, compiled.padded_lm_head_weight), | ||
| logits_padded, | ||
| ) |
Similar to run_prefill, this if/else block for a8w8 quantization duplicates a significant amount of code from the bf16 path. This makes the code difficult to maintain, as changes need to be applied in two places.
Please refactor this to remove the duplication. You could prepare the arguments for _run_l2_program in a helper function based on the quantization mode, then have a single call to _run_l2_program.
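As a rough illustration of the single-call idea, the argument order can be built once, with the A8W8-only scale tensors spliced in at the positions where the two quoted call sites differ. The names below are copied from the quoted branches; `decode_arg_names` itself is a hypothetical helper, and mapping each name to a tensor is left out.

```python
from typing import List


def decode_arg_names(quantization: str) -> List[str]:
    """Return the ordered decode argument names for the given quantization mode."""
    a8w8 = quantization == "a8w8"
    names = ["hidden", "decode_input_rms_weight", "decode_wq", "decode_wk", "decode_wv"]
    if a8w8:
        # Per-projection weight scales follow the int8 Q/K/V weights.
        names += ["decode_wq_scale", "decode_wk_scale", "decode_wv_scale"]
    names += [
        "decode_q_norm_weight", "decode_k_norm_weight",
        "seq_lens", "block_table", "slot_mapping",
        "rope_cos", "rope_sin", "k_cache", "v_cache",
    ]
    if a8w8:
        # KV-cache scales follow the caches themselves.
        names += ["k_cache_scale", "v_cache_scale"]
    names.append("decode_wo")
    if a8w8:
        names.append("decode_wo_scale")
    names += [
        "decode_w_gate", "decode_w_up", "decode_w_down", "decode_post_rms_weight",
        "final_norm_weight", "padded_lm_head_weight", "logits_padded",
    ]
    return names


bf16_names = decode_arg_names("bf16")
a8w8_names = decode_arg_names("a8w8")
```

With a list like this, one `_run_l2_program` call could unpack the resolved tensors, and the KV-scale presence check would remain the only A8W8-specific guard.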
with rec.shard.open("rb") as f:
    f.seek(rec.data_start + begin)
    raw = bytearray(f.read(end - begin))
SimpleNamespace(
    quantization="a8w8",
    input_rms_weight=index.load(f"{prefix}.input_layernorm.weight").reshape(1, -1).float(),
    wq=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.q_proj.weight"),
    wk=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.k_proj.weight"),
    wv=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.v_proj.weight"),
    wq_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.q_proj.weight_scale"),
    wk_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.k_proj.weight_scale"),
    wv_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.v_proj.weight_scale"),
    q_norm_weight=index.load(f"{prefix}.self_attn.q_norm.weight").reshape(1, -1).float(),
    k_norm_weight=index.load(f"{prefix}.self_attn.k_norm.weight").reshape(1, -1).float(),
    wo=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.o_proj.weight"),
    wo_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.o_proj.weight_scale"),
    post_rms_weight=index.load(f"{prefix}.post_attention_layernorm.weight").reshape(1, -1).float(),
    w_gate=_hf_linear_to_bf16(
        index,
        f"{prefix}.mlp.gate_proj.weight",
        f"{prefix}.mlp.gate_proj.weight_scale",
    ),
    w_up=_hf_linear_to_bf16(
        index,
        f"{prefix}.mlp.up_proj.weight",
        f"{prefix}.mlp.up_proj.weight_scale",
    ),
    w_down=_hf_linear_to_bf16(
        index,
        f"{prefix}.mlp.down_proj.weight",
        f"{prefix}.mlp.down_proj.weight_scale",
    ),
)
Using types.SimpleNamespace is flexible but lacks the explicitness and type safety of a dataclass. For better maintainability and code clarity, consider defining a dataclass for the layer weights. This provides type hints, autocompletion, and a more rigid structure, reducing the chance of runtime errors.
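A dataclass version of the layer record might look like the sketch below. The field names are taken from the quoted `SimpleNamespace` call; the class name `A8W8LayerWeights` is hypothetical, and the fields are typed as `Any` only to keep the example free of a `torch` import (in the loader they would presumably be `torch.Tensor`).

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass(frozen=True)
class A8W8LayerWeights:
    """Per-layer weights for the A8W8 path; one field per quoted keyword."""
    input_rms_weight: Any
    wq: Any
    wk: Any
    wv: Any
    wq_scale: Any
    wk_scale: Any
    wv_scale: Any
    q_norm_weight: Any
    k_norm_weight: Any
    wo: Any
    wo_scale: Any
    post_rms_weight: Any
    w_gate: Any
    w_up: Any
    w_down: Any
    quantization: str = "a8w8"


# Placeholder values stand in for tensors in this illustration.
tensor_fields = [f.name for f in fields(A8W8LayerWeights) if f.name != "quantization"]
layer = A8W8LayerWeights(**{name: 0 for name in tensor_fields})
```

Unlike `SimpleNamespace`, a missing or misspelled keyword fails at construction time with a `TypeError`, and `frozen=True` rejects accidental reassignment later.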
assembled = compile_and_assemble(
    work_dir,
    self._platform,
    pto_isa_commit=config.pto_isa_commit,
)
if len(assembled) == 2:
    chip_callable, runtime_name = assembled
    runtime_config = {}
else:
    chip_callable, runtime_name, runtime_config = assembled

def _compile_prefill_fwd_callable_a8w8(
    self,
    jit_fn: object,
    *,
    batch: int,
    max_seq: int,
    block_table_stride: int,
    hidden_size: int,
    intermediate_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    num_layers: int,
    vocab_size: int,
    page_size: int,
) -> _L2Callable:
This new function _compile_prefill_fwd_callable_a8w8 is very similar to _compile_prefill_fwd_callable. The same applies to the decode counterparts. This introduces significant code duplication.
Consider merging the a8w8 and non-a8w8 versions of these functions. You could parameterize them by passing the dtypes for dummy tensors based on the quantization mode, which would make the code more maintainable.
except Exception:
    self._free_kv_cache_tensor(key_scale)
    raise
Catching the broad except Exception: is generally discouraged because it can catch and hide unexpected errors, making debugging more difficult. Please catch a more specific exception if possible, for instance a memory allocation error. If you must catch Exception, consider logging it before re-raising to provide more context on failures.
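The logging variant of this cleanup-and-re-raise pattern could be sketched as follows. The function `allocate_scale_pair` and its `alloc`/`free` callables are hypothetical stand-ins for the executor's allocation helpers; only the shape of the `try`/`except`/`raise` is taken from the quoted code.

```python
import logging
from typing import Any, Callable, Tuple

logger = logging.getLogger(__name__)


def allocate_scale_pair(
    alloc: Callable[[str], Any], free: Callable[[Any], None]
) -> Tuple[Any, Any]:
    """Allocate two buffers; if the second fails, free the first and re-raise."""
    key_scale = alloc("key_scale")
    try:
        value_scale = alloc("value_scale")
    except Exception:
        # Record the failure with its traceback, release the partial allocation,
        # then let the original exception propagate unchanged.
        logger.exception("value-scale allocation failed; freeing key scale")
        free(key_scale)
        raise
    return key_scale, value_scale
```

Because the handler re-raises, the error is not swallowed; the log line only adds context about which allocation step failed.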
while i < len(word):
    try:
        j = word.index(first, i)
    except ValueError:
        new_word.extend(word[i:])
        break
    new_word.extend(word[i:j])
    i = j
    if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
        new_word.append(first + second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1
The loop to rebuild the word after finding a bigram merge can be inefficient. The word.index(first, i) call inside the while loop may repeatedly scan parts of the word tuple.
For better performance, you could iterate through the word once and build the new_word list in a single pass, checking for the (first, second) bigram at each position. While lru_cache helps, optimizing the core algorithm is still beneficial.
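The single-pass form described in this comment can be sketched as below. `merge_pair` is a hypothetical standalone function rather than the tokenizer's method, and no performance claim is made here; the sketch only shows that one left-to-right scan produces the same merged sequence.

```python
from typing import List, Sequence, Tuple


def merge_pair(word: Sequence[str], first: str, second: str) -> Tuple[str, ...]:
    """Merge every non-overlapping (first, second) bigram in one left-to-right pass."""
    merged: List[str] = []
    i = 0
    n = len(word)
    while i < n:
        if i < n - 1 and word[i] == first and word[i + 1] == second:
            # Bigram found: emit the merged symbol and skip both parts.
            merged.append(first + second)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)


result = merge_pair(("a", "b", "a", "b", "c"), "a", "b")
```

Overlapping candidates are resolved greedily from the left, which matches the behavior of the quoted loop for inputs such as a run of identical symbols.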
a9b4d34 to
b09f2cf
Compare
b09f2cf to
2343c8e
Compare
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/model/qwen3_14b/npu_generate.py`:
- Around line 421-422: The QWEN_A8W8_BF16_KV environment flag is only ever set
and never cleared in main(), so a previous bf16-kv run can leak into later A8W8
runs in the same process. Update the argument-handling path around
args.model_format and args.decode_backend to deterministically reset
QWEN_A8W8_BF16_KV for every invocation, explicitly setting it to "1" only for
qwen3-a8w8 with bf16-kv and removing or clearing it otherwise.
In `@examples/model/qwen3_14b/runner/a8w8_loader.py`:
- Around line 86-89: Reshape the linear scale tensor in _hf_linear_to_bf16
before multiplying so it broadcasts over output channels correctly instead of
the last axis. Update the _hf_linear_to_bf16 flow to explicitly reshape or
unsqueeze the loaded scale from index.load(scale_key) before the weight * scale
operation, then keep the transpose/contiguous BF16 conversion as-is.
In `@python/core/tokenizer.py`:
- Around line 248-250: The byte-level tokenizer input is being altered in
Tokenizer.encode by NFC normalization, which breaks round-tripping for
decomposed Unicode text. Remove the unicodedata.normalize("NFC", text) step from
encode and preserve the original input bytes before byte-level BPE processing;
keep the rest of Tokenizer.encode and decode behavior unchanged.
- Around line 135-181: Add parity tests for the Qwen fallback tokenizer behavior
in _split_qwen_text, focusing on edge cases where the current split logic can
diverge from Hugging Face: repeated spaces, newline handling, and
punctuation/space boundaries. Create coverage that compares the fallback output
against the expected Qwen regex-based tokenization behavior for representative
inputs, so regressions are caught before relying on this path in serving.
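To make the NFC finding above concrete, the following sketch shows how normalization changes the UTF-8 bytes of a decomposed character, which is why a byte-level tokenizer that normalizes on encode cannot round-trip such input. It uses only the standard library; the variable names are illustrative.

```python
import unicodedata

# 'e' followed by COMBINING ACUTE ACCENT (U+0301): two code points.
decomposed = "e\u0301"

# NFC composes the pair into the single code point U+00E9.
composed = unicodedata.normalize("NFC", decomposed)

# The byte sequences seen by a byte-level BPE differ after normalization.
raw_bytes = decomposed.encode("utf-8")
nfc_bytes = composed.encode("utf-8")

# Decoding the normalized bytes does not give back the original string.
round_trips = nfc_bytes.decode("utf-8") == decomposed
```

Both strings render identically, so the difference is visible only at the code-point and byte level.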
📒 Files selected for processing (8)
examples/model/qwen3_14b/npu_generate.py
examples/model/qwen3_14b/runner/a8w8_loader.py
examples/model/qwen3_14b/runner/npu_executor.py
examples/model/qwen3_14b/runner/npu_runner.py
python/core/pypto_executor.py
python/core/tokenizer.py
python/runtime/worker.py
tests/test_batching.py
if args.model_format == "qwen3-a8w8" and args.decode_backend == "bf16-kv":
    os.environ["QWEN_A8W8_BF16_KV"] = "1"
🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win
Set the BF16-KV env flag deterministically.
If main() runs more than once in-process, a prior bf16-kv run leaves QWEN_A8W8_BF16_KV=1, so a later plain A8W8 run can compile decode with BF16 KV assumptions while allocating INT8 KV.
🐛 Proposed fix
-    if args.model_format == "qwen3-a8w8" and args.decode_backend == "bf16-kv":
-        os.environ["QWEN_A8W8_BF16_KV"] = "1"
+    os.environ["QWEN_A8W8_BF16_KV"] = (
+        "1" if args.model_format == "qwen3-a8w8" and args.decode_backend == "bf16-kv" else "0"
+    )
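An alternative to writing "0" is to remove the variable when the flag does not apply. How the flag is read elsewhere is not shown in this conversation, so this is a sketch under the assumption that some consumer might test only for presence or truthiness, in which case the non-empty string "0" would still count as set. `sync_bf16_kv_flag` is a hypothetical helper name.

```python
import os
from typing import MutableMapping

FLAG = "QWEN_A8W8_BF16_KV"


def sync_bf16_kv_flag(
    model_format: str,
    decode_backend: str,
    environ: MutableMapping[str, str] = os.environ,
) -> None:
    """Set the flag for qwen3-a8w8 with bf16-kv; otherwise make sure it is absent."""
    if model_format == "qwen3-a8w8" and decode_backend == "bf16-kv":
        environ[FLAG] = "1"
    else:
        # pop with a default is a no-op when the variable was never set.
        environ.pop(FLAG, None)


# Demonstration on a plain dict so the real process environment is untouched.
env: dict = {}
sync_bf16_kv_flag("qwen3-a8w8", "bf16-kv", env)
after_bf16_run = dict(env)
sync_bf16_kv_flag("qwen3-a8w8", "a8w8", env)
after_plain_run = dict(env)
```

Either approach makes repeated in-process calls deterministic; the choice depends on whether readers of the flag compare against "1" or merely check that it exists.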
def _hf_linear_to_bf16(index: _SafeTensorIndex, weight_key: str, scale_key: str) -> torch.Tensor:
    weight = index.load(weight_key).float()
    scale = index.load(scale_key).float()
    return (weight * scale).t().contiguous().to(torch.bfloat16)
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify PyTorch broadcasting behavior for HF [out, in] linear weights.
python - <<'PY'
import torch
weight = torch.empty(17408, 5120)
scale = torch.empty(17408)
try:
    weight * scale
except RuntimeError as exc:
    print("1-D scale fails as expected:", str(exc).splitlines()[0])
print((weight * scale.reshape(-1, 1)).shape)
PY
Repository: hw-native-sys/pypto-serving
Length of output: 283
🏁 Script executed:
#!/bin/bash
set -e
file="examples/model/qwen3_14b/runner/a8w8_loader.py"
printf '== outline ==\n'
ast-grep outline "$file" --view expanded || true
printf '\n== relevant lines ==\n'
nl -ba "$file" | sed -n '1,220p'
Repository: hw-native-sys/pypto-serving
Length of output: 1168
🏁 Script executed:
#!/bin/bash
set -e
rg -n "_hf_linear_to_bf16|_hf_scale_to_kernel|reshape|transpose|contiguous|bfloat16|scale_key|weight_key" examples/model/qwen3_14b/runner/a8w8_loader.py
Repository: hw-native-sys/pypto-serving
Length of output: 2194
🌐 Web query:
PyTorch broadcasting 1D tensor with 2D tensor trailing dimension semantics official docs
💡 Result:
According to the official PyTorch documentation, broadcasting follows NumPy's semantics [1][2][3]. Two tensors are "broadcastable" if, when iterating over their dimension sizes starting from the trailing (last) dimension, the dimension sizes either match, one of them is 1, or one of them does not exist [1][4][5]. When broadcasting a 1D tensor with a 2D tensor, the 1D tensor's dimensions are compared starting from the trailing dimension of the 2D tensor [1][5]. Specifically, the 1D tensor (size N) is implicitly treated as having a leading dimension of 1 (size 1 x N) to align with the 2D tensor (size M x N) [1][3][5]. If you wish to broadcast a 1D tensor (size N) against a 2D tensor (size M x N) such that the operation applies across rows, you must explicitly align the dimensions, often by unsqueezing the 1D tensor to shape (N, 1) or (1, N) depending on the desired operation [6]. It is important to distinguish between general element-wise broadcasting (such as addition or multiplication) and specialized matrix operations like torch.matmul. While element-wise broadcasting follows the general rules described above [1], torch.matmul has its own distinct broadcasting logic for batched matrix multiplication, which includes special handling for appending or prepending dimensions specifically for 1D tensors to facilitate matrix-vector or batch-matrix multiplication [7].
Citations:
- 1: https://docs.pytorch.org/docs/main/notes/broadcasting.html
- 2: https://docs.pytorch.org/docs/2.12/notes/broadcasting.html
- 3: https://docs.pytorch.org/docs/2.9/notes/broadcasting.html
- 4: http://docs.pytorch.org/docs/2.11/notes/broadcasting.html
- 5: https://docs.pytorch.org/docs/2.5/notes/broadcasting.html
- 6: https://discuss.pytorch.org/t/multiply-1d-tensor-by-2d-tensor/9333
- 7: https://docs.pytorch.org/docs/2.1/generated/torch.matmul.html
Reshape the linear scales before BF16 dequantization.
weight * scale treats the 1-D scale as a row vector, so it broadcasts across the last axis. That errors on non-square projections and scales columns instead of output channels on square ones.
🐛 Proposed fix
 def _hf_linear_to_bf16(index: _SafeTensorIndex, weight_key: str, scale_key: str) -> torch.Tensor:
     weight = index.load(weight_key).float()
-    scale = index.load(scale_key).float()
+    scale = index.load(scale_key).float().reshape(-1, 1)
     return (weight * scale).t().contiguous().to(torch.bfloat16)
def _split_qwen_text(text: str) -> list[str]:
    parts: list[str] = []
    i = 0
    n = len(text)
    while i < n:
        lower_tail = text[i : i + 4].lower()
        matched = False
        for suffix in _CONTRACTIONS:
            if lower_tail.startswith(suffix):
                parts.append(text[i : i + len(suffix)])
                i += len(suffix)
                matched = True
                break
        if matched:
            continue

        start = i
        if text[i] == " " and i + 1 < n and _is_letter(text[i + 1]):
            i += 1
        if i < n and _is_letter(text[i]):
            i += 1
            while i < n and _is_letter(text[i]):
                i += 1
            parts.append(text[start:i])
            continue

        if _is_number(text[i]):
            parts.append(text[i])
            i += 1
            continue

        start = i
        if text[i] == " " and i + 1 < n and _is_punct(text[i + 1]):
            i += 1
        if i < n and _is_punct(text[i]):
            i += 1
            while i < n and _is_punct(text[i]):
                i += 1
            while i < n and _is_newline(text[i]):
                i += 1
            parts.append(text[start:i])
            continue

        start = i
        while i < n and _is_space(text[i]):
            i += 1
        parts.append(text[start:i])
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Inspect local tests or fixtures for tokenizer parity coverage without executing repo code.
rg -n -C2 'QwenByteBpe|from_pretrained|encode\(.*decode|tokenizer.*roundtrip|multiple spaces|newline|punct' tests python/core examples
Repository: hw-native-sys/pypto-serving
Length of output: 3949
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '\n== python/core/tokenizer.py ==\n'
sed -n '1,260p' python/core/tokenizer.py
printf '\n== tests mentioning tokenizer ==\n'
rg -n -C2 'tokenizer|Qwen|ByteBpe|transformers|from_pretrained|encode|decode' tests python/core examples
printf '\n== Qwen-related model artifacts and configs ==\n'
fd -a 'tokenizer.json|tokenizer_config.json|vocab.json|merges.txt|config.json' examples python tests 2>/dev/null | sed -n '1,120p'
Repository: hw-native-sys/pypto-serving
Length of output: 50384
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the fallback implementation and any direct references to Qwen tokenization behavior.
ast-grep outline python/core/tokenizer.py --view expanded
printf '\n== Qwen-specific references ==\n'
rg -n -C3 'Qwen|qwen|_split_qwen_text|_CONTRACTIONS|_QwenByteBpeTokenizer' python/core tests examples
Repository: hw-native-sys/pypto-serving
Length of output: 50383
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '\n== tokenizer encode/decode section ==\n'
sed -n '135,320p' python/core/tokenizer.py
printf '\n== possible tokenizer fixtures/artifacts in repo ==\n'
fd -a -t f 'vocab.json' 'merges.txt' 'tokenizer_config.json' 'tokenizer.json' tests examples python 2>/dev/null | sed -n '1,200p'
printf '\n== existing tokenizer tests ==\n'
rg -n -C2 'tokenizer|Qwen|ByteBpe|vocab.json|merges.txt|tokenizer_config.json' tests
Repository: hw-native-sys/pypto-serving
Length of output: 11386
🌐 Web query:
Qwen tokenizer pre-tokenization rules Hugging Face ByteLevel BPE exact regex Qwen3
💡 Result:
The Qwen tokenizer series, including Qwen3, utilizes a Byte-level Byte Pair Encoding (BBPE) architecture [1][2]. Unlike tokenizers that operate on Unicode codepoints (like SentencePiece), the Qwen tokenizer operates directly on UTF-8 byte sequences [3], which ensures no out-of-vocabulary (OOV) tokens exist for any input text [4][2].
Pre-tokenization Rules and Implementation:
The Qwen3 tokenizer implemented in the Hugging Face transformers library employs a specific pre_tokenizer sequence [5]. This sequence consists of:
1. A split step using a regular expression defined as PRETOKENIZE_REGEX. This regex breaks the text into chunks before BPE is applied, which dictates how the text is partitioned and defines boundaries that cannot be crossed by learned merges [6][5].
2. A ByteLevel pre-tokenizer that handles the mapping of characters to byte sequences, ensuring consistent byte-level processing [5].
The regex used for pre-tokenization in the Qwen series (often consistent across versions) is:
(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ [7]
Key Characteristics:
- Byte-level operation: It operates purely on byte distribution [3].
- Pre-tokenization constraints: Because pre-tokenization splits text into specific units, new tokens cannot be added if they span across these pre-tokenized boundaries [6][3]. For example, a multi-word sequence containing a space might be split by the regex, preventing it from being learned as a single new token [6][3].
- Compatibility: Qwen3 maintains this efficient BBPE system, supporting a large vocabulary size (e.g., 151,646 to 151,936 tokens depending on the specific model variant) [1][2][8].
This structure allows the tokenizer to remain highly robust across various languages and character sets without the need for fallback mechanisms [3][4][2].
Citations:
- 1: https://arxiv.org/html/2505.09388
- 2: https://deepwiki.com/QwenLM/Qwen3/3-model-architecture-and-core-concepts
- 3: https://deepwiki.com/QwenLM/Qwen/8-tokenization
- 4: https://qwen.readthedocs.io/en/latest/getting_started/concepts.html
- 5: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/tokenization_qwen3_5.py
- 6: https://github.com/QwenLM/Qwen-7B/blob/main/tokenization_note.md
- 7: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/tokenization_qwen2.py
- 8: https://huggingface.co/docs/transformers/main/model_doc/qwen3
Add parity tests for the Qwen fallback tokenizer. The fallback pre-tokenizer doesn’t match Qwen’s regex-based split rules exactly, so repeated spaces, newlines, and punctuation boundaries can drift from Hugging Face token IDs; add coverage for those edge cases before depending on this path in serving.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@python/core/tokenizer.py` around lines 135 - 181, Add parity tests for the
Qwen fallback tokenizer behavior in _split_qwen_text, focusing on edge cases
where the current split logic can diverge from Hugging Face: repeated spaces,
newline handling, and punctuation/space boundaries. Create coverage that
compares the fallback output against the expected Qwen regex-based tokenization
behavior for representative inputs, so regressions are caught before relying on
this path in serving.
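A minimal sketch of what such a parity check could look like, assuming the splitter under test is a callable that takes a string and returns a list of pieces (as `_split_qwen_text` appears to). The helper name `split_mismatches` and the `CASES` table are hypothetical, and the expected pieces were derived by hand from the regex quoted above, so they should be confirmed against the Hugging Face tokenizer before being treated as ground truth.

```python
from typing import Callable, Iterable

# Hand-derived from the pre-tokenization regex quoted in the review above.
# These are assumptions to be re-verified against the Hugging Face tokenizer.
CASES: list[tuple[str, list[str]]] = [
    ("a  b", ["a", " ", " b"]),  # \s+(?!\S) stops early, leaving one space for the word
    ("hi!\n", ["hi", "!\n"]),    # a punctuation run absorbs trailing newlines
    ("12", ["1", "2"]),          # \p{N} matches one digit at a time
]


def split_mismatches(
    split_fn: Callable[[str], Iterable[str]],
    cases: list[tuple[str, list[str]]] = CASES,
) -> list[tuple[str, list[str], list[str]]]:
    """Return (text, expected, actual) for every case where split_fn disagrees."""
    failures = []
    for text, expected in cases:
        actual = list(split_fn(text))
        # Any pre-tokenizer must partition the input without dropping characters.
        assert "".join(actual) == text, f"lossy split for {text!r}"
        if actual != expected:
            failures.append((text, expected, actual))
    return failures
```

In the repository this would presumably be called as `split_mismatches(_split_qwen_text)` inside a test that asserts the result is empty; a stronger variant would generate the expected pieces from the reference tokenizer instead of a hand-written table.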
| def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: | ||
| del add_special_tokens | ||
| text = unicodedata.normalize("NFC", text) |
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Do not normalize byte-level tokenizer input.
Line 250 changes user text before byte encoding, so decode(encode("e\u0301")) returns NFC-normalized text instead of the original byte sequence. Byte-level BPE should preserve the input.
Proposed fix
 def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
     del add_special_tokens
-    text = unicodedata.normalize("NFC", text)
     ids: list[int] = []
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: | |
| del add_special_tokens | |
| text = unicodedata.normalize("NFC", text) | |
| def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: | |
| del add_special_tokens | |
| ids: list[int] = [] |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@python/core/tokenizer.py` around lines 248 - 250, The byte-level tokenizer
input is being altered in Tokenizer.encode by NFC normalization, which breaks
round-tripping for decomposed Unicode text. Remove the
unicodedata.normalize("NFC", text) step from encode and preserve the original
input bytes before byte-level BPE processing; keep the rest of Tokenizer.encode
and decode behavior unchanged.
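To make the round-trip concern concrete, the following standard-library sketch shows how NFC normalization rewrites the bytes that a byte-level tokenizer would see. It does not use the repository's tokenizer; it only illustrates the effect of `unicodedata.normalize` on the example string from the comment.

```python
import unicodedata

decomposed = "e\u0301"  # 'e' followed by COMBINING ACUTE ACCENT
composed = unicodedata.normalize("NFC", decomposed)  # single code point U+00E9

# The two strings render the same but encode to different UTF-8 bytes,
# so a byte-level encode/decode round trip cannot restore the original.
print(decomposed.encode("utf-8"))  # b'e\xcc\x81'
print(composed.encode("utf-8"))    # b'\xc3\xa9'
```

Whether dropping the normalization keeps token IDs aligned with the Hugging Face tokenizer is a separate question that the parity tests requested earlier would answer.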
Register qwen3-a8w8 as a model format on the existing Qwen3-14B serving entrypoint instead of carrying a separate bridge driver.
Load compressed-tensors W8A8 weights into kernel-ready INT8/scale layouts, extend the PyPTO runner with A8W8 KV scale pages, and run prefill through the previously validated 10-layer hidden chunks.
Keep LLMEngine generation, KV allocation, sampling, and decode scheduling on the existing serving path; add a dependency-free Qwen ByteBPE fallback for environments without transformers.
Validated with py_compile, git diff --check, a tokenizer roundtrip, a 1-token smoke test, and an 8-token A8W8 hardware smoke test for the prompt '介绍一下北京故宫' ("Give an introduction to the Forbidden City in Beijing"): TTFT 6.053s, TPOT 425.0ms/token.
Related: