Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# ruff: noqa: E731
# AUTOGENERATED KERNEL (LLM)
# Source: LLM-generated candidate implementation
# Status: Experimental / uncurated
# Expectation: Correctness-first, performance not representative

import torch
import torch.nn as nn
import triton
import triton.language as tl


def _gemm_configs():
return [
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4},
),
]


@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"])
@triton.jit
def _gemm_scale_hardtanh_gelu_kernel(
x_ptr,
w_ptr,
b_ptr,
y_ptr,
M,
N,
K,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
scale,
ht_min,
ht_max,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

x_desc = tl.make_tensor_descriptor(
base=x_ptr,
shape=(M, K),
strides=(stride_xm, stride_xk),
block_shape=(BLOCK_M, BLOCK_K),
)
w_desc = tl.make_tensor_descriptor(
base=w_ptr,
shape=(K, N),
strides=(stride_wk, stride_wn),
block_shape=(BLOCK_K, BLOCK_N),
)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for off_k in range(0, K, BLOCK_K):
a = x_desc.load([pid_m * BLOCK_M, off_k])
b = w_desc.load([off_k, pid_n * BLOCK_N])
acc += tl.dot(a, b)

offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
acc += bias[None, :]
acc = acc * scale
acc = tl.minimum(tl.maximum(acc, ht_min), ht_max)
acc = 0.5 * acc * (1.0 + tl.math.erf(acc * 0.7071067811865476))

y_desc = tl.make_tensor_descriptor(
base=y_ptr,
shape=(M, N),
strides=(stride_ym, stride_yn),
block_shape=(BLOCK_M, BLOCK_N),
)
y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty))


def kernel_function(
x: torch.Tensor,
weight_kn: torch.Tensor,
bias: torch.Tensor,
scaling_factor: float,
hardtanh_min: float,
hardtanh_max: float,
) -> torch.Tensor:
M, K = x.shape
_, N = weight_kn.shape
y = torch.empty((M, N), device=x.device, dtype=x.dtype)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
)
_gemm_scale_hardtanh_gelu_kernel[grid](
x,
weight_kn,
bias,
y,
M,
N,
K,
x.stride(0),
x.stride(1),
weight_kn.stride(0),
weight_kn.stride(1),
y.stride(0),
y.stride(1),
float(scaling_factor),
float(hardtanh_min),
float(hardtanh_max),
)
return y


batch_size = 2048
in_features = 8192
out_features = 8192
scaling_factor = 0.5
hardtanh_min = -2
hardtanh_max = 2


def get_inputs():
return [torch.rand(batch_size, in_features)]


def get_init_inputs():
return [in_features, out_features, scaling_factor, hardtanh_min, hardtanh_max]


class Model(nn.Module):
def __init__(
self, in_features, out_features, scaling_factor, hardtanh_min, hardtanh_max
):
super().__init__()
self.gemm = nn.Linear(in_features, out_features)
self.scaling_factor = scaling_factor
self.hardtanh_min = hardtanh_min
self.hardtanh_max = hardtanh_max
self._weight_kn = None

def forward(self, x):
if self._weight_kn is None:
w = self.gemm.weight.data.to(dtype=x.dtype).contiguous()
self.gemm.weight.data = w
self.gemm.bias.data = self.gemm.bias.data.to(dtype=x.dtype).contiguous()
self._weight_kn = w.t().contiguous()
return kernel_function(
x.contiguous(),
self._weight_kn,
self.gemm.bias,
self.scaling_factor,
self.hardtanh_min,
self.hardtanh_max,
)
142 changes: 142 additions & 0 deletions backends/triton/cpu/KernelBench/level2/59_Matmul_Swish_Scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# ruff: noqa: E731
# AUTOGENERATED KERNEL (LLM)
# Source: LLM-generated candidate implementation
# Status: Experimental / uncurated
# Expectation: Correctness-first, performance not representative

import torch
import torch.nn as nn
import triton
import triton.language as tl


def _gemm_configs():
return [
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4},
),
]


@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"])
@triton.jit
def _gemm_swish_scale_kernel(
x_ptr,
w_ptr,
b_ptr,
y_ptr,
M,
N,
K,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
scale,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

x_desc = tl.make_tensor_descriptor(
base=x_ptr,
shape=(M, K),
strides=(stride_xm, stride_xk),
block_shape=(BLOCK_M, BLOCK_K),
)
w_desc = tl.make_tensor_descriptor(
base=w_ptr,
shape=(K, N),
strides=(stride_wk, stride_wn),
block_shape=(BLOCK_K, BLOCK_N),
)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for off_k in range(0, K, BLOCK_K):
a = x_desc.load([pid_m * BLOCK_M, off_k])
b = w_desc.load([off_k, pid_n * BLOCK_N])
acc += tl.dot(a, b)

offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
acc += bias[None, :]
sig = 1.0 / (1.0 + tl.exp(-acc))
acc = acc * sig * scale

y_desc = tl.make_tensor_descriptor(
base=y_ptr,
shape=(M, N),
strides=(stride_ym, stride_yn),
block_shape=(BLOCK_M, BLOCK_N),
)
y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty))


def kernel_function(x, weight_kn, bias, scaling_factor):
M, K = x.shape
_, N = weight_kn.shape
y = torch.empty((M, N), device=x.device, dtype=x.dtype)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
)
_gemm_swish_scale_kernel[grid](
x,
weight_kn,
bias,
y,
M,
N,
K,
x.stride(0),
x.stride(1),
weight_kn.stride(0),
weight_kn.stride(1),
y.stride(0),
y.stride(1),
float(scaling_factor),
)
return y


batch_size = 128
in_features = 32768
out_features = 32768
scaling_factor = 2.0


def get_inputs():
return [torch.rand(batch_size, in_features)]


def get_init_inputs():
return [in_features, out_features, scaling_factor]


class Model(nn.Module):
def __init__(self, in_features, out_features, scaling_factor):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
self.scaling_factor = scaling_factor
self._weight_kn = None

def forward(self, x):
if self._weight_kn is None:
w = self.linear.weight.data.to(dtype=x.dtype).contiguous()
self.linear.weight.data = w
self.linear.bias.data = self.linear.bias.data.to(dtype=x.dtype).contiguous()
self._weight_kn = w.t().contiguous()
return kernel_function(
x.contiguous(), self._weight_kn, self.linear.bias, self.scaling_factor
)
Loading
Loading