From 4281491fe386cd9fd6db2614524c6eaf2b1914f5 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Sat, 4 Apr 2026 03:26:27 +0000 Subject: [PATCH 01/21] Add vector add example Signed-off-by: William Zhang --- examples/vector_add/vector_add.py | 84 +++++++++++++++++++++++++++++++ tests/examples/test_examples.py | 2 + 2 files changed, 86 insertions(+) create mode 100644 examples/vector_add/vector_add.py diff --git a/examples/vector_add/vector_add.py b/examples/vector_add/vector_add.py new file mode 100644 index 00000000..2e8b8f29 --- /dev/null +++ b/examples/vector_add/vector_add.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Elementwise vector addition on the GPU. + +Each thread block loads a contiguous tile of ``block_elems`` values from ``a`` and ``b``, +computes ``c = a + b``, and stores the result. This is a minimal Tilus example: one +:class:`tilus.Script`, :meth:`global_view`, :meth:`load_global`, elementwise ``+``, and +:meth:`store_global`. + +``n`` must be divisible by ``block_elems`` (enforced in :func:`main`). +""" + +import pandas +import tilus +import torch +from tilus import float32, int32 +from tilus.utils import benchmark_func, cdiv + + +class VectorAdd(tilus.Script): + """``c[i] = a[i] + b[i]`` for ``i in range(n)``.""" + + def __init__(self): + super().__init__() + self.block_elems = 1024 + + def __call__( + self, + n: int32, + a_ptr: ~float32, + b_ptr: ~float32, + c_ptr: ~float32, + ): + self.attrs.blocks = (cdiv(n, self.block_elems),) + self.attrs.warps = 4 + + offset: int32 = self.block_elems * self.blockIdx.x + + ga = self.global_view(a_ptr, dtype=float32, shape=[n]) + gb = self.global_view(b_ptr, dtype=float32, shape=[n]) + gc = self.global_view(c_ptr, dtype=float32, shape=[n]) + + ra = self.load_global(ga, offsets=[offset], shape=[self.block_elems]) + rb = self.load_global(gb, offsets=[offset], shape=[self.block_elems]) + rc = ra + rb + self.store_global(gc, rc, offsets=[offset]) + + +def main(): + headers = ["n", "name", "latency (ms)", "GB/s"] + # 3 x fp32: read a, read b, write c + nbytes = lambda n_elts: n_elts * 4 * 3 + workloads = [1 << 20, 1 << 24] + + rows = [] + for n in workloads: + assert n % 1024 == 0, "n must be divisible by block_elems (1024)" + + kernel = VectorAdd() + a = torch.randn(n, dtype=torch.float32, device="cuda") + b = torch.randn(n, dtype=torch.float32, device="cuda") + c_actual = torch.empty(n, dtype=torch.float32, device="cuda") + c_expect = a + b + torch.cuda.synchronize() + + kernel(n, a, b, c_actual) + torch.cuda.synchronize() + + torch.testing.assert_close(c_expect, c_actual) + + for name, func in [ + ("torch", lambda: torch.add(a, b, out=c_actual)), + ("tilus", lambda: kernel(n, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + gbps = nbytes(n) / (latency * 1e-3) / 1e9 + rows.append([n, name, latency, gbps]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +if __name__ == "__main__": + main() diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index ef61ae82..67431e28 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -43,6 +43,8 @@ ("norm", "layer_norm.py", None), # softmax example ("softmax", "softmax.py", None), + # vector add + ("vector_add", "vector_add.py", None), # attention examples (SM 8.0+) ("attention", "flash_attention_v1.py", nvgpu_sm80), ("attention", "flash_attention_v2.py", nvgpu_sm80), From c1d50d070832a42c30d937219528709f76e87c4f Mon Sep 17 00:00:00 2001 From: William Zhang Date: Tue, 7 Apr 2026 03:13:56 +0000 Subject: [PATCH 02/21] fix formatting Signed-off-by: William Zhang --- examples/vector_add/vector_add.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/vector_add/vector_add.py b/examples/vector_add/vector_add.py index 2e8b8f29..25f92beb 100644 --- a/examples/vector_add/vector_add.py +++ b/examples/vector_add/vector_add.py @@ -46,10 +46,13 @@ def __call__( self.store_global(gc, rc, offsets=[offset]) +def _nbytes_fp32_vector_add(n_elts: int) -> int: + # 3 x fp32: read a, read b, write c + return n_elts * 4 * 3 + + def main(): headers = ["n", "name", "latency (ms)", "GB/s"] - # 3 x fp32: read a, read b, write c - nbytes = lambda n_elts: n_elts * 4 * 3 workloads = [1 << 20, 1 << 24] rows = [] @@ -73,7 +76,7 @@ def main(): ("tilus", lambda: kernel(n, a, b, c_actual)), ]: latency = benchmark_func(func, warmup=5, repeat=20) - gbps = nbytes(n) / (latency * 1e-3) / 1e9 + gbps = _nbytes_fp32_vector_add(n) / (latency * 1e-3) / 1e9 rows.append([n, name, latency, gbps]) df = pandas.DataFrame(rows, columns=headers) From 145f5bf1d4b3fbadadd2343ecc9e56982c0c2c24 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Sun, 3 May 2026 20:22:37 +0000 Subject: [PATCH 03/21] add hopper matmuls Signed-off-by: William Zhang --- examples/hopper_matmul/benchmark.py | 116 +++++++++++++ examples/hopper_matmul/matmul_v4.py | 245 ++++++++++++++++++++++++++ examples/hopper_matmul/matmul_v5.py | 255 ++++++++++++++++++++++++++++ 3 files changed, 616 insertions(+) create mode 100644 examples/hopper_matmul/benchmark.py create mode 100644 examples/hopper_matmul/matmul_v4.py create mode 100644 examples/hopper_matmul/matmul_v5.py diff --git a/examples/hopper_matmul/benchmark.py b/examples/hopper_matmul/benchmark.py new file mode 100644 index 00000000..0e65bfe1 --- /dev/null +++ b/examples/hopper_matmul/benchmark.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark all hopper_matmul versions against cuBLAS (torch.matmul). + +Run directly: + python benchmark.py + +Or via the slurm job: + sbatch sample_slurm_hopper_benchmark.sh +""" + +import importlib +import math +import sys + +import pandas +import torch + +import tilus +from tilus.utils import benchmark_func, cdiv + +tilus.option.cache_dir("./cache") + +WORKLOADS = [ + # (m, n, k, label) + (1024, 1024, 1024, "1k-sq"), + (2048, 2048, 2048, "2k-sq"), + (4096, 4096, 4096, "4k-sq"), + (4096, 4096, 14336, "llm-ffn"), + (8192, 8192, 8192, "8k-sq"), + (10240, 10240, 10240, "10k-sq"), +] + +VERSIONS = ["v0", "v1", "v2", "v3", "v4", "v5"] + +VERSION_CLASS = { + "v0": "MatmulTMA", + "v1": "MatmulWGMMA", + "v2": "MatmulWGMMAV2", + "v3": "MatmulWGMMAV3", + "v4": "MatmulWGMMAV4", + "v5": "MatmulWGMMAV5", +} + +WARMUP = 5 +REPEAT = 30 + + +def load_version(name: str): + mod = importlib.import_module(f"matmul_{name}") + return getattr(mod, VERSION_CLASS[name]) + + +def tflops(m, n, k, latency_ms): + return 2 * m * n * k / latency_ms * 1e-9 + + +def run_benchmark(versions=None): + if versions is None: + versions = VERSIONS + + device = torch.cuda.get_device_name(0) + print(f"Device: {device}") + print(f"Versions: {versions}") + print() + + headers = ["workload", "m", "n", "k", "kernel", "latency (ms)", "tflops", "% of cublas"] + rows = [] + + for m, n, k, label in WORKLOADS: + print(f"--- {label} m={m} n={n} k={k} ---") + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_ref = torch.empty(m, n, dtype=torch.float16).cuda() + c_tilus = torch.empty(m, n, dtype=torch.float16).cuda() + + # cuBLAS baseline + cublas_lat = benchmark_func(lambda: torch.matmul(a, b.T, out=c_ref), warmup=WARMUP, repeat=REPEAT) + cublas_tf = tflops(m, n, k, cublas_lat) + rows.append([label, m, n, k, "cublas", cublas_lat, cublas_tf, 100.0]) + print(f" cublas {cublas_lat:.4f} ms {cublas_tf:.1f} TFLOPS") + + for ver in versions: + try: + cls = load_version(ver) + kernel = cls() + # correctness check + kernel(m, n, k, a, b, c_tilus) + torch.cuda.synchronize() + torch.testing.assert_close(c_ref, c_tilus, atol=1e-2, rtol=1e-2) + + lat = benchmark_func(lambda: kernel(m, n, k, a, b, c_tilus), warmup=WARMUP, repeat=REPEAT) + tf = tflops(m, n, k, lat) + pct = tf / cublas_tf * 100.0 + rows.append([label, m, n, k, f"tilus-{ver}", lat, tf, pct]) + print(f" tilus-{ver} {lat:.4f} ms {tf:.1f} TFLOPS ({pct:.1f}% of cuBLAS)") + except Exception as e: + print(f" tilus-{ver} ERROR: {e}", file=sys.stderr) + rows.append([label, m, n, k, f"tilus-{ver}", float("nan"), float("nan"), float("nan")]) + + print() + + df = pandas.DataFrame(rows, columns=headers) + print("\n=== Summary ===") + print(df.to_string(index=False)) + return df + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--versions", nargs="+", default=None, + help="Subset of versions to benchmark, e.g. --versions v3 v4 v5") + args = parser.parse_args() + run_benchmark(versions=args.versions) diff --git a/examples/hopper_matmul/matmul_v4.py b/examples/hopper_matmul/matmul_v4.py new file mode 100644 index 00000000..229d2796 --- /dev/null +++ b/examples/hopper_matmul/matmul_v4.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# v4: Pipeline class abstraction + tile rasterization for better L2 cache reuse. +# +# Changes from v3: +# - Pipeline class encapsulates barrier/phase/stage ring-buffer logic. +# - 1D grid launch with compute_block_coord() maps linear blockIdx to (m, n) +# using a swizzle group so adjacent tiles share B columns → L2 reuse. + +import math + +import pandas +import tilus +import torch +from tilus import RegisterTensor, float16, float32, int32, uint32 +from tilus.utils import benchmark_func, cdiv + + +class Pipeline(tilus.Class): + def __init__( + self, + num_stages: int, + producer_arrive_count: int = 1, + consumer_arrive_count: int = 1, + ): + self.num_stages: int = num_stages + self.empty_barriers = self.mbarrier.alloc( + [consumer_arrive_count for _ in range(num_stages)] + ) + self.full_barriers = self.mbarrier.alloc( + [producer_arrive_count for _ in range(num_stages)] + ) + self.producer_stage: int32 = 0 + self.consumer_stage: int32 = 0 + self.producer_phase: uint32 = self.mbarrier.producer_initial_phase + self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase + + def producer_acquire(self): + self.mbarrier.wait( + barrier=self.empty_barriers[self.producer_stage], + phase=self.producer_phase, + sem="relaxed", + scope="cta", + ) + + def producer_barrier(self) -> RegisterTensor: + return self.full_barriers[self.producer_stage] + + def producer_advance(self): + self.producer_stage = (self.producer_stage + 1) % self.num_stages + self.producer_phase = self.producer_phase ^ (self.producer_stage == 0) + + def consumer_acquire(self): + self.mbarrier.wait( + barrier=self.full_barriers[self.consumer_stage], + phase=self.consumer_phase, + sem="relaxed", + scope="cta", + ) + + def consumer_barrier(self) -> RegisterTensor: + return self.empty_barriers[self.consumer_stage] + + def consumer_advance(self): + self.consumer_stage = (self.consumer_stage + 1) % self.num_stages + self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0) + + def prev_consumer_barrier(self) -> RegisterTensor: + # Use (consumer_stage + num_stages - 1) so num_stages - 1 is evaluated as a + # Python int constant first, ensuring the inner expression is always non-negative + # and avoids C's negative-dividend truncated-modulo behaviour. + prev_stage = (self.consumer_stage + (self.num_stages - 1)) % self.num_stages + return self.empty_barriers[prev_stage] + + +@tilus.autotune("num_stages", [2, 3, 4, 5, 6, 7]) +@tilus.autotune( + "block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]] +) +@tilus.autotune("block_k", [16, 32, 64]) +@tilus.autotune("swizzle_size", [1, 4, 8]) +class MatmulWGMMAV4(tilus.Script): + def __init__(self, num_stages, block_m, block_n, block_k, swizzle_size): + super().__init__() + self.num_stages = num_stages + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + self.swizzle_size = swizzle_size + + def compute_block_coord( + self, linear_idx: int32, num_m_blocks: int32, num_n_blocks: int + ): + """Map 1D linear block index to 2D (m_block, n_block) with swizzle grouping. + + Tiles in the same swizzle group share N-columns, keeping B columns in L2. + """ + swizzle_size = self.swizzle_size + tiles_per_group = num_m_blocks * swizzle_size + group_idx, in_group_idx = self.fast_divmod(linear_idx, tiles_per_group) + first_n = group_idx * swizzle_size + m_block: int32 = 0 + n_block: int32 = 0 + remainder = num_n_blocks - num_n_blocks // swizzle_size * swizzle_size + last_group_width = remainder if remainder > 0 else swizzle_size + if first_n + swizzle_size <= num_n_blocks: + m_block, r = self.fast_divmod(in_group_idx, swizzle_size) + n_block = first_n + r + else: + m_block, r = self.fast_divmod(in_group_idx, last_group_width) + n_block = first_n + r + return m_block, n_block + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + num_stages = self.num_stages + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + + num_m_blocks = cdiv(m_size, block_m) + num_n_blocks = cdiv(n_size, block_n) + self.attrs.blocks = num_m_blocks * num_n_blocks + self.attrs.warps = 5 + + m_block, n_block = self.compute_block_coord( + self.blockIdx.x, num_m_blocks, num_n_blocks + ) + offset_m: int32 = m_block * block_m + offset_n: int32 = n_block * block_n + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) + sa = self.shared_tensor(dtype=float16, shape=[num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + # producer_arrive_count=1: single thread does arrive_and_expect_tx + # consumer_arrive_count=128: all 128 consumer threads arrive when done + tma_pipe = Pipeline(num_stages, producer_arrive_count=1, consumer_arrive_count=128) + + with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp + for offset_k in self.range(0, k_size, block_k, unroll=num_stages): + tma_pipe.producer_acquire() + with self.single_thread(): + self.mbarrier.arrive_and_expect_tx( + tma_pipe.producer_barrier(), + transaction_bytes=sa[tma_pipe.producer_stage].nbytes + + sb[tma_pipe.producer_stage].nbytes, + ) + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage], + offsets=[offset_m, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + self.tma.global_to_shared( + src=gb, + dst=sb[tma_pipe.producer_stage], + offsets=[offset_n, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + tma_pipe.producer_advance() + + # drain: wait for consumer to finish processing all in-flight stages + for _ in self.range(min(num_stages, cdiv(k_size, block_k))): + tma_pipe.producer_acquire() + tma_pipe.producer_advance() + + with self.thread_group(thread_begin=0, num_threads=128): # WGMMA consumer + # Prologue: issue first MMA; don't release stage 0 yet (it is + # still being read; we release it in the first main-loop iteration + # after wait_group(1) confirms the MMA is done). + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.commit_group() + tma_pipe.consumer_advance() + + # Main loop: issue MMA then call wait_group(1) *after* commit so + # the hardware can pipeline the current and previous groups while + # the TMA wait (consumer_acquire) was overlapping with the prior + # group. wait_group(1) stalls until the *previous* group (n-1) is + # done, then we safely release its stage before advancing. + for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.commit_group() + self.wgmma.wait_group(1) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + tma_pipe.consumer_advance() + + # Epilogue: drain the last in-flight MMA group, then release its stage. + self.wgmma.wait_group(0) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + + self.sync() + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + [4096, 4096, 14336], + [8192, 8192, 8192], + [10240, 10240, 10240], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulWGMMAV4() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b.T + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + + for name, func in [ + ("torch", lambda: torch.matmul(a, b.T, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +if __name__ == "__main__": + main() diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py new file mode 100644 index 00000000..6a2187d1 --- /dev/null +++ b/examples/hopper_matmul/matmul_v5.py @@ -0,0 +1,255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# v5: TMA epilogue — replaces direct register-to-global stores with a +# shared-memory staging buffer written out via TMA bulk store. +# +# Changes from v4: +# - Pre-allocates s_c[block_m, block_n] alongside s_a/s_b. +# - After the K loop, stores the float16-cast accumulator to s_c, then +# issues a TMA shared→global transfer instead of store_global. +# - A fence.proxy_async(space="shared") between store_shared and TMA is +# required so the generic-proxy writes are visible to the async proxy. + +import math + +import pandas +import tilus +import torch +from tilus import RegisterTensor, float16, float32, int32, uint32 +from tilus.utils import benchmark_func, cdiv + + +class Pipeline(tilus.Class): + def __init__( + self, + num_stages: int, + producer_arrive_count: int = 1, + consumer_arrive_count: int = 1, + ): + self.num_stages: int = num_stages + self.empty_barriers = self.mbarrier.alloc( + [consumer_arrive_count for _ in range(num_stages)] + ) + self.full_barriers = self.mbarrier.alloc( + [producer_arrive_count for _ in range(num_stages)] + ) + self.producer_stage: int32 = 0 + self.consumer_stage: int32 = 0 + self.producer_phase: uint32 = self.mbarrier.producer_initial_phase + self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase + + def producer_acquire(self): + self.mbarrier.wait( + barrier=self.empty_barriers[self.producer_stage], + phase=self.producer_phase, + sem="relaxed", + scope="cta", + ) + + def producer_barrier(self) -> RegisterTensor: + return self.full_barriers[self.producer_stage] + + def producer_advance(self): + self.producer_stage = (self.producer_stage + 1) % self.num_stages + self.producer_phase = self.producer_phase ^ (self.producer_stage == 0) + + def consumer_acquire(self): + self.mbarrier.wait( + barrier=self.full_barriers[self.consumer_stage], + phase=self.consumer_phase, + sem="relaxed", + scope="cta", + ) + + def consumer_barrier(self) -> RegisterTensor: + return self.empty_barriers[self.consumer_stage] + + def consumer_advance(self): + self.consumer_stage = (self.consumer_stage + 1) % self.num_stages + self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0) + + def prev_consumer_barrier(self) -> RegisterTensor: + # Use (consumer_stage + num_stages - 1) so num_stages - 1 is evaluated as a + # Python int constant first, ensuring the inner expression is always non-negative + # and avoids C's negative-dividend truncated-modulo behaviour. + prev_stage = (self.consumer_stage + (self.num_stages - 1)) % self.num_stages + return self.empty_barriers[prev_stage] + + +@tilus.autotune("num_stages", [2, 3, 4, 5]) +@tilus.autotune( + "block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]] +) +@tilus.autotune("block_k", [16, 32, 64]) +@tilus.autotune("swizzle_size", [1, 4, 8]) +class MatmulWGMMAV5(tilus.Script): + def __init__(self, num_stages, block_m, block_n, block_k, swizzle_size): + super().__init__() + self.num_stages = num_stages + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + self.swizzle_size = swizzle_size + + def compute_block_coord( + self, linear_idx: int32, num_m_blocks: int32, num_n_blocks: int + ): + swizzle_size = self.swizzle_size + tiles_per_group = num_m_blocks * swizzle_size + group_idx, in_group_idx = self.fast_divmod(linear_idx, tiles_per_group) + first_n = group_idx * swizzle_size + m_block: int32 = 0 + n_block: int32 = 0 + remainder = num_n_blocks - num_n_blocks // swizzle_size * swizzle_size + last_group_width = remainder if remainder > 0 else swizzle_size + if first_n + swizzle_size <= num_n_blocks: + m_block, r = self.fast_divmod(in_group_idx, swizzle_size) + n_block = first_n + r + else: + m_block, r = self.fast_divmod(in_group_idx, last_group_width) + n_block = first_n + r + return m_block, n_block + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + num_stages = self.num_stages + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + + num_m_blocks = cdiv(m_size, block_m) + num_n_blocks = cdiv(n_size, block_n) + self.attrs.blocks = num_m_blocks * num_n_blocks + self.attrs.warps = 5 + + m_block, n_block = self.compute_block_coord( + self.blockIdx.x, num_m_blocks, num_n_blocks + ) + offset_m: int32 = m_block * block_m + offset_n: int32 = n_block * block_n + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + sa = self.shared_tensor(dtype=float16, shape=[num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k]) + # s_c is the staging buffer for the TMA epilogue; allocated alongside + # s_a/s_b so the allocator can pick a fitting shared-memory partition. + sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + tma_pipe = Pipeline(num_stages, producer_arrive_count=1, consumer_arrive_count=128) + + with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp + for offset_k in self.range(0, k_size, block_k, unroll=num_stages): + tma_pipe.producer_acquire() + with self.single_thread(): + self.mbarrier.arrive_and_expect_tx( + tma_pipe.producer_barrier(), + transaction_bytes=sa[tma_pipe.producer_stage].nbytes + + sb[tma_pipe.producer_stage].nbytes, + ) + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage], + offsets=[offset_m, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + self.tma.global_to_shared( + src=gb, + dst=sb[tma_pipe.producer_stage], + offsets=[offset_n, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + tma_pipe.producer_advance() + + for _ in self.range(min(num_stages, cdiv(k_size, block_k))): + tma_pipe.producer_acquire() + tma_pipe.producer_advance() + + with self.thread_group(thread_begin=0, num_threads=128): # WGMMA consumer + # Prologue: issue first MMA; release happens in first main-loop + # iteration after wait_group(1) confirms the MMA is done. + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.commit_group() + tma_pipe.consumer_advance() + + # Main loop: issue MMA then wait_group(1) *after* commit so the + # hardware pipelines the current and previous MMA groups while + # consumer_acquire overlaps with the prior group's execution. + for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.commit_group() + self.wgmma.wait_group(1) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + tma_pipe.consumer_advance() + + # Epilogue: drain last in-flight MMA, release its stage. + self.wgmma.wait_group(0) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + + # TMA epilogue: registers → shared → global + self.sync() + casted_acc = self.cast(acc, dtype=float16) + self.store_shared(sc, casted_acc) + # fence required: store_shared uses generic proxy; TMA uses async proxy + self.fence.proxy_async(space="shared") + self.sync() + with self.single_thread(): + self.tma.shared_to_global( + sc, + gc, + offsets=[offset_m, offset_n], + dims=[0, 1], + ) + self.tma.commit_group() + self.tma.wait_group(n=0, read=True) + self.sync() + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + [4096, 4096, 14336], + [8192, 8192, 8192], + [10240, 10240, 10240], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulWGMMAV5() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b.T + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + + for name, func in [ + ("torch", lambda: torch.matmul(a, b.T, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +if __name__ == "__main__": + main() From 1e53e7aab53be4f775397b7127d6339c78240447 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Wed, 6 May 2026 13:18:59 +0000 Subject: [PATCH 04/21] make hopper matmuls generate correct results on h200 Signed-off-by: William Zhang --- .claude/skills/ncu-report/SKILL.md | 271 ---------------- .claude/skills/write-docs/SKILL.md | 220 ------------- examples/hopper_matmul/benchmark.py | 300 +++++++++++++----- examples/hopper_matmul/matmul_v3.py | 27 +- examples/hopper_matmul/matmul_v4.py | 27 +- examples/hopper_matmul/matmul_v5.py | 30 +- .../backends/contexts/mbarrier_alloc_ctx.py | 15 +- python/tilus/backends/emitter.py | 26 ++ .../backends/emitters/cuda/cp_async_tensor.py | 18 +- .../hidet/ir/analyzers/bound_analyzer.py | 4 +- python/tilus/hidet/option.py | 13 + .../transforms/_rule_based_simplifier_base.py | 58 +++- python/tilus/lang/instantiated_script.py | 80 ++++- python/tilus/option.py | 12 + tests/examples/test_examples.py | 3 + 15 files changed, 461 insertions(+), 643 deletions(-) delete mode 100644 .claude/skills/ncu-report/SKILL.md delete mode 100644 .claude/skills/write-docs/SKILL.md diff --git a/.claude/skills/ncu-report/SKILL.md b/.claude/skills/ncu-report/SKILL.md deleted file mode 100644 index 9fbccf1f..00000000 --- a/.claude/skills/ncu-report/SKILL.md +++ /dev/null @@ -1,271 +0,0 @@ ---- -name: ncu-report -description: > - Analyze NVIDIA Nsight Compute (ncu) profiling reports (.ncu-rep files). - Extract metrics, performance data, SASS/CUDA source, and identify bottlenecks. - TRIGGER when: user asks to analyze, profile, or look at an ncu report, .ncu-rep file, - Nsight Compute report, kernel performance/profiling data from ncu, or asks to generate/collect - an ncu profile for a tilus kernel or example script. - DO NOT TRIGGER when: user is writing unrelated profiling code. -user-invocable: true ---- - -# Nsight Compute Report Analysis - -This skill handles two modes: - -1. **Analyze an existing report**: The user provides a path to an `.ncu-rep` file (or one exists under `examples/`). Use the `ncu` CLI to extract and present metrics. - -2. **Generate a new report**: The user specifies a script or kernel to profile but does NOT provide a `.ncu-rep` file. In this case, set up profiling using `tilus.utils.ncu_utils.ncu_run()`, run it, then analyze the resulting report. - -## Generating a Report - -Tilus provides `tilus.utils.ncu_utils.ncu_run()` to profile kernels with full metrics and source correlation. - -### API -```python -from tilus.utils.ncu_utils import ncu_run - -# ncu_run(func, *args, kernel_regex=".*", **kwargs) -> NsightComputeReport -report = ncu_run(main, bench=False, kernel_regex="tilus|nvjet") -``` - -- `func`: a callable (typically a `main()` function) that runs the kernel(s) to profile -- `*args`, `**kwargs`: passed through to `func` -- `kernel_regex`: regex to filter which kernels to profile (default `".*"`) -- Returns `NsightComputeReport` with `.report_path` pointing to the generated `.ncu-rep` file - -### What it does -- Runs the function under `ncu` with `--set full`, all rules enabled, and `--import-source yes` -- Saves the report to `ncu-reports/reportN.ncu-rep` next to the script (auto-increments N) -- Uses the system Python and the `ncu` binary at `/usr/local/cuda/bin/ncu` - -### How to generate a report for the user - -**IMPORTANT**: `ncu_run()` must be called inside a `if __name__ == "__main__":` block. It works by re-importing the script as a subprocess under `ncu` — if `ncu_run()` is at module level, the subprocess will call `ncu_run()` again, causing infinite recursion and a runtime error. - -**Step 1: Read the script** — find the example script the user specified and read it to understand how the kernel is invoked. - -**Step 2: Set up profiling** — choose one of these approaches: - -- **If the script already has a `__main__` block with `ncu_run()`**: just run it directly. -- **If the script has a `__main__` block but no `ncu_run()`**: edit the `__main__` block to add `ncu_run()`. For example, if the block calls `main()`, change it to call `ncu_run(main, bench=False, kernel_regex="tilus")`. -- **If the script has no `__main__` block or is hard to modify** (e.g., it's a test file, or the kernel launch is deeply nested): create a new script next to it (e.g., `profile_.py`) that imports and calls the kernel under `ncu_run()`. - -Example of editing an existing `__main__` block: -```python -if __name__ == "__main__": - from tilus.utils.ncu_utils import ncu_run - ncu_run(main, bench=False, kernel_regex="tilus") -``` - -Example of creating a new profiling script: -```python -from tilus.utils.ncu_utils import ncu_run -from matmul_v9 import main - -if __name__ == "__main__": - ncu_run(main, bench=False, kernel_regex="tilus") -``` - -**Step 3: Run the script** — `python `. The report will be saved to `/ncu-reports/reportN.ncu-rep`. - -**Step 4: Analyze** — proceed to the Analysis Workflow below with the generated report. - -**Note**: `ncu` profiling requires `sudo` or appropriate permissions (CAP_SYS_ADMIN). If the command fails with permission errors, suggest running with `sudo`. - -## Analysis Workflow - -Follow this sequence. Skip steps the user doesn't need, but always start with Step 1. - -### Step 1: Overview — List kernels and session info - -Run these in parallel: - -```bash -# List all kernels with timing -ncu -i --page raw --csv --metrics gpu__time_duration.sum 2>&1 - -# Session/device info -ncu -i --page session --csv 2>&1 -``` - -Present a summary table: -- Kernel name (shortened), Block Size, Grid Size, Duration (ms) -- Device name, compute capability, CUDA version - -### Step 2: Speed of Light — Top-level throughput - -```bash -ncu -i --page details --csv --section SpeedOfLight 2>&1 -``` - -Key metrics to highlight per kernel: -- **Duration** (ms) -- **Compute (SM) Throughput** (%) — how busy the SMs are -- **Memory Throughput** (%) — overall memory utilization -- **DRAM Throughput** (%) — HBM bandwidth utilization -- **L1/TEX Cache Throughput** (%) -- **L2 Cache Throughput** (%) -- **SOLBottleneck rule** — check the Rule Description column for bottleneck guidance - -### Step 3: Compute & Memory Workload Analysis - -```bash -# Compute workload -ncu -i --page details --csv --section ComputeWorkloadAnalysis 2>&1 - -# Memory workload -ncu -i --page details --csv --section MemoryWorkloadAnalysis 2>&1 -``` - -Key compute metrics: Executed IPC Active, SM Busy %, Issue Slots Busy % -Key memory metrics: Mem Busy %, Max Bandwidth %, L1/L2 hit rates - -### Step 4: Occupancy - -```bash -ncu -i --page details --csv --section Occupancy 2>&1 -``` - -Report: Theoretical Occupancy, Achieved Occupancy, and limiters (registers, shared memory, block size). - -### Step 5: Detailed metrics (on demand) - -To extract specific raw metrics: -```bash -ncu -i --page raw --csv --metrics ,,... 2>&1 -``` - -To filter by kernel: -```bash -ncu -i --page raw --csv --metrics --kernel-name regex: 2>&1 -``` - -### Step 6: Source-level analysis (on demand) - -SASS-only (default, always available): -```bash -ncu -i --page source --csv --kernel-name regex: 2>&1 -``` - -CUDA source correlated with SASS (requires `--import-source yes` during profiling): -```bash -ncu -i --page source --csv --print-source cuda,sass --kernel-name regex: 2>&1 -``` - -Source output columns include per-instruction: Warp Stall Sampling, Instructions Executed, Thread Instructions Executed, stall reasons (stall_barrier, stall_math, stall_wait, etc.), shared memory conflicts, and more. - -### Step 7: Rules / automated analysis (on demand) - -Rules are included in the details page output. Look for non-empty "Rule Name" column entries. -```bash -ncu -i --page details --csv --print-rule-details 2>&1 | grep -v '^"[0-9]' | head -5 # header -``` - -To see all rule results with descriptions: -```bash -ncu -i --page details --csv --print-rule-details 2>&1 -``` -Filter for rows where column 17 (Rule Name) is non-empty. - -## Reference: ncu CLI Options for Report Analysis - -### Pages (`--page`) -| Page | Description | -|------|-------------| -| `details` | Sections with metrics organized by section name + rule results | -| `raw` | All collected metrics as flat columns (one row per kernel) | -| `source` | Per-instruction source code with correlated metrics | -| `session` | Session info, device attributes, launch settings | - -### Key Flags -| Flag | Description | -|------|-------------| -| `--csv` | Output as CSV (essential for parsing) | -| `--metrics ,` | Filter specific metrics (for `raw` page) | -| `--section ` | Filter by section identifier (for `details` page) | -| `--kernel-name regex:` | Filter kernels by name regex | -| `--kernel-name ` | Filter by exact kernel name | -| `--print-source sass\|ptx\|cuda\|cuda,sass` | Select source view for `source` page | -| `--print-details header\|body\|all` | Control detail level: `header` (default), `body` (charts/tables), `all` | -| `--print-metric-name name` | Show internal metric names instead of display labels | -| `--print-metric-name label-name` | Show both label and internal name | -| `--print-units base` | Show metrics in base units (no auto-scaling) | -| `--print-summary per-kernel` | Aggregate across invocations per kernel (min/max/avg) | -| `--print-rule-details` | Include additional rule tables and KPI metrics | - -### Section Identifiers -| Identifier | Display Name | -|------------|-------------| -| `SpeedOfLight` | GPU Speed Of Light Throughput | -| `ComputeWorkloadAnalysis` | Compute Workload Analysis | -| `MemoryWorkloadAnalysis` | Memory Workload Analysis | -| `MemoryWorkloadAnalysis_Tables` | Memory Workload Analysis Tables | -| `Occupancy` | Occupancy | -| `LaunchStats` | Launch Statistics | -| `SchedulerStats` | Scheduler Statistics | -| `WarpStateStats` | Warp State Statistics | -| `InstructionStats` | Instruction Statistics | -| `SourceCounters` | Source Counters | -| `WorkloadDistribution` | GPU and Memory Workload Distribution | -| `NumaAffinity` | NUMA Affinity | -| `SpeedOfLight_RooflineChart` | GPU Speed Of Light Roofline Chart | -| `SpeedOfLight_HierarchicalTensorRooflineChart` | Roofline Chart (Tensor Core) | -| `SpeedOfLight_HierarchicalHalfRooflineChart` | Roofline Chart (Half Precision) | - -### Section Sets (used during profiling with `--set`) -| Set | Sections | Est. Metrics | -|-----|----------|-------------| -| `basic` | LaunchStats, Occupancy, SpeedOfLight, WorkloadDistribution | 213 | -| `detailed` | basic + ComputeWorkloadAnalysis, MemoryWorkloadAnalysis, SourceCounters, Roofline | 906 | -| `full` | All sections including Instruction/Scheduler/WarpState stats, all Rooflines | 7794 | - -### Commonly Used Raw Metrics -| Metric | Description | -|--------|-------------| -| `gpu__time_duration.sum` | Kernel wall-clock duration | -| `sm__throughput.avg.pct_of_peak_sustained_elapsed` | SM throughput % | -| `dram__throughput.avg.pct_of_peak_sustained_elapsed` | DRAM throughput % | -| `sm__warps_active.avg.pct_of_peak_sustained_active` | Active warps % | -| `launch__occupancy_limit_registers` | Occupancy limiter: registers | -| `launch__occupancy_limit_shared_mem` | Occupancy limiter: shared memory | -| `launch__occupancy_limit_blocks` | Occupancy limiter: blocks | -| `launch__occupancy_limit_warps` | Occupancy limiter: warps | -| `sm__sass_thread_inst_executed_op_*` | Per-opcode instruction counts | -| `l1tex__t_sector_hit_rate.pct` | L1 cache hit rate | -| `lts__t_sector_hit_rate.pct` | L2 cache hit rate | - -### Available Rules (used during profiling with `--rule`) -| Rule ID | Description | -|---------|-------------| -| `SOLBottleneck` | High-level bottleneck detection | -| `SOLFPRoofline` | Floating Point Roofline Analysis | -| `CPIStall` | Warp stall analysis | -| `Occupancy` | Achieved Occupancy analysis | -| `LaunchConfiguration` | Kernel launch config analysis | -| `HighPipeUtilization` | High pipe utilization bottleneck | -| `IssueSlotUtilization` | Scheduler issue analysis | -| `SharedMemoryConflicts` | Shared memory bank conflicts | -| `ThreadDivergence` | Warp/thread divergence | -| `UncoalescedGlobalAccess` | Uncoalesced global memory | -| `UncoalescedSharedAccess` | Uncoalesced shared memory | -| `SlowPipeLimiter` | Slow pipe limiting compute | -| `FPInstructions` | FP instruction analysis | -| `PCSamplingData` | PC sampling data | - -## Comparing Kernels - -When the report contains multiple kernels (e.g., a reference nvjet kernel and a tilus kernel), always present metrics side-by-side for comparison. Highlight: -1. Duration difference (which is faster, by how much) -2. Throughput differences (compute vs memory bound) -3. Occupancy differences -4. Any rule findings that differ - -## Tips -- The `raw` page has one row per kernel with all metrics as columns — good for extracting specific values. -- The `details` page organizes metrics by section — good for browsing all metrics in a section. -- The `source` page is per-instruction — good for hotspot analysis. Output can be very large; pipe through `head` or filter with `grep`. -- Use `--print-units base` with `--csv` for consistent numeric parsing. -- Use `--print-metric-name name` to get programmatic metric names instead of human labels. -- Source analysis with `cuda,sass` view shows CUDA source lines interleaved with their SASS instructions — extremely useful for correlating high-level code with assembly hotspots. diff --git a/.claude/skills/write-docs/SKILL.md b/.claude/skills/write-docs/SKILL.md deleted file mode 100644 index 7566fb85..00000000 --- a/.claude/skills/write-docs/SKILL.md +++ /dev/null @@ -1,220 +0,0 @@ ---- -name: write-docs -description: > - Convention and format for writing instruction docstrings and RST tutorials in - tilus documentation. TRIGGER when: user asks to add, update, or write - documentation for tilus instructions, instruction groups, or tutorials. ---- - -# Writing Instruction Documentation - -## Docstring Format - -All instruction docstrings use **NumPy-style** format with the following structure: - -```python -def method_name(self, param1: Type1, param2: Type2) -> ReturnType: - """One-line summary of what the instruction does. - - Extended description explaining the behavior, semantics, and constraints - of the instruction. This can be multiple paragraphs. - - Parameters - ---------- - param1: Type1 - Description of the parameter. Include constraints and valid ranges - (e.g., "must be evaluated to a positive int32"). - param2: Type2 - Description. For parameters with defaults, explain the default behavior - (e.g., "By default, it is 1."). - - Returns - ------- - ret: ReturnType - Description of the return value including shape, dtype, and relationship - to inputs. - """ -``` - -## Conventions - -### Summary line -- Start with a verb: "Allocate...", "Arrive at...", "Compute the...", "Load...", "Wait at..." -- Keep it to one line - -### Extended description -- Explain what the instruction does at the hardware level when relevant -- Document state transitions (e.g., barrier phase switching) -- Describe relationships between parameters -- Explain multicast/cluster behavior for distributed instructions -- Scale detail to complexity: simple methods (load, cast) need minimal description; complex methods (mbarrier.alloc, tma.global_to_shared) need thorough explanation - -### Parameters section -- Document in the same order as the function signature -- Use full type annotations: `RegisterTensor`, `Expr | int`, `Optional[Type]` -- Include constraints: "must be in the range of [0, N)", "must be evaluated to a non-negative int32" -- Document valid candidates for string parameters: `Candidates: 'relaxed', 'release'.` -- Explain default values: "By default, it is 1." - -### Returns section -- Use `ret: Type` format -- Describe shape, dtype, and semantic meaning - -### Notes section (required) -Every instruction must have a Notes section with these items as a compact bullet list: - -```python - Notes - ----- - - **Thread group**: Can be executed by any sized thread group. - - **Hardware**: Requires compute capability 8.0+ (sm_80). - - **PTX**: ``mbarrier.init.shared::cta.b64`` -``` - -The three standard note items: - -- **Thread group**: Execution requirements. Common values: - - "Can be executed by any sized thread group." - - "Must be executed by a warp group (4 warps)." - - "Must be executed by a single thread (use ``self.single_thread()``)." - - "Must be executed by a single warp (use ``self.single_warp()`)" -- **Hardware**: Minimum compute capability. Format: "Requires compute capability X.Y+ (sm_XY)." - - sm_80 = Ampere (A100), sm_89 = Ada Lovelace (L4/L40), sm_90 = Hopper (H100), sm_100 = Blackwell (B200) -- **PTX**: The underlying PTX instruction(s) this maps to. Use double backticks for inline code. - If the instruction maps to multiple PTX instructions depending on parameters, list them: - ``` - - **PTX**: ``mbarrier.arrive.shared::cta.b64`` or ``mbarrier.arrive.noComplete.shared::cta.b64`` - ``` - If the instruction does not lower to a specific PTX instruction (e.g., it's a high-level - construct), omit the PTX line. - -### Other optional sections -- **Examples**: Use `.. code-block:: python` for usage examples -- **See Also**: Cross-reference related methods with `:py:meth:` or `:py:func:` - -### Memory ordering parameters -For synchronization instructions, document `sem` and `scope` parameters consistently: -```python -sem: str - The memory ordering semantics for the operation. Candidates: 'relaxed', 'release'. -scope: str - The synchronization scope for the operation. Candidates: 'cta', 'cluster'. -``` - -## Reference examples -- Simple instruction: `root.py:load_global`, `root.py:cast` -- Complex instruction: `mbarrier.py:alloc`, `mbarrier.py:arrive_and_expect_tx` -- With PTX reference: `fence.py:proxy_async`, `fence.py:proxy_async_release` -- With code example: `root.py:range`, `root.py:thread_group` -- TMA instruction: `tma.py:global_to_shared` - ---- - -# Writing RST Tutorials - -## Target audience - -Tutorials target **CS researchers who can write Triton kernels** but want to -understand the hardware features underneath Triton's abstractions. Assume readers -know: -- Block-level GPU programming (each program instance processes a tile) -- `tl.load`, `tl.store`, `tl.dot` semantics -- Autotuning concepts -- Basic GPU memory hierarchy (global, shared, registers) - -Do **not** assume they know: -- Explicit shared memory management or allocation -- Warp-level programming or thread indexing within a block -- Asynchronous execution models (mbarrier, TMA, commit/wait patterns) -- Tensor Memory or tcgen05 instruction families -- Memory ordering semantics (acquire/release, proxy fences) - -## Bridging the gap from Triton - -When introducing a concept that Triton handles implicitly, briefly explain **why** -explicit control is needed. Common contrasts: - -- **Shared memory**: "Unlike Triton where shared memory is managed automatically, - tilus gives explicit control --- necessary to use hardware features like TMA and - tcgen05." -- **Registers vs Tensor Memory**: "On earlier architectures (and in Triton), - MMA results accumulate in registers. Blackwell's tensor cores use dedicated - Tensor Memory, which provides higher bandwidth and avoids consuming register - file capacity for large tiles." -- **Synchronization**: "Triton handles synchronization implicitly. On Blackwell, - many operations are asynchronous --- the instruction returns immediately and - completes in the background. This enables overlap of data movement and - computation, but requires explicit tracking via mbarriers." -- **Thread/warp management**: "In Triton, all threads execute the same code. - Efficient Blackwell kernels require different warps to perform different - jobs (loading, computing, scheduling) and collaborate asynchronously via - thread groups." - -## Tutorial structure - -Each tutorial version (v0, v1, ...) should follow this structure: - -1. **Introduction** --- What this version adds, what Blackwell features it uses - (with hyperlinks to instruction group docs). -2. **Full kernel** --- Show the complete kernel upfront so readers see the big - picture before the detailed walkthrough. -3. **Topic sections** --- Explain each new concept with enough detail to - understand the example. Order by conceptual dependency. Include: - - A brief motivation (why does this exist / why do we need it) - - How it works at a high level - - Link to the detailed API/programming guide for deeper reading -4. **Walkthrough** --- Walk through the kernel code in logical groups (setup, - main loop, epilogue). Use `literalinclude` with `:start-at:`/`:end-at:` - markers (never absolute line numbers). For each group, use a bullet list - explaining each instruction with hyperlinks. -5. **What's Next** --- Motivate the next version by identifying the current - bottleneck. -6. **Full Source** --- Download link to the example file. - -## Writing guidelines - -### Explain the "why", not just the "what" -- For every `sync()` call, explain what it guards (e.g., "ensures shared memory - writes are visible to all threads before the MMA warp reads them"). -- For magic numbers like `warps = 4`, explain the choice (e.g., "4 warps = 128 - threads; later versions use more warps to overlap loading and computing"). -- For `enable_input_d`, explain: "On the first iteration, tensor memory contains - uninitialized data, so we ignore it. On subsequent iterations, it holds the - running sum from prior tiles." -- For mbarrier phase flipping, explain: "The same barrier is reused across - iterations. The phase distinguishes this iteration's completion from the - previous one's." - -### Hyperlinks -- Use `:meth:` for instruction methods: - `:meth:`~tilus.Script.copy_async`` for root instructions, - `:meth:`tcgen05.mma `` - for instruction group methods (shows short name, links to full path). -- Use `:attr:` for attributes: `:attr:`self.attrs.blocks `` -- Use `:doc:` for cross-references to other pages: `:doc:`/programming-guides/thread-group`` -- Use `:class:` for tensor types: `:class:`~tilus.ir.tensor.TMemoryTensor`` - -### Code inclusion -- Always use `:start-at:` / `:end-at:` / `:start-after:` / `:end-before:` - instead of absolute line numbers. This makes includes resilient to code - changes. -- Use `:dedent:` to strip leading indentation when including method bodies. -- Use `:caption:` for all included code blocks. - -### Figures and diagrams -- Place SVGs in a `figures/` subdirectory next to the tutorial RST files. -- Use `.. figure::` with `:width:` and `:align: center`. -- SVGs should be editable in draw.io for collaborative iteration. -- Suggest diagrams for: block tiling, data flow, pipeline stages, - cluster layouts, and any concept that benefits from a visual. - -### Tone -- Concise and direct. Avoid filler words. -- Vary sentence structure --- avoid starting every bullet with "We ..." - (lead with the instruction/concept name instead). -- Don't over-explain concepts that are well-covered in linked pages. The tutorial - should give enough to understand the example; detailed semantics belong in the - programming guides and API docs. - -## Reference tutorial -- Blackwell matmul V0: `docs/source/tutorials/matmul-blackwell/v0.rst` diff --git a/examples/hopper_matmul/benchmark.py b/examples/hopper_matmul/benchmark.py index 0e65bfe1..4daa3c5f 100644 --- a/examples/hopper_matmul/benchmark.py +++ b/examples/hopper_matmul/benchmark.py @@ -5,34 +5,17 @@ Run directly: python benchmark.py - -Or via the slurm job: - sbatch sample_slurm_hopper_benchmark.sh + python benchmark.py --ncu + python benchmark.py --versions v3 v4 v5 --size 4096 4096 4096 """ -import importlib -import math -import sys - -import pandas -import torch - -import tilus -from tilus.utils import benchmark_func, cdiv - -tilus.option.cache_dir("./cache") - -WORKLOADS = [ - # (m, n, k, label) - (1024, 1024, 1024, "1k-sq"), - (2048, 2048, 2048, "2k-sq"), - (4096, 4096, 4096, "4k-sq"), - (4096, 4096, 14336, "llm-ffn"), - (8192, 8192, 8192, "8k-sq"), - (10240, 10240, 10240, "10k-sq"), -] +import argparse +import csv +import io +import subprocess +import time -VERSIONS = ["v0", "v1", "v2", "v3", "v4", "v5"] +VERSION_NAMES = ["v0", "v1", "v2", "v3", "v4", "v5"] VERSION_CLASS = { "v0": "MatmulTMA", @@ -43,74 +26,231 @@ "v5": "MatmulWGMMAV5", } -WARMUP = 5 -REPEAT = 30 +def _load_version(name: str): + """Lazily import a matmul module by version name and return the class.""" + import importlib + + import tilus + + tilus.option.cache_dir("./cache") + + module = importlib.import_module(f"matmul_{name}") + return getattr(module, VERSION_CLASS[name]) + + +def run_kernels(version_names: list, m_size: int, n_size: int, k_size: int): + """Run cuBLAS and tilus matmul versions sequentially (used as the target for ncu_run).""" + import torch -def load_version(name: str): - mod = importlib.import_module(f"matmul_{name}") - return getattr(mod, VERSION_CLASS[name]) + a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda") + b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda") + c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") + # tilus versions + for name in version_names: + matmul = _load_version(name)() + matmul(m_size, n_size, k_size, a, b, c) + torch.cuda.synchronize() -def tflops(m, n, k, latency_ms): - return 2 * m * n * k / latency_ms * 1e-9 + # cuBLAS + _ = a @ b.T + torch.cuda.synchronize() -def run_benchmark(versions=None): - if versions is None: - versions = VERSIONS +def _read_ncu_csv( + report_path: str, page: str, metrics: str | None = None +) -> csv.DictReader: + """Run ncu --import --csv and return a DictReader, skipping the units row.""" + cmd = ["/usr/local/cuda/bin/ncu", "--import", report_path, "--csv", "--page", page] + if metrics: + cmd += ["--metrics", metrics] + result = subprocess.run(cmd, capture_output=True, text=True) + reader = csv.DictReader(io.StringIO(result.stdout)) + next(reader, None) + return reader - device = torch.cuda.get_device_name(0) - print(f"Device: {device}") - print(f"Versions: {versions}") - print() - headers = ["workload", "m", "n", "k", "kernel", "latency (ms)", "tflops", "% of cublas"] +def _short_kernel_name(name: str) -> str: + idx = name.find("(") + return name[:idx] if idx != -1 else name + + +def parse_ncu_report(report_path: str) -> list[tuple[str, dict]]: + """Extract per-kernel metrics from an NCU report. Returns [(kernel_name, metrics), ...] in order.""" + tensor_col = "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed" + reader = _read_ncu_csv(report_path, "raw", metrics=tensor_col) + per_kernel: dict[str, dict] = {} + kernel_order: list[str] = [] + for row in reader: + kernel = _short_kernel_name(row["Kernel Name"]) + if kernel not in per_kernel: + per_kernel[kernel] = {} + kernel_order.append(kernel) + metrics = per_kernel[kernel] + if tensor_col in row and row[tensor_col]: + metrics["tensor_core_util (%)"] = float(row[tensor_col]) + + reader2 = _read_ncu_csv(report_path, "details") + for row in reader2: + kernel = _short_kernel_name(row["Kernel Name"]) + if kernel not in per_kernel: + per_kernel[kernel] = {} + kernel_order.append(kernel) + metrics = per_kernel[kernel] + if row.get("Metric Name") == "DRAM Throughput": + metrics["dram_throughput (%)"] = float(row["Metric Value"]) + if row.get("Metric Name") == "Compute (SM) Throughput": + metrics["sm_throughput (%)"] = float(row["Metric Value"]) + if row.get("Metric Name") == "SM Frequency": + metrics["sm_freq (GHz)"] = float(row["Metric Value"]) + if row.get("Metric Name") == "Duration": + value = float(row["Metric Value"]) + unit = row.get("Metric Unit", "ms") + if unit == "us": + value /= 1000.0 + elif unit == "s": + value *= 1000.0 + metrics["duration (ms)"] = value + + return [(k, per_kernel[k]) for k in kernel_order] + + +def benchmark_all(versions: list[str], m_size: int, n_size: int, k_size: int): + """Benchmark all versions using benchmark_func.""" + import math + + import pandas + import torch + from tilus.utils import benchmark_func + + headers = ["version", "latency (ms)", "tflops", "% of cublas"] rows = [] - for m, n, k, label in WORKLOADS: - print(f"--- {label} m={m} n={n} k={k} ---") - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_ref = torch.empty(m, n, dtype=torch.float16).cuda() - c_tilus = torch.empty(m, n, dtype=torch.float16).cuda() - - # cuBLAS baseline - cublas_lat = benchmark_func(lambda: torch.matmul(a, b.T, out=c_ref), warmup=WARMUP, repeat=REPEAT) - cublas_tf = tflops(m, n, k, cublas_lat) - rows.append([label, m, n, k, "cublas", cublas_lat, cublas_tf, 100.0]) - print(f" cublas {cublas_lat:.4f} ms {cublas_tf:.1f} TFLOPS") - - for ver in versions: - try: - cls = load_version(ver) - kernel = cls() - # correctness check - kernel(m, n, k, a, b, c_tilus) - torch.cuda.synchronize() - torch.testing.assert_close(c_ref, c_tilus, atol=1e-2, rtol=1e-2) - - lat = benchmark_func(lambda: kernel(m, n, k, a, b, c_tilus), warmup=WARMUP, repeat=REPEAT) - tf = tflops(m, n, k, lat) - pct = tf / cublas_tf * 100.0 - rows.append([label, m, n, k, f"tilus-{ver}", lat, tf, pct]) - print(f" tilus-{ver} {lat:.4f} ms {tf:.1f} TFLOPS ({pct:.1f}% of cuBLAS)") - except Exception as e: - print(f" tilus-{ver} ERROR: {e}", file=sys.stderr) - rows.append([label, m, n, k, f"tilus-{ver}", float("nan"), float("nan"), float("nan")]) - - print() + a = (torch.rand(m_size, k_size, dtype=torch.float16, device="cuda") - 0.5) / math.sqrt(k_size) + b = (torch.rand(n_size, k_size, dtype=torch.float16, device="cuda") - 0.5) / math.sqrt(k_size) + c_ref = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") + c_tilus = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") + + # cuBLAS baseline first, so we can compute % of cublas + cublas_lat = benchmark_func(lambda: torch.matmul(a, b.T, out=c_ref), warmup=5, repeat=30) + cublas_tf = 2 * m_size * n_size * k_size / cublas_lat * 1e-9 + + for name in versions: + try: + matmul = _load_version(name)() + # warmup + correctness check + matmul(m_size, n_size, k_size, a, b, c_tilus) + torch.cuda.synchronize() + torch.testing.assert_close(c_ref, c_tilus, atol=1e-2, rtol=1e-2) + + latency = benchmark_func( + lambda: matmul(m_size, n_size, k_size, a, b, c_tilus), warmup=5, repeat=30 + ) + tf = 2 * m_size * n_size * k_size / latency * 1e-9 + pct = tf / cublas_tf * 100.0 + rows.append([f"tilus_{name}", latency, tf, pct]) + time.sleep(1) # cool down between runs + except Exception as e: + print(f" tilus_{name} ERROR: {e}") + rows.append([f"tilus_{name}", float("nan"), float("nan"), float("nan")]) + + rows.append(["cublas", cublas_lat, cublas_tf, 100.0]) df = pandas.DataFrame(rows, columns=headers) - print("\n=== Summary ===") + print(f"\nBenchmark results (m={m_size}, n={n_size}, k={k_size}):") print(df.to_string(index=False)) - return df -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--versions", nargs="+", default=None, - help="Subset of versions to benchmark, e.g. --versions v3 v4 v5") +def ncu_profile_all(versions: list[str], m_size: int, n_size: int, k_size: int): + """Profile all versions in a single ncu_run and extract key metrics.""" + import pandas + import tilus + + print("Warming up (JIT + autotuning)...") + run_kernels(versions, m_size, n_size, k_size) + + labels = list(versions) + ["cublas"] + + print(f"Profiling cublas, {', '.join(versions)} ...") + report = tilus.utils.ncu_run( + run_kernels, + versions, + m_size, + n_size, + k_size, + kernel_regex="tilus|cutlass|sm90|gemm|cublas", + ) + print(f"Report saved to: {report.report_path}") + + kernel_metrics = parse_ncu_report(report.report_path) + + headers = [ + "version", + "kernel", + "duration (ms)", + "tflops", + "sm_freq (GHz)", + "sm_throughput (%)", + "dram_throughput (%)", + "tensor_core_util (%)", + ] + rows = [] + for i, name in enumerate(labels): + if i < len(kernel_metrics): + kernel, metrics = kernel_metrics[i] + else: + kernel, metrics = "?", {} + duration_ms = metrics.get("duration (ms)", "") + tflops = 2 * m_size * n_size * k_size / duration_ms * 1e-9 if duration_ms else "" + rows.append( + [ + name, + kernel, + duration_ms, + tflops, + metrics.get("sm_freq (GHz)", ""), + metrics.get("sm_throughput (%)", ""), + metrics.get("dram_throughput (%)", ""), + metrics.get("tensor_core_util (%)", ""), + ] + ) + + df = pandas.DataFrame(rows, columns=headers) + print(f"\nNCU profiling results (m={m_size}, n={n_size}, k={k_size}):") + print(df.to_string(index=False)) + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Hopper matmul V0-V5") + parser.add_argument( + "--ncu", + action="store_true", + help="Use NCU profiling instead of benchmark_func", + ) + parser.add_argument( + "--versions", + nargs="+", + default=VERSION_NAMES, + choices=VERSION_NAMES, + help="Which versions to benchmark (default: all)", + ) + parser.add_argument( + "--size", + nargs=3, + type=int, + default=[8192, 8192, 8192], + metavar=("M", "N", "K"), + help="Workload size M N K (default: 8192 8192 8192)", + ) args = parser.parse_args() - run_benchmark(versions=args.versions) + m_size, n_size, k_size = args.size + + if args.ncu: + ncu_profile_all(args.versions, m_size, n_size, k_size) + else: + benchmark_all(args.versions, m_size, n_size, k_size) + + +if __name__ == "__main__": + main() diff --git a/examples/hopper_matmul/matmul_v3.py b/examples/hopper_matmul/matmul_v3.py index a6eb96d8..ff850539 100644 --- a/examples/hopper_matmul/matmul_v3.py +++ b/examples/hopper_matmul/matmul_v3.py @@ -69,23 +69,26 @@ def __call__( for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): self.mbarrier.wait(producer_barriers[stage], phase=producer_phases[stage]) producer_phases[stage] ^= 1 + # Producer is already 32 threads (one warp) — TMA needs that to + # be warp-cooperative. Only the arrive must be by a single + # thread so tx-bytes is counted once. with self.single_thread(): self.mbarrier.arrive_and_expect_tx( consumer_barriers[stage], transaction_bytes=sa[stage].nbytes + sb[stage].nbytes, ) - self.tma.global_to_shared( - src=ga, - dst=sa[stage], - offsets=[offset_m, offset_k], - mbarrier=consumer_barriers[stage], - ) - self.tma.global_to_shared( - src=gb, - dst=sb[stage], - offsets=[offset_n, offset_k], - mbarrier=consumer_barriers[stage], - ) + self.tma.global_to_shared( + src=ga, + dst=sa[stage], + offsets=[offset_m, offset_k], + mbarrier=consumer_barriers[stage], + ) + self.tma.global_to_shared( + src=gb, + dst=sb[stage], + offsets=[offset_n, offset_k], + mbarrier=consumer_barriers[stage], + ) stage = (stage + 1) % self.num_stages for _ in self.range(min(self.num_stages, cdiv(k_size, self.block_k))): diff --git a/examples/hopper_matmul/matmul_v4.py b/examples/hopper_matmul/matmul_v4.py index 229d2796..c35d533e 100644 --- a/examples/hopper_matmul/matmul_v4.py +++ b/examples/hopper_matmul/matmul_v4.py @@ -148,24 +148,27 @@ def __call__( with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp for offset_k in self.range(0, k_size, block_k, unroll=num_stages): tma_pipe.producer_acquire() + # Producer warp is already 32 threads — the granularity TMA + # needs at SASS level. Only the arrive runs in single_thread so + # transaction-bytes is counted once. with self.single_thread(): self.mbarrier.arrive_and_expect_tx( tma_pipe.producer_barrier(), transaction_bytes=sa[tma_pipe.producer_stage].nbytes + sb[tma_pipe.producer_stage].nbytes, ) - self.tma.global_to_shared( - src=ga, - dst=sa[tma_pipe.producer_stage], - offsets=[offset_m, offset_k], - mbarrier=tma_pipe.producer_barrier(), - ) - self.tma.global_to_shared( - src=gb, - dst=sb[tma_pipe.producer_stage], - offsets=[offset_n, offset_k], - mbarrier=tma_pipe.producer_barrier(), - ) + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage], + offsets=[offset_m, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + self.tma.global_to_shared( + src=gb, + dst=sb[tma_pipe.producer_stage], + offsets=[offset_n, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) tma_pipe.producer_advance() # drain: wait for consumer to finish processing all in-flight stages diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index 6a2187d1..50790798 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -149,24 +149,27 @@ def __call__( with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp for offset_k in self.range(0, k_size, block_k, unroll=num_stages): tma_pipe.producer_acquire() + # Producer warp is already 32 threads — the granularity TMA + # needs at SASS level. Only the arrive runs in single_thread so + # transaction-bytes is counted once. with self.single_thread(): self.mbarrier.arrive_and_expect_tx( tma_pipe.producer_barrier(), transaction_bytes=sa[tma_pipe.producer_stage].nbytes + sb[tma_pipe.producer_stage].nbytes, ) - self.tma.global_to_shared( - src=ga, - dst=sa[tma_pipe.producer_stage], - offsets=[offset_m, offset_k], - mbarrier=tma_pipe.producer_barrier(), - ) - self.tma.global_to_shared( - src=gb, - dst=sb[tma_pipe.producer_stage], - offsets=[offset_n, offset_k], - mbarrier=tma_pipe.producer_barrier(), - ) + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage], + offsets=[offset_m, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + self.tma.global_to_shared( + src=gb, + dst=sb[tma_pipe.producer_stage], + offsets=[offset_n, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) tma_pipe.producer_advance() for _ in self.range(min(num_stages, cdiv(k_size, block_k))): @@ -205,7 +208,8 @@ def __call__( # fence required: store_shared uses generic proxy; TMA uses async proxy self.fence.proxy_async(space="shared") self.sync() - with self.single_thread(): + # tma.shared_to_global is warp-cooperative; run inside single_warp. + with self.single_warp(): self.tma.shared_to_global( sc, gc, diff --git a/python/tilus/backends/contexts/mbarrier_alloc_ctx.py b/python/tilus/backends/contexts/mbarrier_alloc_ctx.py index 73c5ab8b..b5a4a8ca 100644 --- a/python/tilus/backends/contexts/mbarrier_alloc_ctx.py +++ b/python/tilus/backends/contexts/mbarrier_alloc_ctx.py @@ -26,8 +26,6 @@ from tilus.hidet.ir.primitives.cuda.smem import dynamic_shared_memory from tilus.hidet.ir.primitives.cuda.sync import syncthreads from tilus.hidet.ir.primitives.cuda.vars import threadIdx -from tilus.ir.layout import ops -from tilus.ir.tensor import SharedTensor class BarrierAllocContext(BaseEmitContext): @@ -46,8 +44,17 @@ def finalize(self): # No barriers to allocate return - tensor = SharedTensor(dtype=uint64, shape=(num_barriers,), optional_layout=ops.shared_row_major(num_barriers)) - virtual_smem_addr = self.contexts.smem_alloc_ctx.allocate_shared_tensor(tensor, nbytes=tensor.storage_nbytes) + # Place barriers past the smem high-water mark rather than into a slot that + # was freed by FreeShared. The smem allocator only tracks the static + # post-emit free list; at runtime, data tensors that get freed *after* the + # main loop are still alive while the barriers are in use, so reusing their + # slot would cause TMA writes to clobber the barrier state. + smem_alloc_ctx = self.contexts.smem_alloc_ctx + nbytes = num_barriers * uint64.nbytes + alignment = 128 + aligned_hwm = (smem_alloc_ctx.smem_allocator.maximum_allocated + alignment - 1) // alignment * alignment + virtual_smem_addr = aligned_hwm + smem_alloc_ctx.smem_allocator.maximum_allocated = aligned_hwm + nbytes sb = StmtBuilder() sb.declare( v=self.barrier_addr, diff --git a/python/tilus/backends/emitter.py b/python/tilus/backends/emitter.py index 4078b982..955197cb 100644 --- a/python/tilus/backends/emitter.py +++ b/python/tilus/backends/emitter.py @@ -59,6 +59,32 @@ def assert_is_warp_aligned(self, inst: Instruction, msg: str) -> None: f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." ) + def assert_is_single_thread_or_warp_aligned(self, inst: Instruction, msg: str) -> None: + # TMA copies must be issued by exactly one thread. The user can express + # that with single_thread() (num_threads == 1), or at warp scope where the + # `@pred` predicate selects the elected lane. Both are valid; reject only + # multi-thread non-warp contexts. + if self.current_num_threads == 1: + return + if self.current_num_threads != 32 or self.current_thread_group_begin % 32 != 0: + raise ValueError( + f"Instruction {inst} requires a single-thread or warp-aligned context " + f"(num_threads==1, or thread_begin % 32 == 0 and num_threads == 32), " + f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." + ) + + @property + def tma_predicate(self): + # Inside single_thread() only one thread runs the TMA call, so the + # @pred predicate is the constant 1. At warp scope we still need to + # select a single lane, so use the elected leader-lane predicate to + # avoid an if-branch divergence. + from tilus.hidet.ir.dtypes import uint32 as _u32 + + if self.current_num_threads == 1: + return _u32(1) + return self.contexts.leader_lane_ctx.leader_lane + def sync(self): optional_sync_call = self.contexts.sync_ctx.sync() if optional_sync_call is not None: diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index e71b18cd..39c99e3d 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -52,7 +52,7 @@ from tilus.ir.tensor import GlobalTensor, SharedTensor from tilus.ir.utils.lineardec import LinearDecompositionError, decompose_linear from tilus.ir.utils.veceval import vectorized_evaluate -from tilus.target import nvgpu_sm90 +from tilus.target import get_current_target, nvgpu_sm90 @dataclass(frozen=True, eq=False) @@ -285,7 +285,7 @@ def create_tensor_map(self, global_info: GlobalTensorInfo, shared_info: SharedTe @register_emitter(CopyAsyncTensorGlobalToSharedInst, target=nvgpu_sm90) class CopyAsyncTensorGlobalToSharedInstEmitter(CopyAsyncTensorBaseEmitter): def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: - self.assert_is_warp_aligned(inst, "TMA global to shared is a warp-cooperative instruction") + self.assert_is_single_thread_or_warp_aligned(inst, "TMA global to shared must be issued by one thread") global_tensor: GlobalTensor = inst.inputs[1].as_global_tensor() shared_tensor: SharedTensor = inst.inputs[0].as_shared_tensor() assert global_tensor.dtype == shared_tensor.dtype @@ -301,7 +301,11 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: src_tensor_map = ~self.create_tensor_map(global_tensor_info, shared_tensor_info, dtype) coords = list(reversed(inst.offsets)) optional_multicast_mask = inst.multicast_mask - predicate = self.contexts.leader_lane_ctx.leader_lane + predicate = self.tma_predicate + # `.cta_group::{n}` is a Blackwell (sm_100+) PTX feature; ptxas rejects it on + # sm_90a even though the IR always carries cta_group=1. Pass None on Hopper so + # the inline asm template emits the unqualified TMA instruction. + cta_group = inst.cta_group if get_current_target().properties.compute_capability >= (10, 0) else None if optional_multicast_mask is None: self.append( @@ -310,7 +314,7 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: src_tensor_map=src_tensor_map, coords=coords, mbarrier=inst.mbarrier, - cta_group=inst.cta_group, + cta_group=cta_group, cache_policy=inst.cache_policy, predicate=predicate, ) @@ -324,7 +328,7 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: coords=coords, mbarrier=inst.mbarrier, multicast_mask=multicast_mask, - cta_group=inst.cta_group, + cta_group=cta_group, cache_policy=inst.cache_policy, predicate=predicate, ) @@ -334,7 +338,7 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: @register_emitter(CopyAsyncTensorSharedToGlobalInst, target=nvgpu_sm90) class CopyAsyncTensorSharedToGlobalInstEmitter(CopyAsyncTensorBaseEmitter): def emit(self, inst: CopyAsyncTensorSharedToGlobalInst) -> None: - self.assert_is_warp_aligned(inst, "TMA shared to global is a warp-cooperative instruction") + self.assert_is_single_thread_or_warp_aligned(inst, "TMA shared to global must be issued by one thread") global_tensor: GlobalTensor = inst.inputs[0].as_global_tensor() shared_tensor: SharedTensor = inst.inputs[1].as_shared_tensor() assert global_tensor.dtype == shared_tensor.dtype @@ -355,7 +359,7 @@ def emit(self, inst: CopyAsyncTensorSharedToGlobalInst) -> None: src=shared_addr, coords=list(reversed(tensor_coords)), cache_policy=inst.cache_policy, - predicate=self.contexts.leader_lane_ctx.leader_lane, + predicate=self.tma_predicate, ) ) diff --git a/python/tilus/hidet/ir/analyzers/bound_analyzer.py b/python/tilus/hidet/ir/analyzers/bound_analyzer.py index 633f935e..14a3366f 100644 --- a/python/tilus/hidet/ir/analyzers/bound_analyzer.py +++ b/python/tilus/hidet/ir/analyzers/bound_analyzer.py @@ -237,8 +237,8 @@ class BoundAnalyzer(ExprVisitor, StmtVisitor, ModuleVisitor): Add: operator.add, Sub: operator.sub, Multiply: operator.mul, - Mod: operator.mod, - Div: operator.floordiv, # for the node with BoundInfo, we are sure they are integers + Mod: operator.mod, # floor-mod; only used for candidate enumeration (exact path) + Div: operator.floordiv, # floor-div; same — only used when candidates are enumerable } def __init__(self, var2bound: Dict[Expr, BoundInfo] = None): diff --git a/python/tilus/hidet/option.py b/python/tilus/hidet/option.py index 5bd3a8bf..b61e0e58 100644 --- a/python/tilus/hidet/option.py +++ b/python/tilus/hidet/option.py @@ -27,6 +27,7 @@ from __future__ import annotations +import os from typing import Any, Callable, Dict, Iterable, List, Optional @@ -115,6 +116,18 @@ def get_option(self, name: str) -> Any: if name not in OptionRegistry.registered_options: raise KeyError(f"Option {name} has not been registered.") registry = OptionRegistry.registered_options[name] + if registry.env is not None and registry.env in os.environ: + raw = os.environ[registry.env] + if registry.normalizer is not None: + return registry.normalizer(raw) + # coerce to the same type as the default value when no normalizer is provided + if isinstance(registry.default_value, int): + return int(raw) + if isinstance(registry.default_value, float): + return float(raw) + if isinstance(registry.default_value, bool): + return raw.lower() not in ("0", "false", "no", "off") + return raw return registry.default_value diff --git a/python/tilus/hidet/transforms/_rule_based_simplifier_base.py b/python/tilus/hidet/transforms/_rule_based_simplifier_base.py index e227c836..e8441e42 100644 --- a/python/tilus/hidet/transforms/_rule_based_simplifier_base.py +++ b/python/tilus/hidet/transforms/_rule_based_simplifier_base.py @@ -82,6 +82,35 @@ def c_div(a, b): return a / b +def _c_trunc_div(a, b): + """Integer division truncated toward zero (C semantics: a/b). + + When a or b are not plain ints (i.e. they are Hidet IR Expr nodes used + for pattern construction) fall back to Python's // so that the operator + overloads build the correct IR node. The C-semantics branch is only + reached during candidate-value verification where both operands are ints. + """ + if not isinstance(a, int) or not isinstance(b, int): + return a // b # IR expression path — builds a Hidet Div node + if b == 0: + return 0 + return int(a / b) # truncate-toward-zero, avoiding Python's floor behaviour + + +def _c_trunc_mod(a, b): + """Integer modulo with C semantics (result has the sign of a, not b). + + Same dual-mode design as _c_trunc_div: falls back to Python % when called + with Hidet IR Expr nodes so the operator overloads build a Hidet Mod node. + The C-semantics branch is only reached during candidate-value verification. + """ + if not isinstance(a, int) or not isinstance(b, int): + return a % b # IR expression path — builds a Hidet Mod node + if b == 0: + return 0 + return a - b * _c_trunc_div(a, b) + + class ConstExprSimplifier(IRRewriter): op_dict = { Add: operator.add, @@ -92,7 +121,7 @@ class ConstExprSimplifier(IRRewriter): BitwiseAnd: operator.and_, BitwiseXor: operator.xor, BitwiseNot: operator.invert, - Mod: operator.mod, + Mod: _c_trunc_mod, LessThan: operator.lt, LessEqual: operator.le, Equal: operator.eq, @@ -192,21 +221,29 @@ def __init__(self, skip_node_types: Optional[Sequence[Type[Expr]]] = None): (IfThenElse(ec1, ec2, ec2), ec2), ] self.bound_patterns = [ - # ((pattern_args, pattern_func, target_args, target_func) + # Verify using C-compatible truncated division/modulo rather than Python + # floor division/modulo. The generated code uses C's % and /, which + # differ from Python for negative operands, so the candidate-value checks + # must use the same semantics as the emitted C code. ( (ec1, ec2, c1), (ec1, ec2, c1), - lambda ec1, ec2, c1: (ec1 + ec2) // c1, - lambda ec1, ec2, c1: ec1 // c1 + ec2 // c1, + lambda ec1, ec2, c1: _c_trunc_div(ec1 + ec2, c1), + lambda ec1, ec2, c1: _c_trunc_div(ec1, c1) + _c_trunc_div(ec2, c1), ), ( (ec1, ec2, c1), (ec1, ec2, c1), - lambda ec1, ec2, c1: (ec1 + ec2) % c1, - lambda ec1, ec2, c1: ec1 % c1 + ec2 % c1, + lambda ec1, ec2, c1: _c_trunc_mod(ec1 + ec2, c1), + lambda ec1, ec2, c1: _c_trunc_mod(ec1, c1) + _c_trunc_mod(ec2, c1), + ), + ((ec1, c1), (ec1,), lambda ec1, c1: _c_trunc_mod(ec1, c1), lambda ec1: ec1), + ( + (ec1, c1, c2), + (ec1, c2), + lambda ec1, c1, c2: _c_trunc_mod(_c_trunc_mod(ec1, c1), c2), + lambda ec1, c2: _c_trunc_mod(ec1, c2), ), - ((ec1, c1), (ec1,), lambda ec1, c1: ec1 % c1, lambda ec1: ec1), - ((ec1, c1, c2), (ec1, c2), lambda ec1, c1, c2: (ec1 % c1) % c2, lambda ec1, c2: ec1 % c2), ] if skip_node_types: for skip_type in skip_node_types: @@ -288,7 +325,10 @@ def visit(self, obj): def visit_Mod(self, e: Mod): ua, ub = self.bound[e.a], self.bound[e.b] - if ua.is_zero() or ua < ub: + # In C, x % y == x only when 0 <= x < y (truncated modulo). The < check + # covers the upper bound; also require a provably non-negative lower bound. + a_min = ua.possible_min_value() + if ua.is_zero() or (ua < ub and a_min is not None and a_min >= 0): return self(e.a) return IRRewriter.visit_Mod(self, e) diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index 222cd8b7..275016a8 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -44,6 +44,11 @@ logger = logging.getLogger(__name__) +# Bump when tuner semantics change (e.g. correctness gates, new selection +# criteria). On bump, dispatch_table.json files written under the prior version +# are ignored and tuning re-runs, so users don't have to manually delete cache. +_TUNER_VERSION = 2 + def span_space(space: Mapping[str, Sequence[Any]]) -> list[dict[str, Any]]: """ @@ -650,10 +655,24 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: latency.append(0.0) else: tuning_key_name = " " + "-".join([str(v) for v in tuning_key]) if tuning_key else "" - kernel_args = [ - args[i].clone() if isinstance(args[i], torch.Tensor) else args[i] - for i in self.call_params.kernel_params - ] + # Clone tensor args fresh for each candidate. The clone serves two + # purposes: shield the user's buffers from mutation, and prevent + # one candidate's output (e.g. NaN) from becoming the next + # candidate's input. + # Snapshot pre-kernel finiteness so the correctness gate flags + # only finite→non-finite transitions caused by the kernel. + # User-supplied output buffers (e.g. torch.empty()) may already + # contain NaN/Inf bit patterns from uninitialized memory, and + # cloning preserves those bits — so a slot that was already + # non-finite cannot fail the gate. + nan_gate_enabled: bool = bool(tilus.option.get_option("autotune_nan_gate")) + pre_finite: list[bool] = [] + for j in self.call_params.kernel_params: + a_j = args[j] + if isinstance(a_j, torch.Tensor) and a_j.is_floating_point(): + pre_finite.append(bool(torch.isfinite(a_j).all().item())) + else: + pre_finite.append(True) for i, compiled_program in tqdm( iterable=enumerate(self.compiled_programs), desc="[{}] {}{}".format("Tuning", self.instance_name, tuning_key_name), @@ -662,14 +681,16 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: ncols=60 + max(60, len(self.instance_name) + len(tuning_key_name)), ): compiled_func = compiled_program.get_launch_func() + kernel_args = [ + args[j].clone() if isinstance(args[j], torch.Tensor) else args[j] + for j in self.call_params.kernel_params + ] try: - latency.append( - benchmark_func( - lambda: compiled_func(*kernel_args), - warmup=tilus.option.get_option("bench_warmup"), - repeat=tilus.option.get_option("bench_repeat"), - ) - ) # type: ignore + lat = benchmark_func( + lambda: compiled_func(*kernel_args), + warmup=tilus.option.get_option("bench_warmup"), + repeat=tilus.option.get_option("bench_repeat"), + ) except RuntimeError as e: raise RuntimeError( f"Failed to benchmark the kernel {self.instance_name} with schedule: \n" @@ -677,8 +698,33 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: "Error message:\n" f" {str(e)}" ) from e + # Correctness gate: reject candidates that flip a kernel + # tensor arg from finite to non-finite. Slots that were + # already non-finite before the kernel ran are ignored — + # otherwise uninitialized output buffers spuriously fail + # the gate even when the kernel writes valid values. + if nan_gate_enabled: + for slot_idx, t in enumerate(kernel_args): + if not pre_finite[slot_idx]: + continue + if ( + isinstance(t, torch.Tensor) + and t.is_floating_point() + and not torch.isfinite(t).all() + ): + lat = float("inf") + break + latency.append(lat) # type: ignore best_latency = min(latency) + if not (best_latency < float("inf")): + raise RuntimeError( + f"Autotune for {self.instance_name} found no schedule that produced finite outputs. " + f"All {len(latency)} candidates flipped a kernel tensor argument from finite to NaN/Inf. " + f"Inspect schedules in {self.cache_dir} or narrow the autotune space. " + "If you want to confirm whether the gate is the cause, re-run with " + "TILUS_AUTOTUNE_NAN_GATE=0 (the autotuner will then accept any candidate)." + ) best_program_idx = latency.index(best_latency) self.dispatch_table[tuning_key] = best_program_idx self.dump_dispatch_table() @@ -707,7 +753,15 @@ def load_dispatch_table(self): table_path = self.cache_dir / "dispatch_table.json" if table_path.exists(): with open(table_path, "r") as f: - entries = json.load(f) + payload = json.load(f) + # New format is {"tuner_version": int, "entries": [...]}; legacy format + # was a bare list. Treat legacy or version-mismatched files as empty so + # changes to tuner semantics (e.g. adding a NaN/Inf rejection gate) + # automatically re-run tuning instead of serving stale picks. + if isinstance(payload, dict) and payload.get("tuner_version") == _TUNER_VERSION: + entries = payload["entries"] + else: + entries = [] self.dispatch_table = {tuple(key): value for key, value in entries} def dump_dispatch_table(self): @@ -715,7 +769,7 @@ def dump_dispatch_table(self): table_txt_path = self.cache_dir / "dispatch_table.txt" entries = [[list(key), value] for key, value in self.dispatch_table.items()] with open(table_path, "w") as f: - json.dump(entries, f) + json.dump({"tuner_version": _TUNER_VERSION, "entries": entries}, f) headers = [] for idx in self.call_params.tuning_params: headers.append(self.call_params.param_names[idx]) diff --git a/python/tilus/option.py b/python/tilus/option.py index e9e6e87b..9ee8291c 100644 --- a/python/tilus/option.py +++ b/python/tilus/option.py @@ -61,6 +61,7 @@ def _register_options(): "tilus.parallel_workers", type_hint="int", default_value=os.cpu_count(), + env="TILUS_PARALLEL_WORKERS", description="The number of parallel workers the tilus package could use for parallel jobs.", ) _register_hidet_option( @@ -89,6 +90,17 @@ def _register_options(): default_value=50, description="The number of repeat iterations for benchmarking during autotuning.", ) + _register_hidet_option( + "tilus.autotune_nan_gate", + type_hint="bool", + default_value=True, + env="TILUS_AUTOTUNE_NAN_GATE", + description=( + "Whether the autotuner rejects schedules that flip a kernel tensor argument from " + "finite to non-finite. Set to 0 to disable for diagnosing whether a tuning failure " + "is caused by the gate or by the kernel actually producing NaN/Inf." + ), + ) _register_options() diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index eef63fb1..690878a7 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -64,6 +64,8 @@ ("hopper_matmul", "matmul_v1.py", nvgpu_sm90a), ("hopper_matmul", "matmul_v2.py", nvgpu_sm90a), ("hopper_matmul", "matmul_v3.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v4.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v5.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+) @@ -78,6 +80,7 @@ ("flash_attention_decode", "tilus_kernel.py"), # Benchmark utilities ("blackwell_matmul", "benchmark.py"), + ("hopper_matmul", "benchmark.py"), ] From f3015a9abe3aec1f854185f6dd77f7988b601d1b Mon Sep 17 00:00:00 2001 From: William Zhang Date: Wed, 6 May 2026 20:29:45 +0000 Subject: [PATCH 05/21] restore claude folder Signed-off-by: William Zhang --- .claude/skills/ncu-report/SKILL.md | 271 +++++++++++++++++++++++++++++ .claude/skills/write-docs/SKILL.md | 220 +++++++++++++++++++++++ 2 files changed, 491 insertions(+) create mode 100644 .claude/skills/ncu-report/SKILL.md create mode 100644 .claude/skills/write-docs/SKILL.md diff --git a/.claude/skills/ncu-report/SKILL.md b/.claude/skills/ncu-report/SKILL.md new file mode 100644 index 00000000..9fbccf1f --- /dev/null +++ b/.claude/skills/ncu-report/SKILL.md @@ -0,0 +1,271 @@ +--- +name: ncu-report +description: > + Analyze NVIDIA Nsight Compute (ncu) profiling reports (.ncu-rep files). + Extract metrics, performance data, SASS/CUDA source, and identify bottlenecks. + TRIGGER when: user asks to analyze, profile, or look at an ncu report, .ncu-rep file, + Nsight Compute report, kernel performance/profiling data from ncu, or asks to generate/collect + an ncu profile for a tilus kernel or example script. + DO NOT TRIGGER when: user is writing unrelated profiling code. +user-invocable: true +--- + +# Nsight Compute Report Analysis + +This skill handles two modes: + +1. **Analyze an existing report**: The user provides a path to an `.ncu-rep` file (or one exists under `examples/`). Use the `ncu` CLI to extract and present metrics. + +2. **Generate a new report**: The user specifies a script or kernel to profile but does NOT provide a `.ncu-rep` file. In this case, set up profiling using `tilus.utils.ncu_utils.ncu_run()`, run it, then analyze the resulting report. + +## Generating a Report + +Tilus provides `tilus.utils.ncu_utils.ncu_run()` to profile kernels with full metrics and source correlation. + +### API +```python +from tilus.utils.ncu_utils import ncu_run + +# ncu_run(func, *args, kernel_regex=".*", **kwargs) -> NsightComputeReport +report = ncu_run(main, bench=False, kernel_regex="tilus|nvjet") +``` + +- `func`: a callable (typically a `main()` function) that runs the kernel(s) to profile +- `*args`, `**kwargs`: passed through to `func` +- `kernel_regex`: regex to filter which kernels to profile (default `".*"`) +- Returns `NsightComputeReport` with `.report_path` pointing to the generated `.ncu-rep` file + +### What it does +- Runs the function under `ncu` with `--set full`, all rules enabled, and `--import-source yes` +- Saves the report to `ncu-reports/reportN.ncu-rep` next to the script (auto-increments N) +- Uses the system Python and the `ncu` binary at `/usr/local/cuda/bin/ncu` + +### How to generate a report for the user + +**IMPORTANT**: `ncu_run()` must be called inside a `if __name__ == "__main__":` block. It works by re-importing the script as a subprocess under `ncu` — if `ncu_run()` is at module level, the subprocess will call `ncu_run()` again, causing infinite recursion and a runtime error. + +**Step 1: Read the script** — find the example script the user specified and read it to understand how the kernel is invoked. + +**Step 2: Set up profiling** — choose one of these approaches: + +- **If the script already has a `__main__` block with `ncu_run()`**: just run it directly. +- **If the script has a `__main__` block but no `ncu_run()`**: edit the `__main__` block to add `ncu_run()`. For example, if the block calls `main()`, change it to call `ncu_run(main, bench=False, kernel_regex="tilus")`. +- **If the script has no `__main__` block or is hard to modify** (e.g., it's a test file, or the kernel launch is deeply nested): create a new script next to it (e.g., `profile_.py`) that imports and calls the kernel under `ncu_run()`. + +Example of editing an existing `__main__` block: +```python +if __name__ == "__main__": + from tilus.utils.ncu_utils import ncu_run + ncu_run(main, bench=False, kernel_regex="tilus") +``` + +Example of creating a new profiling script: +```python +from tilus.utils.ncu_utils import ncu_run +from matmul_v9 import main + +if __name__ == "__main__": + ncu_run(main, bench=False, kernel_regex="tilus") +``` + +**Step 3: Run the script** — `python `. The report will be saved to `/ncu-reports/reportN.ncu-rep`. + +**Step 4: Analyze** — proceed to the Analysis Workflow below with the generated report. + +**Note**: `ncu` profiling requires `sudo` or appropriate permissions (CAP_SYS_ADMIN). If the command fails with permission errors, suggest running with `sudo`. + +## Analysis Workflow + +Follow this sequence. Skip steps the user doesn't need, but always start with Step 1. + +### Step 1: Overview — List kernels and session info + +Run these in parallel: + +```bash +# List all kernels with timing +ncu -i --page raw --csv --metrics gpu__time_duration.sum 2>&1 + +# Session/device info +ncu -i --page session --csv 2>&1 +``` + +Present a summary table: +- Kernel name (shortened), Block Size, Grid Size, Duration (ms) +- Device name, compute capability, CUDA version + +### Step 2: Speed of Light — Top-level throughput + +```bash +ncu -i --page details --csv --section SpeedOfLight 2>&1 +``` + +Key metrics to highlight per kernel: +- **Duration** (ms) +- **Compute (SM) Throughput** (%) — how busy the SMs are +- **Memory Throughput** (%) — overall memory utilization +- **DRAM Throughput** (%) — HBM bandwidth utilization +- **L1/TEX Cache Throughput** (%) +- **L2 Cache Throughput** (%) +- **SOLBottleneck rule** — check the Rule Description column for bottleneck guidance + +### Step 3: Compute & Memory Workload Analysis + +```bash +# Compute workload +ncu -i --page details --csv --section ComputeWorkloadAnalysis 2>&1 + +# Memory workload +ncu -i --page details --csv --section MemoryWorkloadAnalysis 2>&1 +``` + +Key compute metrics: Executed IPC Active, SM Busy %, Issue Slots Busy % +Key memory metrics: Mem Busy %, Max Bandwidth %, L1/L2 hit rates + +### Step 4: Occupancy + +```bash +ncu -i --page details --csv --section Occupancy 2>&1 +``` + +Report: Theoretical Occupancy, Achieved Occupancy, and limiters (registers, shared memory, block size). + +### Step 5: Detailed metrics (on demand) + +To extract specific raw metrics: +```bash +ncu -i --page raw --csv --metrics ,,... 2>&1 +``` + +To filter by kernel: +```bash +ncu -i --page raw --csv --metrics --kernel-name regex: 2>&1 +``` + +### Step 6: Source-level analysis (on demand) + +SASS-only (default, always available): +```bash +ncu -i --page source --csv --kernel-name regex: 2>&1 +``` + +CUDA source correlated with SASS (requires `--import-source yes` during profiling): +```bash +ncu -i --page source --csv --print-source cuda,sass --kernel-name regex: 2>&1 +``` + +Source output columns include per-instruction: Warp Stall Sampling, Instructions Executed, Thread Instructions Executed, stall reasons (stall_barrier, stall_math, stall_wait, etc.), shared memory conflicts, and more. + +### Step 7: Rules / automated analysis (on demand) + +Rules are included in the details page output. Look for non-empty "Rule Name" column entries. +```bash +ncu -i --page details --csv --print-rule-details 2>&1 | grep -v '^"[0-9]' | head -5 # header +``` + +To see all rule results with descriptions: +```bash +ncu -i --page details --csv --print-rule-details 2>&1 +``` +Filter for rows where column 17 (Rule Name) is non-empty. + +## Reference: ncu CLI Options for Report Analysis + +### Pages (`--page`) +| Page | Description | +|------|-------------| +| `details` | Sections with metrics organized by section name + rule results | +| `raw` | All collected metrics as flat columns (one row per kernel) | +| `source` | Per-instruction source code with correlated metrics | +| `session` | Session info, device attributes, launch settings | + +### Key Flags +| Flag | Description | +|------|-------------| +| `--csv` | Output as CSV (essential for parsing) | +| `--metrics ,` | Filter specific metrics (for `raw` page) | +| `--section ` | Filter by section identifier (for `details` page) | +| `--kernel-name regex:` | Filter kernels by name regex | +| `--kernel-name ` | Filter by exact kernel name | +| `--print-source sass\|ptx\|cuda\|cuda,sass` | Select source view for `source` page | +| `--print-details header\|body\|all` | Control detail level: `header` (default), `body` (charts/tables), `all` | +| `--print-metric-name name` | Show internal metric names instead of display labels | +| `--print-metric-name label-name` | Show both label and internal name | +| `--print-units base` | Show metrics in base units (no auto-scaling) | +| `--print-summary per-kernel` | Aggregate across invocations per kernel (min/max/avg) | +| `--print-rule-details` | Include additional rule tables and KPI metrics | + +### Section Identifiers +| Identifier | Display Name | +|------------|-------------| +| `SpeedOfLight` | GPU Speed Of Light Throughput | +| `ComputeWorkloadAnalysis` | Compute Workload Analysis | +| `MemoryWorkloadAnalysis` | Memory Workload Analysis | +| `MemoryWorkloadAnalysis_Tables` | Memory Workload Analysis Tables | +| `Occupancy` | Occupancy | +| `LaunchStats` | Launch Statistics | +| `SchedulerStats` | Scheduler Statistics | +| `WarpStateStats` | Warp State Statistics | +| `InstructionStats` | Instruction Statistics | +| `SourceCounters` | Source Counters | +| `WorkloadDistribution` | GPU and Memory Workload Distribution | +| `NumaAffinity` | NUMA Affinity | +| `SpeedOfLight_RooflineChart` | GPU Speed Of Light Roofline Chart | +| `SpeedOfLight_HierarchicalTensorRooflineChart` | Roofline Chart (Tensor Core) | +| `SpeedOfLight_HierarchicalHalfRooflineChart` | Roofline Chart (Half Precision) | + +### Section Sets (used during profiling with `--set`) +| Set | Sections | Est. Metrics | +|-----|----------|-------------| +| `basic` | LaunchStats, Occupancy, SpeedOfLight, WorkloadDistribution | 213 | +| `detailed` | basic + ComputeWorkloadAnalysis, MemoryWorkloadAnalysis, SourceCounters, Roofline | 906 | +| `full` | All sections including Instruction/Scheduler/WarpState stats, all Rooflines | 7794 | + +### Commonly Used Raw Metrics +| Metric | Description | +|--------|-------------| +| `gpu__time_duration.sum` | Kernel wall-clock duration | +| `sm__throughput.avg.pct_of_peak_sustained_elapsed` | SM throughput % | +| `dram__throughput.avg.pct_of_peak_sustained_elapsed` | DRAM throughput % | +| `sm__warps_active.avg.pct_of_peak_sustained_active` | Active warps % | +| `launch__occupancy_limit_registers` | Occupancy limiter: registers | +| `launch__occupancy_limit_shared_mem` | Occupancy limiter: shared memory | +| `launch__occupancy_limit_blocks` | Occupancy limiter: blocks | +| `launch__occupancy_limit_warps` | Occupancy limiter: warps | +| `sm__sass_thread_inst_executed_op_*` | Per-opcode instruction counts | +| `l1tex__t_sector_hit_rate.pct` | L1 cache hit rate | +| `lts__t_sector_hit_rate.pct` | L2 cache hit rate | + +### Available Rules (used during profiling with `--rule`) +| Rule ID | Description | +|---------|-------------| +| `SOLBottleneck` | High-level bottleneck detection | +| `SOLFPRoofline` | Floating Point Roofline Analysis | +| `CPIStall` | Warp stall analysis | +| `Occupancy` | Achieved Occupancy analysis | +| `LaunchConfiguration` | Kernel launch config analysis | +| `HighPipeUtilization` | High pipe utilization bottleneck | +| `IssueSlotUtilization` | Scheduler issue analysis | +| `SharedMemoryConflicts` | Shared memory bank conflicts | +| `ThreadDivergence` | Warp/thread divergence | +| `UncoalescedGlobalAccess` | Uncoalesced global memory | +| `UncoalescedSharedAccess` | Uncoalesced shared memory | +| `SlowPipeLimiter` | Slow pipe limiting compute | +| `FPInstructions` | FP instruction analysis | +| `PCSamplingData` | PC sampling data | + +## Comparing Kernels + +When the report contains multiple kernels (e.g., a reference nvjet kernel and a tilus kernel), always present metrics side-by-side for comparison. Highlight: +1. Duration difference (which is faster, by how much) +2. Throughput differences (compute vs memory bound) +3. Occupancy differences +4. Any rule findings that differ + +## Tips +- The `raw` page has one row per kernel with all metrics as columns — good for extracting specific values. +- The `details` page organizes metrics by section — good for browsing all metrics in a section. +- The `source` page is per-instruction — good for hotspot analysis. Output can be very large; pipe through `head` or filter with `grep`. +- Use `--print-units base` with `--csv` for consistent numeric parsing. +- Use `--print-metric-name name` to get programmatic metric names instead of human labels. +- Source analysis with `cuda,sass` view shows CUDA source lines interleaved with their SASS instructions — extremely useful for correlating high-level code with assembly hotspots. diff --git a/.claude/skills/write-docs/SKILL.md b/.claude/skills/write-docs/SKILL.md new file mode 100644 index 00000000..7566fb85 --- /dev/null +++ b/.claude/skills/write-docs/SKILL.md @@ -0,0 +1,220 @@ +--- +name: write-docs +description: > + Convention and format for writing instruction docstrings and RST tutorials in + tilus documentation. TRIGGER when: user asks to add, update, or write + documentation for tilus instructions, instruction groups, or tutorials. +--- + +# Writing Instruction Documentation + +## Docstring Format + +All instruction docstrings use **NumPy-style** format with the following structure: + +```python +def method_name(self, param1: Type1, param2: Type2) -> ReturnType: + """One-line summary of what the instruction does. + + Extended description explaining the behavior, semantics, and constraints + of the instruction. This can be multiple paragraphs. + + Parameters + ---------- + param1: Type1 + Description of the parameter. Include constraints and valid ranges + (e.g., "must be evaluated to a positive int32"). + param2: Type2 + Description. For parameters with defaults, explain the default behavior + (e.g., "By default, it is 1."). + + Returns + ------- + ret: ReturnType + Description of the return value including shape, dtype, and relationship + to inputs. + """ +``` + +## Conventions + +### Summary line +- Start with a verb: "Allocate...", "Arrive at...", "Compute the...", "Load...", "Wait at..." +- Keep it to one line + +### Extended description +- Explain what the instruction does at the hardware level when relevant +- Document state transitions (e.g., barrier phase switching) +- Describe relationships between parameters +- Explain multicast/cluster behavior for distributed instructions +- Scale detail to complexity: simple methods (load, cast) need minimal description; complex methods (mbarrier.alloc, tma.global_to_shared) need thorough explanation + +### Parameters section +- Document in the same order as the function signature +- Use full type annotations: `RegisterTensor`, `Expr | int`, `Optional[Type]` +- Include constraints: "must be in the range of [0, N)", "must be evaluated to a non-negative int32" +- Document valid candidates for string parameters: `Candidates: 'relaxed', 'release'.` +- Explain default values: "By default, it is 1." + +### Returns section +- Use `ret: Type` format +- Describe shape, dtype, and semantic meaning + +### Notes section (required) +Every instruction must have a Notes section with these items as a compact bullet list: + +```python + Notes + ----- + - **Thread group**: Can be executed by any sized thread group. + - **Hardware**: Requires compute capability 8.0+ (sm_80). + - **PTX**: ``mbarrier.init.shared::cta.b64`` +``` + +The three standard note items: + +- **Thread group**: Execution requirements. Common values: + - "Can be executed by any sized thread group." + - "Must be executed by a warp group (4 warps)." + - "Must be executed by a single thread (use ``self.single_thread()``)." + - "Must be executed by a single warp (use ``self.single_warp()`)" +- **Hardware**: Minimum compute capability. Format: "Requires compute capability X.Y+ (sm_XY)." + - sm_80 = Ampere (A100), sm_89 = Ada Lovelace (L4/L40), sm_90 = Hopper (H100), sm_100 = Blackwell (B200) +- **PTX**: The underlying PTX instruction(s) this maps to. Use double backticks for inline code. + If the instruction maps to multiple PTX instructions depending on parameters, list them: + ``` + - **PTX**: ``mbarrier.arrive.shared::cta.b64`` or ``mbarrier.arrive.noComplete.shared::cta.b64`` + ``` + If the instruction does not lower to a specific PTX instruction (e.g., it's a high-level + construct), omit the PTX line. + +### Other optional sections +- **Examples**: Use `.. code-block:: python` for usage examples +- **See Also**: Cross-reference related methods with `:py:meth:` or `:py:func:` + +### Memory ordering parameters +For synchronization instructions, document `sem` and `scope` parameters consistently: +```python +sem: str + The memory ordering semantics for the operation. Candidates: 'relaxed', 'release'. +scope: str + The synchronization scope for the operation. Candidates: 'cta', 'cluster'. +``` + +## Reference examples +- Simple instruction: `root.py:load_global`, `root.py:cast` +- Complex instruction: `mbarrier.py:alloc`, `mbarrier.py:arrive_and_expect_tx` +- With PTX reference: `fence.py:proxy_async`, `fence.py:proxy_async_release` +- With code example: `root.py:range`, `root.py:thread_group` +- TMA instruction: `tma.py:global_to_shared` + +--- + +# Writing RST Tutorials + +## Target audience + +Tutorials target **CS researchers who can write Triton kernels** but want to +understand the hardware features underneath Triton's abstractions. Assume readers +know: +- Block-level GPU programming (each program instance processes a tile) +- `tl.load`, `tl.store`, `tl.dot` semantics +- Autotuning concepts +- Basic GPU memory hierarchy (global, shared, registers) + +Do **not** assume they know: +- Explicit shared memory management or allocation +- Warp-level programming or thread indexing within a block +- Asynchronous execution models (mbarrier, TMA, commit/wait patterns) +- Tensor Memory or tcgen05 instruction families +- Memory ordering semantics (acquire/release, proxy fences) + +## Bridging the gap from Triton + +When introducing a concept that Triton handles implicitly, briefly explain **why** +explicit control is needed. Common contrasts: + +- **Shared memory**: "Unlike Triton where shared memory is managed automatically, + tilus gives explicit control --- necessary to use hardware features like TMA and + tcgen05." +- **Registers vs Tensor Memory**: "On earlier architectures (and in Triton), + MMA results accumulate in registers. Blackwell's tensor cores use dedicated + Tensor Memory, which provides higher bandwidth and avoids consuming register + file capacity for large tiles." +- **Synchronization**: "Triton handles synchronization implicitly. On Blackwell, + many operations are asynchronous --- the instruction returns immediately and + completes in the background. This enables overlap of data movement and + computation, but requires explicit tracking via mbarriers." +- **Thread/warp management**: "In Triton, all threads execute the same code. + Efficient Blackwell kernels require different warps to perform different + jobs (loading, computing, scheduling) and collaborate asynchronously via + thread groups." + +## Tutorial structure + +Each tutorial version (v0, v1, ...) should follow this structure: + +1. **Introduction** --- What this version adds, what Blackwell features it uses + (with hyperlinks to instruction group docs). +2. **Full kernel** --- Show the complete kernel upfront so readers see the big + picture before the detailed walkthrough. +3. **Topic sections** --- Explain each new concept with enough detail to + understand the example. Order by conceptual dependency. Include: + - A brief motivation (why does this exist / why do we need it) + - How it works at a high level + - Link to the detailed API/programming guide for deeper reading +4. **Walkthrough** --- Walk through the kernel code in logical groups (setup, + main loop, epilogue). Use `literalinclude` with `:start-at:`/`:end-at:` + markers (never absolute line numbers). For each group, use a bullet list + explaining each instruction with hyperlinks. +5. **What's Next** --- Motivate the next version by identifying the current + bottleneck. +6. **Full Source** --- Download link to the example file. + +## Writing guidelines + +### Explain the "why", not just the "what" +- For every `sync()` call, explain what it guards (e.g., "ensures shared memory + writes are visible to all threads before the MMA warp reads them"). +- For magic numbers like `warps = 4`, explain the choice (e.g., "4 warps = 128 + threads; later versions use more warps to overlap loading and computing"). +- For `enable_input_d`, explain: "On the first iteration, tensor memory contains + uninitialized data, so we ignore it. On subsequent iterations, it holds the + running sum from prior tiles." +- For mbarrier phase flipping, explain: "The same barrier is reused across + iterations. The phase distinguishes this iteration's completion from the + previous one's." + +### Hyperlinks +- Use `:meth:` for instruction methods: + `:meth:`~tilus.Script.copy_async`` for root instructions, + `:meth:`tcgen05.mma `` + for instruction group methods (shows short name, links to full path). +- Use `:attr:` for attributes: `:attr:`self.attrs.blocks `` +- Use `:doc:` for cross-references to other pages: `:doc:`/programming-guides/thread-group`` +- Use `:class:` for tensor types: `:class:`~tilus.ir.tensor.TMemoryTensor`` + +### Code inclusion +- Always use `:start-at:` / `:end-at:` / `:start-after:` / `:end-before:` + instead of absolute line numbers. This makes includes resilient to code + changes. +- Use `:dedent:` to strip leading indentation when including method bodies. +- Use `:caption:` for all included code blocks. + +### Figures and diagrams +- Place SVGs in a `figures/` subdirectory next to the tutorial RST files. +- Use `.. figure::` with `:width:` and `:align: center`. +- SVGs should be editable in draw.io for collaborative iteration. +- Suggest diagrams for: block tiling, data flow, pipeline stages, + cluster layouts, and any concept that benefits from a visual. + +### Tone +- Concise and direct. Avoid filler words. +- Vary sentence structure --- avoid starting every bullet with "We ..." + (lead with the instruction/concept name instead). +- Don't over-explain concepts that are well-covered in linked pages. The tutorial + should give enough to understand the example; detailed semantics belong in the + programming guides and API docs. + +## Reference tutorial +- Blackwell matmul V0: `docs/source/tutorials/matmul-blackwell/v0.rst` From 32b51d866c8f0325f654445f89c01393e15323f6 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Thu, 7 May 2026 01:25:22 +0000 Subject: [PATCH 06/21] bring to 80% cuBLAS Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v3.py | 28 ++++- examples/hopper_matmul/matmul_v4.py | 11 +- examples/hopper_matmul/matmul_v5.py | 172 ++++++++++++++++------------ 3 files changed, 129 insertions(+), 82 deletions(-) diff --git a/examples/hopper_matmul/matmul_v3.py b/examples/hopper_matmul/matmul_v3.py index ff850539..8603260c 100644 --- a/examples/hopper_matmul/matmul_v3.py +++ b/examples/hopper_matmul/matmul_v3.py @@ -103,15 +103,37 @@ def __call__( dtype=uint32, shape=[self.num_stages], init=0 ) stage: int32 = 0 - for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): + + # Prologue: issue first wgmma but don't wait. The current group is + # still in flight; we'll overlap it with the next iteration's TMA + # wait via wait_group(1) below. + self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage]) + consumer_phases[stage] ^= 1 + self.wgmma.fence() + self.wgmma.mma(sa[stage], sb[stage].transpose(), acc) + self.wgmma.commit_group() + stage = (stage + 1) % self.num_stages + + # Main loop: issue the next wgmma, then wait_group(1) drains the + # previous group while the new one runs. This is the double-buffered + # async pattern that lets wgmma overlap the next consumer_acquire. + for offset_k in self.range(block_k, k_size, block_k, unroll=self.num_stages): self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage]) consumer_phases[stage] ^= 1 self.wgmma.fence() self.wgmma.mma(sa[stage], sb[stage].transpose(), acc) self.wgmma.commit_group() - self.wgmma.wait_group(0) - self.mbarrier.arrive(producer_barriers[stage]) + self.wgmma.wait_group(1) + # Release the previous stage (whose mma just finished). + prev_stage = (stage + self.num_stages - 1) % self.num_stages + self.mbarrier.arrive(producer_barriers[prev_stage]) stage = (stage + 1) % self.num_stages + + # Epilogue: drain the last in-flight group, release its stage. + self.wgmma.wait_group(0) + prev_stage = (stage + self.num_stages - 1) % self.num_stages + self.mbarrier.arrive(producer_barriers[prev_stage]) + self.sync() casted_acc = self.cast(acc, dtype=float16) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) diff --git a/examples/hopper_matmul/matmul_v4.py b/examples/hopper_matmul/matmul_v4.py index c35d533e..e1afc324 100644 --- a/examples/hopper_matmul/matmul_v4.py +++ b/examples/hopper_matmul/matmul_v4.py @@ -74,12 +74,17 @@ def prev_consumer_barrier(self) -> RegisterTensor: return self.empty_barriers[prev_stage] -@tilus.autotune("num_stages", [2, 3, 4, 5, 6, 7]) +# Tightened autotune space: drop block_m=128/n=64 (skinny → poor wgmma fill), +# drop block_m=256/n=128 (heavy register pressure), drop num_stages=2 (too +# shallow to hide TMA), drop num_stages=7 (smem-thrashing on 256×256), drop +# swizzle_size=1 (≡ default rasterization, no L2 win). The remaining 60 configs +# are biased toward shapes cuBLAS uses for 8K² fp16 GEMMs. +@tilus.autotune("num_stages", [3, 4, 5, 6]) @tilus.autotune( - "block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]] + "block_m, block_n", [[128, 128], [128, 256], [256, 256]] ) @tilus.autotune("block_k", [16, 32, 64]) -@tilus.autotune("swizzle_size", [1, 4, 8]) +@tilus.autotune("swizzle_size", [4, 8]) class MatmulWGMMAV4(tilus.Script): def __init__(self, num_stages, block_m, block_n, block_k, swizzle_size): super().__init__() diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index 50790798..783cdd16 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -1,15 +1,18 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# v5: TMA epilogue — replaces direct register-to-global stores with a -# shared-memory staging buffer written out via TMA bulk store. +# v5: TMA epilogue + persistent grid. # # Changes from v4: -# - Pre-allocates s_c[block_m, block_n] alongside s_a/s_b. -# - After the K loop, stores the float16-cast accumulator to s_c, then -# issues a TMA shared→global transfer instead of store_global. -# - A fence.proxy_async(space="shared") between store_shared and TMA is -# required so the generic-proxy writes are visible to the async proxy. +# - Persistent grid: launch one CTA per SM and stride through tiles inside +# the kernel. The TMA epilogue of tile T runs while the wgmma compute of +# tile T+1 is in flight, hiding the shared->global bulk-store latency. +# v4's direct register stores are synchronous and can't be overlapped. +# - s_c[block_m, block_n] staging buffer: cast(acc) -> store_shared(s_c) +# -> tma.shared_to_global. fence.proxy_async between the generic-proxy +# store_shared and the async-proxy TMA store. +# - tma.wait_group is moved to the *start* of the next tile's epilogue so +# it overlaps with cast and consumer drain instead of blocking the CTA. import math @@ -77,12 +80,16 @@ def prev_consumer_barrier(self) -> RegisterTensor: return self.empty_barriers[prev_stage] +# Keep block_n=64 / [256,128] in the space: the TMA epilogue stages s_c via +# store_shared, and only certain (m,n) shapes give a single-swizzle s_c +# layout that tma.shared_to_global can consume. Trimming those out leaves +# every remaining schedule failing layout validation. @tilus.autotune("num_stages", [2, 3, 4, 5]) @tilus.autotune( "block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]] ) @tilus.autotune("block_k", [16, 32, 64]) -@tilus.autotune("swizzle_size", [1, 4, 8]) +@tilus.autotune("swizzle_size", [4, 8]) class MatmulWGMMAV5(tilus.Script): def __init__(self, num_stages, block_m, block_n, block_k, swizzle_size): super().__init__() @@ -125,100 +132,113 @@ def __call__( num_m_blocks = cdiv(m_size, block_m) num_n_blocks = cdiv(n_size, block_n) - self.attrs.blocks = num_m_blocks * num_n_blocks + total_tiles = num_m_blocks * num_n_blocks + # Persistent grid: launch one CTA per SM (H200 NVL = 132 SMs) and + # iterate tiles inside the kernel. The TMA epilogue of tile T then + # overlaps with the wgmma compute of tile T+1, hiding the + # shared->global bulk-store latency that v4 (direct register stores) + # cannot hide. + num_sms = 132 + self.attrs.blocks = num_sms self.attrs.warps = 5 - m_block, n_block = self.compute_block_coord( - self.blockIdx.x, num_m_blocks, num_n_blocks - ) - offset_m: int32 = m_block * block_m - offset_n: int32 = n_block * block_n - ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) sa = self.shared_tensor(dtype=float16, shape=[num_stages, block_m, block_k]) sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k]) - # s_c is the staging buffer for the TMA epilogue; allocated alongside - # s_a/s_b so the allocator can pick a fitting shared-memory partition. sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) - acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) tma_pipe = Pipeline(num_stages, producer_arrive_count=1, consumer_arrive_count=128) with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp - for offset_k in self.range(0, k_size, block_k, unroll=num_stages): - tma_pipe.producer_acquire() - # Producer warp is already 32 threads — the granularity TMA - # needs at SASS level. Only the arrive runs in single_thread so - # transaction-bytes is counted once. - with self.single_thread(): - self.mbarrier.arrive_and_expect_tx( - tma_pipe.producer_barrier(), - transaction_bytes=sa[tma_pipe.producer_stage].nbytes - + sb[tma_pipe.producer_stage].nbytes, - ) - self.tma.global_to_shared( - src=ga, - dst=sa[tma_pipe.producer_stage], - offsets=[offset_m, offset_k], - mbarrier=tma_pipe.producer_barrier(), - ) - self.tma.global_to_shared( - src=gb, - dst=sb[tma_pipe.producer_stage], - offsets=[offset_n, offset_k], - mbarrier=tma_pipe.producer_barrier(), + for tile_idx in self.range(self.blockIdx.x, total_tiles, num_sms): + m_block, n_block = self.compute_block_coord( + tile_idx, num_m_blocks, num_n_blocks ) - tma_pipe.producer_advance() + offset_m: int32 = m_block * block_m + offset_n: int32 = n_block * block_n + for offset_k in self.range(0, k_size, block_k, unroll=num_stages): + tma_pipe.producer_acquire() + with self.single_thread(): + self.mbarrier.arrive_and_expect_tx( + tma_pipe.producer_barrier(), + transaction_bytes=sa[tma_pipe.producer_stage].nbytes + + sb[tma_pipe.producer_stage].nbytes, + ) + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage], + offsets=[offset_m, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + self.tma.global_to_shared( + src=gb, + dst=sb[tma_pipe.producer_stage], + offsets=[offset_n, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + tma_pipe.producer_advance() - for _ in self.range(min(num_stages, cdiv(k_size, block_k))): + # Drain: let the consumer release the last num_stages stages. + for _ in self.range(num_stages): tma_pipe.producer_acquire() tma_pipe.producer_advance() with self.thread_group(thread_begin=0, num_threads=128): # WGMMA consumer - # Prologue: issue first MMA; release happens in first main-loop - # iteration after wait_group(1) confirms the MMA is done. - tma_pipe.consumer_acquire() - self.wgmma.fence() - self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) - self.wgmma.commit_group() - tma_pipe.consumer_advance() - - # Main loop: issue MMA then wait_group(1) *after* commit so the - # hardware pipelines the current and previous MMA groups while - # consumer_acquire overlaps with the prior group's execution. - for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): + for tile_idx in self.range(self.blockIdx.x, total_tiles, num_sms): + m_block, n_block = self.compute_block_coord( + tile_idx, num_m_blocks, num_n_blocks + ) + offset_m: int32 = m_block * block_m + offset_n: int32 = n_block * block_n + acc = self.register_tensor( + dtype=float32, shape=[block_m, block_n], init=0.0 + ) + + # Prologue tma_pipe.consumer_acquire() self.wgmma.fence() self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) self.wgmma.commit_group() - self.wgmma.wait_group(1) - self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) tma_pipe.consumer_advance() - # Epilogue: drain last in-flight MMA, release its stage. - self.wgmma.wait_group(0) - self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) - - # TMA epilogue: registers → shared → global - self.sync() - casted_acc = self.cast(acc, dtype=float16) - self.store_shared(sc, casted_acc) - # fence required: store_shared uses generic proxy; TMA uses async proxy - self.fence.proxy_async(space="shared") - self.sync() - # tma.shared_to_global is warp-cooperative; run inside single_warp. + # Main K loop + for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.commit_group() + self.wgmma.wait_group(1) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + tma_pipe.consumer_advance() + + # Epilogue: drain last in-flight MMA, release its stage. + self.wgmma.wait_group(0) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + + # Cast in registers (no shared dep) so it can run while the + # previous tile's bulk-store is still in flight. + casted_acc = self.cast(acc, dtype=float16) + + # Wait for the previous tile's TMA to finish before reusing + # sc. On the first tile this is a no-op (zero pending). + with self.single_warp(): + self.tma.wait_group(n=0, read=True) + self.sync() + self.store_shared(sc, casted_acc) + # store_shared uses generic proxy; TMA uses async proxy. + self.fence.proxy_async(space="shared") + self.sync() + with self.single_warp(): + self.tma.shared_to_global( + sc, gc, offsets=[offset_m, offset_n], dims=[0, 1] + ) + self.tma.commit_group() + + # Drain the final tile's bulk store before kernel exit. with self.single_warp(): - self.tma.shared_to_global( - sc, - gc, - offsets=[offset_m, offset_n], - dims=[0, 1], - ) - self.tma.commit_group() self.tma.wait_group(n=0, read=True) - self.sync() def main(): From 1564f49f5bb3ee3bb29de5e7c8926933f6a73621 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Thu, 7 May 2026 14:59:07 +0000 Subject: [PATCH 07/21] adjust v5 further Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v5.py | 144 +++++++----------- .../backends/emitters/cuda/cp_async_tensor.py | 133 ++++++++++++++-- 2 files changed, 175 insertions(+), 102 deletions(-) diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index 783cdd16..e9f7f8ac 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -1,18 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# v5: TMA epilogue + persistent grid. +# v5: TMA epilogue — replaces direct register-to-global stores with a +# shared-memory staging buffer written out via TMA bulk store. # # Changes from v4: -# - Persistent grid: launch one CTA per SM and stride through tiles inside -# the kernel. The TMA epilogue of tile T runs while the wgmma compute of -# tile T+1 is in flight, hiding the shared->global bulk-store latency. -# v4's direct register stores are synchronous and can't be overlapped. -# - s_c[block_m, block_n] staging buffer: cast(acc) -> store_shared(s_c) -# -> tma.shared_to_global. fence.proxy_async between the generic-proxy -# store_shared and the async-proxy TMA store. -# - tma.wait_group is moved to the *start* of the next tile's epilogue so -# it overlaps with cast and consumer drain instead of blocking the CTA. +# - Pre-allocates s_c[block_m, block_n] alongside s_a/s_b. +# - After the K loop, stores the float16-cast accumulator to s_c, then +# issues a TMA shared->global transfer instead of store_global. +# - A fence.proxy_async(space="shared") between store_shared and TMA is +# required so the generic-proxy writes are visible to the async proxy. import math @@ -132,112 +129,81 @@ def __call__( num_m_blocks = cdiv(m_size, block_m) num_n_blocks = cdiv(n_size, block_n) - total_tiles = num_m_blocks * num_n_blocks - # Persistent grid: launch one CTA per SM (H200 NVL = 132 SMs) and - # iterate tiles inside the kernel. The TMA epilogue of tile T then - # overlaps with the wgmma compute of tile T+1, hiding the - # shared->global bulk-store latency that v4 (direct register stores) - # cannot hide. - num_sms = 132 - self.attrs.blocks = num_sms + self.attrs.blocks = num_m_blocks * num_n_blocks self.attrs.warps = 5 + m_block, n_block = self.compute_block_coord( + self.blockIdx.x, num_m_blocks, num_n_blocks + ) + offset_m: int32 = m_block * block_m + offset_n: int32 = n_block * block_n + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) sa = self.shared_tensor(dtype=float16, shape=[num_stages, block_m, block_k]) sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k]) sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) tma_pipe = Pipeline(num_stages, producer_arrive_count=1, consumer_arrive_count=128) with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp - for tile_idx in self.range(self.blockIdx.x, total_tiles, num_sms): - m_block, n_block = self.compute_block_coord( - tile_idx, num_m_blocks, num_n_blocks - ) - offset_m: int32 = m_block * block_m - offset_n: int32 = n_block * block_n - for offset_k in self.range(0, k_size, block_k, unroll=num_stages): - tma_pipe.producer_acquire() - with self.single_thread(): - self.mbarrier.arrive_and_expect_tx( - tma_pipe.producer_barrier(), - transaction_bytes=sa[tma_pipe.producer_stage].nbytes - + sb[tma_pipe.producer_stage].nbytes, - ) - self.tma.global_to_shared( - src=ga, - dst=sa[tma_pipe.producer_stage], - offsets=[offset_m, offset_k], - mbarrier=tma_pipe.producer_barrier(), - ) - self.tma.global_to_shared( - src=gb, - dst=sb[tma_pipe.producer_stage], - offsets=[offset_n, offset_k], - mbarrier=tma_pipe.producer_barrier(), + for offset_k in self.range(0, k_size, block_k, unroll=num_stages): + tma_pipe.producer_acquire() + with self.single_thread(): + self.mbarrier.arrive_and_expect_tx( + tma_pipe.producer_barrier(), + transaction_bytes=sa[tma_pipe.producer_stage].nbytes + + sb[tma_pipe.producer_stage].nbytes, ) - tma_pipe.producer_advance() + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage], + offsets=[offset_m, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + self.tma.global_to_shared( + src=gb, + dst=sb[tma_pipe.producer_stage], + offsets=[offset_n, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) + tma_pipe.producer_advance() - # Drain: let the consumer release the last num_stages stages. - for _ in self.range(num_stages): + for _ in self.range(min(num_stages, cdiv(k_size, block_k))): tma_pipe.producer_acquire() tma_pipe.producer_advance() with self.thread_group(thread_begin=0, num_threads=128): # WGMMA consumer - for tile_idx in self.range(self.blockIdx.x, total_tiles, num_sms): - m_block, n_block = self.compute_block_coord( - tile_idx, num_m_blocks, num_n_blocks - ) - offset_m: int32 = m_block * block_m - offset_n: int32 = n_block * block_n - acc = self.register_tensor( - dtype=float32, shape=[block_m, block_n], init=0.0 - ) + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.commit_group() + tma_pipe.consumer_advance() - # Prologue + for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): tma_pipe.consumer_acquire() self.wgmma.fence() self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) self.wgmma.commit_group() - tma_pipe.consumer_advance() - - # Main K loop - for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): - tma_pipe.consumer_acquire() - self.wgmma.fence() - self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) - self.wgmma.commit_group() - self.wgmma.wait_group(1) - self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) - tma_pipe.consumer_advance() - - # Epilogue: drain last in-flight MMA, release its stage. - self.wgmma.wait_group(0) + self.wgmma.wait_group(1) self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + tma_pipe.consumer_advance() - # Cast in registers (no shared dep) so it can run while the - # previous tile's bulk-store is still in flight. - casted_acc = self.cast(acc, dtype=float16) - - # Wait for the previous tile's TMA to finish before reusing - # sc. On the first tile this is a no-op (zero pending). - with self.single_warp(): - self.tma.wait_group(n=0, read=True) - self.sync() - self.store_shared(sc, casted_acc) - # store_shared uses generic proxy; TMA uses async proxy. - self.fence.proxy_async(space="shared") - self.sync() - with self.single_warp(): - self.tma.shared_to_global( - sc, gc, offsets=[offset_m, offset_n], dims=[0, 1] - ) - self.tma.commit_group() + self.wgmma.wait_group(0) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) - # Drain the final tile's bulk store before kernel exit. + self.sync() + casted_acc = self.cast(acc, dtype=float16) + self.store_shared(sc, casted_acc) + self.fence.proxy_async(space="shared") + self.sync() with self.single_warp(): + self.tma.shared_to_global( + sc, gc, offsets=[offset_m, offset_n], dims=[0, 1] + ) + self.tma.commit_group() self.tma.wait_group(n=0, read=True) diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index 39c99e3d..e96af98e 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -221,6 +221,98 @@ def resolve_shared_tensor_info(self, shared_tensor: SharedTensor) -> SharedTenso + layout.visualize() ) + def resolve_shared_tensor_segments( + self, shared_tensor: SharedTensor + ) -> tuple[list[tuple[SharedTensorInfo, int]], int]: + """Decompose a 2D shared tensor whose n-axis has been split by the layout + selector into ``S`` row-major sub-blocks (``mode_shape=[bm, S, bn/S]``, + ``mode_strides=[bn/S, bm*bn/S, 1]``) into ``S`` TMA-issuable boxes. + + Such layouts arise when ``store_shared`` mirrors the wgmma register-tile + fragmentation (each warp's output spans ``bn/S`` columns and the warps + are stacked along the n-axis at strided offsets). The aggregate layout + is not one of the 7 hardware swizzle patterns, but each per-segment + ``[bm, bn/S]`` sub-tile *is*, so we emit one TMA store per segment. + + Returns + ------- + segments + A list of ``(per_segment_info, segment_n_offset)`` pairs, in segment + order. ``segment_n_offset`` is the column offset (in elements) of the + sub-tile within the original n-axis, to be added to ``inst.offsets`` + on the segmented dim. + seg_dim + The shared-tensor dimension that has been segmented (always 1 for + now; surfaced for future extension). + """ + layout: SharedLayout = shared_tensor.layout + + if len(shared_tensor.shape) != 2 or len(layout.mode_shape) != 3: + raise NotImplementedError("segmented decomposition only handles 2D layouts split into 3 modes") + + bm, bn = shared_tensor.shape + m0, s, n_inner = layout.mode_shape + sm, ss, sn = layout.mode_strides + + # Expected: dim 0 = [bm], dim 1 = [S, bn/S]; segments contiguous as + # ``[bm, bn/S]`` row-major boxes stacked along n. + if m0 != bm or s * n_inner != bn: + raise NotImplementedError("mode shape does not segment the n-axis") + if sn != 1 or sm != n_inner or ss != bm * n_inner: + raise NotImplementedError( + "mode strides do not match contiguous [bm, bn/S] segments stacked along n" + ) + + # Validate that each segment is a single TMA-supported swizzle box. + per_segment_shape = (bm, n_inner) + per_segment_layout = SharedLayout( + shape=per_segment_shape, + mode_shape=(bm, n_inner), + mode_strides=(n_inner, 1), + optional_swizzle=layout.optional_swizzle, + ) + per_segment_grid = per_segment_layout.as_numpy_grid() + + chosen_swizzle: Optional[TensorMapSwizzle] = None + for swizzle in [ + TensorMapSwizzle.NONE, + TensorMapSwizzle.B32, + TensorMapSwizzle.B64, + TensorMapSwizzle.B128, + TensorMapSwizzle.B128_ATOM_32B, + TensorMapSwizzle.B128_ATOM_32B_FLIP_8B, + TensorMapSwizzle.B128_ATOM_64B, + ]: + swizzled_grid = get_offset_grid_of_swizzled_layout( + dtype_nbits=shared_tensor.dtype.nbits, shape=per_segment_shape, swizzle=swizzle + ) + if swizzled_grid is not None and np.array_equal(per_segment_grid, swizzled_grid): + chosen_swizzle = swizzle + break + + if chosen_swizzle is None: + raise NotImplementedError( + "Segment layout does not match any TMA hardware swizzle: \n" + + f"Per-segment shape: {shared_tensor.dtype.name}{list(per_segment_shape)}\n" + + per_segment_layout.visualize() + ) + + base_addr = self.shared_tensor_shared_space_addr[shared_tensor] + segment_nbytes = bm * n_inner * shared_tensor.dtype.nbytes + segments: list[tuple[SharedTensorInfo, int]] = [] + for k in range(s): + segments.append( + ( + SharedTensorInfo( + addr=base_addr + k * segment_nbytes, + shape=per_segment_shape, + swizzle=chosen_swizzle, + ), + k * n_inner, + ) + ) + return segments, 1 + def declare_host_buffer(self, name: str, dtype: DataType, shape: Sequence[int]) -> Var: return self.host_builder.declare_var(name=name, tp=TensorType(dtype=dtype, shape=shape)) @@ -348,20 +440,35 @@ def emit(self, inst: CopyAsyncTensorSharedToGlobalInst) -> None: global_tensor, offsets=inst.offsets, dims=inst.dims ) - shared_tensor_info: SharedTensorInfo = self.resolve_shared_tensor_info(shared_tensor) - - shared_addr = self.shared_tensor_shared_space_addr[shared_tensor] - tensor_map = self.create_tensor_map(global_tensor_info, shared_tensor_info, dtype) - tensor_coords = inst.offsets - self.append( - cp_async_tensor_shared_to_global( - dst_tensor_map=~tensor_map, - src=shared_addr, - coords=list(reversed(tensor_coords)), - cache_policy=inst.cache_policy, - predicate=self.tma_predicate, + try: + shared_tensor_info: SharedTensorInfo = self.resolve_shared_tensor_info(shared_tensor) + segments: list[tuple[SharedTensorInfo, int]] = [(shared_tensor_info, 0)] + seg_dim: Optional[int] = None + except NotImplementedError: + # Fall back to per-segment emission for layouts that split a dim + # into stacked sub-boxes (typically the wgmma-fragment layout that + # store_shared inherits when targeting sc[bm, bn]). + segments, seg_dim = self.resolve_shared_tensor_segments(shared_tensor) + + # All segments share box shape and swizzle, so reuse one descriptor. + first_info = segments[0][0] + tensor_map = self.create_tensor_map(global_tensor_info, first_info, dtype) + for info, segment_offset in segments: + tensor_coords = list(inst.offsets) + if seg_dim is not None and segment_offset != 0: + # seg_dim indexes the *shared* dims; map it to the matching + # global dim through inst.dims, then shift that coord. + global_seg_dim = inst.dims[seg_dim] + tensor_coords[global_seg_dim] = tensor_coords[global_seg_dim] + segment_offset + self.append( + cp_async_tensor_shared_to_global( + dst_tensor_map=~tensor_map, + src=info.addr, + coords=list(reversed(tensor_coords)), + cache_policy=inst.cache_policy, + predicate=self.tma_predicate, + ) ) - ) @register_emitter(CopyAsyncTensorCommitGroupInst, target=nvgpu_sm90) From 5c2a0d99265b3d83c778bee9459f2b24e85c9bb3 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Thu, 7 May 2026 15:05:47 +0000 Subject: [PATCH 08/21] fix linting Signed-off-by: William Zhang --- examples/hopper_matmul/benchmark.py | 12 +++++++++--- examples/hopper_matmul/matmul_v4.py | 18 ++++++++++++------ examples/hopper_matmul/matmul_v5.py | 14 +++++++++++--- .../backends/emitters/cuda/cp_async_tensor.py | 9 ++++----- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/examples/hopper_matmul/benchmark.py b/examples/hopper_matmul/benchmark.py index 4daa3c5f..cf654f3d 100644 --- a/examples/hopper_matmul/benchmark.py +++ b/examples/hopper_matmul/benchmark.py @@ -127,13 +127,19 @@ def benchmark_all(versions: list[str], m_size: int, n_size: int, k_size: int): headers = ["version", "latency (ms)", "tflops", "% of cublas"] rows = [] - a = (torch.rand(m_size, k_size, dtype=torch.float16, device="cuda") - 0.5) / math.sqrt(k_size) - b = (torch.rand(n_size, k_size, dtype=torch.float16, device="cuda") - 0.5) / math.sqrt(k_size) + a = ( + torch.rand(m_size, k_size, dtype=torch.float16, device="cuda") - 0.5 + ) / math.sqrt(k_size) + b = ( + torch.rand(n_size, k_size, dtype=torch.float16, device="cuda") - 0.5 + ) / math.sqrt(k_size) c_ref = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") c_tilus = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") # cuBLAS baseline first, so we can compute % of cublas - cublas_lat = benchmark_func(lambda: torch.matmul(a, b.T, out=c_ref), warmup=5, repeat=30) + cublas_lat = benchmark_func( + lambda: torch.matmul(a, b.T, out=c_ref), warmup=5, repeat=30 + ) cublas_tf = 2 * m_size * n_size * k_size / cublas_lat * 1e-9 for name in versions: diff --git a/examples/hopper_matmul/matmul_v4.py b/examples/hopper_matmul/matmul_v4.py index e1afc324..7a48a859 100644 --- a/examples/hopper_matmul/matmul_v4.py +++ b/examples/hopper_matmul/matmul_v4.py @@ -80,9 +80,7 @@ def prev_consumer_barrier(self) -> RegisterTensor: # swizzle_size=1 (≡ default rasterization, no L2 win). The remaining 60 configs # are biased toward shapes cuBLAS uses for 8K² fp16 GEMMs. @tilus.autotune("num_stages", [3, 4, 5, 6]) -@tilus.autotune( - "block_m, block_n", [[128, 128], [128, 256], [256, 256]] -) +@tilus.autotune("block_m, block_n", [[128, 128], [128, 256], [256, 256]]) @tilus.autotune("block_k", [16, 32, 64]) @tilus.autotune("swizzle_size", [4, 8]) class MatmulWGMMAV4(tilus.Script): @@ -148,7 +146,9 @@ def __call__( # producer_arrive_count=1: single thread does arrive_and_expect_tx # consumer_arrive_count=128: all 128 consumer threads arrive when done - tma_pipe = Pipeline(num_stages, producer_arrive_count=1, consumer_arrive_count=128) + tma_pipe = Pipeline( + num_stages, producer_arrive_count=1, consumer_arrive_count=128 + ) with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp for offset_k in self.range(0, k_size, block_k, unroll=num_stages): @@ -187,7 +187,9 @@ def __call__( # after wait_group(1) confirms the MMA is done). tma_pipe.consumer_acquire() self.wgmma.fence() - self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.mma( + sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc + ) self.wgmma.commit_group() tma_pipe.consumer_advance() @@ -199,7 +201,11 @@ def __call__( for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): tma_pipe.consumer_acquire() self.wgmma.fence() - self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.mma( + sa[tma_pipe.consumer_stage], + sb[tma_pipe.consumer_stage].transpose(), + acc, + ) self.wgmma.commit_group() self.wgmma.wait_group(1) self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index e9f7f8ac..f9e406c7 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -146,7 +146,9 @@ def __call__( sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) - tma_pipe = Pipeline(num_stages, producer_arrive_count=1, consumer_arrive_count=128) + tma_pipe = Pipeline( + num_stages, producer_arrive_count=1, consumer_arrive_count=128 + ) with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp for offset_k in self.range(0, k_size, block_k, unroll=num_stages): @@ -178,14 +180,20 @@ def __call__( with self.thread_group(thread_begin=0, num_threads=128): # WGMMA consumer tma_pipe.consumer_acquire() self.wgmma.fence() - self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.mma( + sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc + ) self.wgmma.commit_group() tma_pipe.consumer_advance() for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): tma_pipe.consumer_acquire() self.wgmma.fence() - self.wgmma.mma(sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc) + self.wgmma.mma( + sa[tma_pipe.consumer_stage], + sb[tma_pipe.consumer_stage].transpose(), + acc, + ) self.wgmma.commit_group() self.wgmma.wait_group(1) self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index e96af98e..fb5d82d7 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -224,8 +224,9 @@ def resolve_shared_tensor_info(self, shared_tensor: SharedTensor) -> SharedTenso def resolve_shared_tensor_segments( self, shared_tensor: SharedTensor ) -> tuple[list[tuple[SharedTensorInfo, int]], int]: - """Decompose a 2D shared tensor whose n-axis has been split by the layout - selector into ``S`` row-major sub-blocks (``mode_shape=[bm, S, bn/S]``, + """Decompose a 2D shared tensor whose n-axis has been split by the layout. + + The selector into ``S`` row-major sub-blocks (``mode_shape=[bm, S, bn/S]``, ``mode_strides=[bn/S, bm*bn/S, 1]``) into ``S`` TMA-issuable boxes. Such layouts arise when ``store_shared`` mirrors the wgmma register-tile @@ -259,9 +260,7 @@ def resolve_shared_tensor_segments( if m0 != bm or s * n_inner != bn: raise NotImplementedError("mode shape does not segment the n-axis") if sn != 1 or sm != n_inner or ss != bm * n_inner: - raise NotImplementedError( - "mode strides do not match contiguous [bm, bn/S] segments stacked along n" - ) + raise NotImplementedError("mode strides do not match contiguous [bm, bn/S] segments stacked along n") # Validate that each segment is a single TMA-supported swizzle box. per_segment_shape = (bm, n_inner) From 94cada85d21296d5fa818d9743e1424563ab8c39 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Thu, 7 May 2026 15:22:35 +0000 Subject: [PATCH 09/21] revert unnecessary changes to v3 Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v3.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/examples/hopper_matmul/matmul_v3.py b/examples/hopper_matmul/matmul_v3.py index 8603260c..db4a127e 100644 --- a/examples/hopper_matmul/matmul_v3.py +++ b/examples/hopper_matmul/matmul_v3.py @@ -69,9 +69,6 @@ def __call__( for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): self.mbarrier.wait(producer_barriers[stage], phase=producer_phases[stage]) producer_phases[stage] ^= 1 - # Producer is already 32 threads (one warp) — TMA needs that to - # be warp-cooperative. Only the arrive must be by a single - # thread so tx-bytes is counted once. with self.single_thread(): self.mbarrier.arrive_and_expect_tx( consumer_barriers[stage], @@ -104,36 +101,16 @@ def __call__( ) stage: int32 = 0 - # Prologue: issue first wgmma but don't wait. The current group is - # still in flight; we'll overlap it with the next iteration's TMA - # wait via wait_group(1) below. - self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage]) - consumer_phases[stage] ^= 1 - self.wgmma.fence() - self.wgmma.mma(sa[stage], sb[stage].transpose(), acc) - self.wgmma.commit_group() - stage = (stage + 1) % self.num_stages - - # Main loop: issue the next wgmma, then wait_group(1) drains the - # previous group while the new one runs. This is the double-buffered - # async pattern that lets wgmma overlap the next consumer_acquire. - for offset_k in self.range(block_k, k_size, block_k, unroll=self.num_stages): + for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage]) consumer_phases[stage] ^= 1 self.wgmma.fence() self.wgmma.mma(sa[stage], sb[stage].transpose(), acc) self.wgmma.commit_group() - self.wgmma.wait_group(1) - # Release the previous stage (whose mma just finished). - prev_stage = (stage + self.num_stages - 1) % self.num_stages - self.mbarrier.arrive(producer_barriers[prev_stage]) + self.wgmma.wait_group(0) + self.mbarrier.arrive(producer_barriers[stage]) stage = (stage + 1) % self.num_stages - # Epilogue: drain the last in-flight group, release its stage. - self.wgmma.wait_group(0) - prev_stage = (stage + self.num_stages - 1) % self.num_stages - self.mbarrier.arrive(producer_barriers[prev_stage]) - self.sync() casted_acc = self.cast(acc, dtype=float16) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) From 2c09900a7df4137b39179f894615bd7aa7091910 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Thu, 7 May 2026 15:54:01 +0000 Subject: [PATCH 10/21] fully restore v3, bring v5 to 91 percent Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v3.py | 26 ++++--- examples/hopper_matmul/matmul_v5.py | 106 +++++++++++++++++----------- 2 files changed, 77 insertions(+), 55 deletions(-) diff --git a/examples/hopper_matmul/matmul_v3.py b/examples/hopper_matmul/matmul_v3.py index db4a127e..a6eb96d8 100644 --- a/examples/hopper_matmul/matmul_v3.py +++ b/examples/hopper_matmul/matmul_v3.py @@ -74,18 +74,18 @@ def __call__( consumer_barriers[stage], transaction_bytes=sa[stage].nbytes + sb[stage].nbytes, ) - self.tma.global_to_shared( - src=ga, - dst=sa[stage], - offsets=[offset_m, offset_k], - mbarrier=consumer_barriers[stage], - ) - self.tma.global_to_shared( - src=gb, - dst=sb[stage], - offsets=[offset_n, offset_k], - mbarrier=consumer_barriers[stage], - ) + self.tma.global_to_shared( + src=ga, + dst=sa[stage], + offsets=[offset_m, offset_k], + mbarrier=consumer_barriers[stage], + ) + self.tma.global_to_shared( + src=gb, + dst=sb[stage], + offsets=[offset_n, offset_k], + mbarrier=consumer_barriers[stage], + ) stage = (stage + 1) % self.num_stages for _ in self.range(min(self.num_stages, cdiv(k_size, self.block_k))): @@ -100,7 +100,6 @@ def __call__( dtype=uint32, shape=[self.num_stages], init=0 ) stage: int32 = 0 - for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage]) consumer_phases[stage] ^= 1 @@ -110,7 +109,6 @@ def __call__( self.wgmma.wait_group(0) self.mbarrier.arrive(producer_barriers[stage]) stage = (stage + 1) % self.num_stages - self.sync() casted_acc = self.cast(acc, dtype=float16) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index f9e406c7..a15efe2a 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -1,16 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# v5: TMA epilogue — replaces direct register-to-global stores with a -# shared-memory staging buffer written out via TMA bulk store. -# -# Changes from v4: -# - Pre-allocates s_c[block_m, block_n] alongside s_a/s_b. -# - After the K loop, stores the float16-cast accumulator to s_c, then -# issues a TMA shared->global transfer instead of store_global. -# - A fence.proxy_async(space="shared") between store_shared and TMA is -# required so the generic-proxy writes are visible to the async proxy. - import math import pandas @@ -70,20 +60,14 @@ def consumer_advance(self): self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0) def prev_consumer_barrier(self) -> RegisterTensor: - # Use (consumer_stage + num_stages - 1) so num_stages - 1 is evaluated as a - # Python int constant first, ensuring the inner expression is always non-negative - # and avoids C's negative-dividend truncated-modulo behaviour. prev_stage = (self.consumer_stage + (self.num_stages - 1)) % self.num_stages return self.empty_barriers[prev_stage] -# Keep block_n=64 / [256,128] in the space: the TMA epilogue stages s_c via -# store_shared, and only certain (m,n) shapes give a single-swizzle s_c -# layout that tma.shared_to_global can consume. Trimming those out leaves -# every remaining schedule failing layout validation. -@tilus.autotune("num_stages", [2, 3, 4, 5]) +# block_m must be >= 128 so each WG's WGMMA M = block_m/2 >= 64. +@tilus.autotune("num_stages", [3, 4, 5, 6]) @tilus.autotune( - "block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]] + "block_m, block_n", [[128, 128], [128, 256], [256, 128], [256, 256]] ) @tilus.autotune("block_k", [16, 32, 64]) @tilus.autotune("swizzle_size", [4, 8]) @@ -126,11 +110,12 @@ def __call__( ): num_stages = self.num_stages block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + block_m_half = block_m // 2 num_m_blocks = cdiv(m_size, block_m) num_n_blocks = cdiv(n_size, block_n) self.attrs.blocks = num_m_blocks * num_n_blocks - self.attrs.warps = 5 + self.attrs.warps = 9 # 1 producer + 2 consumer WGs (4 warps each) m_block, n_block = self.compute_block_coord( self.blockIdx.x, num_m_blocks, num_n_blocks @@ -141,30 +126,38 @@ def __call__( ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - sa = self.shared_tensor(dtype=float16, shape=[num_stages, block_m, block_k]) + # Per-WG A slab: index as sa[stage, wg_idx]. + sa = self.shared_tensor( + dtype=float16, shape=[num_stages, 2, block_m_half, block_k] + ) sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k]) - sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) - acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) tma_pipe = Pipeline( - num_stages, producer_arrive_count=1, consumer_arrive_count=128 + num_stages, producer_arrive_count=1, consumer_arrive_count=256 ) - with self.thread_group(thread_begin=128, num_threads=32): # TMA producer warp + with self.thread_group(thread_begin=256, num_threads=32): # TMA producer for offset_k in self.range(0, k_size, block_k, unroll=num_stages): tma_pipe.producer_acquire() with self.single_thread(): self.mbarrier.arrive_and_expect_tx( tma_pipe.producer_barrier(), - transaction_bytes=sa[tma_pipe.producer_stage].nbytes + transaction_bytes=sa[tma_pipe.producer_stage, 0].nbytes + + sa[tma_pipe.producer_stage, 1].nbytes + sb[tma_pipe.producer_stage].nbytes, ) self.tma.global_to_shared( src=ga, - dst=sa[tma_pipe.producer_stage], + dst=sa[tma_pipe.producer_stage, 0], offsets=[offset_m, offset_k], mbarrier=tma_pipe.producer_barrier(), ) + self.tma.global_to_shared( + src=ga, + dst=sa[tma_pipe.producer_stage, 1], + offsets=[offset_m + block_m_half, offset_k], + mbarrier=tma_pipe.producer_barrier(), + ) self.tma.global_to_shared( src=gb, dst=sb[tma_pipe.producer_stage], @@ -177,11 +170,16 @@ def __call__( tma_pipe.producer_acquire() tma_pipe.producer_advance() - with self.thread_group(thread_begin=0, num_threads=128): # WGMMA consumer + with self.thread_group(thread_begin=0, num_threads=128): # consumer WG0 + acc0 = self.register_tensor( + dtype=float32, shape=[block_m_half, block_n], init=0.0 + ) tma_pipe.consumer_acquire() self.wgmma.fence() self.wgmma.mma( - sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc + sa[tma_pipe.consumer_stage, 0], + sb[tma_pipe.consumer_stage].transpose(), + acc0, ) self.wgmma.commit_group() tma_pipe.consumer_advance() @@ -190,9 +188,9 @@ def __call__( tma_pipe.consumer_acquire() self.wgmma.fence() self.wgmma.mma( - sa[tma_pipe.consumer_stage], + sa[tma_pipe.consumer_stage, 0], sb[tma_pipe.consumer_stage].transpose(), - acc, + acc0, ) self.wgmma.commit_group() self.wgmma.wait_group(1) @@ -202,17 +200,43 @@ def __call__( self.wgmma.wait_group(0) self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) - self.sync() - casted_acc = self.cast(acc, dtype=float16) - self.store_shared(sc, casted_acc) - self.fence.proxy_async(space="shared") - self.sync() - with self.single_warp(): - self.tma.shared_to_global( - sc, gc, offsets=[offset_m, offset_n], dims=[0, 1] + casted0 = self.cast(acc0, dtype=float16) + self.store_global(gc, casted0, offsets=[offset_m, offset_n]) + + with self.thread_group(thread_begin=128, num_threads=128): # consumer WG1 + acc1 = self.register_tensor( + dtype=float32, shape=[block_m_half, block_n], init=0.0 + ) + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma( + sa[tma_pipe.consumer_stage, 1], + sb[tma_pipe.consumer_stage].transpose(), + acc1, + ) + self.wgmma.commit_group() + tma_pipe.consumer_advance() + + for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages): + tma_pipe.consumer_acquire() + self.wgmma.fence() + self.wgmma.mma( + sa[tma_pipe.consumer_stage, 1], + sb[tma_pipe.consumer_stage].transpose(), + acc1, ) - self.tma.commit_group() - self.tma.wait_group(n=0, read=True) + self.wgmma.commit_group() + self.wgmma.wait_group(1) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + tma_pipe.consumer_advance() + + self.wgmma.wait_group(0) + self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) + + casted1 = self.cast(acc1, dtype=float16) + self.store_global( + gc, casted1, offsets=[offset_m + block_m_half, offset_n] + ) def main(): From af76ccc11a5fa5be333f8beacb5a66baf9cea49c Mon Sep 17 00:00:00 2001 From: William Zhang Date: Fri, 8 May 2026 00:30:01 +0000 Subject: [PATCH 11/21] add swiglu per token quant kernel Signed-off-by: William Zhang --- .../swiglu_forward_and_per_token_cast.py | 317 ++++++++++++++++++ .../include/tilus/tvm/ffi/extra_type_traits.h | 27 ++ python/tilus/lang/instantiated_script.py | 13 +- tests/examples/test_examples.py | 1 + 4 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 examples/quantization/swiglu_forward_and_per_token_cast.py diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py new file mode 100644 index 00000000..eb626fbd --- /dev/null +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Fused SwiGLU forward with per-token FP8 cast. + +This is a Tilus translation of DeepSeek TileKernels' +``swiglu_forward_and_per_token_cast_kernel.py``. It computes + + out = silu(x[:, :hidden]) * x[:, hidden:] + +optionally applies a routing weight and expert mask, then quantizes each +``num_per_channels`` group to FP8 e4m3 with one float32 scale factor per +token/group. +""" + +import pandas +import tilus +import torch +from tilus import float16, float32, float8_e4m3, int32 +from tilus.utils import benchmark_func, cdiv + + +@tilus.autotune("block_n", [128]) +@tilus.autotune("warps", [4, 8]) +class SwiGLUForwardAndPerTokenCast(tilus.Script): + def __init__( + self, + block_n: int, + warps: int, + with_weight: bool = True, + with_pos_to_expert: bool = True, + use_clamp: bool = True, + num_per_channels: int = 128, + ): + super().__init__() + self.block_m = 1 + self.block_n = block_n + self.warps = warps + self.with_weight = with_weight + self.with_pos_to_expert = with_pos_to_expert + self.use_clamp = use_clamp + self.num_per_channels = num_per_channels + + def __call__( + self, + num_expanded_tokens: int, + hidden: int32, + num_topk_values: int32, + x_ptr: ~float16, + out_ptr: ~float8_e4m3, + out_sf_ptr: ~float32, + pos_to_token_topk_ptr: ~int32, + topk_weights_ptr: ~float32, + pos_to_expert_ptr: ~int32, + swiglu_clamp_value: float32, + ): + self.attrs.blocks = ( + cdiv(num_expanded_tokens, self.block_m), + cdiv(hidden, self.block_n), + ) + self.attrs.warps = self.warps + + offset_m = self.blockIdx.x * self.block_m + offset_n = self.blockIdx.y * self.block_n + sf_col = offset_n // self.num_per_channels + + g_x = self.global_view( + x_ptr, + dtype=float16, + shape=[num_expanded_tokens, hidden * 2], + ) + g_out = self.global_view( + out_ptr, + dtype=float8_e4m3, + shape=[num_expanded_tokens, hidden], + ) + g_out_sf = self.global_view( + out_sf_ptr, + dtype=float32, + shape=[num_expanded_tokens, cdiv(hidden, self.num_per_channels)], + ) + g_pos_to_token_topk = self.global_view( + pos_to_token_topk_ptr, + dtype=int32, + shape=[num_expanded_tokens], + ) + g_topk_weights = self.global_view( + topk_weights_ptr, + dtype=float32, + shape=[num_topk_values], + ) + g_pos_to_expert = self.global_view( + pos_to_expert_ptr, + dtype=int32, + shape=[num_expanded_tokens], + ) + + if (not self.with_pos_to_expert) or g_pos_to_expert[offset_m].item() >= 0: + r_l = self.load_global( + g_x, + offsets=[offset_m, offset_n], + shape=[self.block_m, self.block_n], + ).to(float32) + r_r = self.load_global( + g_x, + offsets=[offset_m, offset_n + hidden], + shape=[self.block_m, self.block_n], + ).to(float32) + + if self.use_clamp: + r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l) + r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r) + r_r = self.where(r_r < -swiglu_clamp_value, x=-swiglu_clamp_value, y=r_r) + + r_silu = r_l / (self.exp(-r_l) + 1.0) + r_value = r_silu * r_r + + if self.with_weight: + topk_pos = g_pos_to_token_topk[offset_m].item() + if topk_pos >= 0: + topk_weight = g_topk_weights[topk_pos].item() + r_value = r_value * topk_weight + + r_absmax = self.max(self.abs(r_value), dim=1, keepdim=True) + r_fp8_max = self.register_tensor( + dtype=float32, + shape=[self.block_m, 1], + init=448.0, + ) + r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) + r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) + + self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) + self.store_global( + g_out, + (r_value * r_inv_scale).to(float8_e4m3), + offsets=[offset_m, offset_n], + ) + + +def swiglu_reference( + x: torch.Tensor, + pos_to_token_topk: torch.Tensor, + topk_weights: torch.Tensor, + pos_to_expert: torch.Tensor, + clamp_value: float, + num_per_channels: int, +) -> tuple[torch.Tensor, torch.Tensor]: + hidden = x.shape[1] // 2 + x_l, x_r = x[:, :hidden].float(), x[:, hidden:].float() + x_l = torch.minimum(x_l, torch.tensor(clamp_value, device=x.device)) + x_r = torch.clamp(x_r, min=-clamp_value, max=clamp_value) + y = torch.nn.functional.silu(x_l) * x_r + + valid_weight = pos_to_token_topk >= 0 + weights = torch.ones(x.shape[0], dtype=torch.float32, device=x.device) + weights[valid_weight] = topk_weights.flatten()[pos_to_token_topk[valid_weight]] + y = y * weights[:, None] + + valid_expert = pos_to_expert >= 0 + y = torch.where(valid_expert[:, None], y, torch.zeros_like(y)) + + grouped = y.reshape(x.shape[0], hidden // num_per_channels, num_per_channels) + scales = grouped.abs().amax(dim=2) / 448.0 + scales = torch.where(scales > 0.0, scales, torch.ones_like(scales)) + out = (grouped / scales[:, :, None]).reshape_as(y).to(torch.float8_e4m3fn) + return out, scales + + +def dequantized_sum(out: torch.Tensor, scales: torch.Tensor, num_per_channels: int) -> torch.Tensor: + grouped = out.float().reshape( + out.shape[0], + out.shape[1] // num_per_channels, + num_per_channels, + ) + return (grouped * scales[:, :, None]).sum() + + +def main(): + rows = [] + headers = [ + "tokens", + "hidden", + "torch (ms)", + "tilus (ms)", + "speedup", + "sum diff", + ] + + for num_expanded_tokens, hidden, num_tokens, num_topk in [ + (128, 1024, 64, 2), + (256, 2048, 128, 2), + ]: + num_per_channels = 128 + kernel = SwiGLUForwardAndPerTokenCast(num_per_channels=num_per_channels) + + x = ( + torch.randn( + num_expanded_tokens, + hidden * 2, + device="cuda", + dtype=torch.float16, + ) + * 2.0 + ).contiguous() + pos_to_token_topk = torch.arange( + num_expanded_tokens, + device="cuda", + dtype=torch.int32, + ) % (num_tokens * num_topk) + topk_weights = torch.rand( + num_tokens, + num_topk, + device="cuda", + dtype=torch.float32, + ) + pos_to_expert = torch.ones(num_expanded_tokens, device="cuda", dtype=torch.int32) + pos_to_expert[::17] = -1 + + out = torch.empty( + (num_expanded_tokens, hidden), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + out_sf = torch.empty( + (num_expanded_tokens, hidden // num_per_channels), + device="cuda", + dtype=torch.float32, + ) + + clamp_value = 6.0 + kernel( + num_expanded_tokens, + hidden, + num_tokens * num_topk, + x, + out, + out_sf, + pos_to_token_topk, + topk_weights, + pos_to_expert, + clamp_value, + ) + + expected_out, expected_sf = swiglu_reference( + x, + pos_to_token_topk, + topk_weights, + pos_to_expert, + clamp_value, + num_per_channels, + ) + valid = pos_to_expert >= 0 + torch.testing.assert_close( + out[valid].float(), + expected_out[valid].float(), + atol=1.0, + rtol=0.0, + ) + torch.testing.assert_close( + out_sf[valid], + expected_sf[valid], + atol=1e-5, + rtol=1e-5, + ) + actual_sum = dequantized_sum(out[valid], out_sf[valid], num_per_channels) + expected_sum = dequantized_sum( + expected_out[valid], + expected_sf[valid], + num_per_channels, + ) + torch.testing.assert_close(actual_sum, expected_sum, atol=1e-2, rtol=1e-4) + sum_diff = (actual_sum - expected_sum).abs().item() + + torch_ms = benchmark_func( + lambda: swiglu_reference( + x, + pos_to_token_topk, + topk_weights, + pos_to_expert, + clamp_value, + num_per_channels, + ) + ) + tilus_ms = benchmark_func( + lambda: kernel( + num_expanded_tokens, + hidden, + num_tokens * num_topk, + x, + out, + out_sf, + pos_to_token_topk, + topk_weights, + pos_to_expert, + clamp_value, + ) + ) + rows.append( + [ + num_expanded_tokens, + hidden, + torch_ms, + tilus_ms, + f"{torch_ms / tilus_ms:.2f}x", + sum_diff, + ] + ) + print( + "SwiGLU FP8 cast matches reference for size " + f"({num_expanded_tokens}, {hidden}); dequantized sum diff={sum_diff:.6g}" + ) + + print(pandas.DataFrame(rows, columns=headers)) + + +if __name__ == "__main__": + main() diff --git a/python/tilus/hidet/include/tilus/tvm/ffi/extra_type_traits.h b/python/tilus/hidet/include/tilus/tvm/ffi/extra_type_traits.h index ac8b7a92..7569da75 100644 --- a/python/tilus/hidet/include/tilus/tvm/ffi/extra_type_traits.h +++ b/python/tilus/hidet/include/tilus/tvm/ffi/extra_type_traits.h @@ -22,6 +22,9 @@ #include #include +#include +#include + #include "void_p.h" namespace tvm { @@ -85,6 +88,30 @@ struct TypeTraits<__nv_bfloat16*> : public FallbackOnlyTraitsBase<__nv_bfloat16* } }; +template <> +struct TypeTraits : public FallbackOnlyTraitsBase { + TVM_FFI_INLINE static std::string TypeStr() { return "float8_e4m3*"; } + + TVM_FFI_INLINE static float8_e4m3* ConvertFallbackValue(DLTensor* src) { + if (src->dtype.code != kDLFloat8_e4m3fn || src->dtype.bits != 8) { + TVM_FFI_THROW(ValueError) << "Expect a tensor with 8 bit float8_e4m3, got a tensor with dtype " << dtype_to_str(src->dtype); + } + return reinterpret_cast(src->data); + } +}; + +template <> +struct TypeTraits : public FallbackOnlyTraitsBase { + TVM_FFI_INLINE static std::string TypeStr() { return "float8_e5m2*"; } + + TVM_FFI_INLINE static float8_e5m2* ConvertFallbackValue(DLTensor* src) { + if (src->dtype.code != kDLFloat8_e5m2 || src->dtype.bits != 8) { + TVM_FFI_THROW(ValueError) << "Expect a tensor with 8 bit float8_e5m2, got a tensor with dtype " << dtype_to_str(src->dtype); + } + return reinterpret_cast(src->data); + } +}; + // Template specialization for float*, double* template struct TypeTraits>> : public FallbackOnlyTraitsBase { diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index 0f3a9092..d6ac4dcc 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -666,11 +666,20 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: # cloning preserves those bits — so a slot that was already # non-finite cannot fail the gate. nan_gate_enabled: bool = bool(tilus.option.get_option("autotune_nan_gate")) + + def all_finite_or_unchecked(t: torch.Tensor) -> bool: + if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + return True + try: + return bool(torch.isfinite(t).all().item()) + except NotImplementedError: + return True + pre_finite: list[bool] = [] for j in self.call_params.kernel_params: a_j = args[j] if isinstance(a_j, torch.Tensor) and a_j.is_floating_point(): - pre_finite.append(bool(torch.isfinite(a_j).all().item())) + pre_finite.append(all_finite_or_unchecked(a_j)) else: pre_finite.append(True) for i, compiled_program in tqdm( @@ -710,7 +719,7 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: if ( isinstance(t, torch.Tensor) and t.is_floating_point() - and not torch.isfinite(t).all() + and not all_finite_or_unchecked(t) ): lat = float("inf") break diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 690878a7..de978c94 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -68,6 +68,7 @@ ("hopper_matmul", "matmul_v5.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), + ("quantization", "swiglu_forward_and_per_token_cast.py", nvgpu_sm90a), # flash attention decode examples (SM 8.0+) ("flash_attention_decode", "main.py", nvgpu_sm80), ] From b22e56a6c36fc1b9305e9db559c4c4d0e0a3fd4e Mon Sep 17 00:00:00 2001 From: William Zhang Date: Fri, 8 May 2026 01:00:24 +0000 Subject: [PATCH 12/21] add per token cast Signed-off-by: William Zhang --- examples/quantization/per_token_cast.py | 175 ++++++++++++++++++ .../swiglu_forward_and_per_token_cast.py | 2 +- tests/examples/test_examples.py | 1 + 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 examples/quantization/per_token_cast.py diff --git a/examples/quantization/per_token_cast.py b/examples/quantization/per_token_cast.py new file mode 100644 index 00000000..0a1ae1ee --- /dev/null +++ b/examples/quantization/per_token_cast.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Per-token FP8 cast with scale factors. + +This is a Tilus translation of DeepSeek TileKernels' +``per_token_cast_kernel.py`` for the common FP16 -> FP8 e4m3 path. Each CTA +processes one token and one channel group, computes the absolute maximum within +that group, stores a float32 scale factor, and writes the scaled FP8 output. +""" + +import pandas +import tilus +import torch +from tilus import float8_e4m3, float16, float32, int32 +from tilus.utils import benchmark_func, cdiv + + +@tilus.autotune("block_n", [128]) +@tilus.autotune("warps", [4, 8]) +class PerTokenCast(tilus.Script): + def __init__(self, block_n: int, warps: int, num_per_channels: int = 128): + super().__init__() + self.block_m = 1 + self.block_n = block_n + self.warps = warps + self.num_per_channels = num_per_channels + + def __call__( + self, + num_tokens: int, + hidden: int32, + x_ptr: ~float16, + out_ptr: ~float8_e4m3, + out_sf_ptr: ~float32, + ): + self.attrs.blocks = ( + cdiv(num_tokens, self.block_m), + cdiv(hidden, self.block_n), + ) + self.attrs.warps = self.warps + + offset_m = self.blockIdx.x * self.block_m + offset_n = self.blockIdx.y * self.block_n + sf_col = offset_n // self.num_per_channels + + g_x = self.global_view( + x_ptr, + dtype=float16, + shape=[num_tokens, hidden], + ) + g_out = self.global_view( + out_ptr, + dtype=float8_e4m3, + shape=[num_tokens, hidden], + ) + g_out_sf = self.global_view( + out_sf_ptr, + dtype=float32, + shape=[num_tokens, cdiv(hidden, self.num_per_channels)], + ) + + r_x = self.load_global( + g_x, + offsets=[offset_m, offset_n], + shape=[self.block_m, self.block_n], + ).to(float32) + + r_absmax = self.max(self.abs(r_x), dim=1, keepdim=True) + r_fp8_max = self.register_tensor( + dtype=float32, + shape=[self.block_m, 1], + init=448.0, + ) + r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) + r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) + + self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) + self.store_global( + g_out, + (r_x * r_inv_scale).to(float8_e4m3), + offsets=[offset_m, offset_n], + ) + + +def per_token_cast_reference( + x: torch.Tensor, + num_per_channels: int, +) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens, hidden = x.shape + grouped = x.float().reshape(num_tokens, hidden // num_per_channels, num_per_channels) + scales = grouped.abs().amax(dim=2) / 448.0 + scales = torch.where(scales > 0.0, scales, torch.ones_like(scales)) + out = (grouped / scales[:, :, None]).reshape_as(x).to(torch.float8_e4m3fn) + return out, scales + + +def dequantized_sum(out: torch.Tensor, scales: torch.Tensor, num_per_channels: int) -> torch.Tensor: + grouped = out.float().reshape( + out.shape[0], + out.shape[1] // num_per_channels, + num_per_channels, + ) + return (grouped * scales[:, :, None]).sum() + + +def main(): + rows = [] + headers = [ + "tokens", + "hidden", + "torch (ms)", + "tilus (ms)", + "speedup", + "sum diff", + ] + + for num_tokens, hidden in [ + (128, 1024), + (256, 2048), + (257, 4096), + ]: + num_per_channels = 128 + kernel = PerTokenCast(num_per_channels=num_per_channels) + + x = ( + torch.randn( + num_tokens, + hidden, + device="cuda", + dtype=torch.float16, + ) + * 2.0 + ).contiguous() + out = torch.empty((num_tokens, hidden), device="cuda", dtype=torch.float8_e4m3fn) + out_sf = torch.empty( + (num_tokens, hidden // num_per_channels), + device="cuda", + dtype=torch.float32, + ) + + kernel(num_tokens, hidden, x, out, out_sf) + expected_out, expected_sf = per_token_cast_reference(x, num_per_channels) + + max_code_diff = (out.float() - expected_out.float()).abs().max().item() + assert max_code_diff <= 32.0, f"max decoded FP8 code diff is {max_code_diff}" + torch.testing.assert_close(out_sf, expected_sf, atol=1e-5, rtol=1e-5) + + actual_sum = dequantized_sum(out, out_sf, num_per_channels) + expected_sum = dequantized_sum(expected_out, expected_sf, num_per_channels) + torch.testing.assert_close(actual_sum, expected_sum, atol=2.0, rtol=2e-2) + sum_diff = (actual_sum - expected_sum).abs().item() + + torch_ms = benchmark_func(lambda: per_token_cast_reference(x, num_per_channels)) + tilus_ms = benchmark_func(lambda: kernel(num_tokens, hidden, x, out, out_sf)) + rows.append( + [ + num_tokens, + hidden, + torch_ms, + tilus_ms, + f"{torch_ms / tilus_ms:.2f}x", + sum_diff, + ] + ) + print( + "Per-token FP8 cast matches reference for size " + f"({num_tokens}, {hidden}); max code diff={max_code_diff:.6g}; " + f"dequantized sum diff={sum_diff:.6g}" + ) + + print(pandas.DataFrame(rows, columns=headers)) + + +if __name__ == "__main__": + main() diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py index eb626fbd..50def68e 100644 --- a/examples/quantization/swiglu_forward_and_per_token_cast.py +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -15,7 +15,7 @@ import pandas import tilus import torch -from tilus import float16, float32, float8_e4m3, int32 +from tilus import float8_e4m3, float16, float32, int32 from tilus.utils import benchmark_func, cdiv diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index de978c94..dc0e7d94 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -68,6 +68,7 @@ ("hopper_matmul", "matmul_v5.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), + ("quantization", "per_token_cast.py", nvgpu_sm90a), ("quantization", "swiglu_forward_and_per_token_cast.py", nvgpu_sm90a), # flash attention decode examples (SM 8.0+) ("flash_attention_decode", "main.py", nvgpu_sm80), From e6c70efa2e6cc666fbe0b2ddda35180726b78541 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Fri, 8 May 2026 01:00:44 +0000 Subject: [PATCH 13/21] reformat from linting Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v5.py | 8 ++------ examples/quantization/per_token_cast.py | 4 +++- .../quantization/swiglu_forward_and_per_token_cast.py | 4 +++- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index a15efe2a..625ef0a0 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -66,9 +66,7 @@ def prev_consumer_barrier(self) -> RegisterTensor: # block_m must be >= 128 so each WG's WGMMA M = block_m/2 >= 64. @tilus.autotune("num_stages", [3, 4, 5, 6]) -@tilus.autotune( - "block_m, block_n", [[128, 128], [128, 256], [256, 128], [256, 256]] -) +@tilus.autotune("block_m, block_n", [[128, 128], [128, 256], [256, 128], [256, 256]]) @tilus.autotune("block_k", [16, 32, 64]) @tilus.autotune("swizzle_size", [4, 8]) class MatmulWGMMAV5(tilus.Script): @@ -234,9 +232,7 @@ def __call__( self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) casted1 = self.cast(acc1, dtype=float16) - self.store_global( - gc, casted1, offsets=[offset_m + block_m_half, offset_n] - ) + self.store_global(gc, casted1, offsets=[offset_m + block_m_half, offset_n]) def main(): diff --git a/examples/quantization/per_token_cast.py b/examples/quantization/per_token_cast.py index 0a1ae1ee..bbbdbba6 100644 --- a/examples/quantization/per_token_cast.py +++ b/examples/quantization/per_token_cast.py @@ -94,7 +94,9 @@ def per_token_cast_reference( return out, scales -def dequantized_sum(out: torch.Tensor, scales: torch.Tensor, num_per_channels: int) -> torch.Tensor: +def dequantized_sum( + out: torch.Tensor, scales: torch.Tensor, num_per_channels: int +) -> torch.Tensor: grouped = out.float().reshape( out.shape[0], out.shape[1] // num_per_channels, diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py index 50def68e..82cd2bfa 100644 --- a/examples/quantization/swiglu_forward_and_per_token_cast.py +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -166,7 +166,9 @@ def swiglu_reference( return out, scales -def dequantized_sum(out: torch.Tensor, scales: torch.Tensor, num_per_channels: int) -> torch.Tensor: +def dequantized_sum( + out: torch.Tensor, scales: torch.Tensor, num_per_channels: int +) -> torch.Tensor: grouped = out.float().reshape( out.shape[0], out.shape[1] // num_per_channels, From fe3074f9c838a08646eccae0eb5fed45a02ad4be Mon Sep 17 00:00:00 2001 From: William Zhang Date: Mon, 11 May 2026 01:36:50 +0000 Subject: [PATCH 14/21] fix mypy error, move emitter function Signed-off-by: William Zhang --- .../swiglu_forward_and_per_token_cast.py | 7 ++++- python/tilus/backends/emitter.py | 26 ------------------- .../backends/emitters/cuda/cp_async_tensor.py | 25 ++++++++++++++++++ 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py index 82cd2bfa..52d1750a 100644 --- a/examples/quantization/swiglu_forward_and_per_token_cast.py +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -107,9 +107,14 @@ def __call__( ).to(float32) if self.use_clamp: + negative_swiglu_clamp_value = 0.0 - swiglu_clamp_value r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l) r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r) - r_r = self.where(r_r < -swiglu_clamp_value, x=-swiglu_clamp_value, y=r_r) + r_r = self.where( + r_r < negative_swiglu_clamp_value, + x=negative_swiglu_clamp_value, + y=r_r, + ) r_silu = r_l / (self.exp(-r_l) + 1.0) r_value = r_silu * r_r diff --git a/python/tilus/backends/emitter.py b/python/tilus/backends/emitter.py index 955197cb..4078b982 100644 --- a/python/tilus/backends/emitter.py +++ b/python/tilus/backends/emitter.py @@ -59,32 +59,6 @@ def assert_is_warp_aligned(self, inst: Instruction, msg: str) -> None: f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." ) - def assert_is_single_thread_or_warp_aligned(self, inst: Instruction, msg: str) -> None: - # TMA copies must be issued by exactly one thread. The user can express - # that with single_thread() (num_threads == 1), or at warp scope where the - # `@pred` predicate selects the elected lane. Both are valid; reject only - # multi-thread non-warp contexts. - if self.current_num_threads == 1: - return - if self.current_num_threads != 32 or self.current_thread_group_begin % 32 != 0: - raise ValueError( - f"Instruction {inst} requires a single-thread or warp-aligned context " - f"(num_threads==1, or thread_begin % 32 == 0 and num_threads == 32), " - f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." - ) - - @property - def tma_predicate(self): - # Inside single_thread() only one thread runs the TMA call, so the - # @pred predicate is the constant 1. At warp scope we still need to - # select a single lane, so use the elected leader-lane predicate to - # avoid an if-branch divergence. - from tilus.hidet.ir.dtypes import uint32 as _u32 - - if self.current_num_threads == 1: - return _u32(1) - return self.contexts.leader_lane_ctx.leader_lane - def sync(self): optional_sync_call = self.contexts.sync_ctx.sync() if optional_sync_call is not None: diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index fb5d82d7..db077f28 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -43,6 +43,7 @@ from tilus.hidet.ir.tools import rewrite, simplify from tilus.hidet.ir.type import DataType, PointerType, TensorType, sizeof from tilus.ir import GlobalLayout +from tilus.ir.inst import Instruction from tilus.ir.instructions.cuda.cp_async_tensor import ( CopyAsyncTensorCommitGroupInst, CopyAsyncTensorGlobalToSharedInst, @@ -151,6 +152,30 @@ def get_offset_grid_of_swizzled_layout( class CopyAsyncTensorBaseEmitter(BaseInstEmitter): + def assert_is_single_thread_or_warp_aligned(self, inst: Instruction, msg: str) -> None: + # TMA copies must be issued by exactly one thread. The user can express + # that with single_thread() (num_threads == 1), or at warp scope where the + # `@pred` predicate selects the elected lane. Both are valid; reject only + # multi-thread non-warp contexts. + if self.current_num_threads == 1: + return + if self.current_num_threads != 32 or self.current_thread_group_begin % 32 != 0: + raise ValueError( + f"Instruction {inst} requires a single-thread or warp-aligned context " + f"(num_threads==1, or thread_begin % 32 == 0 and num_threads == 32), " + f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." + ) + + @property + def tma_predicate(self) -> Expr: + # Inside single_thread() only one thread runs the TMA call, so the + # @pred predicate is the constant 1. At warp scope we still need to + # select a single lane, so use the elected leader-lane predicate to + # avoid an if-branch divergence. + if self.current_num_threads == 1: + return uint32(1) + return self.contexts.leader_lane_ctx.leader_lane + def resolve_global_tensor_info( self, global_tensor: GlobalTensor, offsets: Sequence[Expr], dims: Sequence[int] ) -> GlobalTensorInfo: From 95ee4ad2a9ae4a3bcb6ddb29ce230b377f05aa34 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Mon, 11 May 2026 21:26:26 +0000 Subject: [PATCH 15/21] remove nan gate Signed-off-by: William Zhang --- python/tilus/lang/instantiated_script.py | 52 +++--------------------- python/tilus/option.py | 11 ----- 2 files changed, 5 insertions(+), 58 deletions(-) diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index 0f3a9092..9f33fbb2 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -44,9 +44,9 @@ logger = logging.getLogger(__name__) -# Bump when tuner semantics change (e.g. correctness gates, new selection -# criteria). On bump, dispatch_table.json files written under the prior version -# are ignored and tuning re-runs, so users don't have to manually delete cache. +# Bump when tuner semantics change. On bump, dispatch_table.json files written +# under the prior version are ignored and tuning re-runs, so users don't have to +# manually delete cache. _TUNER_VERSION = 2 @@ -655,24 +655,6 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: latency.append(0.0) else: tuning_key_name = " " + "-".join([str(v) for v in tuning_key]) if tuning_key else "" - # Clone tensor args fresh for each candidate. The clone serves two - # purposes: shield the user's buffers from mutation, and prevent - # one candidate's output (e.g. NaN) from becoming the next - # candidate's input. - # Snapshot pre-kernel finiteness so the correctness gate flags - # only finite→non-finite transitions caused by the kernel. - # User-supplied output buffers (e.g. torch.empty()) may already - # contain NaN/Inf bit patterns from uninitialized memory, and - # cloning preserves those bits — so a slot that was already - # non-finite cannot fail the gate. - nan_gate_enabled: bool = bool(tilus.option.get_option("autotune_nan_gate")) - pre_finite: list[bool] = [] - for j in self.call_params.kernel_params: - a_j = args[j] - if isinstance(a_j, torch.Tensor) and a_j.is_floating_point(): - pre_finite.append(bool(torch.isfinite(a_j).all().item())) - else: - pre_finite.append(True) for i, compiled_program in tqdm( iterable=enumerate(self.compiled_programs), desc="[{}] {}{}".format("Tuning", self.instance_name, tuning_key_name), @@ -698,33 +680,9 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: "Error message:\n" f" {str(e)}" ) from e - # Correctness gate: reject candidates that flip a kernel - # tensor arg from finite to non-finite. Slots that were - # already non-finite before the kernel ran are ignored — - # otherwise uninitialized output buffers spuriously fail - # the gate even when the kernel writes valid values. - if nan_gate_enabled: - for slot_idx, t in enumerate(kernel_args): - if not pre_finite[slot_idx]: - continue - if ( - isinstance(t, torch.Tensor) - and t.is_floating_point() - and not torch.isfinite(t).all() - ): - lat = float("inf") - break latency.append(lat) # type: ignore best_latency = min(latency) - if not (best_latency < float("inf")): - raise RuntimeError( - f"Autotune for {self.instance_name} found no schedule that produced finite outputs. " - f"All {len(latency)} candidates flipped a kernel tensor argument from finite to NaN/Inf. " - f"Inspect schedules in {self.cache_dir} or narrow the autotune space. " - "If you want to confirm whether the gate is the cause, re-run with " - "TILUS_AUTOTUNE_NAN_GATE=0 (the autotuner will then accept any candidate)." - ) best_program_idx = latency.index(best_latency) self.dispatch_table[tuning_key] = best_program_idx self.dump_dispatch_table() @@ -756,8 +714,8 @@ def load_dispatch_table(self): payload = json.load(f) # New format is {"tuner_version": int, "entries": [...]}; legacy format # was a bare list. Treat legacy or version-mismatched files as empty so - # changes to tuner semantics (e.g. adding a NaN/Inf rejection gate) - # automatically re-run tuning instead of serving stale picks. + # tuner semantic changes automatically re-run tuning instead of serving + # stale picks. if isinstance(payload, dict) and payload.get("tuner_version") == _TUNER_VERSION: entries = payload["entries"] else: diff --git a/python/tilus/option.py b/python/tilus/option.py index 9ee8291c..d1793391 100644 --- a/python/tilus/option.py +++ b/python/tilus/option.py @@ -90,17 +90,6 @@ def _register_options(): default_value=50, description="The number of repeat iterations for benchmarking during autotuning.", ) - _register_hidet_option( - "tilus.autotune_nan_gate", - type_hint="bool", - default_value=True, - env="TILUS_AUTOTUNE_NAN_GATE", - description=( - "Whether the autotuner rejects schedules that flip a kernel tensor argument from " - "finite to non-finite. Set to 0 to disable for diagnosing whether a tuning failure " - "is caused by the gate or by the kernel actually producing NaN/Inf." - ), - ) _register_options() From 47281f81e90ea0763849eb49265b78b682c6f3ae Mon Sep 17 00:00:00 2001 From: William Zhang Date: Mon, 11 May 2026 23:18:49 +0000 Subject: [PATCH 16/21] add v5 optimization description, restore some prev editions Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v5.py | 4 ++++ python/tilus/hidet/ir/analyzers/bound_analyzer.py | 4 ++-- python/tilus/lang/instantiated_script.py | 6 +++--- python/tilus/option.py | 1 - 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index a15efe2a..32e5e92b 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# Optimizations from v4: +# - Use mbarrier for synchronization between producer and consumer WGs, instead of using shared memory as flags. +# - Use two consumer WGs to consume the produced tiles in parallel, and use WGMMA commit/wait group to synchronize between them, instead of using a single consumer WG to consume all tiles. + import math import pandas diff --git a/python/tilus/hidet/ir/analyzers/bound_analyzer.py b/python/tilus/hidet/ir/analyzers/bound_analyzer.py index 14a3366f..633f935e 100644 --- a/python/tilus/hidet/ir/analyzers/bound_analyzer.py +++ b/python/tilus/hidet/ir/analyzers/bound_analyzer.py @@ -237,8 +237,8 @@ class BoundAnalyzer(ExprVisitor, StmtVisitor, ModuleVisitor): Add: operator.add, Sub: operator.sub, Multiply: operator.mul, - Mod: operator.mod, # floor-mod; only used for candidate enumeration (exact path) - Div: operator.floordiv, # floor-div; same — only used when candidates are enumerable + Mod: operator.mod, + Div: operator.floordiv, # for the node with BoundInfo, we are sure they are integers } def __init__(self, var2bound: Dict[Expr, BoundInfo] = None): diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index 9f33fbb2..fff48696 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -44,9 +44,9 @@ logger = logging.getLogger(__name__) -# Bump when tuner semantics change. On bump, dispatch_table.json files written -# under the prior version are ignored and tuning re-runs, so users don't have to -# manually delete cache. +# Bump when tuner semantics change (e.g. correctness gates, new selection +# criteria). On bump, dispatch_table.json files written under the prior version +# are ignored and tuning re-runs, so users don't have to manually delete cache. _TUNER_VERSION = 2 diff --git a/python/tilus/option.py b/python/tilus/option.py index d1793391..e9e6e87b 100644 --- a/python/tilus/option.py +++ b/python/tilus/option.py @@ -61,7 +61,6 @@ def _register_options(): "tilus.parallel_workers", type_hint="int", default_value=os.cpu_count(), - env="TILUS_PARALLEL_WORKERS", description="The number of parallel workers the tilus package could use for parallel jobs.", ) _register_hidet_option( From 8e69f28e11a2bb0be1fd40f6b41889f7549229df Mon Sep 17 00:00:00 2001 From: William Zhang Date: Tue, 12 May 2026 12:39:53 +0000 Subject: [PATCH 17/21] change benchmark to tilekernels Signed-off-by: William Zhang --- examples/quantization/per_token_cast.py | 29 +++++---- .../swiglu_forward_and_per_token_cast.py | 65 +++++++++---------- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/examples/quantization/per_token_cast.py b/examples/quantization/per_token_cast.py index bbbdbba6..ddedebc5 100644 --- a/examples/quantization/per_token_cast.py +++ b/examples/quantization/per_token_cast.py @@ -11,6 +11,7 @@ import pandas import tilus import torch +from tile_kernels.quant.per_token_cast_kernel import per_token_cast from tilus import float8_e4m3, float16, float32, int32 from tilus.utils import benchmark_func, cdiv @@ -82,16 +83,11 @@ def __call__( ) -def per_token_cast_reference( +def tilekernels_per_token_cast_reference( x: torch.Tensor, num_per_channels: int, ) -> tuple[torch.Tensor, torch.Tensor]: - num_tokens, hidden = x.shape - grouped = x.float().reshape(num_tokens, hidden // num_per_channels, num_per_channels) - scales = grouped.abs().amax(dim=2) / 448.0 - scales = torch.where(scales > 0.0, scales, torch.ones_like(scales)) - out = (grouped / scales[:, :, None]).reshape_as(x).to(torch.float8_e4m3fn) - return out, scales + return per_token_cast(x, "e4m3", num_per_channels) def dequantized_sum( @@ -110,7 +106,7 @@ def main(): headers = [ "tokens", "hidden", - "torch (ms)", + "tilekernels (ms)", "tilus (ms)", "speedup", "sum diff", @@ -139,9 +135,13 @@ def main(): device="cuda", dtype=torch.float32, ) + x_tilekernels = x.float() kernel(num_tokens, hidden, x, out, out_sf) - expected_out, expected_sf = per_token_cast_reference(x, num_per_channels) + expected_out, expected_sf = tilekernels_per_token_cast_reference( + x_tilekernels, + num_per_channels, + ) max_code_diff = (out.float() - expected_out.float()).abs().max().item() assert max_code_diff <= 32.0, f"max decoded FP8 code diff is {max_code_diff}" @@ -152,15 +152,20 @@ def main(): torch.testing.assert_close(actual_sum, expected_sum, atol=2.0, rtol=2e-2) sum_diff = (actual_sum - expected_sum).abs().item() - torch_ms = benchmark_func(lambda: per_token_cast_reference(x, num_per_channels)) + tilekernels_ms = benchmark_func( + lambda: tilekernels_per_token_cast_reference( + x_tilekernels, + num_per_channels, + ) + ) tilus_ms = benchmark_func(lambda: kernel(num_tokens, hidden, x, out, out_sf)) rows.append( [ num_tokens, hidden, - torch_ms, + tilekernels_ms, tilus_ms, - f"{torch_ms / tilus_ms:.2f}x", + f"{tilekernels_ms / tilus_ms:.2f}x", sum_diff, ] ) diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py index 52d1750a..d49c2c7f 100644 --- a/examples/quantization/swiglu_forward_and_per_token_cast.py +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -15,6 +15,9 @@ import pandas import tilus import torch +from tile_kernels.quant.swiglu_forward_and_per_token_cast_kernel import ( + swiglu_forward_and_per_token_cast, +) from tilus import float8_e4m3, float16, float32, int32 from tilus.utils import benchmark_func, cdiv @@ -142,7 +145,7 @@ def __call__( ) -def swiglu_reference( +def tilekernels_swiglu_reference( x: torch.Tensor, pos_to_token_topk: torch.Tensor, topk_weights: torch.Tensor, @@ -150,25 +153,15 @@ def swiglu_reference( clamp_value: float, num_per_channels: int, ) -> tuple[torch.Tensor, torch.Tensor]: - hidden = x.shape[1] // 2 - x_l, x_r = x[:, :hidden].float(), x[:, hidden:].float() - x_l = torch.minimum(x_l, torch.tensor(clamp_value, device=x.device)) - x_r = torch.clamp(x_r, min=-clamp_value, max=clamp_value) - y = torch.nn.functional.silu(x_l) * x_r - - valid_weight = pos_to_token_topk >= 0 - weights = torch.ones(x.shape[0], dtype=torch.float32, device=x.device) - weights[valid_weight] = topk_weights.flatten()[pos_to_token_topk[valid_weight]] - y = y * weights[:, None] - - valid_expert = pos_to_expert >= 0 - y = torch.where(valid_expert[:, None], y, torch.zeros_like(y)) - - grouped = y.reshape(x.shape[0], hidden // num_per_channels, num_per_channels) - scales = grouped.abs().amax(dim=2) / 448.0 - scales = torch.where(scales > 0.0, scales, torch.ones_like(scales)) - out = (grouped / scales[:, :, None]).reshape_as(y).to(torch.float8_e4m3fn) - return out, scales + return swiglu_forward_and_per_token_cast( + x, + "e4m3", + num_per_channels, + pos_to_token_topk=pos_to_token_topk, + topk_weights=topk_weights, + pos_to_expert=pos_to_expert, + swiglu_clamp_value=clamp_value, + ) def dequantized_sum( @@ -187,7 +180,7 @@ def main(): headers = [ "tokens", "hidden", - "torch (ms)", + "tilekernels (ms)", "tilus (ms)", "speedup", "sum diff", @@ -233,6 +226,7 @@ def main(): device="cuda", dtype=torch.float32, ) + x_tilekernels = x.float() clamp_value = 6.0 kernel( @@ -248,8 +242,8 @@ def main(): clamp_value, ) - expected_out, expected_sf = swiglu_reference( - x, + expected_out, expected_sf = tilekernels_swiglu_reference( + x_tilekernels, pos_to_token_topk, topk_weights, pos_to_expert, @@ -257,12 +251,10 @@ def main(): num_per_channels, ) valid = pos_to_expert >= 0 - torch.testing.assert_close( - out[valid].float(), - expected_out[valid].float(), - atol=1.0, - rtol=0.0, - ) + max_code_diff = ( + out[valid].float() - expected_out[valid].float() + ).abs().max().item() + assert max_code_diff <= 32.0, f"max decoded FP8 code diff is {max_code_diff}" torch.testing.assert_close( out_sf[valid], expected_sf[valid], @@ -275,12 +267,12 @@ def main(): expected_sf[valid], num_per_channels, ) - torch.testing.assert_close(actual_sum, expected_sum, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(actual_sum, expected_sum, atol=2.0, rtol=2e-2) sum_diff = (actual_sum - expected_sum).abs().item() - torch_ms = benchmark_func( - lambda: swiglu_reference( - x, + tilekernels_ms = benchmark_func( + lambda: tilekernels_swiglu_reference( + x_tilekernels, pos_to_token_topk, topk_weights, pos_to_expert, @@ -306,15 +298,16 @@ def main(): [ num_expanded_tokens, hidden, - torch_ms, + tilekernels_ms, tilus_ms, - f"{torch_ms / tilus_ms:.2f}x", + f"{tilekernels_ms / tilus_ms:.2f}x", sum_diff, ] ) print( "SwiGLU FP8 cast matches reference for size " - f"({num_expanded_tokens}, {hidden}); dequantized sum diff={sum_diff:.6g}" + f"({num_expanded_tokens}, {hidden}); max code diff={max_code_diff:.6g}; " + f"dequantized sum diff={sum_diff:.6g}" ) print(pandas.DataFrame(rows, columns=headers)) From c07413f057d10e8a79466095f775490bea188a64 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Sat, 16 May 2026 01:32:28 +0000 Subject: [PATCH 18/21] move emitter details, cut autotuner changes Signed-off-by: William Zhang --- python/tilus/backends/emitter.py | 26 ------------------- .../backends/emitters/cuda/cp_async_tensor.py | 22 ++++++++++++++++ python/tilus/lang/instantiated_script.py | 17 ++---------- 3 files changed, 24 insertions(+), 41 deletions(-) diff --git a/python/tilus/backends/emitter.py b/python/tilus/backends/emitter.py index 955197cb..4078b982 100644 --- a/python/tilus/backends/emitter.py +++ b/python/tilus/backends/emitter.py @@ -59,32 +59,6 @@ def assert_is_warp_aligned(self, inst: Instruction, msg: str) -> None: f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." ) - def assert_is_single_thread_or_warp_aligned(self, inst: Instruction, msg: str) -> None: - # TMA copies must be issued by exactly one thread. The user can express - # that with single_thread() (num_threads == 1), or at warp scope where the - # `@pred` predicate selects the elected lane. Both are valid; reject only - # multi-thread non-warp contexts. - if self.current_num_threads == 1: - return - if self.current_num_threads != 32 or self.current_thread_group_begin % 32 != 0: - raise ValueError( - f"Instruction {inst} requires a single-thread or warp-aligned context " - f"(num_threads==1, or thread_begin % 32 == 0 and num_threads == 32), " - f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." - ) - - @property - def tma_predicate(self): - # Inside single_thread() only one thread runs the TMA call, so the - # @pred predicate is the constant 1. At warp scope we still need to - # select a single lane, so use the elected leader-lane predicate to - # avoid an if-branch divergence. - from tilus.hidet.ir.dtypes import uint32 as _u32 - - if self.current_num_threads == 1: - return _u32(1) - return self.contexts.leader_lane_ctx.leader_lane - def sync(self): optional_sync_call = self.contexts.sync_ctx.sync() if optional_sync_call is not None: diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index fb5d82d7..c2115c73 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -49,6 +49,7 @@ CopyAsyncTensorSharedToGlobalInst, CopyAsyncTensorWaitGroupInst, ) +from tilus.ir.inst import Instruction from tilus.ir.tensor import GlobalTensor, SharedTensor from tilus.ir.utils.lineardec import LinearDecompositionError, decompose_linear from tilus.ir.utils.veceval import vectorized_evaluate @@ -151,6 +152,27 @@ def get_offset_grid_of_swizzled_layout( class CopyAsyncTensorBaseEmitter(BaseInstEmitter): + def assert_is_single_thread_or_warp_aligned(self, inst: Instruction, msg: str) -> None: + # TMA copies are issued by one elected lane. single_thread() already + # narrows execution to one lane; at warp scope, the TMA predicate elects + # the leader lane. + if self.current_num_threads == 1: + return + if self.current_num_threads != 32 or self.current_thread_group_begin % 32 != 0: + raise ValueError( + f"Instruction {inst} requires a single-thread or warp-aligned context " + f"(num_threads==1, or thread_begin % 32 == 0 and num_threads == 32), " + f"got thread_begin={self.current_thread_group_begin}, num_threads={self.current_num_threads}: {msg}." + ) + + @property + def tma_predicate(self) -> Expr: + # Inside single_thread() only one thread reaches the TMA call, so use + # constant true. At warp scope, predicate the asm on the elected leader. + if self.current_num_threads == 1: + return uint32(1) + return self.contexts.leader_lane_ctx.leader_lane + def resolve_global_tensor_info( self, global_tensor: GlobalTensor, offsets: Sequence[Expr], dims: Sequence[int] ) -> GlobalTensorInfo: diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index fff48696..0014cffd 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -44,11 +44,6 @@ logger = logging.getLogger(__name__) -# Bump when tuner semantics change (e.g. correctness gates, new selection -# criteria). On bump, dispatch_table.json files written under the prior version -# are ignored and tuning re-runs, so users don't have to manually delete cache. -_TUNER_VERSION = 2 - def span_space(space: Mapping[str, Sequence[Any]]) -> list[dict[str, Any]]: """ @@ -711,15 +706,7 @@ def load_dispatch_table(self): table_path = self.cache_dir / "dispatch_table.json" if table_path.exists(): with open(table_path, "r") as f: - payload = json.load(f) - # New format is {"tuner_version": int, "entries": [...]}; legacy format - # was a bare list. Treat legacy or version-mismatched files as empty so - # tuner semantic changes automatically re-run tuning instead of serving - # stale picks. - if isinstance(payload, dict) and payload.get("tuner_version") == _TUNER_VERSION: - entries = payload["entries"] - else: - entries = [] + entries = json.load(f) self.dispatch_table = {tuple(key): value for key, value in entries} def dump_dispatch_table(self): @@ -727,7 +714,7 @@ def dump_dispatch_table(self): table_txt_path = self.cache_dir / "dispatch_table.txt" entries = [[list(key), value] for key, value in self.dispatch_table.items()] with open(table_path, "w") as f: - json.dump({"tuner_version": _TUNER_VERSION, "entries": entries}, f) + json.dump(entries, f) headers = [] for idx in self.call_params.tuning_params: headers.append(self.call_params.param_names[idx]) From 05cda70dedc77f95f42e80660bde62be55d2de37 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Sat, 16 May 2026 03:04:02 +0000 Subject: [PATCH 19/21] bring tile kernels performance up Signed-off-by: William Zhang --- examples/quantization/per_token_cast.py | 71 ++++++----- .../swiglu_forward_and_per_token_cast.py | 114 ++++++++++-------- python/tilus/transforms/lower_assume.py | 10 +- 3 files changed, 111 insertions(+), 84 deletions(-) diff --git a/examples/quantization/per_token_cast.py b/examples/quantization/per_token_cast.py index ddedebc5..49b97960 100644 --- a/examples/quantization/per_token_cast.py +++ b/examples/quantization/per_token_cast.py @@ -16,15 +16,23 @@ from tilus.utils import benchmark_func, cdiv -@tilus.autotune("block_n", [128]) +@tilus.autotune("block_m", [1, 2, 4, 8]) +@tilus.autotune("groups_per_block", [1, 2, 4, 8]) @tilus.autotune("warps", [4, 8]) class PerTokenCast(tilus.Script): - def __init__(self, block_n: int, warps: int, num_per_channels: int = 128): + def __init__( + self, + block_m: int, + groups_per_block: int, + warps: int, + num_per_channels: int = 128, + ): super().__init__() - self.block_m = 1 - self.block_n = block_n - self.warps = warps + self.block_m = block_m self.num_per_channels = num_per_channels + self.groups_per_block = groups_per_block + self.block_n = num_per_channels + self.warps = warps def __call__( self, @@ -34,15 +42,16 @@ def __call__( out_ptr: ~float8_e4m3, out_sf_ptr: ~float32, ): + n_step = self.block_n * self.groups_per_block self.attrs.blocks = ( cdiv(num_tokens, self.block_m), - cdiv(hidden, self.block_n), + cdiv(hidden, n_step), ) self.attrs.warps = self.warps + self.assume(hidden % self.num_per_channels == 0) offset_m = self.blockIdx.x * self.block_m - offset_n = self.blockIdx.y * self.block_n - sf_col = offset_n // self.num_per_channels + base_offset_n = self.blockIdx.y * n_step g_x = self.global_view( x_ptr, @@ -60,27 +69,31 @@ def __call__( shape=[num_tokens, cdiv(hidden, self.num_per_channels)], ) - r_x = self.load_global( - g_x, - offsets=[offset_m, offset_n], - shape=[self.block_m, self.block_n], - ).to(float32) - - r_absmax = self.max(self.abs(r_x), dim=1, keepdim=True) - r_fp8_max = self.register_tensor( - dtype=float32, - shape=[self.block_m, 1], - init=448.0, - ) - r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) - r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) - - self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) - self.store_global( - g_out, - (r_x * r_inv_scale).to(float8_e4m3), - offsets=[offset_m, offset_n], - ) + for gi in range(self.groups_per_block): + offset_n = base_offset_n + gi * self.block_n + sf_col = offset_n // self.num_per_channels + + r_x = self.load_global( + g_x, + offsets=[offset_m, offset_n], + shape=[self.block_m, self.block_n], + ).to(float32) + + r_absmax = self.max(self.abs(r_x), dim=1, keepdim=True) + r_fp8_max = self.register_tensor( + dtype=float32, + shape=[self.block_m, 1], + init=448.0, + ) + r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) + r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) + + self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) + self.store_global( + g_out, + (r_x * r_inv_scale).to(float8_e4m3), + offsets=[offset_m, offset_n], + ) def tilekernels_per_token_cast_reference( diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py index d49c2c7f..afa9b91e 100644 --- a/examples/quantization/swiglu_forward_and_per_token_cast.py +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -22,12 +22,14 @@ from tilus.utils import benchmark_func, cdiv -@tilus.autotune("block_n", [128]) -@tilus.autotune("warps", [4, 8]) +@tilus.autotune("block_m", [1]) +@tilus.autotune("groups_per_block", [1, 2, 4, 8, 16]) +@tilus.autotune("warps", [1, 2, 4, 8]) class SwiGLUForwardAndPerTokenCast(tilus.Script): def __init__( self, - block_n: int, + block_m: int, + groups_per_block: int, warps: int, with_weight: bool = True, with_pos_to_expert: bool = True, @@ -35,13 +37,14 @@ def __init__( num_per_channels: int = 128, ): super().__init__() - self.block_m = 1 - self.block_n = block_n + self.block_m = block_m + self.num_per_channels = num_per_channels + self.groups_per_block = groups_per_block + self.block_n = num_per_channels self.warps = warps self.with_weight = with_weight self.with_pos_to_expert = with_pos_to_expert self.use_clamp = use_clamp - self.num_per_channels = num_per_channels def __call__( self, @@ -56,15 +59,16 @@ def __call__( pos_to_expert_ptr: ~int32, swiglu_clamp_value: float32, ): + n_step = self.block_n * self.groups_per_block self.attrs.blocks = ( cdiv(num_expanded_tokens, self.block_m), - cdiv(hidden, self.block_n), + cdiv(hidden, n_step), ) self.attrs.warps = self.warps + self.assume(hidden % self.num_per_channels == 0) offset_m = self.blockIdx.x * self.block_m - offset_n = self.blockIdx.y * self.block_n - sf_col = offset_n // self.num_per_channels + base_offset_n = self.blockIdx.y * n_step g_x = self.global_view( x_ptr, @@ -98,51 +102,57 @@ def __call__( ) if (not self.with_pos_to_expert) or g_pos_to_expert[offset_m].item() >= 0: - r_l = self.load_global( - g_x, - offsets=[offset_m, offset_n], - shape=[self.block_m, self.block_n], - ).to(float32) - r_r = self.load_global( - g_x, - offsets=[offset_m, offset_n + hidden], - shape=[self.block_m, self.block_n], - ).to(float32) - - if self.use_clamp: - negative_swiglu_clamp_value = 0.0 - swiglu_clamp_value - r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l) - r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r) - r_r = self.where( - r_r < negative_swiglu_clamp_value, - x=negative_swiglu_clamp_value, - y=r_r, - ) - - r_silu = r_l / (self.exp(-r_l) + 1.0) - r_value = r_silu * r_r - if self.with_weight: topk_pos = g_pos_to_token_topk[offset_m].item() - if topk_pos >= 0: - topk_weight = g_topk_weights[topk_pos].item() - r_value = r_value * topk_weight - - r_absmax = self.max(self.abs(r_value), dim=1, keepdim=True) - r_fp8_max = self.register_tensor( - dtype=float32, - shape=[self.block_m, 1], - init=448.0, - ) - r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) - r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) - - self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) - self.store_global( - g_out, - (r_value * r_inv_scale).to(float8_e4m3), - offsets=[offset_m, offset_n], - ) + + for gi in range(self.groups_per_block): + offset_n = base_offset_n + gi * self.block_n + sf_col = offset_n // self.num_per_channels + + r_l = self.load_global( + g_x, + offsets=[offset_m, offset_n], + shape=[self.block_m, self.block_n], + ).to(float32) + r_r = self.load_global( + g_x, + offsets=[offset_m, offset_n + hidden], + shape=[self.block_m, self.block_n], + ).to(float32) + + if self.use_clamp: + negative_swiglu_clamp_value = 0.0 - swiglu_clamp_value + r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l) + r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r) + r_r = self.where( + r_r < negative_swiglu_clamp_value, + x=negative_swiglu_clamp_value, + y=r_r, + ) + + r_silu = r_l / (self.exp(-r_l) + 1.0) + r_value = r_silu * r_r + + if self.with_weight: + if topk_pos >= 0: + topk_weight = g_topk_weights[topk_pos].item() + r_value = r_value * topk_weight + + r_absmax = self.max(self.abs(r_value), dim=1, keepdim=True) + r_fp8_max = self.register_tensor( + dtype=float32, + shape=[self.block_m, 1], + init=448.0, + ) + r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) + r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) + + self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) + self.store_global( + g_out, + (r_value * r_inv_scale).to(float8_e4m3), + offsets=[offset_m, offset_n], + ) def tilekernels_swiglu_reference( diff --git a/python/tilus/transforms/lower_assume.py b/python/tilus/transforms/lower_assume.py index 9be73c37..38bda543 100644 --- a/python/tilus/transforms/lower_assume.py +++ b/python/tilus/transforms/lower_assume.py @@ -17,7 +17,7 @@ from tilus.ir.functors import IRRewriter from tilus.ir.instructions import AssumeInst from tilus.transforms.base import Pass -from tilus.utils import gcd +from tilus.utils import lcm class ApplyAssumeRewriter(IRRewriter): @@ -54,7 +54,11 @@ def visit_AssumeInst(self, inst: AssumeInst) -> None: raise RuntimeError( "We only allow to specify the divisibility of kernel parameter, got {}".format(a.name) ) - self.param2divisibility[a] = int(term.a.b.value) # type: ignore[arg-type] + divisor = int(term.a.b.value) # type: ignore[arg-type] + if a in self.param2divisibility: + self.param2divisibility[a] = lcm(self.param2divisibility[a], divisor) + else: + self.param2divisibility[a] = divisor else: raise RuntimeError("Can not recognize the condition in assume: {}".format(term)) @@ -70,7 +74,7 @@ def visit_Function(self, func: Function) -> Function: param2divisibility = updated_func.metadata.param2divisibility.copy() for var in self.param2divisibility: if var in param2divisibility: - param2divisibility[var] = gcd(param2divisibility[var], self.param2divisibility[var]) + param2divisibility[var] = lcm(param2divisibility[var], self.param2divisibility[var]) else: param2divisibility[var] = self.param2divisibility[var] return updated_func.with_metadata(updated_func.metadata.with_param2divisibility(param2divisibility)) From 55859d68889207888059cebc954f88fc01c51142 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Sun, 17 May 2026 02:43:22 +0000 Subject: [PATCH 20/21] fix linting Signed-off-by: William Zhang --- examples/hopper_matmul/matmul_v5.py | 8 ++------ python/tilus/backends/emitters/cuda/cp_async_tensor.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/hopper_matmul/matmul_v5.py b/examples/hopper_matmul/matmul_v5.py index 32e5e92b..dddd6557 100644 --- a/examples/hopper_matmul/matmul_v5.py +++ b/examples/hopper_matmul/matmul_v5.py @@ -70,9 +70,7 @@ def prev_consumer_barrier(self) -> RegisterTensor: # block_m must be >= 128 so each WG's WGMMA M = block_m/2 >= 64. @tilus.autotune("num_stages", [3, 4, 5, 6]) -@tilus.autotune( - "block_m, block_n", [[128, 128], [128, 256], [256, 128], [256, 256]] -) +@tilus.autotune("block_m, block_n", [[128, 128], [128, 256], [256, 128], [256, 256]]) @tilus.autotune("block_k", [16, 32, 64]) @tilus.autotune("swizzle_size", [4, 8]) class MatmulWGMMAV5(tilus.Script): @@ -238,9 +236,7 @@ def __call__( self.mbarrier.arrive(tma_pipe.prev_consumer_barrier()) casted1 = self.cast(acc1, dtype=float16) - self.store_global( - gc, casted1, offsets=[offset_m + block_m_half, offset_n] - ) + self.store_global(gc, casted1, offsets=[offset_m + block_m_half, offset_n]) def main(): diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index c2115c73..a890b0b9 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -43,13 +43,13 @@ from tilus.hidet.ir.tools import rewrite, simplify from tilus.hidet.ir.type import DataType, PointerType, TensorType, sizeof from tilus.ir import GlobalLayout +from tilus.ir.inst import Instruction from tilus.ir.instructions.cuda.cp_async_tensor import ( CopyAsyncTensorCommitGroupInst, CopyAsyncTensorGlobalToSharedInst, CopyAsyncTensorSharedToGlobalInst, CopyAsyncTensorWaitGroupInst, ) -from tilus.ir.inst import Instruction from tilus.ir.tensor import GlobalTensor, SharedTensor from tilus.ir.utils.lineardec import LinearDecompositionError, decompose_linear from tilus.ir.utils.veceval import vectorized_evaluate From 2ad075b38bc1554811c1d7cf614138a2ece27af5 Mon Sep 17 00:00:00 2001 From: William Zhang Date: Sun, 17 May 2026 02:48:56 +0000 Subject: [PATCH 21/21] beat tilelang by 10% on kernels Signed-off-by: William Zhang --- .../swiglu_forward_and_per_token_cast.py | 111 ++++++++++-------- python/tilus/backends/emitters/transform.py | 10 +- python/tilus/ir/builders/stmt_builder.py | 11 ++ python/tilus/ir/instructions/__init__.py | 1 + python/tilus/ir/instructions/generic.py | 17 +++ .../inference/inference_rules/transform.py | 19 ++- python/tilus/ir/layout/inference/order.py | 4 +- .../inference/validation_rules/transform.py | 15 ++- python/tilus/ir/layout/ops/register_ops.py | 12 +- python/tilus/lang/instructions/root.py | 25 ++++ 10 files changed, 167 insertions(+), 58 deletions(-) diff --git a/examples/quantization/swiglu_forward_and_per_token_cast.py b/examples/quantization/swiglu_forward_and_per_token_cast.py index afa9b91e..153cc343 100644 --- a/examples/quantization/swiglu_forward_and_per_token_cast.py +++ b/examples/quantization/swiglu_forward_and_per_token_cast.py @@ -102,57 +102,66 @@ def __call__( ) if (not self.with_pos_to_expert) or g_pos_to_expert[offset_m].item() >= 0: + base_sf_col = base_offset_n // self.num_per_channels + + # Wide load: full n_step at once so layout-inference vectorises. + r_l = self.load_global( + g_x, + offsets=[offset_m, base_offset_n], + shape=[self.block_m, n_step], + ).to(float32) + r_r = self.load_global( + g_x, + offsets=[offset_m, base_offset_n + hidden], + shape=[self.block_m, n_step], + ).to(float32) + + if self.use_clamp: + negative_swiglu_clamp_value = 0.0 - swiglu_clamp_value + r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l) + r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r) + r_r = self.where( + r_r < negative_swiglu_clamp_value, + x=negative_swiglu_clamp_value, + y=r_r, + ) + + r_silu = r_l / (self.exp(-r_l) + 1.0) + r_value = r_silu * r_r + if self.with_weight: topk_pos = g_pos_to_token_topk[offset_m].item() + if topk_pos >= 0: + topk_weight = g_topk_weights[topk_pos].item() + r_value = r_value * topk_weight + + # Reshape into [block_m, groups_per_block, num_per_channels] so the + # per-group absmax is a single reduce on dim=2. + r_value_grouped = self.reshape( + r_value, + shape=[self.block_m, self.groups_per_block, self.num_per_channels], + ) + r_absmax = self.max( + self.abs(r_value_grouped), dim=2, keepdim=True + ) # [block_m, groups_per_block, 1] + r_fp8_max = self.register_tensor( + dtype=float32, + shape=[self.block_m, self.groups_per_block, 1], + init=448.0, + ) + r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) + r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) - for gi in range(self.groups_per_block): - offset_n = base_offset_n + gi * self.block_n - sf_col = offset_n // self.num_per_channels - - r_l = self.load_global( - g_x, - offsets=[offset_m, offset_n], - shape=[self.block_m, self.block_n], - ).to(float32) - r_r = self.load_global( - g_x, - offsets=[offset_m, offset_n + hidden], - shape=[self.block_m, self.block_n], - ).to(float32) - - if self.use_clamp: - negative_swiglu_clamp_value = 0.0 - swiglu_clamp_value - r_l = self.where(r_l > swiglu_clamp_value, x=swiglu_clamp_value, y=r_l) - r_r = self.where(r_r > swiglu_clamp_value, x=swiglu_clamp_value, y=r_r) - r_r = self.where( - r_r < negative_swiglu_clamp_value, - x=negative_swiglu_clamp_value, - y=r_r, - ) - - r_silu = r_l / (self.exp(-r_l) + 1.0) - r_value = r_silu * r_r - - if self.with_weight: - if topk_pos >= 0: - topk_weight = g_topk_weights[topk_pos].item() - r_value = r_value * topk_weight - - r_absmax = self.max(self.abs(r_value), dim=1, keepdim=True) - r_fp8_max = self.register_tensor( - dtype=float32, - shape=[self.block_m, 1], - init=448.0, - ) - r_scale = self.where(r_absmax > 0.0, x=r_absmax / 448.0, y=1.0) - r_inv_scale = self.where(r_absmax > 0.0, x=r_fp8_max / r_absmax, y=1.0) - - self.store_global(g_out_sf, r_scale, offsets=[offset_m, sf_col]) - self.store_global( - g_out, - (r_value * r_inv_scale).to(float8_e4m3), - offsets=[offset_m, offset_n], - ) + # Store one fp32 scale per group. + r_scale_2d = self.reshape( + r_scale, shape=[self.block_m, self.groups_per_block] + ) + self.store_global(g_out_sf, r_scale_2d, offsets=[offset_m, base_sf_col]) + + # Apply scaling, flatten back, cast to fp8, bulk store. + r_out_grouped = (r_value_grouped * r_inv_scale).to(float8_e4m3) + r_out = self.reshape(r_out_grouped, shape=[self.block_m, n_step]) + self.store_global(g_out, r_out, offsets=[offset_m, base_offset_n]) def tilekernels_swiglu_reference( @@ -199,6 +208,8 @@ def main(): for num_expanded_tokens, hidden, num_tokens, num_topk in [ (128, 1024, 64, 2), (256, 2048, 128, 2), + (257, 4096, 128, 2), + (1024, 4096, 512, 2), ]: num_per_channels = 128 kernel = SwiGLUForwardAndPerTokenCast(num_per_channels=num_per_channels) @@ -262,8 +273,8 @@ def main(): ) valid = pos_to_expert >= 0 max_code_diff = ( - out[valid].float() - expected_out[valid].float() - ).abs().max().item() + (out[valid].float() - expected_out[valid].float()).abs().max().item() + ) assert max_code_diff <= 32.0, f"max decoded FP8 code diff is {max_code_diff}" torch.testing.assert_close( out_sf[valid], diff --git a/python/tilus/backends/emitters/transform.py b/python/tilus/backends/emitters/transform.py index 5f79402c..37acec97 100644 --- a/python/tilus/backends/emitters/transform.py +++ b/python/tilus/backends/emitters/transform.py @@ -13,7 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from tilus.backends.emitter import BaseInstEmitter, register_emitter -from tilus.ir.instructions import RepeatInst, RepeatInterleaveInst, SqueezeInst, TransposeInst, UnsqueezeInst +from tilus.ir.instructions import ( + RepeatInst, + RepeatInterleaveInst, + ReshapeRegisterInst, + SqueezeInst, + TransposeInst, + UnsqueezeInst, +) @register_emitter(RepeatInst) @@ -87,6 +94,7 @@ def emit(self, inst: RepeatInterleaveInst) -> None: @register_emitter(UnsqueezeInst) @register_emitter(SqueezeInst) +@register_emitter(ReshapeRegisterInst) class SqueezeUnsqueezeInstEmitter(BaseInstEmitter): def emit(self, inst: SqueezeInst) -> None: src = inst.register_input diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index be8ba34d..549e818d 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -109,6 +109,7 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, + ReshapeRegisterInst, ReshapeSharedInst, ScanInst, SliceAssignInst, @@ -553,6 +554,16 @@ def unsqueeze( self.append(inst) return inst.register_output + def reshape_register( + self, + x: RegisterTensor, + shape: Sequence[int], + out: Optional[RegisterTensor] = None, + ) -> RegisterTensor: + inst = ReshapeRegisterInst.create(x=x, shape=shape, out=out) + self.append(inst) + return inst.register_output + def cast( self, x: RegisterTensor, diff --git a/python/tilus/ir/instructions/__init__.py b/python/tilus/ir/instructions/__init__.py index 091e9417..e71bc540 100644 --- a/python/tilus/ir/instructions/__init__.py +++ b/python/tilus/ir/instructions/__init__.py @@ -65,6 +65,7 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, + ReshapeRegisterInst, ReshapeSharedInst, ScanInst, ShuffleDownInst, diff --git a/python/tilus/ir/instructions/generic.py b/python/tilus/ir/instructions/generic.py index 4d39a47a..cf66b572 100644 --- a/python/tilus/ir/instructions/generic.py +++ b/python/tilus/ir/instructions/generic.py @@ -637,6 +637,23 @@ def create(x: RegisterTensor, out: Optional[RegisterTensor] = None) -> Transpose return TransposeInst(output=out, inputs=(x,)) +@dataclass(frozen=True, eq=False) +class ReshapeRegisterInst(Instruction): + @staticmethod + def create( + x: RegisterTensor, + shape: Sequence[int], + out: Optional[RegisterTensor] = None, + ) -> ReshapeRegisterInst: + from tilus.utils import prod + + if out is None: + if prod(x.shape) != prod(shape): + raise ValueError(f"Cannot reshape register tensor with shape {x.shape} to shape {shape}: sizes differ") + out = RegisterTensor.create(dtype=x.dtype, shape=tuple(shape)) + return ReshapeRegisterInst(output=out, inputs=(x,)) + + @dataclass(frozen=True, eq=False) class AllocateSharedInst(Instruction): @staticmethod diff --git a/python/tilus/ir/layout/inference/inference_rules/transform.py b/python/tilus/ir/layout/inference/inference_rules/transform.py index 55a5fae6..369eb940 100644 --- a/python/tilus/ir/layout/inference/inference_rules/transform.py +++ b/python/tilus/ir/layout/inference/inference_rules/transform.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from tilus import RegisterLayout -from tilus.ir.instructions import SqueezeInst, UnsqueezeInst +from tilus.ir.instructions import ReshapeRegisterInst, SqueezeInst, UnsqueezeInst from tilus.ir.layout import ops from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule from tilus.ir.tensor import RegisterTensor @@ -51,3 +51,20 @@ def inference(ctx: LayoutInferenceContext, inst: SqueezeInst) -> dict[RegisterTe return {x: ops.unsqueeze(y.layout, dims=inst.dims)} else: return {} + + +@register_rule(ReshapeRegisterInst) +class ReshapeRegisterRule(LayoutInferenceRule): + @staticmethod + def inference(ctx: LayoutInferenceContext, inst: ReshapeRegisterInst) -> dict[RegisterTensor, RegisterLayout]: + x = inst.register_input + y = inst.register_output + + if x.optional_layout is not None and y.optional_layout is not None: + return {} + elif x.optional_layout is not None: + return {y: ops.reshape(x.layout, shape=y.shape)} + elif y.optional_layout is not None: + return {x: ops.reshape(y.layout, shape=x.shape)} + else: + return {} diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index 59118291..e1d5c891 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -47,7 +47,7 @@ from .inference_rules.tcgen05.ldst import Tcgen05LoadRule, Tcgen05StoreRule from .inference_rules.tcgen05.mma import Tcgen05MmaSSRule, Tcgen05MmaTSRule from .inference_rules.tcgen05.slice import Tcgen05SliceRule -from .inference_rules.transform import SqueezeRule, UnsqueezeRule +from .inference_rules.transform import ReshapeRegisterRule, SqueezeRule, UnsqueezeRule from .inference_rules.transform_shared import PermuteSharedRule, SharedSliceRule from .inference_rules.transpose import TransposeRule from .inference_rules.wgmma import WgmmaMmaSSRule @@ -67,7 +67,7 @@ [LoadGlobalRule], [ReduceRule], [ScanRule], - [TransposeRule, SqueezeRule, UnsqueezeRule], + [TransposeRule, SqueezeRule, UnsqueezeRule, ReshapeRegisterRule], [WhereRule], [AssignRule], [StoreGlobalRule], diff --git a/python/tilus/ir/layout/inference/validation_rules/transform.py b/python/tilus/ir/layout/inference/validation_rules/transform.py index 1367e4dc..27af3be9 100644 --- a/python/tilus/ir/layout/inference/validation_rules/transform.py +++ b/python/tilus/ir/layout/inference/validation_rules/transform.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tilus.ir.instructions import SqueezeInst, UnsqueezeInst +from tilus.ir.instructions import ReshapeRegisterInst, SqueezeInst, UnsqueezeInst from tilus.ir.layout import ops from tilus.ir.layout.inference.rule import LayoutValidationRule, register_rule from tilus.ir.tensor import RegisterTensor @@ -36,3 +36,16 @@ def validate(inst: UnsqueezeInst) -> bool: y: RegisterTensor = inst.register_output return y.layout == ops.unsqueeze(x.layout, dims=inst.dims) + + +@register_rule(ReshapeRegisterInst) +class ReshapeRegisterRule(LayoutValidationRule): + @staticmethod + def validate(inst: ReshapeRegisterInst) -> bool: + x: RegisterTensor = inst.register_input + y: RegisterTensor = inst.register_output + + try: + return y.layout == ops.reshape(x.layout, shape=y.shape) + except Exception: + return False diff --git a/python/tilus/ir/layout/ops/register_ops.py b/python/tilus/ir/layout/ops/register_ops.py index 6c59a13d..3434c27d 100644 --- a/python/tilus/ir/layout/ops/register_ops.py +++ b/python/tilus/ir/layout/ops/register_ops.py @@ -496,15 +496,21 @@ def reshape(layout: RegisterLayout, shape: Sequence[int]) -> RegisterLayout: p = mode_shape.pop(0) grouped_mode_shape.append([]) - while shape: + while shape and p > 1: q = shape[0] + if q == 1: + shape.pop(0) + continue if q % p == 0: grouped_mode_shape[-1].append(p) shape[0] = q // p + if shape[0] == 1: + shape.pop(0) + p = 1 break elif p % q == 0: - if q > 1: - grouped_mode_shape[-1].append(q) + grouped_mode_shape[-1].append(q) + p //= q shape.pop(0) else: raise LayoutOperationError("Cannot reshape layout {} to shape {}".format(layout, shape)) diff --git a/python/tilus/lang/instructions/root.py b/python/tilus/lang/instructions/root.py index 2e1d970d..18d01cfd 100644 --- a/python/tilus/lang/instructions/root.py +++ b/python/tilus/lang/instructions/root.py @@ -695,6 +695,31 @@ def free_shared(self, tensor: SharedTensor) -> None: """ self._builder.free_shared(tensor) + def reshape(self, tensor: RegisterTensor, shape: Sequence[int]) -> RegisterTensor: + """Reshape a register tensor. + + The new shape must have the same total size as the original. The + underlying per-thread storage is unchanged; only the logical shape (and + mode grouping used for broadcasts/reductions) is updated. + + Parameters + ---------- + tensor: RegisterTensor + The register tensor to reshape. + shape: Sequence[int] + The new shape of the register tensor. + + Returns + ------- + ret: RegisterTensor + The reshaped register tensor. + + Notes + ----- + - **Thread group**: Can be executed by any sized thread group. + """ + return self._builder.reshape_register(x=tensor, shape=shape) + def reshape_shared(self, tensor: SharedTensor, shape: Sequence[int]) -> SharedTensor: """Reshape a shared tensor.