diff --git a/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py b/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py index 1582894..b48bfbb 100644 --- a/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py +++ b/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py @@ -3,84 +3,14 @@ # Status: Experimental / uncurated # Expectation: Correctness-first, performance not representative +from pathlib import Path +import sys + import torch import torch.nn as nn -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32}), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}), - ], - key=["M", "N", "K"], # autotune per problem size -) -@triton.jit -def _matmul_kernel( - a_ptr, - b_ptr, - c_ptr, - M, - N, - K, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - a_desc = tl.make_tensor_descriptor( - base=a_ptr, shape=(M, K), strides=(K, 1), block_shape=(BLOCK_M, BLOCK_K) - ) - b_desc = tl.make_tensor_descriptor( - base=b_ptr, shape=(K, N), strides=(N, 1), block_shape=(BLOCK_K, BLOCK_N) - ) - c_desc = tl.make_tensor_descriptor( - base=c_ptr, shape=(M, N), strides=(N, 1), block_shape=(BLOCK_M, BLOCK_N) - ) - - m = tl.program_id(0) * BLOCK_M - n = tl.program_id(1) * BLOCK_N - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(0, K, BLOCK_K): - a = a_desc.load((m, k)) - b = b_desc.load((k, n)) - acc = tl.dot(a, b, acc) - - c_desc.store((m, n), acc) - - -def _kernel_function_cpu(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - assert isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor) - assert A.device.type == "cpu" and B.device.type == "cpu", "A and B must be on CPU" - assert A.is_floating_point() and B.is_floating_point(), ( - "A and B must be floating point tensors" - ) - assert A.dtype == B.dtype, f"dtype mismatch: {A.dtype} vs {B.dtype}" - - orig_dtype = A.dtype - - M, K = A.shape - K2, N = B.shape - assert K == K2, f"Incompatible K dimensions: {K} vs {K2}" - - C32 = torch.empty((M, N), device=A.device, dtype=torch.float32) - - # Autotuned grid: depends on BLOCK_M/BLOCK_N chosen by config - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]), - triton.cdiv(N, META["BLOCK_N"]), - ) - - _matmul_kernel[grid]( - A, - B, - C32, - M, - N, - K, - ) - return C32.to(orig_dtype) +sys.path.insert(0, str(Path(__file__).parent)) +from sfc_matmul import sfc_matmul class Model(nn.Module): @@ -90,4 +20,4 @@ def __init__(self, *args, **kwargs): super(Model, self).__init__() def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - return _kernel_function_cpu(A, B) + return sfc_matmul(A, B) diff --git a/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py b/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py new file mode 100644 index 0000000..b48bfbb --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py @@ -0,0 +1,23 @@ +# ruff: noqa: E731 +# Example Triton CPU kernel +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +from pathlib import Path +import sys + +import torch +import torch.nn as nn + +sys.path.insert(0, str(Path(__file__).parent)) +from sfc_matmul import sfc_matmul + + +class Model(nn.Module): + """KernelBench-compatible wrapper""" + + def __init__(self, *args, **kwargs): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return sfc_matmul(A, B) diff --git a/backends/triton/cpu/KernelBench/level1/LICENSE.gilbert_d2xy.txt b/backends/triton/cpu/KernelBench/level1/LICENSE.gilbert_d2xy.txt new file mode 100644 index 0000000..a8fbe21 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/LICENSE.gilbert_d2xy.txt @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2018, Jakub Červený +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/backends/triton/cpu/KernelBench/level1/gilbert_d2xy.py b/backends/triton/cpu/KernelBench/level1/gilbert_d2xy.py new file mode 100755 index 0000000..7a8776e --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/gilbert_d2xy.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: BSD-2-Clause +# Copyright (c) 2024 abetusk + +# Source: https://github.com/jakubcerveny/gilbert/blob/9b080a74c5e3b6fe189785c52742f83ac85ba181/gilbert_d2xy.py + + +def gilbert_d2xy(idx, w, h): + """ + Generalized Hilbert ('gilbert') space-filling curve for arbitrary-sized + 2D rectangular grids. Takes a position along the gilbert curve and returns + its 2D (x,y) coordinate. + """ + + if w >= h: + return gilbert_d2xy_r(idx, 0, 0, 0, w, 0, 0, h) + return gilbert_d2xy_r(idx, 0, 0, 0, 0, h, w, 0) + + +def sgn(x): + return -1 if x < 0 else (1 if x > 0 else 0) + + +def gilbert_d2xy_r(dst_idx, cur_idx, x, y, ax, ay, bx, by): + w = abs(ax + ay) + h = abs(bx + by) + + (dax, day) = (sgn(ax), sgn(ay)) # unit major direction + (dbx, dby) = (sgn(bx), sgn(by)) # unit orthogonal direction + + di = dst_idx - cur_idx + + if h == 1: + return (x + dax * di, y + day * di) + if w == 1: + return (x + dbx * di, y + dby * di) + + (ax2, ay2) = (ax // 2, ay // 2) + (bx2, by2) = (bx // 2, by // 2) + + w2 = abs(ax2 + ay2) + h2 = abs(bx2 + by2) + + if 2 * w > 3 * h: + if (w2 % 2) and (w > 2): + # prefer even steps + (ax2, ay2) = (ax2 + dax, ay2 + day) + + # long case: split in two parts only + nxt_idx = cur_idx + abs((ax2 + ay2) * (bx + by)) + if (cur_idx <= dst_idx) and (dst_idx < nxt_idx): + return gilbert_d2xy_r(dst_idx, cur_idx, x, y, ax2, ay2, bx, by) + cur_idx = nxt_idx + + return gilbert_d2xy_r( + dst_idx, cur_idx, x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by + ) + + if (h2 % 2) and (h > 2): + # prefer even steps + (bx2, by2) = (bx2 + dbx, by2 + dby) + + # standard case: one step up, one long horizontal, one step down + nxt_idx = cur_idx + abs((bx2 + by2) * (ax2 + ay2)) + if (cur_idx <= dst_idx) and (dst_idx < nxt_idx): + return gilbert_d2xy_r(dst_idx, cur_idx, x, y, bx2, by2, ax2, ay2) + cur_idx = nxt_idx + + nxt_idx = cur_idx + abs((ax + ay) * ((bx - bx2) + (by - by2))) + if (cur_idx <= dst_idx) and (dst_idx < nxt_idx): + return gilbert_d2xy_r( + dst_idx, cur_idx, x + bx2, y + by2, ax, ay, bx - bx2, by - by2 + ) + cur_idx = nxt_idx + + return gilbert_d2xy_r( + dst_idx, + cur_idx, + x + (ax - dax) + (bx2 - dbx), + y + (ay - day) + (by2 - dby), + -bx2, + -by2, + -(ax - ax2), + -(ay - ay2), + ) diff --git a/backends/triton/cpu/KernelBench/level1/sfc_matmul.py b/backends/triton/cpu/KernelBench/level1/sfc_matmul.py new file mode 100644 index 0000000..49a5c39 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/sfc_matmul.py @@ -0,0 +1,193 @@ +import functools + +from gilbert_d2xy import gilbert_d2xy +import torch +import triton +import triton.language as tl + + +# Transforms the B matrix into a tensor of shape: +# +# (BLOCKS_N, BLOCKS_K, BLOCK_SIZE_K, BLOCK_SIZE_N) +# +# Data is blocked into contiguous chunks of memory. Neighboring blocks in the K +# dimension will also be neighboring in memory. +@triton.jit +def _block_transpose_kernel( + in_ptr, + out_ptr, + sfc_map_ptr, + N, + K, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_k = tl.load(sfc_map_ptr + 2 * pid) + block_n = tl.load(sfc_map_ptr + 2 * pid + 1) + + in_desc = tl.make_tensor_descriptor( + base=in_ptr, + shape=(K, N), + strides=(N, 1), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + ) + out_desc = tl.make_tensor_descriptor( + base=out_ptr, + shape=(N // BLOCK_SIZE_N, K // BLOCK_SIZE_K, BLOCK_SIZE_K, BLOCK_SIZE_N), + strides=(BLOCK_SIZE_N * K, BLOCK_SIZE_K * BLOCK_SIZE_N, BLOCK_SIZE_N, 1), + block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), + ) + + block = in_desc.load((block_k * BLOCK_SIZE_K, block_n * BLOCK_SIZE_N)).reshape( + (1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N) + ) + out_desc.store((block_n, block_k, 0, 0), block) + + +# Matmul kernel using the space curve filling approach in +# https://arxiv.org/abs/2601.16294v1, based on the generalized hilbert curve +# implementation from https://github.com/jakubcerveny/gilbert +# +# Each program computes a single output tile with the 2D coordinates derived +# from the precomputed SFC mapping. If `BLOCKING_FACTOR_K == 1`, then program +# handles all `BLOCKS_K = K // BLOCK_SIZE_K` blocks along the common dimension, +# otherwise the program performs a partial accumulation of the K blocks in the +# half-open interval: +# [ ik * (BLOCKS_K // BLOCKING_FACTOR_K), (ik + 1) * (BLOCKS_K // BLOCKING_FACTOR_K) ) +@triton.jit +def _sfc_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + sfc_map_ptr, + M, + N, + K, + ik, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCKING_FACTOR_K: tl.constexpr, +): + BLOCKS_M = M // BLOCK_SIZE_M + BLOCKS_N = N // BLOCK_SIZE_N + BLOCKS_K = K // BLOCK_SIZE_K + BLOCKS_K_PER_PROG = BLOCKS_K // BLOCKING_FACTOR_K + + dtype: tl.constexpr = a_ptr.type.element_ty + accum_dtype: tl.constexpr = tl.float32 if dtype.is_floating() else tl.int32 + + pid = tl.program_id(axis=0) + block_m = tl.load(sfc_map_ptr + 2 * pid) + block_n = tl.load(sfc_map_ptr + 2 * pid + 1) + block_k = ik * BLOCKS_K_PER_PROG + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(BLOCKS_M, BLOCKS_K, BLOCK_SIZE_M, BLOCK_SIZE_K), + strides=(BLOCK_SIZE_M * K, BLOCK_SIZE_K, K, 1), + block_shape=(1, 1, BLOCK_SIZE_M, BLOCK_SIZE_K), + ) + + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(BLOCKS_N, BLOCKS_K, BLOCK_SIZE_K, BLOCK_SIZE_N), + strides=(BLOCK_SIZE_N * K, BLOCK_SIZE_K * BLOCK_SIZE_N, BLOCK_SIZE_N, 1), + block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), + ) + + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(BLOCKS_M, BLOCKS_N, BLOCK_SIZE_M, BLOCK_SIZE_N), + strides=(BLOCK_SIZE_M * N, BLOCK_SIZE_N, N, 1), + block_shape=(1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N), + ) + + if ik == 0: + c0 = tl.zeros((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dtype) + c_desc.store([block_m, block_n, 0, 0], c0) + + c = ( + c_desc.load([block_m, block_n, 0, 0]) + .reshape((BLOCK_SIZE_M, BLOCK_SIZE_N)) + .to(accum_dtype) + ) + + for block_ki in range(block_k, block_k + BLOCKS_K_PER_PROG): + a = a_desc.load([block_m, block_ki, 0, 0]).reshape((BLOCK_SIZE_M, BLOCK_SIZE_K)) + b = b_desc.load([block_n, block_ki, 0, 0]).reshape((BLOCK_SIZE_K, BLOCK_SIZE_N)) + + c = tl.dot(a, b, c) + + c = c.to(dtype).reshape((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N)) + c_desc.store([block_m, block_n, 0, 0], c) + + +@functools.lru_cache +def _make_sfc_tensor(x, y, dtype=torch.int32, device="cpu"): + gilbert = (gilbert_d2xy(i, x, y) for i in range(x * y)) + return torch.tensor([c for xy in gilbert for c in xy], dtype=dtype, device=device) + + +def sfc_matmul(a: torch.Tensor, b: torch.Tensor, blocking_factor_k=1): + assert isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor) + assert a.device.type == "cpu" and b.device.type == "cpu", "A and B must be on CPU" + assert a.dtype == b.dtype, f"dtype mismatch: {a.dtype} vs {b.dtype}" + M, K = a.shape + K2, N = b.shape + assert K == K2, f"Incompatible K dimensions: {K} vs {K2}" + + # AMX + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + BLOCK_SIZE_K = 32 + + # TODO: Currently masked load is not supported yet. + assert ( + (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0) + ), ( + "Masking currently not supported, matrix dimensions must be multiples of block size" + ) + + BLOCKS_M = M // BLOCK_SIZE_M + BLOCKS_N = N // BLOCK_SIZE_N + BLOCKS_K = K // BLOCK_SIZE_K + + sfc_map_mn = _make_sfc_tensor(BLOCKS_M, BLOCKS_N) + sfc_map_kn = _make_sfc_tensor(BLOCKS_K, BLOCKS_N) + + bp = torch.empty( + (BLOCKS_N, BLOCKS_K, BLOCK_SIZE_K, BLOCK_SIZE_N), device=b.device, dtype=b.dtype + ) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + _block_transpose_kernel[(BLOCKS_K * BLOCKS_N,)]( + b, + bp, + sfc_map_kn, + N, + K, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + assume_in_bounds=True, + ) + + for ik in range(blocking_factor_k): + _sfc_matmul_kernel[(BLOCKS_M * BLOCKS_N,)]( + a, + bp, + c, # + sfc_map_mn, # + M, + N, + K, # + ik, # + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, # + BLOCKING_FACTOR_K=blocking_factor_k, + assume_in_bounds=True, + ) + + return c diff --git a/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml index 3a0fc9c..e4c7d80 100644 --- a/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml @@ -16,13 +16,23 @@ bench-cpu: - params: [A, B] dtype: float32 dims: - N: 1024 + N: 512 flop: "2*N*N*N" + mem_bytes: "3*N*N * 4" # f32 - params: [A, B] dtype: bfloat16 dims: N: 1024 flop: "2*N*N*N" + mem_bytes: "3*N*N * 2" # bf16 + - params: [A, B] + dtype: bfloat16 + dims: + N: 8192 + flop: "2*N*N*N" + mem_bytes: "3*N*N * 2" # bf16 + atol: 2 # long accumulation chain requires higher tolerances + rtol: 10 bench-gpu: - params: [A, B] diff --git a/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml index 0ba0f91..7bbc9ad 100644 --- a/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml @@ -18,11 +18,13 @@ bench-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 128 - N: 256 - K: 512 + M: 6144 + N: 6144 + K: 8192 flop: "2*M*N*K" mem_bytes: "(M*K + K*N + M*N) * 2" # f16 + atol: 2 # long accumulation chain requires higher tolerance + rtol: 4 bench-gpu: - params: [A, B] diff --git a/pyproject.toml b/pyproject.toml index d0c35b1..ba285e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ override-dependencies = [ torch = { index = "pytorch" } pytorch-triton-xpu = { index = "pytorch" } pytorch-triton = { index = "pytorch" } -triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "270e696" } +triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "eece2e9" } lighthouse = { git = "https://github.com/llvm/lighthouse", rev = "456475d" } mlir-python-bindings = { index = "eudsl" }