From 3c20b9435a971c5375b003bf592ee2bcaad947d8 Mon Sep 17 00:00:00 2001 From: JayceSu98 Date: Wed, 27 May 2026 04:21:41 +0000 Subject: [PATCH] [BugFix][Quant] Accumulate E5M6 BF16 amax in FP32 E5M6 per-token casts use the same scale-selection contract as the FP8/FP4 per-token path: compute a per-token/channel absmax, round the dequant scale when requested, then quantize through the inverse scale. The BF16 input path stored the amax reduction result in a BF16 fragment before computing the E5M6 scale. On H100 with CUDA 13.2 this can underestimate the BF16 row amax for large hidden blocks. The resulting scale is one power-of-two too small when round_sf=True, which changes the packed E5M6 bytes. In the cast-back test this bad forward scale can also overflow BF16 dequantization to inf, making the cosine-style diff report nan. Store the E5M6 amax reduction result in FP32 so scale selection matches the PyTorch reference and is not tied to packed BF16 reduction behavior. The cast-back kernel itself does not need a workaround; its nan failures were downstream of the incorrect forward scale. JayceSu98 authored and validated this patch. Co-author GitHub: https://github.com/dingsg Verified on H100 (NVIDIA H100 PCIe, sm_90) with CUDA 13.2, PyTorch 2.12.0+cu132, and TileLang 0.1.10+cuda.git23d91c58: the 24 per_token_cast_to_e5m6 byte-mismatch cases and 24 cast_back_e5m6 nan-diff cases passed with pytest -n 2. Co-authored-by: dingsg --- tile_kernels/quant/per_token_cast_to_e5m6_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py b/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py index 0956c13..ef8e080 100644 --- a/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py +++ b/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py @@ -119,7 +119,7 @@ def per_token_cast_to_e5m6_kernel( # Copy input into registers T.copy(x[pid_token * block_m, pid_hidden * block_k], x_fragment) - amax_fragment = T.alloc_fragment((block_m, num_groups), in_config.dtype) + amax_fragment = T.alloc_fragment((block_m, num_groups), T.float32) x_fragment_reshaped = T.reshape(x_fragment, [block_m, num_groups, num_per_channels]) # Reduce SF T.reduce_absmax(x_fragment_reshaped, amax_fragment, dim=2)