feat: svdquant w4a4 + awq w4a16 kitchen native (CORE-31) #36
HK416-TYPED wants to merge 11 commits into
Conversation
Introduces:

- `comfy_kitchen.quantize_svdquant_w4a4` / `scaled_mm_svdquant_w4a4` as `torch.library.custom_op` entry points backed by the registry.
- Eager pure-PyTorch reference impls for SVDQuant W4A4 (activation quant + smooth + LoRA-down fused; scaled_mm with LoRA-up + bias) and AWQ W4A16 GEMV, following kitchen's existing eager backend conventions.
- `TensorCoreSVDQuantW4A4Layout` dispatching `aten.t` / `aten.mm` / `aten.addmm` / `aten.linear` for int4 storage with `proj_down` / `proj_up` / `smooth` params alongside weight scales.

The eager path is a correctness reference. Kernel-level parity with nunchaku's MMA tile layout is not part of this change; it is covered by the CUDA backend port in the following commit.
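For orientation, a minimal eager-style sketch of the forward these pieces describe: smoothing plus LoRA-down on the quantize side, main matmul plus LoRA-up and bias on the scaled_mm side. The function name, shapes, and the per-column weight scale below are illustrative assumptions rather than the comfy_kitchen signatures, and the int4 activation path is stood in for by a plain dequantized matmul.

```python
import torch

# Illustrative-only sketch of the eager SVDQuant W4A4 math described above.
# Names, shapes, and the per-column weight scale are assumptions; the real ops
# quantize activations to int4, which is stood in for by a dequantized matmul.
def svdquant_w4a4_reference(x, qweight, wscale, smooth, proj_down, proj_up, bias):
    xs = x / smooth                          # smooth activations, (M, K)
    lora_x = xs @ proj_down                  # LoRA-down stays high precision, (M, R)
    w = qweight.to(x.dtype) * wscale         # dequantized int4 weight, (K, N)
    return xs @ w + lora_x @ proj_up + bias  # main path + LoRA-up + bias, (M, N)

m, k, n, r = 4, 64, 32, 8
out = svdquant_w4a4_reference(
    torch.randn(m, k),
    torch.randint(-7, 8, (k, n), dtype=torch.int8),
    torch.rand(n) * 0.1,            # per-column scale (simplification of per-group)
    torch.rand(k) + 0.5,            # smoothing factors
    torch.randn(k, r) * 0.1,        # proj_down
    torch.randn(r, n) * 0.1,        # proj_up
    torch.randn(n),                 # bias
)
assert out.shape == (m, n)
```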
Native int4 GEMM + activation quantize kernels that consume kitchen-native row-major tensors (no nunchaku vendor code, no tile-interleaved packing). sm_80+. End-to-end verified on Qwen-Image-Edit via ComfyUI.

Accumulator is fp32 by design, not fp16 half2: production Qwen weight scales reach ~28, so per-term products overflow fp16 (±65504) and silently cascade into NaN → black images. See the header of ops/scaled_mm_svdquant_w4a4.cu for the full rationale before considering any fp16-accumulator optimization.

Also wires act_unsigned end-to-end for the post-GELU fc2 layers (u4.s4 MMA variant, layer-level +0.171875 shift), aligns the eager signed-quantize clamp to [-7, 7] to match nunchaku's absmax/7 scheme, and adds tests/test_svdquant_w4a4.py (16 tests covering clamp contract, MMA dispatch, lora_x separation, cross-shape smoke).
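As a hedged reference for the clamp contract mentioned above (not the kernel code), a per-group absmax/7 quantizer that never emits -8; the group size and layout are illustrative:

```python
import torch

# Hedged sketch of the signed-quantize contract: absmax/7 per-group scale,
# values clamped to [-7, 7] (the -8 code is never produced). Group size and
# layout are illustrative assumptions, not the kernel's exact format.
def quantize_signed_int4(x: torch.Tensor, group_size: int = 64):
    xg = x.reshape(-1, group_size)
    scale = (xg.abs().amax(dim=-1, keepdim=True) / 7.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(xg / scale), -7, 7)   # skip -8 by construction
    return q.to(torch.int8), scale

q, s = quantize_signed_int4(torch.randn(8, 3072))
assert int(q.min()) >= -7 and int(q.max()) <= 7
```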
The eager AWQ W4A16 GEMV (`comfy_kitchen.gemv_awq_w4a16`, registered as a torch.library custom op) was already in place but had no aten dispatch plumbing: there was no `QuantizedTensor` layout to wrap pre-quantized weights and route `F.linear` / `mm` / `addmm` to the GEMV. Add `TensorCoreAWQW4A16Layout` mirroring `TensorCoreSVDQuantW4A4Layout`:

- `Params(scale=wscales, zeros=wzeros, group_size=64, transposed=False)` on top of `BaseLayoutParams`; both wscales and wzeros are per-group fp tensors of shape `(K // G, N)` matching the eager kernel's contract.
- `aten.t` / `aten.mm` / `aten.addmm` / `aten.linear` handlers, plus a `dequantize()` reference for the fallback path (used when the RHS is not transposed or operand types don't match).

Targets the modulation linears (`img_mod.1` / `txt_mod.1` in Qwen-Image-Edit, ~13 GB of bf16 storage on an r96 checkpoint after the current dequant-on-load path) so they can stay int4 end-to-end and cut both checkpoint size and resident VRAM by ~4x. CUDA fast path TBD; eager backs the dispatch in the meantime.

Verified that `F.linear` via this layout is bit-exact with `dequantize()` + plain matmul on synthetic AWQ tensors.
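A hedged sketch of the `dequantize()` fallback contract, using the `(nibble - 8) * scale + zero` formula quoted in the verification notes below and per-group `(K // G, N)` scales/zeros; the unpacked uint8 nibble storage here is a simplification, not the kitchen packing:

```python
import torch
import torch.nn.functional as F  # noqa: N812

# Hedged sketch of the AWQ W4A16 dequant reference: (nibble - 8) * scale + zero,
# with per-group wscales/wzeros of shape (K // G, N). The unpacked uint8 nibble
# storage below is an illustration, not the kitchen tensor packing.
def dequantize_awq_w4a16(qweight, wscales, wzeros, group_size=64):
    scale = wscales.repeat_interleave(group_size, dim=0)  # (K, N)
    zero = wzeros.repeat_interleave(group_size, dim=0)    # (K, N)
    return (qweight.float() - 8.0) * scale + zero

k, n, g = 256, 128, 64
qw = torch.randint(0, 16, (k, n), dtype=torch.uint8)
ws = torch.rand(k // g, n) * 0.05
wz = torch.randn(k // g, n) * 0.01
x = torch.randn(4, k)
w = dequantize_awq_w4a16(qw, ws, wz, g)
# fallback-path check: F.linear through the dequantized weight matches plain matmul
assert torch.allclose(F.linear(x, w.t()), x @ w)
```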
Native int4 × bf16/fp16 matmul on the kitchen-native row-major layout:
* GEMV path (M ≤ 8): naive 1-thread-per-output kernel, fp32 accum.
18× faster than eager at M=1, 4× at M=8 on RTX 5090 / Qwen-Image
modulation shape (N=18432, K=3072).
* Fused MMA path (8 < M ≤ 256): cooperative dequantize qweight to
bf16 in shmem, then mma.m16n8k16.row.col.f32.bf16.bf16.f32 (or
fp16 specialization) along K. No 113 MB intermediate W workspace.
BLOCK_M=16 / BLOCK_N=128 / BLOCK_K=64 (= one quant group),
4 warps, ldmatrix.x4 for both A and B (W is N-major row-major in
shmem, which matches mma B's lane layout directly), shmem stride
padded to 72 b16 to break the 128-byte ldmatrix bank conflict.
1.6× at M=256 vs eager.
* Large-M fallback (M > 256): dequantize-then-cuBLAS bf16 matmul.
The crossover sits there on Blackwell because the MMA kernel's
single-thread-per-N-row dequant pass and lack of cp.async
pipelining stop scaling past M ≈ 256; raising the limit is a
follow-up tuning task. 1.2× at M=4096 vs eager.
bias is applied externally in the Python wrapper (`out.add_(bias)`),
mirroring scaled_mm_svdquant_w4a4's epilogue contract.
Routing is internal to `gemv_awq_w4a16` so the public API and the
ComfyUI MixedPrecisionOps dispatch are unchanged.
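A hedged Python-level sketch of that internal routing; the path labels are descriptive stand-ins for the three CUDA paths, and only the M thresholds and the external bias epilogue follow the text above:

```python
# Hedged sketch of the internal M-based routing described above; the labels are
# descriptive stand-ins for the three CUDA paths, and the thresholds follow the
# commit text. Bias stays outside all three paths: out.add_(bias) in the wrapper.
def _route_awq_w4a16(m: int) -> str:
    if m <= 8:
        return "gemv"            # 1-thread-per-output GEMV, fp32 accum
    if m <= 256:
        return "fused_mma"       # cooperative shmem dequant + mma.m16n8k16
    return "dequant_cublas"      # dequantize W to bf16, then cuBLAS matmul

assert [_route_awq_w4a16(m) for m in (1, 8, 64, 256, 4096)] == [
    "gemv", "gemv", "fused_mma", "fused_mma", "dequant_cublas"
]
```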
Verified vs nunchaku runtime `ops.gemv_awq` oracle on a real r96
modulation tensor: max-abs / max-magnitude ≤ 0.65 % at all M (within
the bf16 ULP precision floor of `(nibble - 8) * scale + zero` chains).
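One possible reading of the "max-abs / max-magnitude" figure as code; this interpretation is an assumption, not lifted from the test suite:

```python
import torch

# Assumed reading of the "max-abs / max-magnitude" metric quoted above:
# maximum absolute deviation from the oracle output, normalized by the
# oracle's maximum magnitude. Not taken from the test suite.
def max_abs_over_max_magnitude(out: torch.Tensor, oracle: torch.Tensor) -> float:
    return ((out - oracle).abs().max() / oracle.abs().max()).item()
```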
Qwen-Image-Edit r96 ComfyUI E2E sampling: 274 s vs 342 s eager-AWQ
(20 % faster), PSNR 33.48 dB vs the bf16-dequant kitchen-native
baseline image — visually equivalent.
- Rename M/N/K/R/M_pad/K_half/G locals to lowercase per pep8-naming (N806/N803). Math notation in docstrings/comments is preserved.
- Replace × (multiplication sign) with x in docstrings/comments (RUF002/003).
- Remove quoted type annotations (UP037).
- Tag the standard `import torch.nn.functional as F` alias with `noqa: N812`.
- Drop the unused get_capable_backends import + reflow imports in test_svdquant_w4a4.py.
The previous lint pass missed comfy_kitchen/backends/cuda/__init__.py (N806 uppercase math vars in Python wrappers for quantize/scaled_mm and gemv_awq_w4a16) and the I001 import sort in tensor/__init__.py. Apply the same M/N/K/R/G/M_pad/K_half lowercase rename plus × → x in the two remaining docstring/comment occurrences.
Does this have an associated PR for https://github.com/Comfy-Org/ComfyUI? Edit: Ignore me. I am just blind xD
Add SVDQuant W4A4 + AWQ W4A16 int4 quantization support for Qwen-Image-Edit and similar SVD-LoRA / AWQ-modulation checkpoints. Built as kitchen-native row-major backends.
Targets sm_80+ (Ampere).
SVDQuant W4A4
- `comfy_kitchen.quantize_svdquant_w4a4` / `scaled_mm_svdquant_w4a4` as `torch.library.custom_op` entry points with eager pure-PyTorch reference.
- `TensorCoreSVDQuantW4A4Layout` dispatching `aten.t` / `aten.mm` / `aten.addmm` / `aten.linear` over int4 qdata + `proj_down` / `proj_up` / `smooth_factor` + per-group `weight_scale`.
- `m16n8k64` int4 MMA GEMM + fused activation quantize over kitchen-native row-major tensors; fp32 accumulator (fp16 overflows at production Qwen scale ≈ 28; see kernel header).
- `act_unsigned` end-to-end for post-GELU `fc2` layers: `u4.s4` MMA variant + `+0.171875` shift applied at the layer (kernel API stays shift-free).
- Signed-quantize clamp to `[-7, 7]` (skip -8) to match nunchaku's `absmax/7` dequant-symmetric contract.
- LoRA-up applied via `addmm_` (faster than kernel fusion for R=96; see the `scaled_mm_svdquant_w4a4` docstring for the precision tradeoff and the cuBLASLt upgrade path if needed).
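A hedged companion sketch for the `act_unsigned` bullet above: the +0.171875 shift is applied at the layer so post-GELU activations are non-negative before u4 quantization. The max/15 scale and the group handling are illustrative assumptions, not the kernel contract.

```python
import torch

# Hedged sketch of the layer-level unsigned path: shift post-GELU activations by
# +0.171875 so they are non-negative, then quantize to u4. The max/15 scale and
# group size are illustrative assumptions, not the kernel contract.
GELU_SHIFT = 0.171875

def quantize_unsigned_u4(x_post_gelu: torch.Tensor, group_size: int = 64):
    xg = (x_post_gelu + GELU_SHIFT).reshape(-1, group_size)
    scale = (xg.amax(dim=-1, keepdim=True) / 15.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(xg / scale), 0, 15)
    return q.to(torch.uint8), scale

q, s = quantize_unsigned_u4(torch.nn.functional.gelu(torch.randn(8, 3072)))
assert int(q.min()) >= 0 and int(q.max()) <= 15
```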
AWQ W4A16
- `comfy_kitchen.gemv_awq_w4a16` as a single op covering both GEMV (M ≤ 8) and GEMM regimes; the kernel routes internally.
- `TensorCoreAWQW4A16Layout` dispatching `aten.t` / `aten.mm` / `aten.addmm` / `aten.linear` over kitchen-native uint4 qdata + per-group `wscales` / `wzeros`.
- `mma.m16n8k16.f32.bf16.bf16.f32` with cooperative dequant of the int4 weight tile to bf16 in shmem. No 113 MB intermediate W workspace. shmem stride padded by 8 b16 to break the 128-byte ldmatrix bank conflict.
- Large-M fallback (M > 256): dequantize-then-cuBLAS bf16 matmul (raising the limit is a follow-up tuning task).
- Bias applied externally in the Python wrapper (`out.add_(bias)`), mirroring `scaled_mm_svdquant_w4a4`'s epilogue contract.

Tests

- `tests/test_svdquant_w4a4.py`: 16 tests covering quantizer clamp contract, signed/unsigned MMA dispatch, `lora_x` separation, cross-shape smoke.

Verification

- Verified vs the nunchaku runtime `ops.gemv_awq` oracle on a real r96 modulation tensor: max-abs / max-magnitude ≤ 0.65% at all M (within bf16 ULP precision floor).
- Variant families (balanced / fast / mid / quality × ranks 32/64/96/128 × base/lightning4/lightning8) all sample successfully with the right `(steps, cfg)` per distillation contract.

Model repository