diff --git a/examples/hopper_matmul/benchmark.py b/examples/hopper_matmul/benchmark.py new file mode 100644 index 00000000..cf654f3d --- /dev/null +++ b/examples/hopper_matmul/benchmark.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark all hopper_matmul versions against cuBLAS (torch.matmul). + +Run directly: + python benchmark.py + python benchmark.py --ncu + python benchmark.py --versions v3 v4 v5 --size 4096 4096 4096 +""" + +import argparse +import csv +import io +import subprocess +import time + +VERSION_NAMES = ["v0", "v1", "v2", "v3", "v4", "v5"] + +VERSION_CLASS = { + "v0": "MatmulTMA", + "v1": "MatmulWGMMA", + "v2": "MatmulWGMMAV2", + "v3": "MatmulWGMMAV3", + "v4": "MatmulWGMMAV4", + "v5": "MatmulWGMMAV5", +} + + +def _load_version(name: str): + """Lazily import a matmul module by version name and return the class.""" + import importlib + + import tilus + + tilus.option.cache_dir("./cache") + + module = importlib.import_module(f"matmul_{name}") + return getattr(module, VERSION_CLASS[name]) + + +def run_kernels(version_names: list, m_size: int, n_size: int, k_size: int): + """Run cuBLAS and tilus matmul versions sequentially (used as the target for ncu_run).""" + import torch + + a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda") + b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda") + c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") + + # tilus versions + for name in version_names: + matmul = _load_version(name)() + matmul(m_size, n_size, k_size, a, b, c) + torch.cuda.synchronize() + + # cuBLAS + _ = a @ b.T + torch.cuda.synchronize() + + +def _read_ncu_csv( + report_path: str, page: str, metrics: str | None = None +) -> csv.DictReader: + """Run ncu --import --csv and return a DictReader, skipping the units row.""" + cmd = ["/usr/local/cuda/bin/ncu", "--import", report_path, 
"--csv", "--page", page] + if metrics: + cmd += ["--metrics", metrics] + result = subprocess.run(cmd, capture_output=True, text=True) + reader = csv.DictReader(io.StringIO(result.stdout)) + next(reader, None) + return reader + + +def _short_kernel_name(name: str) -> str: + idx = name.find("(") + return name[:idx] if idx != -1 else name + + +def parse_ncu_report(report_path: str) -> list[tuple[str, dict]]: + """Extract per-kernel metrics from an NCU report. Returns [(kernel_name, metrics), ...] in order.""" + tensor_col = "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed" + reader = _read_ncu_csv(report_path, "raw", metrics=tensor_col) + per_kernel: dict[str, dict] = {} + kernel_order: list[str] = [] + for row in reader: + kernel = _short_kernel_name(row["Kernel Name"]) + if kernel not in per_kernel: + per_kernel[kernel] = {} + kernel_order.append(kernel) + metrics = per_kernel[kernel] + if tensor_col in row and row[tensor_col]: + metrics["tensor_core_util (%)"] = float(row[tensor_col]) + + reader2 = _read_ncu_csv(report_path, "details") + for row in reader2: + kernel = _short_kernel_name(row["Kernel Name"]) + if kernel not in per_kernel: + per_kernel[kernel] = {} + kernel_order.append(kernel) + metrics = per_kernel[kernel] + if row.get("Metric Name") == "DRAM Throughput": + metrics["dram_throughput (%)"] = float(row["Metric Value"]) + if row.get("Metric Name") == "Compute (SM) Throughput": + metrics["sm_throughput (%)"] = float(row["Metric Value"]) + if row.get("Metric Name") == "SM Frequency": + metrics["sm_freq (GHz)"] = float(row["Metric Value"]) + if row.get("Metric Name") == "Duration": + value = float(row["Metric Value"]) + unit = row.get("Metric Unit", "ms") + if unit == "us": + value /= 1000.0 + elif unit == "s": + value *= 1000.0 + metrics["duration (ms)"] = value + + return [(k, per_kernel[k]) for k in kernel_order] + + +def benchmark_all(versions: list[str], m_size: int, n_size: int, k_size: int): + """Benchmark all versions using 
benchmark_func.""" + import math + + import pandas + import torch + from tilus.utils import benchmark_func + + headers = ["version", "latency (ms)", "tflops", "% of cublas"] + rows = [] + + a = ( + torch.rand(m_size, k_size, dtype=torch.float16, device="cuda") - 0.5 + ) / math.sqrt(k_size) + b = ( + torch.rand(n_size, k_size, dtype=torch.float16, device="cuda") - 0.5 + ) / math.sqrt(k_size) + c_ref = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") + c_tilus = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") + + # cuBLAS baseline first, so we can compute % of cublas + cublas_lat = benchmark_func( + lambda: torch.matmul(a, b.T, out=c_ref), warmup=5, repeat=30 + ) + cublas_tf = 2 * m_size * n_size * k_size / cublas_lat * 1e-9 + + for name in versions: + try: + matmul = _load_version(name)() + # warmup + correctness check + matmul(m_size, n_size, k_size, a, b, c_tilus) + torch.cuda.synchronize() + torch.testing.assert_close(c_ref, c_tilus, atol=1e-2, rtol=1e-2) + + latency = benchmark_func( + lambda: matmul(m_size, n_size, k_size, a, b, c_tilus), warmup=5, repeat=30 + ) + tf = 2 * m_size * n_size * k_size / latency * 1e-9 + pct = tf / cublas_tf * 100.0 + rows.append([f"tilus_{name}", latency, tf, pct]) + time.sleep(1) # cool down between runs + except Exception as e: + print(f" tilus_{name} ERROR: {e}") + rows.append([f"tilus_{name}", float("nan"), float("nan"), float("nan")]) + + rows.append(["cublas", cublas_lat, cublas_tf, 100.0]) + + df = pandas.DataFrame(rows, columns=headers) + print(f"\nBenchmark results (m={m_size}, n={n_size}, k={k_size}):") + print(df.to_string(index=False)) + + +def ncu_profile_all(versions: list[str], m_size: int, n_size: int, k_size: int): + """Profile all versions in a single ncu_run and extract key metrics.""" + import pandas + import tilus + + print("Warming up (JIT + autotuning)...") + run_kernels(versions, m_size, n_size, k_size) + + labels = list(versions) + ["cublas"] + + print(f"Profiling cublas, {', 
'.join(versions)} ...") + report = tilus.utils.ncu_run( + run_kernels, + versions, + m_size, + n_size, + k_size, + kernel_regex="tilus|cutlass|sm90|gemm|cublas", + ) + print(f"Report saved to: {report.report_path}") + + kernel_metrics = parse_ncu_report(report.report_path) + + headers = [ + "version", + "kernel", + "duration (ms)", + "tflops", + "sm_freq (GHz)", + "sm_throughput (%)", + "dram_throughput (%)", + "tensor_core_util (%)", + ] + rows = [] + for i, name in enumerate(labels): + if i < len(kernel_metrics): + kernel, metrics = kernel_metrics[i] + else: + kernel, metrics = "?", {} + duration_ms = metrics.get("duration (ms)", "") + tflops = 2 * m_size * n_size * k_size / duration_ms * 1e-9 if duration_ms else "" + rows.append( + [ + name, + kernel, + duration_ms, + tflops, + metrics.get("sm_freq (GHz)", ""), + metrics.get("sm_throughput (%)", ""), + metrics.get("dram_throughput (%)", ""), + metrics.get("tensor_core_util (%)", ""), + ] + ) + + df = pandas.DataFrame(rows, columns=headers) + print(f"\nNCU profiling results (m={m_size}, n={n_size}, k={k_size}):") + print(df.to_string(index=False)) + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Hopper matmul V0-V5") + parser.add_argument( + "--ncu", + action="store_true", + help="Use NCU profiling instead of benchmark_func", + ) + parser.add_argument( + "--versions", + nargs="+", + default=VERSION_NAMES, + choices=VERSION_NAMES, + help="Which versions to benchmark (default: all)", + ) + parser.add_argument( + "--size", + nargs=3, + type=int, + default=[8192, 8192, 8192], + metavar=("M", "N", "K"), + help="Workload size M N K (default: 8192 8192 8192)", + ) + args = parser.parse_args() + m_size, n_size, k_size = args.size + + if args.ncu: + ncu_profile_all(args.versions, m_size, n_size, k_size) + else: + benchmark_all(args.versions, m_size, n_size, k_size) + + +if __name__ == "__main__": + main() diff --git a/examples/hopper_matmul/matmul_v4.py b/examples/hopper_matmul/matmul_v4.py new 
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# v4: Pipeline class abstraction + tile rasterization for better L2 cache reuse.
#
# Changes from v3:
# - Pipeline class encapsulates barrier/phase/stage ring-buffer logic.
# - 1D grid launch with compute_block_coord() maps linear blockIdx to (m, n)
#   using a swizzle group so adjacent tiles share B columns → L2 reuse.

import math

import pandas
import tilus
import torch
from tilus import RegisterTensor, float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv


# Ring buffer of `num_stages` slots guarded by two mbarrier arrays:
#   - full_barriers[i]:  the producer signals it; the consumer waits on it.
#   - empty_barriers[i]: the consumer signals it; the producer waits on it.
# Producer and consumer each track their own stage index and phase value.
class Pipeline(tilus.Class):
    def __init__(
        self,
        num_stages: int,
        producer_arrive_count: int = 1,
        consumer_arrive_count: int = 1,
    ):
        self.num_stages: int = num_stages
        # One barrier per stage; each is allocated with the arrive count of
        # the side that signals it.
        self.empty_barriers = self.mbarrier.alloc(
            [consumer_arrive_count for _ in range(num_stages)]
        )
        self.full_barriers = self.mbarrier.alloc(
            [producer_arrive_count for _ in range(num_stages)]
        )
        self.producer_stage: int32 = 0
        self.consumer_stage: int32 = 0
        self.producer_phase: uint32 = self.mbarrier.producer_initial_phase
        self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase

    # Wait on the "empty" barrier of the slot the producer is about to fill.
    def producer_acquire(self):
        self.mbarrier.wait(
            barrier=self.empty_barriers[self.producer_stage],
            phase=self.producer_phase,
            sem="relaxed",
            scope="cta",
        )

    # "Full" barrier of the producer's current slot.
    def producer_barrier(self) -> RegisterTensor:
        return self.full_barriers[self.producer_stage]

    # Step to the next slot; the phase is XOR-ed with 1 each time the stage
    # index wraps back to 0.
    def producer_advance(self):
        self.producer_stage = (self.producer_stage + 1) % self.num_stages
        self.producer_phase = self.producer_phase ^ (self.producer_stage == 0)

    # Wait on the "full" barrier of the slot the consumer is about to read.
    def consumer_acquire(self):
        self.mbarrier.wait(
            barrier=self.full_barriers[self.consumer_stage],
            phase=self.consumer_phase,
            sem="relaxed",
            scope="cta",
        )

    # "Empty" barrier of the consumer's current slot.
    def consumer_barrier(self) -> RegisterTensor:
        return self.empty_barriers[self.consumer_stage]

    # Consumer-side counterpart of producer_advance().
    def consumer_advance(self):
        self.consumer_stage = (self.consumer_stage + 1) % self.num_stages
        self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0)

    # "Empty" barrier of the slot just before the consumer's current one.
    def prev_consumer_barrier(self) -> RegisterTensor:
        # Use (consumer_stage + num_stages - 1) so num_stages - 1 is evaluated as a
        # Python int constant first, ensuring the inner expression is always non-negative
        # and avoids C's negative-dividend truncated-modulo behaviour.
        prev_stage = (self.consumer_stage + (self.num_stages - 1)) % self.num_stages
        return self.empty_barriers[prev_stage]


# Tightened autotune space: drop block_m=128/n=64 (skinny → poor wgmma fill),
# drop block_m=256/n=128 (heavy register pressure), drop num_stages=2 (too
# shallow to hide TMA), drop num_stages=7 (smem-thrashing on 256×256), drop
# swizzle_size=1 (≡ default rasterization, no L2 win). The decorators below
# enumerate 4 × 3 × 3 × 2 = 72 value combinations, biased toward shapes cuBLAS
# uses for 8K² fp16 GEMMs.
@tilus.autotune("num_stages", [3, 4, 5, 6])
@tilus.autotune("block_m, block_n", [[128, 128], [128, 256], [256, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("swizzle_size", [4, 8])
class MatmulWGMMAV4(tilus.Script):
    def __init__(self, num_stages, block_m, block_n, block_k, swizzle_size):
        super().__init__()
        self.num_stages = num_stages
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.swizzle_size = swizzle_size

    def compute_block_coord(
        self, linear_idx: int32, num_m_blocks: int32, num_n_blocks: int
    ):
        """Map 1D linear block index to 2D (m_block, n_block) with swizzle grouping.

        Tiles in the same swizzle group share N-columns, keeping B columns in L2.
        """
        swizzle_size = self.swizzle_size
        # A group covers `swizzle_size` consecutive n-blocks for every m-block.
        tiles_per_group = num_m_blocks * swizzle_size
        group_idx, in_group_idx = self.fast_divmod(linear_idx, tiles_per_group)
        first_n = group_idx * swizzle_size
        m_block: int32 = 0
        n_block: int32 = 0
        # Width of the trailing group when num_n_blocks is not a multiple of
        # swizzle_size (remainder == num_n_blocks mod swizzle_size).
        remainder = num_n_blocks - num_n_blocks // swizzle_size * swizzle_size
        last_group_width = remainder if remainder > 0 else swizzle_size
        if first_n + swizzle_size <= num_n_blocks:
            # Full-width group: walk the n-blocks first, then step to the next m-block.
            m_block, r = self.fast_divmod(in_group_idx, swizzle_size)
            n_block = first_n + r
        else:
            # Trailing, narrower group.
            m_block, r = self.fast_divmod(in_group_idx, last_group_width)
            n_block = first_n + r
        return m_block, n_block

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        num_stages = self.num_stages
        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k

        num_m_blocks = cdiv(m_size, block_m)
        num_n_blocks = cdiv(n_size, block_n)
        # 1D launch: one block per output tile.
        self.attrs.blocks = num_m_blocks * num_n_blocks
        # 5 warps: threads 0-127 consume, threads 128-159 produce (see the
        # thread_group ranges below).
        self.attrs.warps = 5

        m_block, n_block = self.compute_block_coord(
            self.blockIdx.x, num_m_blocks, num_n_blocks
        )
        offset_m: int32 = m_block * block_m
        offset_n: int32 = n_block * block_n

        # a is viewed as [m, k] and b as [n, k]; b is transposed at the MMA.
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        sa = self.shared_tensor(dtype=float16, shape=[num_stages, block_m, block_k])
        sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k])
        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)

        # producer_arrive_count=1: single thread does arrive_and_expect_tx
        # consumer_arrive_count=128: all 128 consumer threads arrive when done
        tma_pipe = Pipeline(
            num_stages, producer_arrive_count=1, consumer_arrive_count=128
        )

        with self.thread_group(thread_begin=128, num_threads=32):  # TMA producer warp
            for offset_k in self.range(0, k_size, block_k, unroll=num_stages):
                tma_pipe.producer_acquire()
                # Producer warp is already 32 threads — the granularity TMA
                # needs at SASS level. Only the arrive runs in single_thread so
                # transaction-bytes is counted once.
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        tma_pipe.producer_barrier(),
                        transaction_bytes=sa[tma_pipe.producer_stage].nbytes
                        + sb[tma_pipe.producer_stage].nbytes,
                    )
                self.tma.global_to_shared(
                    src=ga,
                    dst=sa[tma_pipe.producer_stage],
                    offsets=[offset_m, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                self.tma.global_to_shared(
                    src=gb,
                    dst=sb[tma_pipe.producer_stage],
                    offsets=[offset_n, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                tma_pipe.producer_advance()

            # drain: wait for consumer to finish processing all in-flight stages
            for _ in self.range(min(num_stages, cdiv(k_size, block_k))):
                tma_pipe.producer_acquire()
                tma_pipe.producer_advance()

        with self.thread_group(thread_begin=0, num_threads=128):  # WGMMA consumer
            # Prologue: issue first MMA; don't release stage 0 yet (it is
            # still being read; we release it in the first main-loop iteration
            # after wait_group(1) confirms the MMA is done).
            tma_pipe.consumer_acquire()
            self.wgmma.fence()
            self.wgmma.mma(
                sa[tma_pipe.consumer_stage], sb[tma_pipe.consumer_stage].transpose(), acc
            )
            self.wgmma.commit_group()
            tma_pipe.consumer_advance()

            # Main loop: issue MMA then call wait_group(1) *after* commit so
            # the hardware can pipeline the current and previous groups while
            # the TMA wait (consumer_acquire) was overlapping with the prior
            # group. wait_group(1) stalls until the *previous* group (n-1) is
            # done, then we safely release its stage before advancing.
            for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages):
                tma_pipe.consumer_acquire()
                self.wgmma.fence()
                self.wgmma.mma(
                    sa[tma_pipe.consumer_stage],
                    sb[tma_pipe.consumer_stage].transpose(),
                    acc,
                )
                self.wgmma.commit_group()
                self.wgmma.wait_group(1)
                # consumer_stage has not advanced yet, so "prev" is the slot
                # used by the previous iteration (or by the prologue).
                self.mbarrier.arrive(tma_pipe.prev_consumer_barrier())
                tma_pipe.consumer_advance()

            # Epilogue: drain the last in-flight MMA group, then release its stage.
            self.wgmma.wait_group(0)
            self.mbarrier.arrive(tma_pipe.prev_consumer_barrier())

        self.sync()
        # Accumulate in fp32, store as fp16.
        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])


def main():
    """Check MatmulWGMMAV4 against torch on several shapes and print timings."""
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
        [8192, 8192, 8192],
        [10240, 10240, 10240],
    ]

    rows = []
    for m, n, k in workloads:
        matmul = MatmulWGMMAV4()

        # Centre the inputs and scale by 1/sqrt(k).
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b.T
        matmul(m, n, k, a, b, c_actual)
        torch.cuda.synchronize()

        torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2)

        for name, func in [
            ("torch", lambda: torch.matmul(a, b.T, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            # latency is in ms, so flop / ms * 1e-9 gives TFLOPS.
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# v5: two consumer warpgroups (WGs) fed by one TMA producer warp.
#
# Changes from v4:
# - Two consumer WGs (threads 0-127 and 128-255) consume each produced stage in
#   parallel; each WG accumulates one block_m/2-row half of the output tile and
#   tracks its own WGMMA groups with commit_group/wait_group.
# - The A tile is stored as two per-WG slabs (sa[stage, wg_idx]) and loaded with
#   two TMA copies; the B tile is shared by both WGs.
# - The "empty" mbarriers expect 256 arrivals (both consumer WGs) before the
#   producer may refill a stage.

import math

import pandas
import tilus
import torch
from tilus import RegisterTensor, float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv


# Ring buffer of `num_stages` slots guarded by two mbarrier arrays:
#   - full_barriers[i]:  the producer signals it; the consumers wait on it.
#   - empty_barriers[i]: the consumers signal it; the producer waits on it.
# Producer and consumer each track their own stage index and phase value.
class Pipeline(tilus.Class):
    def __init__(
        self,
        num_stages: int,
        producer_arrive_count: int = 1,
        consumer_arrive_count: int = 1,
    ):
        self.num_stages: int = num_stages
        # One barrier per stage; each is allocated with the arrive count of
        # the side that signals it.
        self.empty_barriers = self.mbarrier.alloc(
            [consumer_arrive_count for _ in range(num_stages)]
        )
        self.full_barriers = self.mbarrier.alloc(
            [producer_arrive_count for _ in range(num_stages)]
        )
        self.producer_stage: int32 = 0
        self.consumer_stage: int32 = 0
        self.producer_phase: uint32 = self.mbarrier.producer_initial_phase
        self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase

    # Wait on the "empty" barrier of the slot the producer is about to fill.
    def producer_acquire(self):
        self.mbarrier.wait(
            barrier=self.empty_barriers[self.producer_stage],
            phase=self.producer_phase,
            sem="relaxed",
            scope="cta",
        )

    # "Full" barrier of the producer's current slot.
    def producer_barrier(self) -> RegisterTensor:
        return self.full_barriers[self.producer_stage]

    # Step to the next slot; the phase is XOR-ed with 1 each time the stage
    # index wraps back to 0.
    def producer_advance(self):
        self.producer_stage = (self.producer_stage + 1) % self.num_stages
        self.producer_phase = self.producer_phase ^ (self.producer_stage == 0)

    # Wait on the "full" barrier of the slot the consumer is about to read.
    def consumer_acquire(self):
        self.mbarrier.wait(
            barrier=self.full_barriers[self.consumer_stage],
            phase=self.consumer_phase,
            sem="relaxed",
            scope="cta",
        )

    # "Empty" barrier of the consumer's current slot.
    def consumer_barrier(self) -> RegisterTensor:
        return self.empty_barriers[self.consumer_stage]

    # Consumer-side counterpart of producer_advance().
    def consumer_advance(self):
        self.consumer_stage = (self.consumer_stage + 1) % self.num_stages
        self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0)

    # "Empty" barrier of the slot just before the consumer's current one.
    # Adding (num_stages - 1) instead of subtracting 1 keeps the dividend
    # non-negative before the modulo.
    def prev_consumer_barrier(self) -> RegisterTensor:
        prev_stage = (self.consumer_stage + (self.num_stages - 1)) % self.num_stages
        return self.empty_barriers[prev_stage]


# block_m must be >= 128 so each WG's WGMMA M = block_m/2 >= 64.
@tilus.autotune("num_stages", [3, 4, 5, 6])
@tilus.autotune("block_m, block_n", [[128, 128], [128, 256], [256, 128], [256, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("swizzle_size", [4, 8])
class MatmulWGMMAV5(tilus.Script):
    def __init__(self, num_stages, block_m, block_n, block_k, swizzle_size):
        super().__init__()
        self.num_stages = num_stages
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.swizzle_size = swizzle_size

    # Map a 1D block index to (m_block, n_block). Blocks are grouped so that
    # each group covers `swizzle_size` consecutive n-blocks for every m-block;
    # the trailing group may be narrower.
    def compute_block_coord(
        self, linear_idx: int32, num_m_blocks: int32, num_n_blocks: int
    ):
        swizzle_size = self.swizzle_size
        tiles_per_group = num_m_blocks * swizzle_size
        group_idx, in_group_idx = self.fast_divmod(linear_idx, tiles_per_group)
        first_n = group_idx * swizzle_size
        m_block: int32 = 0
        n_block: int32 = 0
        # remainder == num_n_blocks mod swizzle_size: width of the trailing group.
        remainder = num_n_blocks - num_n_blocks // swizzle_size * swizzle_size
        last_group_width = remainder if remainder > 0 else swizzle_size
        if first_n + swizzle_size <= num_n_blocks:
            # Full-width group: walk the n-blocks first, then step to the next m-block.
            m_block, r = self.fast_divmod(in_group_idx, swizzle_size)
            n_block = first_n + r
        else:
            # Trailing, narrower group.
            m_block, r = self.fast_divmod(in_group_idx, last_group_width)
            n_block = first_n + r
        return m_block, n_block

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        num_stages = self.num_stages
        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
        # Each consumer WG handles half of the tile's rows.
        block_m_half = block_m // 2

        num_m_blocks = cdiv(m_size, block_m)
        num_n_blocks = cdiv(n_size, block_n)
        # 1D launch: one block per output tile.
        self.attrs.blocks = num_m_blocks * num_n_blocks
        self.attrs.warps = 9  # 1 producer + 2 consumer WGs (4 warps each)

        m_block, n_block = self.compute_block_coord(
            self.blockIdx.x, num_m_blocks, num_n_blocks
        )
        offset_m: int32 = m_block * block_m
        offset_n: int32 = n_block * block_n

        # a is viewed as [m, k] and b as [n, k]; b is transposed at the MMA.
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        # Per-WG A slab: index as sa[stage, wg_idx].
        sa = self.shared_tensor(
            dtype=float16, shape=[num_stages, 2, block_m_half, block_k]
        )
        sb = self.shared_tensor(dtype=float16, shape=[num_stages, block_n, block_k])

        # consumer_arrive_count=256: both 128-thread consumer WGs arrive on a
        # stage's "empty" barrier.
        tma_pipe = Pipeline(
            num_stages, producer_arrive_count=1, consumer_arrive_count=256
        )

        with self.thread_group(thread_begin=256, num_threads=32):  # TMA producer
            for offset_k in self.range(0, k_size, block_k, unroll=num_stages):
                tma_pipe.producer_acquire()
                # Expected bytes cover both A slabs plus the B tile.
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        tma_pipe.producer_barrier(),
                        transaction_bytes=sa[tma_pipe.producer_stage, 0].nbytes
                        + sa[tma_pipe.producer_stage, 1].nbytes
                        + sb[tma_pipe.producer_stage].nbytes,
                    )
                # NOTE(review): the copies are assumed to sit at warp scope,
                # outside single_thread(), as in v4 — confirm against the
                # original indentation.
                self.tma.global_to_shared(
                    src=ga,
                    dst=sa[tma_pipe.producer_stage, 0],
                    offsets=[offset_m, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                self.tma.global_to_shared(
                    src=ga,
                    dst=sa[tma_pipe.producer_stage, 1],
                    offsets=[offset_m + block_m_half, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                self.tma.global_to_shared(
                    src=gb,
                    dst=sb[tma_pipe.producer_stage],
                    offsets=[offset_n, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                tma_pipe.producer_advance()

            # Drain: wait on the "empty" barrier of each slot that was filled.
            for _ in self.range(min(num_stages, cdiv(k_size, block_k))):
                tma_pipe.producer_acquire()
                tma_pipe.producer_advance()

        with self.thread_group(thread_begin=0, num_threads=128):  # consumer WG0
            acc0 = self.register_tensor(
                dtype=float32, shape=[block_m_half, block_n], init=0.0
            )
            # Prologue: issue the first MMA without releasing its stage.
            tma_pipe.consumer_acquire()
            self.wgmma.fence()
            self.wgmma.mma(
                sa[tma_pipe.consumer_stage, 0],
                sb[tma_pipe.consumer_stage].transpose(),
                acc0,
            )
            self.wgmma.commit_group()
            tma_pipe.consumer_advance()

            # Main loop: issue the next MMA, then wait_group(1) and release the
            # previous stage.
            for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages):
                tma_pipe.consumer_acquire()
                self.wgmma.fence()
                self.wgmma.mma(
                    sa[tma_pipe.consumer_stage, 0],
                    sb[tma_pipe.consumer_stage].transpose(),
                    acc0,
                )
                self.wgmma.commit_group()
                self.wgmma.wait_group(1)
                self.mbarrier.arrive(tma_pipe.prev_consumer_barrier())
                tma_pipe.consumer_advance()

            # Epilogue: drain the last MMA group and release its stage.
            self.wgmma.wait_group(0)
            self.mbarrier.arrive(tma_pipe.prev_consumer_barrier())

            # WG0 stores the upper half of the tile.
            casted0 = self.cast(acc0, dtype=float16)
            self.store_global(gc, casted0, offsets=[offset_m, offset_n])

        with self.thread_group(thread_begin=128, num_threads=128):  # consumer WG1
            # Same schedule as WG0, reading A slab 1 and writing the lower half.
            acc1 = self.register_tensor(
                dtype=float32, shape=[block_m_half, block_n], init=0.0
            )
            tma_pipe.consumer_acquire()
            self.wgmma.fence()
            self.wgmma.mma(
                sa[tma_pipe.consumer_stage, 1],
                sb[tma_pipe.consumer_stage].transpose(),
                acc1,
            )
            self.wgmma.commit_group()
            tma_pipe.consumer_advance()

            for offset_k in self.range(block_k, k_size, block_k, unroll=num_stages):
                tma_pipe.consumer_acquire()
                self.wgmma.fence()
                self.wgmma.mma(
                    sa[tma_pipe.consumer_stage, 1],
                    sb[tma_pipe.consumer_stage].transpose(),
                    acc1,
                )
                self.wgmma.commit_group()
                self.wgmma.wait_group(1)
                self.mbarrier.arrive(tma_pipe.prev_consumer_barrier())
                tma_pipe.consumer_advance()

            self.wgmma.wait_group(0)
            self.mbarrier.arrive(tma_pipe.prev_consumer_barrier())

            casted1 = self.cast(acc1, dtype=float16)
            self.store_global(gc, casted1, offsets=[offset_m + block_m_half, offset_n])


def main():
    """Check MatmulWGMMAV5 against torch on several shapes and print timings."""
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
        [8192, 8192, 8192],
        [10240, 10240, 10240],
    ]

    rows = []
    for m, n, k in workloads:
        matmul = MatmulWGMMAV5()

        # Centre the inputs and scale by 1/sqrt(k).
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b.T
        matmul(m, n, k, a, b, c_actual)
        torch.cuda.synchronize()

        torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2)

        for name, func in [
            ("torch", lambda: torch.matmul(a, b.T, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            # latency is in ms, so flop / ms * 1e-9 gives TFLOPS.
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)


if __name__ == "__main__":
    main()
+ smem_alloc_ctx = self.contexts.smem_alloc_ctx + nbytes = num_barriers * uint64.nbytes + alignment = 128 + aligned_hwm = (smem_alloc_ctx.smem_allocator.maximum_allocated + alignment - 1) // alignment * alignment + virtual_smem_addr = aligned_hwm + smem_alloc_ctx.smem_allocator.maximum_allocated = aligned_hwm + nbytes sb = StmtBuilder() sb.declare( v=self.barrier_addr, diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index e71b18cd..a890b0b9 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -43,6 +43,7 @@ from tilus.hidet.ir.tools import rewrite, simplify from tilus.hidet.ir.type import DataType, PointerType, TensorType, sizeof from tilus.ir import GlobalLayout +from tilus.ir.inst import Instruction from tilus.ir.instructions.cuda.cp_async_tensor import ( CopyAsyncTensorCommitGroupInst, CopyAsyncTensorGlobalToSharedInst, @@ -52,7 +53,7 @@ from tilus.ir.tensor import GlobalTensor, SharedTensor from tilus.ir.utils.lineardec import LinearDecompositionError, decompose_linear from tilus.ir.utils.veceval import vectorized_evaluate -from tilus.target import nvgpu_sm90 +from tilus.target import get_current_target, nvgpu_sm90 @dataclass(frozen=True, eq=False) @@ -151,6 +152,27 @@ def get_offset_grid_of_swizzled_layout( class CopyAsyncTensorBaseEmitter(BaseInstEmitter): + def assert_is_single_thread_or_warp_aligned(self, inst: Instruction, msg: str) -> None: + # TMA copies are issued by one elected lane. single_thread() already + # narrows execution to one lane; at warp scope, the TMA predicate elects + # the leader lane. 
@property
def tma_predicate(self) -> Expr:
    """Predicate expression used to guard an emitted TMA instruction.

    When the current context has exactly one thread the predicate is the
    constant ``uint32(1)``; otherwise it is the leader-lane expression taken
    from ``self.contexts.leader_lane_ctx``.
    """
    # Per the original author's note: in a single-thread context only one
    # thread reaches the TMA call, so no election is needed.
    single_thread = self.current_num_threads == 1
    return uint32(1) if single_thread else self.contexts.leader_lane_ctx.leader_lane


def resolve_shared_tensor_segments(
    self, shared_tensor: SharedTensor
) -> tuple[list[tuple[SharedTensorInfo, int]], int]:
    """Split a 2D shared tensor with an n-axis-segmented layout into per-segment boxes.

    The only accepted layout has ``mode_shape=[bm, S, bn/S]`` and
    ``mode_strides=[bn/S, bm*bn/S, 1]``, i.e. ``S`` row-major ``[bm, bn/S]``
    sub-blocks stored one after another. A single sub-block is compared
    against each ``TensorMapSwizzle`` mode; the first mode whose offset grid
    equals the sub-block's grid is used for every segment.

    Per the original author's note, such layouts arise when ``store_shared``
    mirrors the wgmma register-tile fragmentation, and the intent is to emit
    one TMA store per segment.

    Returns
    -------
    segments
        ``(per_segment_info, segment_n_offset)`` pairs in segment order.
        ``segment_n_offset`` is the column offset (in elements) of the
        sub-block along the original n-axis.
    seg_dim
        The shared-tensor dimension that was segmented; this implementation
        always returns 1.

    Raises
    ------
    NotImplementedError
        If the tensor is not 2D with a 3-mode layout, if the mode shape or
        strides do not describe the segmentation above, or if a segment does
        not match any of the candidate swizzle modes.
    """
    layout: SharedLayout = shared_tensor.layout

    if len(shared_tensor.shape) != 2 or len(layout.mode_shape) != 3:
        raise NotImplementedError("segmented decomposition only handles 2D layouts split into 3 modes")

    rows, cols = shared_tensor.shape
    outer_rows, num_segments, seg_cols = layout.mode_shape
    row_stride, seg_stride, col_stride = layout.mode_strides

    # dim 0 must be the single mode [bm]; dim 1 must be the pair [S, bn/S].
    if outer_rows != rows or num_segments * seg_cols != cols:
        raise NotImplementedError("mode shape does not segment the n-axis")
    # Each segment must be a contiguous row-major [bm, bn/S] box, with the
    # boxes stored back to back.
    if col_stride != 1 or row_stride != seg_cols or seg_stride != rows * seg_cols:
        raise NotImplementedError("mode strides do not match contiguous [bm, bn/S] segments stacked along n")

    # Layout of one segment in isolation, keeping the original swizzle.
    seg_shape = (rows, seg_cols)
    seg_layout = SharedLayout(
        shape=seg_shape,
        mode_shape=(rows, seg_cols),
        mode_strides=(seg_cols, 1),
        optional_swizzle=layout.optional_swizzle,
    )
    seg_grid = seg_layout.as_numpy_grid()

    # Candidate modes are tried in this order; the first match wins.
    candidates = (
        TensorMapSwizzle.NONE,
        TensorMapSwizzle.B32,
        TensorMapSwizzle.B64,
        TensorMapSwizzle.B128,
        TensorMapSwizzle.B128_ATOM_32B,
        TensorMapSwizzle.B128_ATOM_32B_FLIP_8B,
        TensorMapSwizzle.B128_ATOM_64B,
    )
    matched: Optional[TensorMapSwizzle] = None
    for candidate in candidates:
        reference_grid = get_offset_grid_of_swizzled_layout(
            dtype_nbits=shared_tensor.dtype.nbits, shape=seg_shape, swizzle=candidate
        )
        if reference_grid is None:
            continue
        if np.array_equal(seg_grid, reference_grid):
            matched = candidate
            break

    if matched is None:
        raise NotImplementedError(
            "Segment layout does not match any TMA hardware swizzle: \n"
            + f"Per-segment shape: {shared_tensor.dtype.name}{list(seg_shape)}\n"
            + seg_layout.visualize()
        )

    first_addr = self.shared_tensor_shared_space_addr[shared_tensor]
    bytes_per_segment = rows * seg_cols * shared_tensor.dtype.nbytes
    segments: list[tuple[SharedTensorInfo, int]] = [
        (
            SharedTensorInfo(
                addr=first_addr + idx * bytes_per_segment,
                shape=seg_shape,
                swizzle=matched,
            ),
            idx * seg_cols,
        )
        for idx in range(num_segments)
    ]
    return segments, 1
+ cta_group = inst.cta_group if get_current_target().properties.compute_capability >= (10, 0) else None if optional_multicast_mask is None: self.append( @@ -310,7 +427,7 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: src_tensor_map=src_tensor_map, coords=coords, mbarrier=inst.mbarrier, - cta_group=inst.cta_group, + cta_group=cta_group, cache_policy=inst.cache_policy, predicate=predicate, ) @@ -324,7 +441,7 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: coords=coords, mbarrier=inst.mbarrier, multicast_mask=multicast_mask, - cta_group=inst.cta_group, + cta_group=cta_group, cache_policy=inst.cache_policy, predicate=predicate, ) @@ -334,7 +451,7 @@ def emit(self, inst: CopyAsyncTensorGlobalToSharedInst) -> None: @register_emitter(CopyAsyncTensorSharedToGlobalInst, target=nvgpu_sm90) class CopyAsyncTensorSharedToGlobalInstEmitter(CopyAsyncTensorBaseEmitter): def emit(self, inst: CopyAsyncTensorSharedToGlobalInst) -> None: - self.assert_is_warp_aligned(inst, "TMA shared to global is a warp-cooperative instruction") + self.assert_is_single_thread_or_warp_aligned(inst, "TMA shared to global must be issued by one thread") global_tensor: GlobalTensor = inst.inputs[0].as_global_tensor() shared_tensor: SharedTensor = inst.inputs[1].as_shared_tensor() assert global_tensor.dtype == shared_tensor.dtype @@ -344,20 +461,35 @@ def emit(self, inst: CopyAsyncTensorSharedToGlobalInst) -> None: global_tensor, offsets=inst.offsets, dims=inst.dims ) - shared_tensor_info: SharedTensorInfo = self.resolve_shared_tensor_info(shared_tensor) - - shared_addr = self.shared_tensor_shared_space_addr[shared_tensor] - tensor_map = self.create_tensor_map(global_tensor_info, shared_tensor_info, dtype) - tensor_coords = inst.offsets - self.append( - cp_async_tensor_shared_to_global( - dst_tensor_map=~tensor_map, - src=shared_addr, - coords=list(reversed(tensor_coords)), - cache_policy=inst.cache_policy, - predicate=self.contexts.leader_lane_ctx.leader_lane, + 
try: + shared_tensor_info: SharedTensorInfo = self.resolve_shared_tensor_info(shared_tensor) + segments: list[tuple[SharedTensorInfo, int]] = [(shared_tensor_info, 0)] + seg_dim: Optional[int] = None + except NotImplementedError: + # Fall back to per-segment emission for layouts that split a dim + # into stacked sub-boxes (typically the wgmma-fragment layout that + # store_shared inherits when targeting sc[bm, bn]). + segments, seg_dim = self.resolve_shared_tensor_segments(shared_tensor) + + # All segments share box shape and swizzle, so reuse one descriptor. + first_info = segments[0][0] + tensor_map = self.create_tensor_map(global_tensor_info, first_info, dtype) + for info, segment_offset in segments: + tensor_coords = list(inst.offsets) + if seg_dim is not None and segment_offset != 0: + # seg_dim indexes the *shared* dims; map it to the matching + # global dim through inst.dims, then shift that coord. + global_seg_dim = inst.dims[seg_dim] + tensor_coords[global_seg_dim] = tensor_coords[global_seg_dim] + segment_offset + self.append( + cp_async_tensor_shared_to_global( + dst_tensor_map=~tensor_map, + src=info.addr, + coords=list(reversed(tensor_coords)), + cache_policy=inst.cache_policy, + predicate=self.tma_predicate, + ) ) - ) @register_emitter(CopyAsyncTensorCommitGroupInst, target=nvgpu_sm90) diff --git a/python/tilus/hidet/option.py b/python/tilus/hidet/option.py index 5bd3a8bf..b61e0e58 100644 --- a/python/tilus/hidet/option.py +++ b/python/tilus/hidet/option.py @@ -27,6 +27,7 @@ from __future__ import annotations +import os from typing import Any, Callable, Dict, Iterable, List, Optional @@ -115,6 +116,18 @@ def get_option(self, name: str) -> Any: if name not in OptionRegistry.registered_options: raise KeyError(f"Option {name} has not been registered.") registry = OptionRegistry.registered_options[name] + if registry.env is not None and registry.env in os.environ: + raw = os.environ[registry.env] + if registry.normalizer is not None: + return 
registry.normalizer(raw) + # coerce to the same type as the default value when no normalizer is provided + if isinstance(registry.default_value, int): + return int(raw) + if isinstance(registry.default_value, float): + return float(raw) + if isinstance(registry.default_value, bool): + return raw.lower() not in ("0", "false", "no", "off") + return raw return registry.default_value diff --git a/python/tilus/hidet/transforms/_rule_based_simplifier_base.py b/python/tilus/hidet/transforms/_rule_based_simplifier_base.py index e227c836..e8441e42 100644 --- a/python/tilus/hidet/transforms/_rule_based_simplifier_base.py +++ b/python/tilus/hidet/transforms/_rule_based_simplifier_base.py @@ -82,6 +82,35 @@ def c_div(a, b): return a / b +def _c_trunc_div(a, b): + """Integer division truncated toward zero (C semantics: a/b). + + When a or b are not plain ints (i.e. they are Hidet IR Expr nodes used + for pattern construction) fall back to Python's // so that the operator + overloads build the correct IR node. The C-semantics branch is only + reached during candidate-value verification where both operands are ints. + """ + if not isinstance(a, int) or not isinstance(b, int): + return a // b # IR expression path — builds a Hidet Div node + if b == 0: + return 0 + return int(a / b) # truncate-toward-zero, avoiding Python's floor behaviour + + +def _c_trunc_mod(a, b): + """Integer modulo with C semantics (result has the sign of a, not b). + + Same dual-mode design as _c_trunc_div: falls back to Python % when called + with Hidet IR Expr nodes so the operator overloads build a Hidet Mod node. + The C-semantics branch is only reached during candidate-value verification. 
+ """ + if not isinstance(a, int) or not isinstance(b, int): + return a % b # IR expression path — builds a Hidet Mod node + if b == 0: + return 0 + return a - b * _c_trunc_div(a, b) + + class ConstExprSimplifier(IRRewriter): op_dict = { Add: operator.add, @@ -92,7 +121,7 @@ class ConstExprSimplifier(IRRewriter): BitwiseAnd: operator.and_, BitwiseXor: operator.xor, BitwiseNot: operator.invert, - Mod: operator.mod, + Mod: _c_trunc_mod, LessThan: operator.lt, LessEqual: operator.le, Equal: operator.eq, @@ -192,21 +221,29 @@ def __init__(self, skip_node_types: Optional[Sequence[Type[Expr]]] = None): (IfThenElse(ec1, ec2, ec2), ec2), ] self.bound_patterns = [ - # ((pattern_args, pattern_func, target_args, target_func) + # Verify using C-compatible truncated division/modulo rather than Python + # floor division/modulo. The generated code uses C's % and /, which + # differ from Python for negative operands, so the candidate-value checks + # must use the same semantics as the emitted C code. ( (ec1, ec2, c1), (ec1, ec2, c1), - lambda ec1, ec2, c1: (ec1 + ec2) // c1, - lambda ec1, ec2, c1: ec1 // c1 + ec2 // c1, + lambda ec1, ec2, c1: _c_trunc_div(ec1 + ec2, c1), + lambda ec1, ec2, c1: _c_trunc_div(ec1, c1) + _c_trunc_div(ec2, c1), ), ( (ec1, ec2, c1), (ec1, ec2, c1), - lambda ec1, ec2, c1: (ec1 + ec2) % c1, - lambda ec1, ec2, c1: ec1 % c1 + ec2 % c1, + lambda ec1, ec2, c1: _c_trunc_mod(ec1 + ec2, c1), + lambda ec1, ec2, c1: _c_trunc_mod(ec1, c1) + _c_trunc_mod(ec2, c1), + ), + ((ec1, c1), (ec1,), lambda ec1, c1: _c_trunc_mod(ec1, c1), lambda ec1: ec1), + ( + (ec1, c1, c2), + (ec1, c2), + lambda ec1, c1, c2: _c_trunc_mod(_c_trunc_mod(ec1, c1), c2), + lambda ec1, c2: _c_trunc_mod(ec1, c2), ), - ((ec1, c1), (ec1,), lambda ec1, c1: ec1 % c1, lambda ec1: ec1), - ((ec1, c1, c2), (ec1, c2), lambda ec1, c1, c2: (ec1 % c1) % c2, lambda ec1, c2: ec1 % c2), ] if skip_node_types: for skip_type in skip_node_types: @@ -288,7 +325,10 @@ def visit(self, obj): def visit_Mod(self, e: Mod): 
ua, ub = self.bound[e.a], self.bound[e.b] - if ua.is_zero() or ua < ub: + # In C, x % y == x only when 0 <= x < y (truncated modulo). The < check + # covers the upper bound; also require a provably non-negative lower bound. + a_min = ua.possible_min_value() + if ua.is_zero() or (ua < ub and a_min is not None and a_min >= 0): return self(e.a) return IRRewriter.visit_Mod(self, e) diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index becefe23..0014cffd 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -650,10 +650,6 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: latency.append(0.0) else: tuning_key_name = " " + "-".join([str(v) for v in tuning_key]) if tuning_key else "" - kernel_args = [ - args[i].clone() if isinstance(args[i], torch.Tensor) else args[i] - for i in self.call_params.kernel_params - ] for i, compiled_program in tqdm( iterable=enumerate(self.compiled_programs), desc="[{}] {}{}".format("Tuning", self.instance_name, tuning_key_name), @@ -662,14 +658,16 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: ncols=60 + max(60, len(self.instance_name) + len(tuning_key_name)), ): compiled_func = compiled_program.get_launch_func() + kernel_args = [ + args[j].clone() if isinstance(args[j], torch.Tensor) else args[j] + for j in self.call_params.kernel_params + ] try: - latency.append( - benchmark_func( - lambda: compiled_func(*kernel_args), - warmup=tilus.option.get_option("bench_warmup"), - repeat=tilus.option.get_option("bench_repeat"), - ) - ) # type: ignore + lat = benchmark_func( + lambda: compiled_func(*kernel_args), + warmup=tilus.option.get_option("bench_warmup"), + repeat=tilus.option.get_option("bench_repeat"), + ) except RuntimeError as e: raise RuntimeError( f"Failed to benchmark the kernel {self.instance_name} with schedule: \n" @@ -677,6 +675,7 @@ def _pick_best_program(self, args: Sequence[Any]) -> 
CompiledProgram: "Error message:\n" f" {str(e)}" ) from e + latency.append(lat) # type: ignore best_latency = min(latency) best_program_idx = latency.index(best_latency) diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index eef63fb1..690878a7 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -64,6 +64,8 @@ ("hopper_matmul", "matmul_v1.py", nvgpu_sm90a), ("hopper_matmul", "matmul_v2.py", nvgpu_sm90a), ("hopper_matmul", "matmul_v3.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v4.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v5.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+) @@ -78,6 +80,7 @@ ("flash_attention_decode", "tilus_kernel.py"), # Benchmark utilities ("blackwell_matmul", "benchmark.py"), + ("hopper_matmul", "benchmark.py"), ]