
Add native Qwen3 14B A8W8 serving path #48

Open
vegetabledoww wants to merge 1 commit into hw-native-sys:main from vegetabledoww:qwen3-a8w8-stage5-serving-slim

@vegetabledoww vegetabledoww commented Jun 29, 2026


Register qwen3-a8w8 as a model format on the existing Qwen3-14B serving entrypoint instead of carrying a separate bridge driver.

Load compressed-tensors W8A8 weights into kernel-ready INT8/scale layouts, extend the PyPTO runner with A8W8 KV scale pages, and run prefill through the previously validated 10-layer hidden chunks.

Keep LLMEngine generation, KV allocation, sampling, and decode scheduling on the existing serving path; add a dependency-free Qwen ByteBPE fallback for environments without transformers.

Validated with py_compile, git diff --check, a tokenizer roundtrip, a 1-token smoke test, and an 8-token A8W8 hardware smoke test for the prompt '介绍一下北京故宫' ("Introduce the Forbidden City in Beijing"): TTFT 6.053 s, TPOT 425.0 ms/token.

Related:

coderabbitai Bot commented Jun 29, 2026


📝 Walkthrough

Adds an A8W8/BF16 checkpoint loader and decode-backend CLI options for Qwen3-14B, introduces L3-mode and A8W8-decode-loop execution paths with quantization-aware compilation and weight handling in the NPU executor, threads a pto_isa_commit parameter through the core executor, adds ContinuousTensor support to the worker, and adds a dependency-free Qwen byte-level BPE tokenizer fallback.

Changes

A8W8/L3 Decode Feature

| Layer | File(s) | Summary |
|---|---|---|
| A8W8 directory checkpoint loader | examples/model/qwen3_14b/runner/a8w8_loader.py | New _SafeTensorIndex shard reader, layout conversion helpers, and Qwen3A8W8DirectoryLoader that builds tokenizer/config/per-layer quantized and bf16 weights into a LoadedModel. |
| CLI options, model-loader wiring, and L2 profiling | examples/model/qwen3_14b/npu_generate.py | New --model-format, --decode-backend, --max-batch-size, --pto-isa-commit, --l3-mode, --a8w8-decode-loop flags; conditional Qwen3A8W8DirectoryLoader registration; QWEN_A8W8_BF16_KV env var; timed_run_l2 replaces prior L2 timing wrapper; RuntimeConfig dtype/batch updates. |
| Executor initialization and runtime patches | examples/model/qwen3_14b/runner/npu_executor.py | New imports/constants, TASK_ID fallback, _patch_orch_make_tensor_arg and _StackedLayerView, and extended __init__ with pto_isa_commit, l3_mode, a8w8_decode_loop. |
| Generation control flow for A8W8 loop and L3 | examples/model/qwen3_14b/runner/npu_executor.py | New validate_generate_batch, prompt_allocation_length, try_generate_batch, run_generate_a8w8_decode_loop, run_generate_l3 selecting between decode-loop and L3 generation paths. |
| Quantization-aware L2/L3 compilation pipeline | examples/model/qwen3_14b/runner/npu_executor.py | _compile_model branches by quantization mode; A8W8-specific L2 callables; L3 artifact compilation/extraction with host orchestration patching; KV cache dtype override. |
| Mixed-precision weight stacking and release | examples/model/qwen3_14b/runner/npu_executor.py, tests/test_batching.py | Merged BF16/A8W8 decode weight stacking, per-projection scale tensor handling, weight-release cleanup, _model_quantization helper, and test update for decode_loop=None. |
| Core executor and worker support | python/core/pypto_executor.py, python/runtime/worker.py | pto_isa_commit threaded into RunConfig; WorkerTensor.to_continuous_tensor() added with ContinuousTensor compatibility import. |

Dependency-free Qwen Tokenizer Fallback

| Layer | File(s) | Summary |
|---|---|---|
| Qwen byte-level BPE fallback tokenizer | python/core/tokenizer.py | Local vocab/merges detection in from_pretrained, byte/unicode and text-splitting helpers, and _QwenByteBpeTokenizer with encode/decode. |

Sequence Diagram(s)

sequenceDiagram
  participant Engine
  participant Executor as Qwen314BPyptoExecutor
  participant A8W8Loop as run_generate_a8w8_decode_loop
  participant L3 as run_generate_l3

  Engine->>Executor: try_generate_batch(record, requests, prefill_batch, config)
  Executor->>Executor: validate_generate_batch / prompt_allocation_length
  alt a8w8_decode_loop enabled
    Executor->>A8W8Loop: run_generate_a8w8_decode_loop(...)
  else l3_mode enabled
    Executor->>L3: run_generate_l3(...)
  end
  Executor-->>Engine: GenerateResult(finish_reason)
sequenceDiagram
  participant main as npu_generate.main
  participant Loader as Qwen3A8W8DirectoryLoader
  participant Index as _SafeTensorIndex
  participant Runtime as RuntimeModel

  main->>Loader: load(request)
  Loader->>Loader: read config.json, build tokenizer
  Loader->>Index: scan model-*.safetensors shards
  Index-->>Loader: tensor offsets/dtype/shape
  Loader->>Index: load(key) per tensor
  Loader->>Loader: convert layouts (int8 kernel, bf16 dequant)
  Loader->>Runtime: populate embeddings/norm/head/layers
  Loader-->>main: LoadedModel

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • hw-native-sys/pypto-serving#14: Prior changes to the same npu_generate.py profiling/timed L2 wrapper and callable naming that this PR further modifies.
  • hw-native-sys/pypto-serving#17: Earlier kernel timing/profiling instrumentation in the same executor path that the new timed_run_l2 wrapper builds upon.
  • hw-native-sys/pypto-serving#29: Earlier L2-to-L3 dispatch switch in the same npu_executor.py that this PR extends with l3_mode, a8w8_decode_loop, and quantization-aware compilation.

Poem

A rabbit hopped through bytes and scale,
Found A8W8 weights without fail,
BPE tokens split so neat,
L3 orchestration, quite a feat,
Thump-thump — the decode loop runs fast,
New tensors continuous at last! 🐰✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.77% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly summarizes the main change: adding a native Qwen3 14B A8W8 serving path. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The description accurately matches the Qwen3 A8W8 serving, loader, tokenizer fallback, and runner changes in the patch. |


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for the qwen3-a8w8 quantized model format, including a new loader for compressed-tensors checkpoints and updated NPU executor and runner logic to handle A8W8 kernels and KV cache scaling. It also adds a dependency-free byte-level BPE tokenizer fallback. The review highlights significant code duplication in the NPU executor and runner logic and suggests refactoring to improve maintainability. Additionally, it recommends performance optimizations for file I/O and tokenization, replacing SimpleNamespace with a dataclass for better type safety, and avoiding overly broad except Exception blocks.


Comment on lines +354 to +414
if quantization == "a8w8":
    prefill = self._compile_prefill_fwd_callable_a8w8(
        qwen3_prefill_fwd.prefill_hidden_a8w8,
        batch=kernel_batch,
        max_seq=model.runtime.max_seq_len,
        hidden_size=model.config.hidden_size,
        intermediate_size=model.config.intermediate_size,
        num_heads=model.config.num_attention_heads,
        num_kv_heads=model.config.num_key_value_heads,
        head_dim=model.config.head_dim,
        num_layers=min(model.config.num_hidden_layers, _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS),
        vocab_size=padded_vocab,
        block_table_stride=max_blocks_per_seq,
        page_size=page_size,
    )
else:
    prefill = self._compile_prefill_fwd_callable(
        qwen3_prefill_fwd.prefill_fwd,
        batch=kernel_batch,
        max_seq=model.runtime.max_seq_len,
        hidden_size=model.config.hidden_size,
        intermediate_size=model.config.intermediate_size,
        num_heads=model.config.num_attention_heads,
        num_kv_heads=model.config.num_key_value_heads,
        head_dim=model.config.head_dim,
        num_layers=model.config.num_hidden_layers,
        vocab_size=padded_vocab,
        block_table_stride=max_blocks_per_seq,
        page_size=page_size,
    )
_mark("compile_prefill")
decode = self._compile_decode_fwd_callable(
    qwen3_l3_dispatch.qwen3_decode_host,
    batch=kernel_batch,
    max_seq=model.runtime.max_seq_len,
    block_table_stride=max_blocks_per_seq,
    hidden_size=model.config.hidden_size,
    intermediate_size=model.config.intermediate_size,
    num_heads=model.config.num_attention_heads,
    num_kv_heads=model.config.num_key_value_heads,
    head_dim=model.config.head_dim,
    num_layers=model.config.num_hidden_layers,
    vocab_size=padded_vocab,
    page_size=page_size,
)
if quantization == "a8w8":
    decode = self._compile_decode_fwd_callable_a8w8(
        qwen3_decode_layer.decode_fwd,
        batch=kernel_batch,
        max_seq=model.runtime.max_seq_len,
        block_table_stride=max_blocks_per_seq,
        hidden_size=model.config.hidden_size,
        intermediate_size=model.config.intermediate_size,
        num_heads=model.config.num_attention_heads,
        num_kv_heads=model.config.num_key_value_heads,
        head_dim=model.config.head_dim,
        num_layers=model.config.num_hidden_layers,
        vocab_size=padded_vocab,
        page_size=page_size,
    )
else:
    decode = self._compile_decode_fwd_callable(
        qwen3_decode_layer.decode_fwd,
        batch=kernel_batch,
        max_seq=model.runtime.max_seq_len,
        block_table_stride=max_blocks_per_seq,
        hidden_size=model.config.hidden_size,
        intermediate_size=model.config.intermediate_size,
        num_heads=model.config.num_attention_heads,
        num_kv_heads=model.config.num_key_value_heads,
        head_dim=model.config.head_dim,
        num_layers=model.config.num_hidden_layers,
        vocab_size=padded_vocab,
        page_size=page_size,
    )


high

There is significant code duplication between the if and else branches for handling a8w8 quantization. The logic for compiling the prefill and decode callables is very similar. This makes the code harder to read and maintain.

Consider refactoring this to reduce duplication. You could use a helper function that takes quantization-specific parameters (like the jit function and some arguments) and returns the compiled callable.
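One way to read this suggestion, as a minimal sketch rather than the PR's code: collect the shared shape arguments once and let the quantization mode change only what actually differs. `CompileShape`, `prefill_compile_kwargs`, and `A8W8_PREFILL_CHUNK_LAYERS` are hypothetical names; the chunk size of 10 follows the "10-layer hidden chunks" wording in the PR description, and the shape values are illustrative.

```python
from dataclasses import asdict, dataclass

# Hypothetical stand-in for _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS (assumed to be 10).
A8W8_PREFILL_CHUNK_LAYERS = 10


@dataclass(frozen=True)
class CompileShape:
    """Shape arguments shared by the prefill and decode compile calls."""
    batch: int
    max_seq: int
    hidden_size: int
    intermediate_size: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    num_layers: int
    vocab_size: int
    block_table_stride: int
    page_size: int


def prefill_compile_kwargs(shape: CompileShape, quantization: str) -> dict:
    """Build the keyword arguments once; only num_layers depends on the mode."""
    kwargs = asdict(shape)
    if quantization == "a8w8":
        # The A8W8 prefill kernel is compiled for one chunk of layers at a time.
        kwargs["num_layers"] = min(shape.num_layers, A8W8_PREFILL_CHUNK_LAYERS)
    return kwargs


# Illustrative values only; they are not read from the PR.
shape = CompileShape(batch=1, max_seq=4096, hidden_size=5120, intermediate_size=17408,
                     num_heads=40, num_kv_heads=8, head_dim=128, num_layers=40,
                     vocab_size=152064, block_table_stride=32, page_size=128)
a8w8_kwargs = prefill_compile_kwargs(shape, "a8w8")
bf16_kwargs = prefill_compile_kwargs(shape, "bf16")
```

With the arguments built this way, each branch would reduce to picking a compile function and a JIT target and then calling it with `**kwargs`.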

Comment on lines +280 to +373
if compiled.quantization == "a8w8":
    if kv_scales is None:
        raise RuntimeError(f"missing A8W8 KV scale cache for model {model.config.model_id!r}")
    k_cache_scale, v_cache_scale = kv_scales
    rows_per_layer = k_cache.shape[0] // model.config.num_hidden_layers
    hidden = prefill_inputs.hidden

    def weight_slice(name: str, start: int, layers: int, rows_per_layer_: int = 1) -> WorkerTensor:
        tensor = dw[name][start * rows_per_layer_ : (start + layers) * rows_per_layer_]
        return self._l2_child_tensor(compiled.prefill.runtime_name, tensor)

    for layer_start in range(0, model.config.num_hidden_layers, _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS):
        layer_count = min(
            _QWEN14B_A8W8_PREFILL_CHUNK_LAYERS,
            model.config.num_hidden_layers - layer_start,
        )
        cache_row_start = layer_start * rows_per_layer
        cache_rows = layer_count * rows_per_layer
        hidden_out = torch.empty_like(hidden).share_memory_()
        scratch_logits = torch.empty_like(logits_padded).share_memory_()
        self._run_l2_program(
            compiled.prefill,
            hidden,
            prefill_inputs.seq_lens,
            prefill_inputs.chunk_lens,
            prefill_inputs.chunk_offsets,
            weight_slice("decode_input_rms_weight", layer_start, layer_count),
            weight_slice("decode_wq", layer_start, layer_count, model.config.hidden_size),
            weight_slice("decode_wk", layer_start, layer_count, model.config.hidden_size),
            weight_slice("decode_wv", layer_start, layer_count, model.config.hidden_size),
            weight_slice("decode_wq_scale", layer_start, layer_count),
            weight_slice("decode_wk_scale", layer_start, layer_count),
            weight_slice("decode_wv_scale", layer_start, layer_count),
            weight_slice("decode_q_norm_weight", layer_start, layer_count),
            weight_slice("decode_k_norm_weight", layer_start, layer_count),
            self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_cos),
            self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_sin),
            prefill_inputs.block_table,
            prefill_inputs.slot_mapping,
            self._worker_tensor_view(
                k_cache,
                cache_row_start * model.config.head_dim,
                (cache_rows, model.config.head_dim),
                1,
            ),
            self._worker_tensor_view(
                v_cache,
                cache_row_start * model.config.head_dim,
                (cache_rows, model.config.head_dim),
                1,
            ),
            self._worker_tensor_view(k_cache_scale, cache_row_start * 8, (cache_rows, 8), 4),
            self._worker_tensor_view(v_cache_scale, cache_row_start * 8, (cache_rows, 8), 4),
            weight_slice("decode_wo", layer_start, layer_count, model.config.hidden_size),
            weight_slice("decode_wo_scale", layer_start, layer_count),
            weight_slice("decode_post_rms_weight", layer_start, layer_count),
            weight_slice("decode_w_gate", layer_start, layer_count, model.config.hidden_size),
            weight_slice("decode_w_up", layer_start, layer_count, model.config.hidden_size),
            weight_slice("decode_w_down", layer_start, layer_count, model.config.intermediate_size),
            self._l2_child_tensor(compiled.prefill.runtime_name, compiled.final_norm_weight),
            self._l2_child_tensor(compiled.prefill.runtime_name, compiled.padded_lm_head_weight),
            scratch_logits,
            hidden_out,
        )
        hidden = hidden_out
    logits_padded = self._project_logits_host(model, compiled, prefill_inputs, hidden)
else:
    self._run_l2_program(
        compiled.prefill,
        prefill_inputs.hidden,
        prefill_inputs.seq_lens,
        prefill_inputs.chunk_lens,
        prefill_inputs.chunk_offsets,
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_input_rms_weight"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wq"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wk"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wv"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_q_norm_weight"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_k_norm_weight"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_cos),
        self._l2_child_tensor(compiled.prefill.runtime_name, compiled.rope_sin),
        prefill_inputs.block_table,
        prefill_inputs.slot_mapping,
        k_cache,
        v_cache,
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_wo"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_post_rms_weight"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_w_gate"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_w_up"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, dw["decode_w_down"]),
        self._l2_child_tensor(compiled.prefill.runtime_name, compiled.final_norm_weight),
        self._l2_child_tensor(compiled.prefill.runtime_name, compiled.padded_lm_head_weight),
        logits_padded,
    )


high

This if/else block for a8w8 quantization introduces a large amount of duplicated code. The logic inside both branches is very similar, differing mainly in the arguments passed to _run_l2_program and the chunking logic for a8w8.

To improve maintainability, this should be refactored. For example, you could prepare a dictionary of arguments for _run_l2_program based on the quantization mode, and then have a more unified control flow.
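As a rough illustration of how the chunking arithmetic could be separated from the kernel call, the sketch below yields the per-chunk layer and cache-row ranges that the A8W8 branch computes inline. `LayerChunk` and `iter_layer_chunks` are hypothetical names, and the numbers in the example are illustrative only.

```python
from typing import Iterator, NamedTuple


class LayerChunk(NamedTuple):
    """One prefill chunk: which layers it covers and which cache rows it owns."""
    layer_start: int
    layer_count: int
    cache_row_start: int
    cache_rows: int


def iter_layer_chunks(num_layers: int, chunk_layers: int, rows_per_layer: int) -> Iterator[LayerChunk]:
    """Yield consecutive chunks; the last chunk may be shorter than chunk_layers."""
    for layer_start in range(0, num_layers, chunk_layers):
        layer_count = min(chunk_layers, num_layers - layer_start)
        yield LayerChunk(
            layer_start=layer_start,
            layer_count=layer_count,
            cache_row_start=layer_start * rows_per_layer,
            cache_rows=layer_count * rows_per_layer,
        )


# Illustrative numbers: 40 layers in chunks of 10, 4 cache rows per layer.
chunks = list(iter_layer_chunks(num_layers=40, chunk_layers=10, rows_per_layer=4))
# An uneven case, to show the shorter final chunk.
uneven = list(iter_layer_chunks(num_layers=25, chunk_layers=10, rows_per_layer=2))
```

The loop body in the runner would then only slice weights and launch the program for each `LayerChunk`, which keeps the index arithmetic testable on its own.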

Comment on lines +446 to +511
if compiled.quantization == "a8w8":
    if kv_scales is None:
        raise RuntimeError(f"missing A8W8 KV scale cache for model {model_id!r}")
    k_cache_scale, v_cache_scale = kv_scales
    self._run_l2_program(
        compiled.decode,
        hidden,
        self._l2_child_tensor(rt, dw["decode_input_rms_weight"]),
        self._l2_child_tensor(rt, dw["decode_wq"]),
        self._l2_child_tensor(rt, dw["decode_wk"]),
        self._l2_child_tensor(rt, dw["decode_wv"]),
        self._l2_child_tensor(rt, dw["decode_wq_scale"]),
        self._l2_child_tensor(rt, dw["decode_wk_scale"]),
        self._l2_child_tensor(rt, dw["decode_wv_scale"]),
        self._l2_child_tensor(rt, dw["decode_q_norm_weight"]),
        self._l2_child_tensor(rt, dw["decode_k_norm_weight"]),
        seq_lens,
        block_table,
        slot_mapping,
        self._l2_child_tensor(rt, compiled.rope_cos),
        self._l2_child_tensor(rt, compiled.rope_sin),
        k_cache,
        v_cache,
        k_cache_scale,
        v_cache_scale,
        self._l2_child_tensor(rt, dw["decode_wo"]),
        self._l2_child_tensor(rt, dw["decode_wo_scale"]),
        self._l2_child_tensor(rt, dw["decode_w_gate"]),
        self._l2_child_tensor(rt, dw["decode_w_up"]),
        self._l2_child_tensor(rt, dw["decode_w_down"]),
        self._l2_child_tensor(rt, dw["decode_post_rms_weight"]),
        self._l2_child_tensor(rt, compiled.final_norm_weight),
        self._l2_child_tensor(rt, compiled.padded_lm_head_weight),
        logits_padded,
    )
else:
    # Argument order MUST match decode_layer.decode_fwd (PAGED):
    #   hidden_states, input_rms_weight, wq, wk, wv, q_norm_weight,
    #   k_norm_weight, seq_lens, block_table, slot_mapping, rope_cos, rope_sin,
    #   k_cache, v_cache, wo, w_gate, w_up, w_down, post_rms_weight,
    #   final_norm_weight, lm_head_weight, out.
    self._run_l2_program(
        compiled.decode,
        hidden,
        self._l2_child_tensor(rt, dw["decode_input_rms_weight"]),
        self._l2_child_tensor(rt, dw["decode_wq"]),
        self._l2_child_tensor(rt, dw["decode_wk"]),
        self._l2_child_tensor(rt, dw["decode_wv"]),
        self._l2_child_tensor(rt, dw["decode_q_norm_weight"]),
        self._l2_child_tensor(rt, dw["decode_k_norm_weight"]),
        seq_lens,
        block_table,
        slot_mapping,
        self._l2_child_tensor(rt, compiled.rope_cos),
        self._l2_child_tensor(rt, compiled.rope_sin),
        k_cache,
        v_cache,
        self._l2_child_tensor(rt, dw["decode_wo"]),
        self._l2_child_tensor(rt, dw["decode_w_gate"]),
        self._l2_child_tensor(rt, dw["decode_w_up"]),
        self._l2_child_tensor(rt, dw["decode_w_down"]),
        self._l2_child_tensor(rt, dw["decode_post_rms_weight"]),
        self._l2_child_tensor(rt, compiled.final_norm_weight),
        self._l2_child_tensor(rt, compiled.padded_lm_head_weight),
        logits_padded,
    )


high

Similar to run_prefill, this if/else block for a8w8 quantization duplicates a significant amount of code from the bf16 path. This makes the code difficult to maintain, as changes need to be applied in two places.

Please refactor this to remove the duplication. You could prepare the arguments for _run_l2_program in a helper function based on the quantization mode, then have a single call to _run_l2_program.
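A possible shape for such a helper, sketched with argument names only: the A8W8 path adds six scale arguments to the BF16 order quoted in the snippet, so one ordered builder can serve both modes. `decode_arg_names` is a hypothetical function; real code would resolve each name to a tensor before the single `_run_l2_program` call.

```python
def decode_arg_names(quantization: str) -> list[str]:
    """Return decode kernel argument names in call order for the given mode."""
    a8w8 = quantization == "a8w8"
    names = ["hidden", "decode_input_rms_weight", "decode_wq", "decode_wk", "decode_wv"]
    if a8w8:
        # Per-projection weight scales follow the Q/K/V weights.
        names += ["decode_wq_scale", "decode_wk_scale", "decode_wv_scale"]
    names += [
        "decode_q_norm_weight", "decode_k_norm_weight",
        "seq_lens", "block_table", "slot_mapping",
        "rope_cos", "rope_sin", "k_cache", "v_cache",
    ]
    if a8w8:
        # KV scale pages follow the KV caches.
        names += ["k_cache_scale", "v_cache_scale"]
    names.append("decode_wo")
    if a8w8:
        names.append("decode_wo_scale")
    names += [
        "decode_w_gate", "decode_w_up", "decode_w_down", "decode_post_rms_weight",
        "final_norm_weight", "padded_lm_head_weight", "logits_padded",
    ]
    return names


bf16_names = decode_arg_names("bf16")
a8w8_names = decode_arg_names("a8w8")
a8w8_only = [name for name in a8w8_names if name not in bf16_names]
```

Keeping the order in one place also makes the "argument order MUST match" comment easier to enforce, since a unit test can compare the list against the kernel signature.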

Comment on lines +72 to +74
with rec.shard.open("rb") as f:
    f.seek(rec.data_start + begin)
    raw = bytearray(f.read(end - begin))


medium

This implementation opens and closes the shard file for every tensor loaded. If multiple tensors are located in the same shard, this can be inefficient due to repeated file I/O. Consider memory-mapping the files in __init__ or grouping tensor loading by shard to reduce file operations.
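A minimal sketch of the memory-mapping idea, using only the standard library: each shard is mapped once and later reads slice the existing map. `ShardReader` is a hypothetical helper, not part of the PR, and the file written below is a stand-in rather than a real safetensors shard.

```python
import mmap
import tempfile
from pathlib import Path


class ShardReader:
    """Hypothetical helper: keeps one read-only mmap per shard file."""

    def __init__(self) -> None:
        self._open: dict[Path, tuple] = {}  # shard path -> (file object, mmap)

    def read(self, shard: Path, start: int, end: int) -> bytearray:
        """Return a writable copy of bytes [start, end) from the shard."""
        entry = self._open.get(shard)
        if entry is None:
            handle = shard.open("rb")
            mapped = mmap.mmap(handle.fileno(), 0, access=mmap.ACCESS_READ)
            entry = (handle, mapped)
            self._open[shard] = entry
        return bytearray(entry[1][start:end])

    def close(self) -> None:
        """Release every map and file handle."""
        for handle, mapped in self._open.values():
            mapped.close()
            handle.close()
        self._open.clear()


with tempfile.TemporaryDirectory() as tmp:
    shard_path = Path(tmp) / "model-00001.safetensors"  # illustrative file name
    shard_path.write_bytes(b"HEADERabcdefgh")            # stand-in contents
    reader = ShardReader()
    first = reader.read(shard_path, 6, 10)    # opens and maps the shard
    second = reader.read(shard_path, 10, 14)  # reuses the existing map
    open_shards = len(reader._open)
    reader.close()
```

Whether this is faster than reopening the file depends on shard count and the OS page cache, so it would be worth measuring on a real checkpoint before adopting it.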

Comment on lines +132 to +161
SimpleNamespace(
    quantization="a8w8",
    input_rms_weight=index.load(f"{prefix}.input_layernorm.weight").reshape(1, -1).float(),
    wq=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.q_proj.weight"),
    wk=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.k_proj.weight"),
    wv=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.v_proj.weight"),
    wq_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.q_proj.weight_scale"),
    wk_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.k_proj.weight_scale"),
    wv_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.v_proj.weight_scale"),
    q_norm_weight=index.load(f"{prefix}.self_attn.q_norm.weight").reshape(1, -1).float(),
    k_norm_weight=index.load(f"{prefix}.self_attn.k_norm.weight").reshape(1, -1).float(),
    wo=_hf_linear_to_kernel_i8(index, f"{prefix}.self_attn.o_proj.weight"),
    wo_scale=_hf_scale_to_kernel(index, f"{prefix}.self_attn.o_proj.weight_scale"),
    post_rms_weight=index.load(f"{prefix}.post_attention_layernorm.weight").reshape(1, -1).float(),
    w_gate=_hf_linear_to_bf16(
        index,
        f"{prefix}.mlp.gate_proj.weight",
        f"{prefix}.mlp.gate_proj.weight_scale",
    ),
    w_up=_hf_linear_to_bf16(
        index,
        f"{prefix}.mlp.up_proj.weight",
        f"{prefix}.mlp.up_proj.weight_scale",
    ),
    w_down=_hf_linear_to_bf16(
        index,
        f"{prefix}.mlp.down_proj.weight",
        f"{prefix}.mlp.down_proj.weight_scale",
    ),
)


medium

Using types.SimpleNamespace is flexible but lacks the explicitness and type safety of a dataclass. For better maintainability and code clarity, consider defining a dataclass for the layer weights. This provides type hints, autocompletion, and a more rigid structure, reducing the chance of runtime errors.
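A sketch of what such a dataclass might look like, with field names copied from the SimpleNamespace call above. `A8W8LayerWeights` is a hypothetical class name, and the fields are typed as `Any` only to keep the example free of a torch dependency; real code would presumably annotate them as `torch.Tensor`.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass(frozen=True)
class A8W8LayerWeights:
    """Hypothetical per-layer weight record for the A8W8 loader."""
    input_rms_weight: Any
    wq: Any
    wk: Any
    wv: Any
    wq_scale: Any
    wk_scale: Any
    wv_scale: Any
    q_norm_weight: Any
    k_norm_weight: Any
    wo: Any
    wo_scale: Any
    post_rms_weight: Any
    w_gate: Any
    w_up: Any
    w_down: Any
    quantization: str = "a8w8"


field_names = [f.name for f in fields(A8W8LayerWeights)]

# Unlike SimpleNamespace, a missing or misspelled field fails at construction.
try:
    A8W8LayerWeights(wq=1)
except TypeError:
    missing_field_rejected = True
else:
    missing_field_rejected = False
```

The practical gain is that a typo such as `w_gat=` raises immediately in the loader instead of surfacing later as an AttributeError in the executor.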

Comment on lines +623 to +632
assembled = compile_and_assemble(
    work_dir,
    self._platform,
    pto_isa_commit=config.pto_isa_commit,
)
if len(assembled) == 2:
    chip_callable, runtime_name = assembled
    runtime_config = {}
else:
    chip_callable, runtime_name, runtime_config = assembled


medium

The logic to handle the two possible return signatures of compile_and_assemble is duplicated here and in _compile_jit_fwd_callable (lines 906-915). To improve maintainability, this logic should be extracted into a shared helper function.
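The shared helper could be as small as the sketch below, which normalizes both return shapes to a 3-tuple. `unpack_assembled` is a hypothetical name; the 2-tuple and 3-tuple shapes are taken from the snippet above, and nothing else is assumed about `compile_and_assemble`.

```python
from typing import Any


def unpack_assembled(assembled: tuple) -> tuple[Any, str, dict]:
    """Normalize a 2- or 3-element result to (chip_callable, runtime_name, runtime_config)."""
    if len(assembled) == 2:
        chip_callable, runtime_name = assembled
        return chip_callable, runtime_name, {}
    chip_callable, runtime_name, runtime_config = assembled
    return chip_callable, runtime_name, runtime_config


# Placeholder values standing in for real compile results.
two = unpack_assembled(("callable", "runtime_a"))
three = unpack_assembled(("callable", "runtime_b", {"block_dim": 24}))
```

Both call sites would then read `chip_callable, runtime_name, runtime_config = unpack_assembled(compile_and_assemble(...))`.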

Comment on lines +760 to +775
def _compile_prefill_fwd_callable_a8w8(
    self,
    jit_fn: object,
    *,
    batch: int,
    max_seq: int,
    block_table_stride: int,
    hidden_size: int,
    intermediate_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    num_layers: int,
    vocab_size: int,
    page_size: int,
) -> _L2Callable:


medium

This new function _compile_prefill_fwd_callable_a8w8 is very similar to _compile_prefill_fwd_callable. The same applies to the decode counterparts. This introduces significant code duplication.

Consider merging the a8w8 and non-a8w8 versions of these functions. You could parameterize them by passing the dtypes for dummy tensors based on the quantization mode, which would make the code more maintainable.

Comment on lines +205 to +207
except Exception:
    self._free_kv_cache_tensor(key_scale)
    raise


medium

Catching the broad Exception class is generally discouraged because it also catches unexpected errors and can obscure where they came from, making debugging more difficult. Please catch a more specific exception if possible, for instance a memory allocation error. If you must catch Exception, consider logging it before re-raising to provide more context on failures.
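A small sketch of the log-then-re-raise pattern, under the assumption that the cleanup exists to free an already allocated key scale when the second allocation fails. `allocate_scale_pair`, `flaky_allocate`, and the logger name are hypothetical and only illustrate the control flow.

```python
import logging

logger = logging.getLogger("kv_scale_sketch")  # hypothetical logger name


def allocate_scale_pair(allocate, free):
    """Allocate key and value scale buffers; free the first if the second fails."""
    key_scale = allocate("key_scale")
    try:
        value_scale = allocate("value_scale")
    except Exception:
        # Record the failure with its traceback, clean up, then propagate it unchanged.
        logger.exception("value scale allocation failed; freeing key scale")
        free(key_scale)
        raise
    return key_scale, value_scale


freed = []


def flaky_allocate(name):
    """Stand-in allocator that fails on the second buffer."""
    if name == "value_scale":
        raise MemoryError("out of device memory")
    return name


try:
    allocate_scale_pair(flaky_allocate, freed.append)
except MemoryError:
    reraised = True
else:
    reraised = False
```

Because the handler ends in a bare `raise`, the original exception type and traceback reach the caller; the log line only adds context about which cleanup ran.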

Comment thread python/core/tokenizer.py
Comment on lines +228 to +241
while i < len(word):
    try:
        j = word.index(first, i)
    except ValueError:
        new_word.extend(word[i:])
        break
    new_word.extend(word[i:j])
    i = j
    if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
        new_word.append(first + second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1


medium

The loop to rebuild the word after finding a bigram merge can be inefficient. The word.index(first, i) call inside the while loop may repeatedly scan parts of the word tuple.

For better performance, you could iterate through the word once and build the new_word list in a single pass, checking for the (first, second) bigram at each position. While lru_cache helps, optimizing the core algorithm is still beneficial.
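The single-pass variant described here could look like the sketch below. `merge_pair` is a hypothetical standalone function; it applies the same left-to-right greedy merge as the loop above, and whether it is measurably faster would need profiling.

```python
def merge_pair(word: tuple[str, ...], first: str, second: str) -> tuple[str, ...]:
    """Merge every non-overlapping (first, second) bigram in one left-to-right pass."""
    merged: list[str] = []
    i = 0
    while i < len(word):
        if i + 1 < len(word) and word[i] == first and word[i + 1] == second:
            merged.append(first + second)
            i += 2  # skip both halves of the merged bigram
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)


example = merge_pair(("a", "b", "a", "b", "c"), "a", "b")
overlapping = merge_pair(("a", "a", "a"), "a", "a")
```

In the overlapping case the leftmost pair wins and the trailing symbol is left alone, which matches the behavior of the original loop.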

@vegetabledoww vegetabledoww force-pushed the qwen3-a8w8-stage5-serving-slim branch 2 times, most recently from a9b4d34 to b09f2cf Compare June 30, 2026 09:36
@vegetabledoww vegetabledoww force-pushed the qwen3-a8w8-stage5-serving-slim branch from b09f2cf to 2343c8e Compare July 1, 2026 02:12

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/model/qwen3_14b/npu_generate.py`:
- Around line 421-422: The QWEN_A8W8_BF16_KV environment flag is only ever set
and never cleared in main(), so a previous bf16-kv run can leak into later A8W8
runs in the same process. Update the argument-handling path around
args.model_format and args.decode_backend to deterministically reset
QWEN_A8W8_BF16_KV for every invocation, explicitly setting it to "1" only for
qwen3-a8w8 with bf16-kv and removing or clearing it otherwise.

In `@examples/model/qwen3_14b/runner/a8w8_loader.py`:
- Around line 86-89: Reshape the linear scale tensor in _hf_linear_to_bf16
before multiplying so it broadcasts over output channels correctly instead of
the last axis. Update the _hf_linear_to_bf16 flow to explicitly reshape or
unsqueeze the loaded scale from index.load(scale_key) before the weight * scale
operation, then keep the transpose/contiguous BF16 conversion as-is.

In `@python/core/tokenizer.py`:
- Around line 248-250: The byte-level tokenizer input is being altered in
Tokenizer.encode by NFC normalization, which breaks round-tripping for
decomposed Unicode text. Remove the unicodedata.normalize("NFC", text) step from
encode and preserve the original input bytes before byte-level BPE processing;
keep the rest of Tokenizer.encode and decode behavior unchanged.
- Around line 135-181: Add parity tests for the Qwen fallback tokenizer behavior
in _split_qwen_text, focusing on edge cases where the current split logic can
diverge from Hugging Face: repeated spaces, newline handling, and
punctuation/space boundaries. Create coverage that compares the fallback output
against the expected Qwen regex-based tokenization behavior for representative
inputs, so regressions are caught before relying on this path in serving.
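To make the NFC finding above concrete, here is a short standard-library illustration of why normalizing before byte-level encoding is lossy. It does not use the PR's tokenizer; it only shows that the UTF-8 bytes change, so a byte-level encode/decode roundtrip cannot return the original decomposed string.

```python
import unicodedata

# "e" followed by COMBINING ACUTE ACCENT (U+0301): a decomposed form of "é".
decomposed = "e\u0301"
# NFC composes the pair into the single code point U+00E9.
normalized = unicodedata.normalize("NFC", decomposed)

raw_bytes = decomposed.encode("utf-8")   # what a byte-level BPE would see without NFC
nfc_bytes = normalized.encode("utf-8")   # what it sees after NFC normalization

# Decoding the NFC bytes yields the composed string, not the original input.
roundtrip_matches_input = nfc_bytes.decode("utf-8") == decomposed
```

The two strings render identically, which is why this kind of mismatch tends to show up only in exact-match roundtrip tests.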

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 5243ea88-e297-4ca7-9abc-c6e9f3bcda03

📥 Commits

Reviewing files that changed from the base of the PR and between d37496a and 2343c8e.

📒 Files selected for processing (8)
  • examples/model/qwen3_14b/npu_generate.py
  • examples/model/qwen3_14b/runner/a8w8_loader.py
  • examples/model/qwen3_14b/runner/npu_executor.py
  • examples/model/qwen3_14b/runner/npu_runner.py
  • python/core/pypto_executor.py
  • python/core/tokenizer.py
  • python/runtime/worker.py
  • tests/test_batching.py

Comment on lines +421 to +422
if args.model_format == "qwen3-a8w8" and args.decode_backend == "bf16-kv":
    os.environ["QWEN_A8W8_BF16_KV"] = "1"


🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Set the BF16-KV env flag deterministically.

If main() runs more than once in-process, a prior bf16-kv run leaves QWEN_A8W8_BF16_KV=1, so a later plain A8W8 run can compile decode with BF16 KV assumptions while allocating INT8 KV.

🐛 Proposed fix
-    if args.model_format == "qwen3-a8w8" and args.decode_backend == "bf16-kv":
-        os.environ["QWEN_A8W8_BF16_KV"] = "1"
+    os.environ["QWEN_A8W8_BF16_KV"] = (
+        "1" if args.model_format == "qwen3-a8w8" and args.decode_backend == "bf16-kv" else "0"
+    )

Comment on lines +86 to +89
def _hf_linear_to_bf16(index: _SafeTensorIndex, weight_key: str, scale_key: str) -> torch.Tensor:
    weight = index.load(weight_key).float()
    scale = index.load(scale_key).float()
    return (weight * scale).t().contiguous().to(torch.bfloat16)


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify PyTorch broadcasting behavior for HF [out, in] linear weights.
python - <<'PY'
import torch
weight = torch.empty(17408, 5120)
scale = torch.empty(17408)
try:
    weight * scale
except RuntimeError as exc:
    print("1-D scale fails as expected:", str(exc).splitlines()[0])
print((weight * scale.reshape(-1, 1)).shape)
PY

Repository: hw-native-sys/pypto-serving

Length of output: 283


🏁 Script executed:

#!/bin/bash
set -e
file="examples/model/qwen3_14b/runner/a8w8_loader.py"

printf '== outline ==\n'
ast-grep outline "$file" --view expanded || true

printf '\n== relevant lines ==\n'
nl -ba "$file" | sed -n '1,220p'

Repository: hw-native-sys/pypto-serving

Length of output: 1168


🏁 Script executed:

#!/bin/bash
set -e
rg -n "_hf_linear_to_bf16|_hf_scale_to_kernel|reshape|transpose|contiguous|bfloat16|scale_key|weight_key" examples/model/qwen3_14b/runner/a8w8_loader.py

Repository: hw-native-sys/pypto-serving

Length of output: 2194


🌐 Web query:

PyTorch broadcasting 1D tensor with 2D tensor trailing dimension semantics official docs

💡 Result:

According to the official PyTorch documentation, broadcasting follows NumPy's semantics [1][2][3]. Two tensors are "broadcastable" if, when iterating over their dimension sizes starting from the trailing (last) dimension, the sizes either match, one of them is 1, or one of them does not exist [1][4][5].

When a 1D tensor is broadcast against a 2D tensor, its dimensions are compared starting from the trailing dimension of the 2D tensor [1][5]. A 1D tensor of size N is implicitly treated as having a leading dimension of 1 (shape 1 x N) so that it aligns with a 2D tensor of shape M x N [1][3][5]. To apply a 1D tensor of size M along the rows of an M x N tensor instead, the dimensions must be aligned explicitly, for example by unsqueezing the 1D tensor to shape (M, 1) [6].

General element-wise broadcasting (such as addition or multiplication) is distinct from specialized matrix operations like torch.matmul. Element-wise broadcasting follows the general rules described above [1], while torch.matmul has its own broadcasting logic for batched matrix multiplication, including special handling that appends or prepends dimensions for 1D tensors to support matrix-vector and batched matrix multiplication [7].

Citations:


Reshape the linear scales before BF16 dequantization.
weight * scale treats the 1-D scale as a row vector, so it broadcasts across the last axis. That errors on non-square projections and scales columns instead of output channels on square ones.
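
The broadcasting rule behind this finding can be checked without PyTorch. The following is a minimal sketch, assuming a hypothetical helper `broadcast_shape` (not part of the PR or of PyTorch) that applies the trailing-dimension rule to plain shape tuples; the shapes `(4, 3)` and `(4, 4)` are illustrative stand-ins for `(out_features, in_features)` weights.

```python
from itertools import zip_longest


def broadcast_shape(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    """Return the broadcast result shape, or raise ValueError if incompatible.

    Hypothetical helper for illustration only; it mirrors the documented
    rule: compare sizes from the trailing dimension, where sizes must match,
    or one must be 1, or one must be missing.
    """
    out = []
    # Walk both shapes from the last dimension; a missing dimension counts as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


# Non-square weight (out=4, in=3) with a 1-D per-output-channel scale of size 4:
# the scale is aligned with the last axis (size 3), so the shapes clash.
try:
    broadcast_shape((4, 3), (4,))
    clash = False
except ValueError:
    clash = True

# Reshaping the scale to (4, 1) aligns it with the output-channel axis.
fixed = broadcast_shape((4, 3), (4, 1))

# Square weight: the 1-D scale broadcasts without error, but along the last
# axis, so it would scale columns rather than output channels.
silent = broadcast_shape((4, 4), (4,))
```

The square case is the more dangerous one, since the shape check passes and only the values are wrong; that is why the proposed fix reshapes the scale explicitly rather than relying on an error.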

🐛 Proposed fix
 def _hf_linear_to_bf16(index: _SafeTensorIndex, weight_key: str, scale_key: str) -> torch.Tensor:
     weight = index.load(weight_key).float()
-    scale = index.load(scale_key).float()
+    scale = index.load(scale_key).float().reshape(-1, 1)
     return (weight * scale).t().contiguous().to(torch.bfloat16)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def _hf_linear_to_bf16(index: _SafeTensorIndex, weight_key: str, scale_key: str) -> torch.Tensor:
-    weight = index.load(weight_key).float()
-    scale = index.load(scale_key).float()
-    return (weight * scale).t().contiguous().to(torch.bfloat16)
+def _hf_linear_to_bf16(index: _SafeTensorIndex, weight_key: str, scale_key: str) -> torch.Tensor:
+    weight = index.load(weight_key).float()
+    scale = index.load(scale_key).float().reshape(-1, 1)
+    return (weight * scale).t().contiguous().to(torch.bfloat16)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/model/qwen3_14b/runner/a8w8_loader.py` around lines 86 - 89, Reshape
the linear scale tensor in _hf_linear_to_bf16 before multiplying so it
broadcasts over output channels correctly instead of the last axis. Update the
_hf_linear_to_bf16 flow to explicitly reshape or unsqueeze the loaded scale from
index.load(scale_key) before the weight * scale operation, then keep the
transpose/contiguous BF16 conversion as-is.

Comment thread python/core/tokenizer.py
Comment on lines +135 to +181
def _split_qwen_text(text: str) -> list[str]:
    parts: list[str] = []
    i = 0
    n = len(text)
    while i < n:
        lower_tail = text[i : i + 4].lower()
        matched = False
        for suffix in _CONTRACTIONS:
            if lower_tail.startswith(suffix):
                parts.append(text[i : i + len(suffix)])
                i += len(suffix)
                matched = True
                break
        if matched:
            continue

        start = i
        if text[i] == " " and i + 1 < n and _is_letter(text[i + 1]):
            i += 1
        if i < n and _is_letter(text[i]):
            i += 1
            while i < n and _is_letter(text[i]):
                i += 1
            parts.append(text[start:i])
            continue

        if _is_number(text[i]):
            parts.append(text[i])
            i += 1
            continue

        start = i
        if text[i] == " " and i + 1 < n and _is_punct(text[i + 1]):
            i += 1
        if i < n and _is_punct(text[i]):
            i += 1
            while i < n and _is_punct(text[i]):
                i += 1
            while i < n and _is_newline(text[i]):
                i += 1
            parts.append(text[start:i])
            continue

        start = i
        while i < n and _is_space(text[i]):
            i += 1
        parts.append(text[start:i])

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Inspect local tests or fixtures for tokenizer parity coverage without executing repo code.
rg -n -C2 'QwenByteBpe|from_pretrained|encode\(.*decode|tokenizer.*roundtrip|multiple spaces|newline|punct' tests python/core examples

Repository: hw-native-sys/pypto-serving

Length of output: 3949


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== python/core/tokenizer.py ==\n'
sed -n '1,260p' python/core/tokenizer.py

printf '\n== tests mentioning tokenizer ==\n'
rg -n -C2 'tokenizer|Qwen|ByteBpe|transformers|from_pretrained|encode|decode' tests python/core examples

printf '\n== Qwen-related model artifacts and configs ==\n'
fd -a 'tokenizer.json|tokenizer_config.json|vocab.json|merges.txt|config.json' examples python tests 2>/dev/null | sed -n '1,120p'

Repository: hw-native-sys/pypto-serving

Length of output: 50384


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the fallback implementation and any direct references to Qwen tokenization behavior.
ast-grep outline python/core/tokenizer.py --view expanded

printf '\n== Qwen-specific references ==\n'
rg -n -C3 'Qwen|qwen|_split_qwen_text|_CONTRACTIONS|_QwenByteBpeTokenizer' python/core tests examples

Repository: hw-native-sys/pypto-serving

Length of output: 50383


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== tokenizer encode/decode section ==\n'
sed -n '135,320p' python/core/tokenizer.py

printf '\n== possible tokenizer fixtures/artifacts in repo ==\n'
fd -a -t f 'vocab.json' 'merges.txt' 'tokenizer_config.json' 'tokenizer.json' tests examples python 2>/dev/null | sed -n '1,200p'

printf '\n== existing tokenizer tests ==\n'
rg -n -C2 'tokenizer|Qwen|ByteBpe|vocab.json|merges.txt|tokenizer_config.json' tests

Repository: hw-native-sys/pypto-serving

Length of output: 11386


🌐 Web query:

Qwen tokenizer pre-tokenization rules Hugging Face ByteLevel BPE exact regex Qwen3

💡 Result:

The Qwen tokenizer series, including Qwen3, utilizes a Byte-level Byte Pair Encoding (BBPE) architecture [1][2]. Unlike tokenizers that operate on Unicode codepoints (like SentencePiece), the Qwen tokenizer operates directly on UTF-8 byte sequences [3], which ensures no out-of-vocabulary (OOV) tokens exist for any input text [4][2].

Pre-tokenization Rules and Implementation: The Qwen3 tokenizer implemented in the Hugging Face transformers library employs a specific pre_tokenizer sequence [5]. This sequence consists of:

1. A split step using a regular expression defined as PRETOKENIZE_REGEX. This regex breaks the text into chunks before BPE is applied, which dictates how the text is partitioned and defines boundaries that cannot be crossed by learned merges [6][5].
2. A ByteLevel pre-tokenizer that handles the mapping of characters to byte sequences, ensuring consistent byte-level processing [5].

The regex used for pre-tokenization in the Qwen series (often consistent across versions) is:

(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ [7]

Key Characteristics:

- Byte-level operation: It operates purely on byte distribution [3].
- Pre-tokenization constraints: Because pre-tokenization splits text into specific units, new tokens cannot be added if they span across these pre-tokenized boundaries [6][3]. For example, a multi-word sequence containing a space might be split by the regex, preventing it from being learned as a single new token [6][3].
- Compatibility: Qwen3 maintains this efficient BBPE system, supporting a large vocabulary size (e.g., 151,646 to 151,936 tokens depending on the specific model variant) [1][2][8].

This structure allows the tokenizer to remain highly robust across various languages and character sets without the need for fallback mechanisms [3][4][2].

Citations:


Add parity tests for the Qwen fallback tokenizer. The fallback pre-tokenizer doesn’t match Qwen’s regex-based split rules exactly, so repeated spaces, newlines, and punctuation boundaries can drift from Hugging Face token IDs; add coverage for those edge cases before depending on this path in serving.
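
One possible shape for such parity coverage is a small harness that runs two splitters over the edge cases and reports where they disagree. This is a sketch under stated assumptions: `find_split_mismatches`, `toy_reference`, and `toy_candidate` are hypothetical names, and the two toy splitters are stand-ins only. In a real test the candidate would be `_split_qwen_text` and the reference would come from the Hugging Face tokenizer, neither of which is imported here.

```python
import re
from typing import Callable, Iterable

Splitter = Callable[[str], list[str]]

# Edge cases named in the review comment, plus a contraction/digit case.
EDGE_CASES = [
    "a  b",         # repeated spaces
    "a\n\nb",       # consecutive newlines
    "hi , there!",  # space before punctuation
    "it's 2024",    # contraction followed by digits
]


def find_split_mismatches(
    candidate: Splitter, reference: Splitter, cases: Iterable[str]
) -> list[tuple[str, list[str], list[str]]]:
    """Return (text, candidate_pieces, reference_pieces) for each disagreement."""
    mismatches = []
    for text in cases:
        got, want = candidate(text), reference(text)
        # A pre-tokenizer must be lossless: the pieces join back to the input.
        if "".join(got) != text:
            raise ValueError(f"candidate split is lossy for {text!r}")
        if got != want:
            mismatches.append((text, got, want))
    return mismatches


def toy_reference(text: str) -> list[str]:
    # Stand-in reference: alternate runs of whitespace and non-whitespace.
    return re.findall(r"\s+|\S+", text)


def toy_candidate(text: str) -> list[str]:
    # Stand-in candidate: lets one leading space attach to the next word.
    return re.findall(r" ?\S+|\s+", text)


report = find_split_mismatches(toy_candidate, toy_reference, EDGE_CASES)
for text, got, want in report:
    print(repr(text), got, want)
```

Returning the mismatching pieces instead of asserting inside the loop keeps a single run informative: every drifting input is listed at once, which helps when deciding whether the fallback's split rules need to change.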

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/core/tokenizer.py` around lines 135 - 181, Add parity tests for the
Qwen fallback tokenizer behavior in _split_qwen_text, focusing on edge cases
where the current split logic can diverge from Hugging Face: repeated spaces,
newline handling, and punctuation/space boundaries. Create coverage that
compares the fallback output against the expected Qwen regex-based tokenization
behavior for representative inputs, so regressions are caught before relying on
this path in serving.

Comment thread python/core/tokenizer.py
Comment on lines +248 to +250
    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
        del add_special_tokens
        text = unicodedata.normalize("NFC", text)

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Do not normalize byte-level tokenizer input.

Line 250 changes user text before byte encoding, so decode(encode("e\u0301")) returns NFC-normalized text instead of the original byte sequence. Byte-level BPE should preserve the input.
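
The effect described here can be reproduced with the standard library alone. The sketch below assumes a hypothetical `utf8_roundtrip` helper that stands in for `decode(encode(text))` at the byte level, with no BPE merges involved; it only shows what NFC does to the input, not how the PR's tokenizer behaves end to end, and it does not check whether the reference tokenizer normalizes as well.

```python
import unicodedata

decomposed = "e\u0301"  # 'e' followed by U+0301 COMBINING ACUTE ACCENT
composed = unicodedata.normalize("NFC", decomposed)  # single code point U+00E9


def utf8_roundtrip(text: str, normalize: bool) -> str:
    """Hypothetical stand-in for decode(encode(text)) on raw UTF-8 bytes."""
    if normalize:
        text = unicodedata.normalize("NFC", text)
    return text.encode("utf-8").decode("utf-8")


print(decomposed.encode("utf-8"))  # b'e\xcc\x81' (three bytes)
print(composed.encode("utf-8"))    # b'\xc3\xa9' (two bytes)
print(utf8_roundtrip(decomposed, normalize=True) == decomposed)   # False
print(utf8_roundtrip(decomposed, normalize=False) == decomposed)  # True
```

Both strings render identically, so the difference only shows up when comparing bytes or code points, which is what a round-trip test on decomposed input would catch.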

Proposed fix
     def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
         del add_special_tokens
-        text = unicodedata.normalize("NFC", text)
         ids: list[int] = []
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
-        del add_special_tokens
-        text = unicodedata.normalize("NFC", text)
+    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+        del add_special_tokens
+        ids: list[int] = []
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/core/tokenizer.py` around lines 248 - 250, The byte-level tokenizer
input is being altered in Tokenizer.encode by NFC normalization, which breaks
round-tripping for decomposed Unicode text. Remove the
unicodedata.normalize("NFC", text) step from encode and preserve the original
input bytes before byte-level BPE processing; keep the rest of Tokenizer.encode
and decode behavior unchanged.
