# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Per-token FP8 cast with scale factors.

This is a Tilus translation of DeepSeek TileKernels'
``per_token_cast_kernel.py`` for the common FP16 -> FP8 e4m3 path. Each thread
block handles ``block_m`` tokens and ``groups_per_block`` channel groups of
``num_per_channels`` elements; for every token/group pair it computes the
absolute maximum within the group, stores a float32 scale factor, and writes
the scaled FP8 output.
"""

import pandas
import tilus
import torch
from tile_kernels.quant.per_token_cast_kernel import per_token_cast
from tilus import float8_e4m3, float16, float32, int32
from tilus.utils import benchmark_func, cdiv


@tilus.autotune("block_m", [1, 2, 4, 8])
@tilus.autotune("groups_per_block", [1, 2, 4, 8])
@tilus.autotune("warps", [4, 8])
class PerTokenCast(tilus.Script):
    """Group-wise FP16 -> FP8 e4m3 quantization kernel.

    ``block_m``, ``groups_per_block`` and ``warps`` are listed in the
    ``tilus.autotune`` decorators above; ``main`` below only passes
    ``num_per_channels`` when constructing the kernel.

    Parameters
    ----------
    block_m: int
        Number of token rows loaded per tile.
    groups_per_block: int
        Number of consecutive channel groups one thread block loops over.
    warps: int
        Value assigned to ``self.attrs.warps`` at launch.
    num_per_channels: int
        Width of one quantization group, i.e. how many channels share a scale.
    """

    def __init__(
        self,
        block_m: int,
        groups_per_block: int,
        warps: int,
        num_per_channels: int = 128,
    ):
        super().__init__()
        self.block_m = block_m
        self.num_per_channels = num_per_channels
        self.groups_per_block = groups_per_block
        # One tile column span equals exactly one quantization group.
        self.block_n = num_per_channels
        self.warps = warps

    # Kernel entry point.
    #   num_tokens  - number of rows of the input/output matrices.
    #   hidden      - number of columns; assumed below to be a multiple of
    #                 ``num_per_channels``.
    #   x_ptr       - float16 input, viewed as [num_tokens, hidden].
    #   out_ptr     - float8_e4m3 output, viewed as [num_tokens, hidden].
    #   out_sf_ptr  - float32 scales, viewed as
    #                 [num_tokens, cdiv(hidden, num_per_channels)].
    def __call__(
        self,
        num_tokens: int,
        hidden: int32,
        x_ptr: ~float16,
        out_ptr: ~float8_e4m3,
        out_sf_ptr: ~float32,
    ):
        # Columns covered by one thread block: groups_per_block groups of block_n.
        n_step = self.block_n * self.groups_per_block
        # Grid: x tiles the token dimension, y tiles the channel dimension.
        self.attrs.blocks = (
            cdiv(num_tokens, self.block_m),
            cdiv(hidden, n_step),
        )
        self.attrs.warps = self.warps
        # Record the assumption that every row splits into whole groups.
        self.assume(hidden % self.num_per_channels == 0)

        offset_m = self.blockIdx.x * self.block_m
        base_offset_n = self.blockIdx.y * n_step

        g_x = self.global_view(
            x_ptr,
            dtype=float16,
            shape=[num_tokens, hidden],
        )
        g_out = self.global_view(
            out_ptr,
            dtype=float8_e4m3,
            shape=[num_tokens, hidden],
        )
        g_out_sf = self.global_view(
            out_sf_ptr,
            dtype=float32,
            shape=[num_tokens, cdiv(hidden, self.num_per_channels)],
        )

        for gi in range(self.groups_per_block):
            offset_n = base_offset_n + gi * self.block_n
            # Column of the scale-factor matrix that belongs to this group.
            sf_col = offset_n // self.num_per_channels

            # Load one [block_m, block_n] tile and do the math in float32.
            r_x = self.load_global(
                g_x,
                offsets=[offset_m, offset_n],
                shape=[self.block_m, self.block_n],
            ).to(float32)

            # Per-row absolute maximum over the group, kept as [block_m, 1].
            r_absmax = self.max(self.abs(r_x), dim=1, keepdim=True)
            # 448.0 is the value the group maximum is mapped onto (the largest
            # finite float8 e4m3 magnitude).
            r_fp8_max = self.register_tensor(
                dtype=float32,
                shape=[self.block_m, 1],
                init=448.0,
            )
            # For an all-zero group both factors fall back to 1.0, which avoids
            # dividing by zero.
            r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0)
            r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0)

            self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col])
            self.store_global(
                g_out,
                (r_x * r_inv_scale).to(float8_e4m3),
                offsets=[offset_m, offset_n],
            )


def tilekernels_per_token_cast_reference(
    x: torch.Tensor,
    num_per_channels: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the TileKernels ``per_token_cast`` reference in ``"e4m3"`` mode.

    Returns whatever ``per_token_cast`` returns; ``main`` unpacks it as
    ``(quantized_output, scale_factors)``.
    """
    return per_token_cast(x, "e4m3", num_per_channels)


def dequantized_sum(
    out: torch.Tensor, scales: torch.Tensor, num_per_channels: int
) -> torch.Tensor:
    """Return the sum of all dequantized values as a scalar tensor.

    ``out`` is converted to float and reshaped to
    ``[rows, cols // num_per_channels, num_per_channels]``; every group is then
    multiplied by its entry in ``scales`` (indexed ``[row, group]``) before
    summing.
    """
    grouped = out.float().reshape(
        out.shape[0],
        out.shape[1] // num_per_channels,
        num_per_channels,
    )
    return (grouped * scales[:, :, None]).sum()


def main() -> None:
    """Check the Tilus kernel against the reference and print a timing table.

    For each problem size the kernel output is compared with the TileKernels
    reference (decoded values, scale factors, and dequantized sum), then both
    implementations are timed with ``benchmark_func``. All tensors are created
    on the ``"cuda"`` device.
    """
    rows = []
    headers = [
        "tokens",
        "hidden",
        "tilekernels (ms)",
        "tilus (ms)",
        "speedup",
        "sum diff",
    ]

    for num_tokens, hidden in [
        (128, 1024),
        (256, 2048),
        (257, 4096),
    ]:
        num_per_channels = 128
        kernel = PerTokenCast(num_per_channels=num_per_channels)

        x = (
            torch.randn(
                num_tokens,
                hidden,
                device="cuda",
                dtype=torch.float16,
            )
            * 2.0
        ).contiguous()
        out = torch.empty((num_tokens, hidden), device="cuda", dtype=torch.float8_e4m3fn)
        out_sf = torch.empty(
            (num_tokens, hidden // num_per_channels),
            device="cuda",
            dtype=torch.float32,
        )
        # The reference is fed a float32 copy of the same input.
        x_tilekernels = x.float()

        kernel(num_tokens, hidden, x, out, out_sf)
        expected_out, expected_sf = tilekernels_per_token_cast_reference(
            x_tilekernels,
            num_per_channels,
        )

        # Loose element-wise bound on the decoded FP8 values; the scale factors
        # themselves must agree tightly.
        max_code_diff = (out.float() - expected_out.float()).abs().max().item()
        assert max_code_diff <= 32.0, f"max decoded FP8 code diff is {max_code_diff}"
        torch.testing.assert_close(out_sf, expected_sf, atol=1e-5, rtol=1e-5)

        # Aggregate check on the dequantized tensors.
        actual_sum = dequantized_sum(out, out_sf, num_per_channels)
        expected_sum = dequantized_sum(expected_out, expected_sf, num_per_channels)
        torch.testing.assert_close(actual_sum, expected_sum, atol=2.0, rtol=2e-2)
        sum_diff = (actual_sum - expected_sum).abs().item()

        tilekernels_ms = benchmark_func(
            lambda: tilekernels_per_token_cast_reference(
                x_tilekernels,
                num_per_channels,
            )
        )
        tilus_ms = benchmark_func(lambda: kernel(num_tokens, hidden, x, out, out_sf))
        rows.append(
            [
                num_tokens,
                hidden,
                tilekernels_ms,
                tilus_ms,
                f"{tilekernels_ms / tilus_ms:.2f}x",
                sum_diff,
            ]
        )
        print(
            "Per-token FP8 cast matches reference for size "
            f"({num_tokens}, {hidden}); max code diff={max_code_diff:.6g}; "
            f"dequantized sum diff={sum_diff:.6g}"
        )

    print(pandas.DataFrame(rows, columns=headers))


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Fused SwiGLU forward with per-token FP8 cast.

This is a Tilus translation of DeepSeek TileKernels'
``swiglu_forward_and_per_token_cast_kernel.py``. It computes

    out = silu(x[:, :hidden]) * x[:, hidden:]

optionally applies a routing weight and expert mask, then quantizes each
``num_per_channels`` group to FP8 e4m3 with one float32 scale factor per
token/group.
"""

import pandas
import tilus
import torch
from tile_kernels.quant.swiglu_forward_and_per_token_cast_kernel import (
    swiglu_forward_and_per_token_cast,
)
from tilus import float8_e4m3, float16, float32, int32
from tilus.utils import benchmark_func, cdiv


@tilus.autotune("block_m", [1])
@tilus.autotune("groups_per_block", [1, 2, 4, 8, 16])
@tilus.autotune("warps", [1, 2, 4, 8])
class SwiGLUForwardAndPerTokenCast(tilus.Script):
    """Fused SwiGLU + group-wise FP8 e4m3 quantization kernel.

    ``block_m``, ``groups_per_block`` and ``warps`` are listed in the
    ``tilus.autotune`` decorators above (``block_m`` has the single candidate
    1); ``main`` below only passes ``num_per_channels``.

    Parameters
    ----------
    block_m: int
        Number of token rows per tile.
    groups_per_block: int
        Number of consecutive channel groups handled by one thread block.
    warps: int
        Value assigned to ``self.attrs.warps`` at launch.
    with_weight: bool
        Multiply the SwiGLU result by a per-row routing weight.
    with_pos_to_expert: bool
        Skip rows whose ``pos_to_expert`` entry is negative.
    use_clamp: bool
        Clamp the two halves of the input with ``swiglu_clamp_value`` first.
    num_per_channels: int
        Width of one quantization group.
    """

    def __init__(
        self,
        block_m: int,
        groups_per_block: int,
        warps: int,
        with_weight: bool = True,
        with_pos_to_expert: bool = True,
        use_clamp: bool = True,
        num_per_channels: int = 128,
    ):
        super().__init__()
        self.block_m = block_m
        self.num_per_channels = num_per_channels
        self.groups_per_block = groups_per_block
        # One group spans exactly num_per_channels columns.
        self.block_n = num_per_channels
        self.warps = warps
        self.with_weight = with_weight
        self.with_pos_to_expert = with_pos_to_expert
        self.use_clamp = use_clamp

    # Kernel entry point.
    #   num_expanded_tokens   - number of rows of x / out / out_sf.
    #   hidden                - output width; x is viewed with 2 * hidden columns.
    #   num_topk_values       - length of the flat topk_weights view.
    #   x_ptr                 - float16 input, viewed as [num_expanded_tokens, 2 * hidden].
    #   out_ptr               - float8_e4m3 output, viewed as [num_expanded_tokens, hidden].
    #   out_sf_ptr            - float32 scales, viewed as
    #                           [num_expanded_tokens, cdiv(hidden, num_per_channels)].
    #   pos_to_token_topk_ptr - int32 index per row into topk_weights.
    #   topk_weights_ptr      - float32 routing weights, viewed as [num_topk_values].
    #   pos_to_expert_ptr     - int32 per row; negative entries are skipped when
    #                           with_pos_to_expert is set.
    #   swiglu_clamp_value    - clamp bound used when use_clamp is set.
    def __call__(
        self,
        num_expanded_tokens: int,
        hidden: int32,
        num_topk_values: int32,
        x_ptr: ~float16,
        out_ptr: ~float8_e4m3,
        out_sf_ptr: ~float32,
        pos_to_token_topk_ptr: ~int32,
        topk_weights_ptr: ~float32,
        pos_to_expert_ptr: ~int32,
        swiglu_clamp_value: float32,
    ):
        # Columns covered by one thread block.
        n_step = self.block_n * self.groups_per_block
        # Grid: x tiles the row dimension, y tiles the output columns.
        self.attrs.blocks = (
            cdiv(num_expanded_tokens, self.block_m),
            cdiv(hidden, n_step),
        )
        self.attrs.warps = self.warps
        # Record the assumption that every row splits into whole groups.
        self.assume(hidden % self.num_per_channels == 0)

        offset_m = self.blockIdx.x * self.block_m
        base_offset_n = self.blockIdx.y * n_step

        g_x = self.global_view(
            x_ptr,
            dtype=float16,
            shape=[num_expanded_tokens, hidden * 2],
        )
        g_out = self.global_view(
            out_ptr,
            dtype=float8_e4m3,
            shape=[num_expanded_tokens, hidden],
        )
        g_out_sf = self.global_view(
            out_sf_ptr,
            dtype=float32,
            shape=[num_expanded_tokens, cdiv(hidden, self.num_per_channels)],
        )
        g_pos_to_token_topk = self.global_view(
            pos_to_token_topk_ptr,
            dtype=int32,
            shape=[num_expanded_tokens],
        )
        g_topk_weights = self.global_view(
            topk_weights_ptr,
            dtype=float32,
            shape=[num_topk_values],
        )
        g_pos_to_expert = self.global_view(
            pos_to_expert_ptr,
            dtype=int32,
            shape=[num_expanded_tokens],
        )

        # Nothing is computed or stored for a row whose expert index is
        # negative. Only the entry at offset_m is inspected.
        if (not self.with_pos_to_expert) or g_pos_to_expert[offset_m].item() >= 0:
            base_sf_col = base_offset_n // self.num_per_channels

            # Wide load: full n_step at once so layout-inference vectorises.
            # r_l comes from the first `hidden` columns of x, r_r from the
            # matching columns offset by `hidden`.
            r_l = self.load_global(
                g_x,
                offsets=[offset_m, base_offset_n],
                shape=[self.block_m, n_step],
            ).to(float32)
            r_r = self.load_global(
                g_x,
                offsets=[offset_m, base_offset_n + hidden],
                shape=[self.block_m, n_step],
            ).to(float32)

            if self.use_clamp:
                # r_l is clamped from above only; r_r is clamped to
                # [-swiglu_clamp_value, swiglu_clamp_value].
                negative_swiglu_clamp_value = 0.0 - swiglu_clamp_value
                r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l)
                r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r)
                r_r = self.where(
                    r_r < negative_swiglu_clamp_value,
                    x=negative_swiglu_clamp_value,
                    y=r_r,
                )

            # silu(r_l) = r_l / (1 + exp(-r_l)).
            r_silu = r_l / (self.exp(-r_l) + 1.0)
            r_value = r_silu * r_r

            if self.with_weight:
                # The routing weight is applied only when the row's index is
                # non-negative.
                topk_pos = g_pos_to_token_topk[offset_m].item()
                if topk_pos >= 0:
                    topk_weight = g_topk_weights[topk_pos].item()
                    r_value = r_value * topk_weight

            # Reshape into [block_m, groups_per_block, num_per_channels] so the
            # per-group absmax is a single reduce on dim=2.
            r_value_grouped = self.reshape(
                r_value,
                shape=[self.block_m, self.groups_per_block, self.num_per_channels],
            )
            r_absmax = self.max(
                self.abs(r_value_grouped), dim=2, keepdim=True
            )  # [block_m, groups_per_block, 1]
            # 448.0 is the value the group maximum is mapped onto (the largest
            # finite float8 e4m3 magnitude).
            r_fp8_max = self.register_tensor(
                dtype=float32,
                shape=[self.block_m, self.groups_per_block, 1],
                init=448.0,
            )
            # All-zero groups fall back to 1.0 to avoid dividing by zero.
            r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0)
            r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0)

            # Store one fp32 scale per group.
            r_scale_2d = self.reshape(
                r_scale, shape=[self.block_m, self.groups_per_block]
            )
            self.store_global(g_out_sf, r_scale_2d, offsets=[offset_m, base_sf_col])

            # Apply scaling, flatten back, cast to fp8, bulk store.
            r_out_grouped = (r_value_grouped * r_inv_scale).to(float8_e4m3)
            r_out = self.reshape(r_out_grouped, shape=[self.block_m, n_step])
            self.store_global(g_out, r_out, offsets=[offset_m, base_offset_n])


def tilekernels_swiglu_reference(
    x: torch.Tensor,
    pos_to_token_topk: torch.Tensor,
    topk_weights: torch.Tensor,
    pos_to_expert: torch.Tensor,
    clamp_value: float,
    num_per_channels: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the TileKernels fused SwiGLU + cast reference in ``"e4m3"`` mode.

    The routing tensors and clamp value are forwarded as keyword arguments.
    Returns whatever the reference returns; ``main`` unpacks it as
    ``(quantized_output, scale_factors)``.
    """
    return swiglu_forward_and_per_token_cast(
        x,
        "e4m3",
        num_per_channels,
        pos_to_token_topk=pos_to_token_topk,
        topk_weights=topk_weights,
        pos_to_expert=pos_to_expert,
        swiglu_clamp_value=clamp_value,
    )


def dequantized_sum(
    out: torch.Tensor, scales: torch.Tensor, num_per_channels: int
) -> torch.Tensor:
    """Return the sum of all dequantized values as a scalar tensor.

    ``out`` is converted to float and reshaped to
    ``[rows, cols // num_per_channels, num_per_channels]``; every group is then
    multiplied by its entry in ``scales`` (indexed ``[row, group]``) before
    summing.
    """
    grouped = out.float().reshape(
        out.shape[0],
        out.shape[1] // num_per_channels,
        num_per_channels,
    )
    return (grouped * scales[:, :, None]).sum()


def main() -> None:
    """Check the Tilus kernel against the reference and print a timing table.

    For each problem size the kernel output is compared with the TileKernels
    reference on the rows that are not masked out, then both implementations
    are timed with ``benchmark_func``. All tensors are created on the
    ``"cuda"`` device.
    """
    rows = []
    headers = [
        "tokens",
        "hidden",
        "tilekernels (ms)",
        "tilus (ms)",
        "speedup",
        "sum diff",
    ]

    for num_expanded_tokens, hidden, num_tokens, num_topk in [
        (128, 1024, 64, 2),
        (256, 2048, 128, 2),
        (257, 4096, 128, 2),
        (1024, 4096, 512, 2),
    ]:
        num_per_channels = 128
        kernel = SwiGLUForwardAndPerTokenCast(num_per_channels=num_per_channels)

        # The input holds both SwiGLU halves side by side, hence 2 * hidden columns.
        x = (
            torch.randn(
                num_expanded_tokens,
                hidden * 2,
                device="cuda",
                dtype=torch.float16,
            )
            * 2.0
        ).contiguous()
        # Row i points at flat weight index i modulo the number of weights.
        pos_to_token_topk = torch.arange(
            num_expanded_tokens,
            device="cuda",
            dtype=torch.int32,
        ) % (num_tokens * num_topk)
        topk_weights = torch.rand(
            num_tokens,
            num_topk,
            device="cuda",
            dtype=torch.float32,
        )
        # Every 17th row gets a negative expert index so the skip path is exercised.
        pos_to_expert = torch.ones(num_expanded_tokens, device="cuda", dtype=torch.int32)
        pos_to_expert[::17] = -1

        out = torch.empty(
            (num_expanded_tokens, hidden),
            device="cuda",
            dtype=torch.float8_e4m3fn,
        )
        out_sf = torch.empty(
            (num_expanded_tokens, hidden // num_per_channels),
            device="cuda",
            dtype=torch.float32,
        )
        # The reference is fed a float32 copy of the same input.
        x_tilekernels = x.float()

        clamp_value = 6.0
        # topk_weights is passed as a 2-D tensor; the kernel views it as a flat
        # array of num_tokens * num_topk values.
        kernel(
            num_expanded_tokens,
            hidden,
            num_tokens * num_topk,
            x,
            out,
            out_sf,
            pos_to_token_topk,
            topk_weights,
            pos_to_expert,
            clamp_value,
        )

        expected_out, expected_sf = tilekernels_swiglu_reference(
            x_tilekernels,
            pos_to_token_topk,
            topk_weights,
            pos_to_expert,
            clamp_value,
            num_per_channels,
        )
        # Compare only rows the kernel writes: `out`/`out_sf` come from
        # torch.empty and the kernel stores nothing for skipped rows.
        valid = pos_to_expert >= 0
        max_code_diff = (
            (out[valid].float() - expected_out[valid].float()).abs().max().item()
        )
        assert max_code_diff <= 32.0, f"max decoded FP8 code diff is {max_code_diff}"
        torch.testing.assert_close(
            out_sf[valid],
            expected_sf[valid],
            atol=1e-5,
            rtol=1e-5,
        )
        actual_sum = dequantized_sum(out[valid], out_sf[valid], num_per_channels)
        expected_sum = dequantized_sum(
            expected_out[valid],
            expected_sf[valid],
            num_per_channels,
        )
        torch.testing.assert_close(actual_sum, expected_sum, atol=2.0, rtol=2e-2)
        sum_diff = (actual_sum - expected_sum).abs().item()

        tilekernels_ms = benchmark_func(
            lambda: tilekernels_swiglu_reference(
                x_tilekernels,
                pos_to_token_topk,
                topk_weights,
                pos_to_expert,
                clamp_value,
                num_per_channels,
            )
        )
        tilus_ms = benchmark_func(
            lambda: kernel(
                num_expanded_tokens,
                hidden,
                num_tokens * num_topk,
                x,
                out,
                out_sf,
                pos_to_token_topk,
                topk_weights,
                pos_to_expert,
                clamp_value,
            )
        )
        rows.append(
            [
                num_expanded_tokens,
                hidden,
                tilekernels_ms,
                tilus_ms,
                f"{tilekernels_ms / tilus_ms:.2f}x",
                sum_diff,
            ]
        )
        print(
            "SwiGLU FP8 cast matches reference for size "
            f"({num_expanded_tokens}, {hidden}); max code diff={max_code_diff:.6g}; "
            f"dequantized sum diff={sum_diff:.6g}"
        )

    print(pandas.DataFrame(rows, columns=headers))


if __name__ == "__main__":
    main()
+88,30 @@ struct TypeTraits<__nv_bfloat16*> : public FallbackOnlyTraitsBase<__nv_bfloat16* } }; +template <> +struct TypeTraits : public FallbackOnlyTraitsBase { + TVM_FFI_INLINE static std::string TypeStr() { return "float8_e4m3*"; } + + TVM_FFI_INLINE static float8_e4m3* ConvertFallbackValue(DLTensor* src) { + if (src->dtype.code != kDLFloat8_e4m3fn || src->dtype.bits != 8) { + TVM_FFI_THROW(ValueError) << "Expect a tensor with 8 bit float8_e4m3, got a tensor with dtype " << dtype_to_str(src->dtype); + } + return reinterpret_cast(src->data); + } +}; + +template <> +struct TypeTraits : public FallbackOnlyTraitsBase { + TVM_FFI_INLINE static std::string TypeStr() { return "float8_e5m2*"; } + + TVM_FFI_INLINE static float8_e5m2* ConvertFallbackValue(DLTensor* src) { + if (src->dtype.code != kDLFloat8_e5m2 || src->dtype.bits != 8) { + TVM_FFI_THROW(ValueError) << "Expect a tensor with 8 bit float8_e5m2, got a tensor with dtype " << dtype_to_str(src->dtype); + } + return reinterpret_cast(src->data); + } +}; + // Template specialization for float*, double* template struct TypeTraits>> : public FallbackOnlyTraitsBase { diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index be8ba34d..549e818d 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -109,6 +109,7 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, + ReshapeRegisterInst, ReshapeSharedInst, ScanInst, SliceAssignInst, @@ -553,6 +554,16 @@ def unsqueeze( self.append(inst) return inst.register_output + def reshape_register( + self, + x: RegisterTensor, + shape: Sequence[int], + out: Optional[RegisterTensor] = None, + ) -> RegisterTensor: + inst = ReshapeRegisterInst.create(x=x, shape=shape, out=out) + self.append(inst) + return inst.register_output + def cast( self, x: RegisterTensor, diff --git a/python/tilus/ir/instructions/__init__.py b/python/tilus/ir/instructions/__init__.py index 091e9417..e71bc540 100644 --- 
@dataclass(frozen=True, eq=False)
class ReshapeRegisterInst(Instruction):
    """Instruction that gives a register tensor a new logical shape.

    The single input is the tensor being reshaped; the output is the register
    tensor that carries the requested shape.
    """

    @staticmethod
    def create(
        x: RegisterTensor,
        shape: Sequence[int],
        out: Optional[RegisterTensor] = None,
    ) -> ReshapeRegisterInst:
        """Build a reshape instruction for ``x``.

        Parameters
        ----------
        x: RegisterTensor
            The register tensor to reshape.
        shape: Sequence[int]
            The requested shape. It must hold as many elements as ``x``.
        out: RegisterTensor, optional
            Pre-allocated output tensor. When omitted, a new register tensor
            with ``x``'s dtype and the requested shape is created.

        Returns
        -------
        ret: ReshapeRegisterInst
            The instruction with ``x`` as its only input and ``out`` as output.

        Raises
        ------
        ValueError
            If ``x.shape`` and ``shape`` describe different element counts, or
            if a caller-supplied ``out`` does not have the requested shape.
        """
        from tilus.utils import prod

        target_shape = tuple(shape)

        # Validate unconditionally: previously the check ran only when `out`
        # was omitted, so a caller-supplied `out` skipped it entirely.
        if prod(x.shape) != prod(target_shape):
            raise ValueError(f"Cannot reshape register tensor with shape {x.shape} to shape {shape}: sizes differ")

        if out is None:
            out = RegisterTensor.create(dtype=x.dtype, shape=target_shape)
        elif tuple(out.shape) != target_shape:
            # Otherwise `shape` would be silently ignored in favour of `out`.
            raise ValueError(
                f"Output register tensor has shape {out.shape}, expected the requested shape {target_shape}"
            )
        return ReshapeRegisterInst(output=out, inputs=(x,))
@register_rule(ReshapeRegisterInst)
class ReshapeRegisterRule(LayoutInferenceRule):
    """Propagate a register layout across a reshape, in either direction."""

    @staticmethod
    def inference(ctx: LayoutInferenceContext, inst: ReshapeRegisterInst) -> dict[RegisterTensor, RegisterLayout]:
        """Infer the missing layout of a reshape from the known one.

        Returns a mapping with a single entry for the tensor whose layout was
        derived, or an empty mapping when both layouts are already known or
        neither is known yet.
        """
        src = inst.register_input
        dst = inst.register_output
        src_known = src.optional_layout is not None
        dst_known = dst.optional_layout is not None

        # Both known: nothing to infer. Neither known: nothing to infer from.
        if src_known == dst_known:
            return {}
        if src_known:
            # Forward: reshape the input layout to the output's shape.
            return {dst: ops.reshape(src.layout, shape=dst.shape)}
        # Backward: reshape the output layout to the input's shape.
        return {src: ops.reshape(dst.layout, shape=src.shape)}
@register_rule(ReshapeRegisterInst)
class ReshapeRegisterRule(LayoutValidationRule):
    """Check that a reshape's output layout is the reshaped input layout."""

    @staticmethod
    def validate(inst: ReshapeRegisterInst) -> bool:
        """Return True when the output layout equals ``ops.reshape`` of the input layout.

        Any exception raised while computing or comparing the layouts is
        reported as an invalid combination (``False``) rather than propagated.
        """
        src: RegisterTensor = inst.register_input
        dst: RegisterTensor = inst.register_output

        try:
            expected_layout = ops.reshape(src.layout, shape=dst.shape)
            return dst.layout == expected_layout
        except Exception:
            return False
def reshape(self, tensor: RegisterTensor, shape: Sequence[int]) -> RegisterTensor:
    """Return ``tensor`` viewed under a new logical shape.

    The element count of ``shape`` must equal that of ``tensor``. Only the
    logical shape (and the mode grouping used for broadcasts/reductions)
    changes; the underlying per-thread storage stays as it is.

    Parameters
    ----------
    tensor: RegisterTensor
        The register tensor to reshape.
    shape: Sequence[int]
        The shape the result should have.

    Returns
    -------
    ret: RegisterTensor
        The reshaped register tensor.

    Notes
    -----
    - **Thread group**: Can be executed by any sized thread group.
    """
    reshaped = self._builder.reshape_register(x=tensor, shape=shape)
    return reshaped
diff --git a/python/tilus/transforms/lower_assume.py b/python/tilus/transforms/lower_assume.py index 9be73c37..38bda543 100644 --- a/python/tilus/transforms/lower_assume.py +++ b/python/tilus/transforms/lower_assume.py @@ -17,7 +17,7 @@ from tilus.ir.functors import IRRewriter from tilus.ir.instructions import AssumeInst from tilus.transforms.base import Pass -from tilus.utils import gcd +from tilus.utils import lcm class ApplyAssumeRewriter(IRRewriter): @@ -54,7 +54,11 @@ def visit_AssumeInst(self, inst: AssumeInst) -> None: raise RuntimeError( "We only allow to specify the divisibility of kernel parameter, got {}".format(a.name) ) - self.param2divisibility[a] = int(term.a.b.value) # type: ignore[arg-type] + divisor = int(term.a.b.value) # type: ignore[arg-type] + if a in self.param2divisibility: + self.param2divisibility[a] = lcm(self.param2divisibility[a], divisor) + else: + self.param2divisibility[a] = divisor else: raise RuntimeError("Can not recognize the condition in assume: {}".format(term)) @@ -70,7 +74,7 @@ def visit_Function(self, func: Function) -> Function: param2divisibility = updated_func.metadata.param2divisibility.copy() for var in self.param2divisibility: if var in param2divisibility: - param2divisibility[var] = gcd(param2divisibility[var], self.param2divisibility[var]) + param2divisibility[var] = lcm(param2divisibility[var], self.param2divisibility[var]) else: param2divisibility[var] = self.param2divisibility[var] return updated_func.with_metadata(updated_func.metadata.with_param2divisibility(param2divisibility)) diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 690878a7..dc0e7d94 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -68,6 +68,8 @@ ("hopper_matmul", "matmul_v5.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), + ("quantization", "per_token_cast.py", nvgpu_sm90a), + ("quantization", 
"swiglu_forward_and_per_token_cast.py", nvgpu_sm90a), # flash attention decode examples (SM 8.0+) ("flash_attention_decode", "main.py", nvgpu_sm80), ]