Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,84 +3,14 @@
# Status: Experimental / uncurated
# Expectation: Correctness-first, performance not representative

from pathlib import Path
import sys

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32}),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}),
],
key=["M", "N", "K"], # autotune per problem size
)
@triton.jit
def _matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
a_desc = tl.make_tensor_descriptor(
base=a_ptr, shape=(M, K), strides=(K, 1), block_shape=(BLOCK_M, BLOCK_K)
)
b_desc = tl.make_tensor_descriptor(
base=b_ptr, shape=(K, N), strides=(N, 1), block_shape=(BLOCK_K, BLOCK_N)
)
c_desc = tl.make_tensor_descriptor(
base=c_ptr, shape=(M, N), strides=(N, 1), block_shape=(BLOCK_M, BLOCK_N)
)

m = tl.program_id(0) * BLOCK_M
n = tl.program_id(1) * BLOCK_N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = a_desc.load((m, k))
b = b_desc.load((k, n))
acc = tl.dot(a, b, acc)

c_desc.store((m, n), acc)


def _kernel_function_cpu(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
assert isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor)
assert A.device.type == "cpu" and B.device.type == "cpu", "A and B must be on CPU"
assert A.is_floating_point() and B.is_floating_point(), (
"A and B must be floating point tensors"
)
assert A.dtype == B.dtype, f"dtype mismatch: {A.dtype} vs {B.dtype}"

orig_dtype = A.dtype

M, K = A.shape
K2, N = B.shape
assert K == K2, f"Incompatible K dimensions: {K} vs {K2}"

C32 = torch.empty((M, N), device=A.device, dtype=torch.float32)

# Autotuned grid: depends on BLOCK_M/BLOCK_N chosen by config
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]),
triton.cdiv(N, META["BLOCK_N"]),
)

_matmul_kernel[grid](
A,
B,
C32,
M,
N,
K,
)

return C32.to(orig_dtype)
sys.path.insert(0, str(Path(__file__).parent))
from sfc_matmul import sfc_matmul


class Model(nn.Module):
Expand All @@ -90,4 +20,4 @@ def __init__(self, *args, **kwargs):
super(Model, self).__init__()

def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
return _kernel_function_cpu(A, B)
return sfc_matmul(A, B)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ruff: noqa: E731
# Example Triton CPU kernel
# Status: Experimental / uncurated
# Expectation: Correctness-first, performance not representative

from pathlib import Path
import sys

import torch
import torch.nn as nn

sys.path.insert(0, str(Path(__file__).parent))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

from sfc_matmul import sfc_matmul


class Model(nn.Module):
"""KernelBench-compatible wrapper"""

def __init__(self, *args, **kwargs):
super(Model, self).__init__()

def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
return sfc_matmul(A, B)
25 changes: 25 additions & 0 deletions backends/triton/cpu/KernelBench/level1/LICENSE.gilbert_d2xy.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
BSD 2-Clause License

Copyright (c) 2018, Jakub Červený
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
85 changes: 85 additions & 0 deletions backends/triton/cpu/KernelBench/level1/gilbert_d2xy.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd throw it into a separate subdir. Otherwise, it'll get a bit lost among other 100 kernels.
Not sure yet how we'd want to organize such helpers so maybe backends/triton/cpu/utils for now?

Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-2-Clause
# Copyright (c) 2024 abetusk

# Source: https://github.com/jakubcerveny/gilbert/blob/9b080a74c5e3b6fe189785c52742f83ac85ba181/gilbert_d2xy.py


def gilbert_d2xy(idx, w, h):
"""
Generalized Hilbert ('gilbert') space-filling curve for arbitrary-sized
2D rectangular grids. Takes a position along the gilbert curve and returns
its 2D (x,y) coordinate.
"""

if w >= h:
return gilbert_d2xy_r(idx, 0, 0, 0, w, 0, 0, h)
return gilbert_d2xy_r(idx, 0, 0, 0, 0, h, w, 0)


def sgn(x):
return -1 if x < 0 else (1 if x > 0 else 0)


def gilbert_d2xy_r(dst_idx, cur_idx, x, y, ax, ay, bx, by):
w = abs(ax + ay)
h = abs(bx + by)

(dax, day) = (sgn(ax), sgn(ay)) # unit major direction
(dbx, dby) = (sgn(bx), sgn(by)) # unit orthogonal direction

di = dst_idx - cur_idx

if h == 1:
return (x + dax * di, y + day * di)
if w == 1:
return (x + dbx * di, y + dby * di)

(ax2, ay2) = (ax // 2, ay // 2)
(bx2, by2) = (bx // 2, by // 2)

w2 = abs(ax2 + ay2)
h2 = abs(bx2 + by2)

if 2 * w > 3 * h:
if (w2 % 2) and (w > 2):
# prefer even steps
(ax2, ay2) = (ax2 + dax, ay2 + day)

# long case: split in two parts only
nxt_idx = cur_idx + abs((ax2 + ay2) * (bx + by))
if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
return gilbert_d2xy_r(dst_idx, cur_idx, x, y, ax2, ay2, bx, by)
cur_idx = nxt_idx

return gilbert_d2xy_r(
dst_idx, cur_idx, x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by
)

if (h2 % 2) and (h > 2):
# prefer even steps
(bx2, by2) = (bx2 + dbx, by2 + dby)

# standard case: one step up, one long horizontal, one step down
nxt_idx = cur_idx + abs((bx2 + by2) * (ax2 + ay2))
if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
return gilbert_d2xy_r(dst_idx, cur_idx, x, y, bx2, by2, ax2, ay2)
cur_idx = nxt_idx

nxt_idx = cur_idx + abs((ax + ay) * ((bx - bx2) + (by - by2)))
if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
return gilbert_d2xy_r(
dst_idx, cur_idx, x + bx2, y + by2, ax, ay, bx - bx2, by - by2
)
cur_idx = nxt_idx

return gilbert_d2xy_r(
dst_idx,
cur_idx,
x + (ax - dax) + (bx2 - dbx),
y + (ay - day) + (by2 - dby),
-bx2,
-by2,
-(ax - ax2),
-(ay - ay2),
)
Loading
Loading