diff --git a/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py b/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py index 0956c13..ef8e080 100644 --- a/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py +++ b/tile_kernels/quant/per_token_cast_to_e5m6_kernel.py @@ -119,7 +119,7 @@ def per_token_cast_to_e5m6_kernel( # Copy input into registers T.copy(x[pid_token * block_m, pid_hidden * block_k], x_fragment) - amax_fragment = T.alloc_fragment((block_m, num_groups), in_config.dtype) + amax_fragment = T.alloc_fragment((block_m, num_groups), T.float32) x_fragment_reshaped = T.reshape(x_fragment, [block_m, num_groups, num_per_channels]) # Reduce SF T.reduce_absmax(x_fragment_reshaped, amax_fragment, dim=2)