Skip to content

[cuda] GGUF Q6_K real packed INT6 (W6A8 dp4a) + GGUF CI export#20229

Merged
Gasoonjia merged 17 commits into
mainfrom
g4-int6-gguf
Jun 17, 2026
Merged

[cuda] GGUF Q6_K real packed INT6 (W6A8 dp4a) + GGUF CI export#20229
Gasoonjia merged 17 commits into
mainfrom
g4-int6-gguf

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Add a genuine 6-bit packed weight path for GGUF Q6_K on the CUDA backend, parallel to the int4/int8 plain_mm paths:

  • int6_plain_mm CUDA shim (W6A8 dp4a; ql/qh planes; spread2; -32 symmetric offset)
  • CudaPackedInt6Tensor (ql/qh + per-group bf16 scale; symmetric, no zero tensor)
  • int6_dispatch: F.linear routing (M<=4 -> executorch_cuda::int6_plain_mm op, M>4 -> dequant)
  • backend fallback-kernel + custom_ops_to_c_shims registration; CMake build
  • route GGUF Q6_K -> CudaPackedInt6Tensor (gguf_loader, pack_cuda, dequantize_weight)
  • tests: int6 gtest, test_int6_dispatch.py, pack round-trip; fix stale int4/int6 type asserts

CI (export_model_artifact.sh, gemma4_31b): download the Q4_K_M GGUF from unsloth/gemma-4-31B-it-GGUF (tokenizer from unsloth/gemma-4-31B-it) and run the inference sanity check + export via the GGUF loader (--gguf) instead of the prequantized HF checkpoint.

prompt(p) prefill tok/s decode tok/s
p=2043, d =128 2280.1 46.73

Gasoonjia added 6 commits June 8, 2026 22:15
…decode

Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the
[N, n_groups] layout into the weight constant at pack time. Introduces
CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the
[n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm
dispatch on it by type, and adds the coalesced dp4a matvec kernel that
reads scale/zero row-for-row with qdata (single coalesced load vs 32
stride-N cache lines). ~29.2 -> 37.4 tok/s on gemma group_size=32.

Rebased onto main; INT8 dp4a decode op and the floor_div pass from this
branch landed separately and now live in quantize_op_dispatch/.
…ied) + benchmark rework

Summary:
At decode (L_q==1) the standard pack-GQA SDPA kernel's grid collapses to
CTA = batch * n_kv_heads, which under-occupies the SMs; split-K flash-decoding
partitions the KV sequence across many more CTAs to fill the GPU. In
ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel, route decode to split-K when
L_q==1 and L_kv >= 256 (power-of-2 head dim required; prefill and non-pow2 head
dims keep the standard kernel).

The 256 crossover was measured under CUDA-graph timing (capture+replay, faithful
to the deployed --cuda_graph runtime). The earlier 2048 boundary was overfit to
a plain (non-cuda-graph) microbenchmark, which charged split-K a ~140us per-call
partial-buffer alloc + extra-launch overhead that the graph runtime eliminates;
under faithful timing split-K wins ~1.2-20x from L_kv ~= 256 upward.

benchmark_sdpa.py reworked: deleted run_sweep and all CSV/sentinel machinery;
run_benchmark now compares all six backends (ET-standard, ET-split-K, PyTorch,
Flash, Efficient, Math) with the PyTorch correctness check, across several
decode configs (gemma D256/CTA16, qwen D256/CTA2, D128/CTA16) over the L_kv
range, with a cuda-graph on/off toggle (--mode {cudagraph,plain,both}) timing
every backend through a small self-contained cuda-graph primitive; terminal-only
output. Each reported cell is the mean+/-std over the last 6 of 10 runs (first 4
discarded as warmup; N_RUNS=10, N_WARMUP=4).

Test Plan:
Exercised against the repo (PYTHONPATH) since the conda env's installed
executorch is stale; a lib reinstall is required for the routing to take effect
in a real export.

backends/cuda/tests/test_sdpa_splitk_replacement.py
  - L_kv=128 -> standard; L_kv=256 -> split-K; L_kv=4096 -> split-K;
    non-pow2 D=96 -> standard.
backends/cuda/tests/test_triton_sdpa_splitk.py (14) and
backends/cuda/tests/test_triton_sdpa_nan.py (3) pass. 21 tests total.

gemma4_31b long-context decode (2401-tok prompt, 256 new tokens, temp 0,
--cuda_graph, 10 runs middle-6) with split-K routing: decode 37.91 -> 43.98
tok/s (+16.0%); prefill within noise.

python backends/cuda/benchmarks/benchmark_sdpa.py --mode cudagraph (gemma
D256/CTA16, mean+/-std us): L_kv=2048 ET-std 102.4+/-0.0 / ET-split-K 24.6+/-0.2 /
PyTorch 475.1+/-0.3 / Flash 56.5+/-0.0; L_kv=16384 ET-std 785.5+/-0.0 /
ET-split-K 179.8+/-0.1 / PyTorch 3447+/-2.6. Plain-timing mode shows split-K's
per-call overhead (the artifact behind the old 2048).
…ock)

The decode-only int4_plain_mm matvec was bound by activation load-instruction
throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner
iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation
loads + the same per-block scale d reloaded 4x.

Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B
halves are 16B-aligned, then load a whole activation block with two vectorized
uint4 loads + one d load (~4x fewer activation loads). dp4a math and
accumulation order are bit-identical; the int8 activation values and scale are
unchanged.

gemma4_31b decode (long-ctx harness, stacked on optimize_1):
  decode  43.98 -> 46.79 tok/s (+6.4%)
  prefill 1193  -> 1186     (noise; int4_plain_mm is decode-only)
nsys: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged.
Unit tests test_aoti_torch_cuda_int4_plain_mm: 6/6 pass (M=1/8, gs=16/32/128).
Block-sparse early-exit in _sdpa_fwd_kernel_body: skip KV blocks that are
entirely masked (sliding-window via HAS_MASK sum==0, causal via start_n>max_seq_pos).
Exact (skipped blocks are x1,+0 no-ops). Prefill +46-88% all lengths; decode safe;
SDPA nsys 58.1%->18.5%. Numerically bf16-exact vs dense+mask (unit test).
@pytorch-bot

pytorch-bot Bot commented Jun 12, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20229

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated Failure, 2 Unclassified Failures

As of commit f1c6087 with merge base eb7473b (image):

NEW FAILURE - The following job has failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 12, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 12, 2026

Copy link
Copy Markdown

CLA Signed
The committers listed above are authorized under a signed CLA.

@mergennachin

Copy link
Copy Markdown
Contributor

The cuda path is doing:

  1. GGUF raw block -> ExportableGGUFTensor
  2. to_intx_unpacked_to_int8_tensor() unpacks GGUF Q6_K into int8, with effective gs=16
  3. pack_cuda.py detects symmetric gs=16 int8 and repacks to CudaPackedInt6Tensor.
  4. Runtime decode calls int6_plain_mm

Can you do directly and skip the int8 path?

@Gasoonjia Gasoonjia force-pushed the g4-opt-prefill-window-sdpa branch from 89b043f to 087938d Compare June 16, 2026 05:45
@Gasoonjia Gasoonjia temporarily deployed to upload-benchmark-results June 16, 2026 06:06 — with GitHub Actions Inactive
Add a genuine 6-bit packed weight path for GGUF Q6_K on the CUDA backend,
parallel to the int4/int8 plain_mm paths:
- int6_plain_mm CUDA shim (W6A8 dp4a; ql/qh planes; spread2; -32 symmetric offset)
- CudaPackedInt6Tensor (ql/qh + per-group bf16 scale; symmetric, no zero tensor)
- int6_dispatch: F.linear routing (M<=4 -> executorch_cuda::int6_plain_mm op, M>4 -> dequant)
- backend fallback-kernel + custom_ops_to_c_shims registration; CMake build
- GGUF Q6_K: gguf_loader returns the native torchao IntxUnpackedToInt8Tensor and
  the backend packer (pack_cuda.pack_linear_for_cuda) repacks a symmetric Q6_K
  weight into CudaPackedInt6Tensor -- mirroring Int4Tensor -> CudaCoalescedInt4Tensor,
  so the loader stays backend-agnostic; dequantize_weight handles the tied embedding
- tests: int6 gtest, test_int6_dispatch.py, pack round-trip; fix stale int4/int6 type asserts

CI (export_model_artifact.sh, gemma4_31b): download the Q4_K_M GGUF from
unsloth/gemma-4-31B-it-GGUF (tokenizer from unsloth/gemma-4-31B-it) and run the
inference sanity check + export via the GGUF loader (--gguf) instead of the
prequantized HF checkpoint.

Signed-off-by: gasoonjia <gasoonjia@icloud.com>
…c_q6k heuristic)

Route the Q6_K CUDA path on the native ExportableGGUFTensor (ggml_type ==
"q6_k") instead of an int8 intermediate. pack_linear_for_cuda now repacks the
raw GGUF tensor via CudaPackedInt6Tensor.from_exportable_gguf, which REUSES the
shared Q6_K block decode in gguf.py (to_intx_unpacked_to_int8_tensor) then bakes
the ql/qh bit-pack -- the decode is not duplicated. This removes the brittle
_is_symmetric_q6k heuristic and makes the int8 passthrough unambiguous.

- packed_int6_tensor: add from_exportable_gguf (keeps from_intx_int8 low-level packer)
- gguf_loader._convert_weight: q6_k returns the raw ExportableGGUFTensor (like MLX); q4_k unchanged
- quantize.dequantize_weight: add ExportableGGUFTensor branch (tied token embedding -> bf16)
- pack_cuda.pack_linear_for_cuda: route Int4Tensor / ExportableGGUFTensor(q6_k) / IntxUnpackedToInt8Tensor; drop heuristic
- tests: feed synthetic q6_k ExportableGGUFTensor; cover from_exportable_gguf

Python-only refactor; .cu kernel and serialized CudaPackedInt6Tensor unchanged.
Can be squashed into the int6 commit (390238e) later.
Rename the int6 tensor subclass and its module file to reflect the
dp4a-planar (ql/qh split bit-plane) layout:
  backends/cuda/packed_int6_tensor.py -> backends/cuda/dp4a_planar_int6_tensor.py
  class CudaPackedInt6Tensor -> CudaDp4aPlanarInt6Tensor

Update all references (imports, type dispatch, CUDA packer, quantize
dequant branch, gguf_loader, tests, kernel comments) and the
torch.serialization.add_safe_globals registration so exported models
round-trip under the new qualified name. Classmethods
from_exportable_gguf/from_intx_int8 and helpers pack_int6/unpack_int6 are
unchanged; the runtime op int6_plain_mm and the .cu/.cuh kernel are untouched.
The gemma4_31b CUDA CI selector still keyed off the stale prequant HF repo
SocialLocalMobile/gemma-4-31B-it-HQQ-INT4, but export_model_artifact.sh
already downloads the weights from unsloth/gemma-4-31B-it-GGUF. Rename the
CI identifier to match the actual source:
  SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 -> unsloth/gemma-4-31B-it-GGUF

Updated the case selectors + help text in export_model_artifact.sh and
test_model_e2e.sh, and the matrix entries, exclude rules, and the A100
runner-selection conditionals in both the export and e2e jobs of cuda.yml.
The executorch registry MODEL_NAME stays gemma4_31b; qwen3_5_moe's
SocialLocalMobile HQQ entry is left unchanged.
…om_intx_int8)

It has no external production caller (from_exportable_gguf is the only entry; pack_cuda routes Q6_K there); keep it as the internal, unit-tested ql/qh packer.

Signed-off-by: gasoonjia <gasoonjia@icloud.com>
@Gasoonjia

Copy link
Copy Markdown
Contributor Author

@mergennachin Done. Now we rename the CudaPackedInt6Tensor into CudaDp4aPlanarInt6Tensor for better repentsent its layout, and makes CudaDp4aPlanarInt6Tensor take ExportableGGUFTensor directly, removed the extra to_intx_unpacked_to_int8_tensor() conversion in pack_cuda.py.

One thing i need to highlight is we still do to_intx_unpacked_to_int8_tensor() inside CudaDp4aPlanarInt6Tensor. from_exportable_gguf() function, since gtensor -> int8 -> int6 is the best conversion path in practice. GGUF stores llama.cpp's scrambled ql/qh plus a two-level scale, which has no direct bit-permutation to our dp4a ql/qh + single bf16 scale.

Comment on lines +73 to +75
# Genuine INT8 weight: left unchanged for the int8 path. Q6_K never reaches
# here (it arrives as an ExportableGGUFTensor), so this is unambiguous.
pass

@mergennachin mergennachin Jun 16, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HQQ-INT4 path arrives here, right?

Because some tensors in that checkpoint is int8

def _(func, types, args, kwargs):
input_tensor = args[0]
weight_tensor = args[1]
bias = args[2] if len(args) > 2 else None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drops keyword bias. F.linear(x, weight, bias=bias) is valid

suggested fix:

bias = args[2] if len(args) > 2 else kwargs.get("bias", None)

also add a test that actually passes bias as keyword argument

@mergennachin mergennachin left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See inline comments. Also, consider adding one more e2e test in test_pack_cuda.py

  • Build synthetic Q6_K tensor in save in ExportableGGUFTensor
  • Pass through pack_cuda
  • Assert it becomes CudaDp4aPlanarInt6Tensor
  • Use a decode-shaped input (M <= 4), then export and lower through the CUDA backend
  • Make sure the exported graph contains int6_plain_mm
  • Run the exported graph and compare against reference original Q6_K tensor

Base automatically changed from g4-opt-prefill-window-sdpa to main June 16, 2026 22:03
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia temporarily deployed to upload-benchmark-results June 17, 2026 02:31 — with GitHub Actions Inactive
@Gasoonjia Gasoonjia merged commit fa5fc74 into main Jun 17, 2026
532 of 538 checks passed
@Gasoonjia Gasoonjia deleted the g4-int6-gguf branch June 17, 2026 05:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/cuda CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants