diff --git a/backends/triton/cpu/KernelBench/level2/40_Matmul_Scaling_ResidualAdd.py b/backends/triton/cpu/KernelBench/level2/40_Matmul_Scaling_ResidualAdd.py new file mode 100644 index 0000000..a33f6f0 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level2/40_Matmul_Scaling_ResidualAdd.py @@ -0,0 +1,479 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import weakref + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +# ----------------------------------------------------------------------------- +# Intel XPU Triton GEMM with packed RHS [K, N]. +# Stage updates: +# - Add reusable packed-weight cache for standalone fused_linear() path. +# - Use explicit grf_mode="256" for large-tile XPU GEMM launches. +# - Keep persistent kernel only for large enough grids. +# ----------------------------------------------------------------------------- + + +_nonpersistent_configs = [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + ), +] + +_persistent_configs = [ + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "BLOCK_K": 32, + "GROUP_SIZE_M": 4, + "NUM_PROGS": 32, + }, + ), +] + + +@triton.autotune(configs=_nonpersistent_configs, key=["M", "N", "K"]) +@triton.jit +def _linear_bias_kernel_packed( + x_ptr, + wt_ptr, + b_ptr, + y_ptr, + M, + N, + K, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_wtk: tl.constexpr, + stride_wtn: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + SCALE, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + K_DIVISIBLE: tl.constexpr, + M_DIVISIBLE: tl.constexpr, + N_DIVISIBLE: tl.constexpr, + grf_mode: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + + 
num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + tl.max_contiguous(offs_n, BLOCK_N) + + x_desc = tl.make_tensor_descriptor( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + block_shape=(BLOCK_M, BLOCK_K), + ) + wt_desc = tl.make_tensor_descriptor( + base=wt_ptr, + shape=(K, N), + strides=(stride_wtk, stride_wtn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = x_desc.load([pid_m * BLOCK_M, off_k]) + w = wt_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, w) + + bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + acc = (acc + bias[None, :]) * SCALE + + y_desc = tl.make_tensor_descriptor( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + block_shape=(BLOCK_M, BLOCK_N), + ) + y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty)) + + +@triton.autotune(configs=_persistent_configs, key=["M", "N", "K"]) +@triton.jit +def _linear_bias_kernel_persistent_packed( + x_ptr, + wt_ptr, + b_ptr, + y_ptr, + M, + N, + K, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_wtk: tl.constexpr, + stride_wtn: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + SCALE, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_PROGS: tl.constexpr, + K_DIVISIBLE: tl.constexpr, + M_DIVISIBLE: tl.constexpr, + N_DIVISIBLE: tl.constexpr, + grf_mode: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_tiles = num_pid_m * num_pid_n + + tile_id = pid + while tile_id < num_tiles: + 
group_tiles = GROUP_SIZE_M * num_pid_n + group_id = tile_id // group_tiles + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + tile_in_group = tile_id % group_tiles + pid_m = first_pid_m + (tile_in_group % group_size_m) + pid_n = tile_in_group // group_size_m + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + tl.max_contiguous(offs_n, BLOCK_N) + + x_desc = tl.make_tensor_descriptor( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + block_shape=(BLOCK_M, BLOCK_K), + ) + wt_desc = tl.make_tensor_descriptor( + base=wt_ptr, + shape=(K, N), + strides=(stride_wtk, stride_wtn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = x_desc.load([pid_m * BLOCK_M, off_k]) + w = wt_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, w) + + bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + acc = (acc + bias[None, :]) * SCALE + + y_desc = tl.make_tensor_descriptor( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + block_shape=(BLOCK_M, BLOCK_N), + ) + y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty)) + + tile_id += NUM_PROGS + + +@triton.jit +def _scale_residual_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask, other=0.0) + y = (x.to(tl.float32) * 1.5).to(out_ptr.dtype.element_ty) + tl.store(out_ptr + offsets, y, mask=mask) + + +_PACKED_WEIGHT_CACHE = {} + + +def _cache_key_for_packed_weight(w: torch.Tensor): + return ( + int(w.data_ptr()), + tuple(w.shape), + tuple(w.stride()), + str(w.dtype), + str(w.device), + ) + + +def _cleanup_packed_weight_cache(): + dead_keys = [k for k, v in _PACKED_WEIGHT_CACHE.items() if v["weak"]() is None] + for k in 
dead_keys: + _PACKED_WEIGHT_CACHE.pop(k, None) + + +def _pack_weight_kn( + w: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16 +) -> torch.Tensor: + if w.dtype != target_dtype: + w = w.to(dtype=target_dtype) + if not w.is_contiguous(): + w = w.contiguous() + return w.transpose(0, 1).contiguous() + + +def _get_cached_packed_weight( + w: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16 +) -> torch.Tensor: + _cleanup_packed_weight_cache() + + if w.dtype != target_dtype: + w_ready = w.to(dtype=target_dtype).contiguous() + else: + w_ready = w.contiguous() + + key = _cache_key_for_packed_weight(w_ready) + entry = _PACKED_WEIGHT_CACHE.get(key) + + if entry is not None: + packed = entry["packed"] + if packed is not None: + return packed + + packed = w_ready.transpose(0, 1).contiguous() + _PACKED_WEIGHT_CACHE[key] = { + "weak": weakref.ref(w_ready), + "packed": packed, + } + return packed + + +def _should_use_persistent(M: int, N: int, K: int) -> bool: + tiles_m = triton.cdiv(M, 256) + tiles_n = triton.cdiv(N, 256) + total_tiles = tiles_m * tiles_n + return total_tiles >= 512 and M >= 512 and K >= 1024 + + +def _select_grf_mode(M: int, N: int, K: int) -> str: + if M >= 256 and N >= 256 and K >= 1024: + return "256" + return "auto" + + +def _launch_linear( + x_in: torch.Tensor, + wt_in: torch.Tensor, + b_in: torch.Tensor, + y: torch.Tensor, + scale: float, +): + M, K = x_in.shape + _, N = wt_in.shape + + stride_xm, stride_xk = x_in.stride() + stride_wtk, stride_wtn = wt_in.stride() + stride_ym, stride_yn = y.stride() + + k_divisible = K % 32 == 0 + m_divisible = M % 256 == 0 + n_divisible = N % 256 == 0 + grf_mode = _select_grf_mode(M, N, K) + + if _should_use_persistent(M, N, K): + + def grid(meta): + return (meta["NUM_PROGS"],) + + _linear_bias_kernel_persistent_packed[grid]( + x_in, + wt_in, + b_in, + y, + M, + N, + K, + stride_xm, + stride_xk, + stride_wtk, + stride_wtn, + stride_ym, + stride_yn, + scale, + K_DIVISIBLE=k_divisible, + 
M_DIVISIBLE=m_divisible, + N_DIVISIBLE=n_divisible, + grf_mode=grf_mode, + ) + else: + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),) + + _linear_bias_kernel_packed[grid]( + x_in, + wt_in, + b_in, + y, + M, + N, + K, + stride_xm, + stride_xk, + stride_wtk, + stride_wtn, + stride_ym, + stride_yn, + scale, + K_DIVISIBLE=k_divisible, + M_DIVISIBLE=m_divisible, + N_DIVISIBLE=n_divisible, + grf_mode=grf_mode, + ) + + +def fused_linear( + x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, scale: float = 1.5 +) -> torch.Tensor: + if not ( + isinstance(x, torch.Tensor) + and isinstance(w, torch.Tensor) + and isinstance(b, torch.Tensor) + ): + raise TypeError("x, w, b must be torch.Tensor") + if x.ndim != 2 or w.ndim != 2 or b.ndim != 1: + raise ValueError("Expected x: [M,K], w: [N,K], b: [N]") + + target_dtype = ( + x.dtype if x.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16 + ) + + if x.dtype != target_dtype: + x_ready = x.to(dtype=target_dtype).contiguous() + else: + x_ready = x.contiguous() + + wt_ready = _get_cached_packed_weight(w, target_dtype=target_dtype) + + if b.dtype != target_dtype: + b_ready = b.to(dtype=target_dtype).contiguous() + else: + b_ready = b.contiguous() + + M, Kx = x_ready.shape + Kw, N = wt_ready.shape + if Kx != Kw: + raise ValueError(f"Incompatible shapes: x[K={Kx}] vs w[K={Kw}]") + if b_ready.shape[0] != N: + raise ValueError(f"Bias shape mismatch: b[{b_ready.shape[0]}] vs N={N}") + + y = torch.empty((M, N), device=x_ready.device, dtype=x_ready.dtype) + _launch_linear(x_ready, wt_ready, b_ready, y, scale) + return y + + +def fused_scale_residual(x: torch.Tensor) -> torch.Tensor: + if not isinstance(x, torch.Tensor): + raise TypeError("Expected a torch.Tensor input") + if x.dtype not in (torch.float16, torch.bfloat16): + raise TypeError(f"Unsupported dtype {x.dtype}. 
Supported: float16, bfloat16") + if not x.is_contiguous(): + x = x.contiguous() + + out = torch.empty_like(x) + n_elements = x.numel() + BLOCK_SIZE = 64 + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + _scale_residual_kernel[grid](x, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return out + + +def kernel_function(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return fused_linear(x, w, b, scale=1.5) + + +batch_size = 16384 +in_features = 4096 +out_features = 4096 +scaling_factor = 0.5 + + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + + +def get_init_inputs(): + return [in_features, out_features, scaling_factor] + + +class Model(nn.Module): + def __init__(self, in_features, out_features, scaling_factor): + super().__init__() + self.linear = nn.Linear(in_features, out_features) + self.scaling_factor = scaling_factor + self._packed_ready = False + self._weight_packed = None + self._weight_version = None + self._target_dtype = None + + def _ensure_params(self, device, target_dtype): + if not self._packed_ready or self._target_dtype != target_dtype: + self._target_dtype = target_dtype + self.linear.weight.data = self.linear.weight.data.to( + device, dtype=target_dtype + ).contiguous() + self.linear.bias.data = self.linear.bias.data.to( + device, dtype=target_dtype + ).contiguous() + self._weight_packed = self.linear.weight.data.transpose(0, 1).contiguous() + self._weight_version = ( + self.linear.weight.data.data_ptr(), + self.linear.weight.shape, + self.linear.weight.dtype, + self.linear.weight.device, + ) + self._packed_ready = True + + def _refresh_packed_weight_if_needed(self): + cur_version = ( + self.linear.weight.data.data_ptr(), + self.linear.weight.shape, + self.linear.weight.dtype, + self.linear.weight.device, + ) + if self._weight_packed is None or self._weight_version != cur_version: + self._weight_packed = self.linear.weight.data.transpose(0, 1).contiguous() + self._weight_version = cur_version + + def forward(self, 
x): + target_dtype = ( + x.dtype if x.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16 + ) + if x.dtype != target_dtype: + x = x.to(dtype=target_dtype).contiguous() + else: + x = x.contiguous() + + self._ensure_params(x.device, target_dtype) + self._refresh_packed_weight_if_needed() + + b = self.linear.bias + y = torch.empty( + (x.shape[0], self.linear.weight.shape[0]), device=x.device, dtype=x.dtype + ) + + M, K = x.shape + Kwt, N = self._weight_packed.shape + if K != Kwt: + raise ValueError(f"Incompatible shapes: x[K={K}] vs packed_w[K={Kwt}]") + + _launch_linear(x, self._weight_packed, b, y, 1.5) + return y diff --git a/backends/triton/cpu/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.py b/backends/triton/cpu/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.py new file mode 100644 index 0000000..6d3e271 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.py @@ -0,0 +1,165 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _gemm_configs(): + return [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + ), + ] + + +@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"]) +@triton.jit +def _gemm_scale_hardtanh_gelu_kernel( + x_ptr, + w_ptr, + b_ptr, + y_ptr, + M, + N, + K, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + scale, + ht_min, + ht_max, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // 
num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + x_desc = tl.make_tensor_descriptor( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + block_shape=(BLOCK_M, BLOCK_K), + ) + w_desc = tl.make_tensor_descriptor( + base=w_ptr, + shape=(K, N), + strides=(stride_wk, stride_wn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = x_desc.load([pid_m * BLOCK_M, off_k]) + b = w_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, b) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + acc += bias[None, :] + acc = acc * scale + acc = tl.minimum(tl.maximum(acc, ht_min), ht_max) + acc = 0.5 * acc * (1.0 + tl.math.erf(acc * 0.7071067811865476)) + + y_desc = tl.make_tensor_descriptor( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + block_shape=(BLOCK_M, BLOCK_N), + ) + y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty)) + + +def kernel_function( + x: torch.Tensor, + weight_kn: torch.Tensor, + bias: torch.Tensor, + scaling_factor: float, + hardtanh_min: float, + hardtanh_max: float, +) -> torch.Tensor: + M, K = x.shape + _, N = weight_kn.shape + y = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + _gemm_scale_hardtanh_gelu_kernel[grid]( + x, + weight_kn, + bias, + y, + M, + N, + K, + x.stride(0), + x.stride(1), + weight_kn.stride(0), + weight_kn.stride(1), + y.stride(0), + y.stride(1), + float(scaling_factor), + float(hardtanh_min), + float(hardtanh_max), + ) + return y + + +batch_size = 2048 +in_features = 8192 +out_features = 8192 +scaling_factor = 0.5 
+hardtanh_min = -2 +hardtanh_max = 2 + + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + + +def get_init_inputs(): + return [in_features, out_features, scaling_factor, hardtanh_min, hardtanh_max] + + +class Model(nn.Module): + def __init__( + self, in_features, out_features, scaling_factor, hardtanh_min, hardtanh_max + ): + super().__init__() + self.gemm = nn.Linear(in_features, out_features) + self.scaling_factor = scaling_factor + self.hardtanh_min = hardtanh_min + self.hardtanh_max = hardtanh_max + self._weight_kn = None + + def forward(self, x): + if self._weight_kn is None: + w = self.gemm.weight.data.to(dtype=x.dtype).contiguous() + self.gemm.weight.data = w + self.gemm.bias.data = self.gemm.bias.data.to(dtype=x.dtype).contiguous() + self._weight_kn = w.t().contiguous() + return kernel_function( + x.contiguous(), + self._weight_kn, + self.gemm.bias, + self.scaling_factor, + self.hardtanh_min, + self.hardtanh_max, + ) diff --git a/backends/triton/cpu/KernelBench/level2/59_Matmul_Swish_Scaling.py b/backends/triton/cpu/KernelBench/level2/59_Matmul_Swish_Scaling.py new file mode 100644 index 0000000..b52fc6d --- /dev/null +++ b/backends/triton/cpu/KernelBench/level2/59_Matmul_Swish_Scaling.py @@ -0,0 +1,142 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _gemm_configs(): + return [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + ), + ] + + +@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"]) +@triton.jit +def _gemm_swish_scale_kernel( + x_ptr, + w_ptr, + b_ptr, + y_ptr, + M, + N, + K, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + 
scale, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + x_desc = tl.make_tensor_descriptor( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + block_shape=(BLOCK_M, BLOCK_K), + ) + w_desc = tl.make_tensor_descriptor( + base=w_ptr, + shape=(K, N), + strides=(stride_wk, stride_wn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = x_desc.load([pid_m * BLOCK_M, off_k]) + b = w_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, b) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + acc += bias[None, :] + sig = 1.0 / (1.0 + tl.exp(-acc)) + acc = acc * sig * scale + + y_desc = tl.make_tensor_descriptor( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + block_shape=(BLOCK_M, BLOCK_N), + ) + y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty)) + + +def kernel_function(x, weight_kn, bias, scaling_factor): + M, K = x.shape + _, N = weight_kn.shape + y = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + _gemm_swish_scale_kernel[grid]( + x, + weight_kn, + bias, + y, + M, + N, + K, + x.stride(0), + x.stride(1), + weight_kn.stride(0), + weight_kn.stride(1), + y.stride(0), + y.stride(1), + float(scaling_factor), + ) + return y + + +batch_size = 128 +in_features = 32768 +out_features = 32768 
+scaling_factor = 2.0 + + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + + +def get_init_inputs(): + return [in_features, out_features, scaling_factor] + + +class Model(nn.Module): + def __init__(self, in_features, out_features, scaling_factor): + super().__init__() + self.linear = nn.Linear(in_features, out_features) + self.scaling_factor = scaling_factor + self._weight_kn = None + + def forward(self, x): + if self._weight_kn is None: + w = self.linear.weight.data.to(dtype=x.dtype).contiguous() + self.linear.weight.data = w + self.linear.bias.data = self.linear.bias.data.to(dtype=x.dtype).contiguous() + self._weight_kn = w.t().contiguous() + return kernel_function( + x.contiguous(), self._weight_kn, self.linear.bias, self.scaling_factor + ) diff --git a/backends/triton/cpu/KernelBench/level2/66_Matmul_Dropout_Softmax.py b/backends/triton/cpu/KernelBench/level2/66_Matmul_Dropout_Softmax.py new file mode 100644 index 0000000..912860d --- /dev/null +++ b/backends/triton/cpu/KernelBench/level2/66_Matmul_Dropout_Softmax.py @@ -0,0 +1,209 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _gemm_configs(): + return [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + ), + ] + + +@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"]) +@triton.jit +def _gemm_bias_kernel( + x_ptr, + w_ptr, + b_ptr, + y_ptr, + M, + N, + K, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = 
tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + x_desc = tl.make_tensor_descriptor( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + block_shape=(BLOCK_M, BLOCK_K), + ) + w_desc = tl.make_tensor_descriptor( + base=w_ptr, + shape=(K, N), + strides=(stride_wk, stride_wn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = x_desc.load([pid_m * BLOCK_M, off_k]) + b = w_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, b) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + acc += bias[None, :] + + y_desc = tl.make_tensor_descriptor( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + block_shape=(BLOCK_M, BLOCK_N), + ) + y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty)) + + +def _softmax_configs(): + configs = [triton.Config({"BLOCK_N": 32},)] + return configs + + +@triton.autotune(configs=_softmax_configs(), key=["M", "N"]) +@triton.jit +def _row_softmax_kernel( + x_ptr, + y_ptr, + M, + N, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + if row >= M: + return + + offs_n = tl.arange(0, BLOCK_N) + neg_inf = -float("inf") + row64 = row.to(tl.int64) + row_start_x = x_ptr + row64 * stride_xm + row_start_y = y_ptr + row64 * stride_ym + + row_max = tl.full((), neg_inf, tl.float32) + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + offs_n + mask = cols < N + vals = tl.load( + row_start_x + cols.to(tl.int64) * stride_xn, mask=mask, other=neg_inf + ).to(tl.float32) + row_max = tl.maximum(row_max, 
tl.max(vals, axis=0)) + + row_sum = tl.zeros((), tl.float32) + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + offs_n + mask = cols < N + vals = tl.load( + row_start_x + cols.to(tl.int64) * stride_xn, mask=mask, other=neg_inf + ).to(tl.float32) + row_sum += tl.sum(tl.exp(vals - row_max), axis=0) + + inv_row_sum = 1.0 / row_sum + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + offs_n + mask = cols < N + vals = tl.load( + row_start_x + cols.to(tl.int64) * stride_xn, mask=mask, other=neg_inf + ).to(tl.float32) + probs = tl.exp(vals - row_max) * inv_row_sum + tl.store( + row_start_y + cols.to(tl.int64) * stride_yn, + probs.to(y_ptr.dtype.element_ty), + mask=mask, + ) + + +def kernel_function(x, weight_kn, bias) -> torch.Tensor: + M, K = x.shape + _, N = weight_kn.shape + + logits = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid_gemm = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + _gemm_bias_kernel[grid_gemm]( + x, + weight_kn, + bias, + logits, + M, + N, + K, + x.stride(0), + x.stride(1), + weight_kn.stride(0), + weight_kn.stride(1), + logits.stride(0), + logits.stride(1), + ) + + y = torch.empty_like(logits) + _row_softmax_kernel[(M,)]( + logits, + y, + M, + N, + logits.stride(0), + logits.stride(1), + y.stride(0), + y.stride(1), + ) + return y + + +batch_size = 128 +in_features = 16384 +out_features = 16384 +dropout_p = 0.2 + + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + + +def get_init_inputs(): + return [in_features, out_features, dropout_p] + + +class Model(nn.Module): + def __init__(self, in_features, out_features, dropout_p=0.2): + super().__init__() + self.linear = nn.Linear(in_features, out_features) + self._weight_kn = None + + def forward(self, x): + if self._weight_kn is None: + w = self.linear.weight.data.to(dtype=x.dtype).contiguous() + self.linear.weight.data = w + self.linear.bias.data = self.linear.bias.data.to(dtype=x.dtype).contiguous() + 
self._weight_kn = w.t().contiguous() + return kernel_function(x.contiguous(), self._weight_kn, self.linear.bias) diff --git a/backends/triton/cpu/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.py b/backends/triton/cpu/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.py new file mode 100644 index 0000000..de8201c --- /dev/null +++ b/backends/triton/cpu/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.py @@ -0,0 +1,148 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _gemm_configs(): + return [ + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}), + ] + + +@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"]) +@triton.jit +def _gemm_sigmoid_scale_residual_kernel( + x_ptr, + w_ptr, + b_ptr, + y_ptr, + M, + N, + K, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + scale, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + x_desc = tl.make_tensor_descriptor( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + block_shape=(BLOCK_M, BLOCK_K), + ) + w_desc = tl.make_tensor_descriptor( + base=w_ptr, + shape=(K, N), + strides=(stride_wk, stride_wn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = x_desc.load([pid_m * BLOCK_M, off_k]) + b = w_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, b) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + acc += bias[None, :] + # out = gemm_out + scale * sigmoid(gemm_out) + absx = tl.abs(acc) + e = tl.exp(-absx) + sig = tl.where(acc >= 0, 1.0 / (1.0 + e), e / (1.0 + e)) + acc = acc + scale * sig + + y_desc = tl.make_tensor_descriptor( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + block_shape=(BLOCK_M, BLOCK_N), + ) + y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty)) + + +def kernel_function( + x: torch.Tensor, + weight_kn: torch.Tensor, + bias: torch.Tensor, + scaling_factor: float, +) -> torch.Tensor: + M, K = x.shape + _, N = weight_kn.shape + y = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + _gemm_sigmoid_scale_residual_kernel[grid]( + x, + weight_kn, + bias, + y, + M, + N, + K, + x.stride(0), + x.stride(1), + weight_kn.stride(0), + weight_kn.stride(1), + y.stride(0), + y.stride(1), + float(scaling_factor), + ) + return y + + +batch_size = 1024 +input_size = 8192 +hidden_size = 8192 +scaling_factor = 2.0 + + +def get_inputs(): + return [torch.rand(batch_size, input_size)] + + +def get_init_inputs(): + return [input_size, hidden_size, scaling_factor] + + +class Model(nn.Module): + def __init__(self, input_size, hidden_size, scaling_factor): + super().__init__() + self.gemm = nn.Linear(input_size, hidden_size) + self.scaling_factor = scaling_factor + self._weight_kn = None + + def forward(self, x): + if self._weight_kn is None: + w = self.gemm.weight.data.to(dtype=x.dtype).contiguous() + self.gemm.weight.data = w + self.gemm.bias.data = self.gemm.bias.data.to(dtype=x.dtype).contiguous() + 
self._weight_kn = w.t().contiguous() + return kernel_function( + x.contiguous(), self._weight_kn, self.gemm.bias, self.scaling_factor + ) diff --git a/backends/triton/cpu/KernelBench/level2/76_Gemm_Add_ReLU.py b/backends/triton/cpu/KernelBench/level2/76_Gemm_Add_ReLU.py new file mode 100644 index 0000000..7125763 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level2/76_Gemm_Add_ReLU.py @@ -0,0 +1,395 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 1}, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _gemm_bias_relu_kernel( + a_ptr, + b_t_ptr, + bias_ptr, + c_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + + group_size = GROUP_SIZE_M * num_pid_n + group_id = pid // group_size + first_pid_m = group_id * GROUP_SIZE_M + group_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_in_group = pid % group_size + pid_m = first_pid_m + (pid_in_group % group_m) + pid_n = pid_in_group // group_m + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_t_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = 
a_desc.load([pid_m * BLOCK_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, b) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + tl.max_contiguous(offs_n, BLOCK_N) + bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32) + + out_dtype = c_ptr.type.element_ty + acc = acc.to(out_dtype).to(tl.float32) + bias[None, :] + acc = tl.maximum(acc, 0.0) + + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(out_dtype)) + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "BLOCK_K": 32, + "GROUP_SIZE_M": 1, + "NUM_PROGS": 32, + }, + num_warps=32, + num_stages=3, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _gemm_bias_relu_persistent_kernel( + a_ptr, + b_t_ptr, + bias_ptr, + c_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_PROGS: tl.constexpr, + grf_mode: tl.constexpr = "auto", +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_tiles = num_pid_m * num_pid_n + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_t_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + + tile_id = pid + while tile_id < num_tiles: + group_size = GROUP_SIZE_M * num_pid_n + group_id = tile_id // group_size + first_pid_m = group_id * GROUP_SIZE_M + group_m = 
def _get_num_workers(device=None):
    """Best-effort estimate of the number of hardware workers on *device*.

    For an ``xpu`` device the capability dict and then the device properties
    are probed for the first positive integer among a fixed list of keys; for
    a ``cuda`` device ``multi_processor_count`` is used. Any other device, a
    missing attribute, or any exception raised while probing yields the
    fallback value 32.
    """
    default_workers = 32
    if device is None:
        return default_workers

    def _is_count(candidate):
        return isinstance(candidate, int) and candidate > 0

    try:
        backend = device.type
        if backend == "xpu":
            if not hasattr(torch, "xpu"):
                return default_workers
            xpu = torch.xpu

            # First source: the capability dictionary.
            if hasattr(xpu, "get_device_capability"):
                capability = xpu.get_device_capability()
                if isinstance(capability, dict):
                    for name in (
                        "gpu_subslice_count",
                        "subslice_count",
                        "max_compute_units",
                        "gpu_eu_count",
                    ):
                        count = capability.get(name, None)
                        if _is_count(count):
                            return count

            # Second source: attributes of the device-properties object.
            if hasattr(xpu, "get_device_properties"):
                properties = xpu.get_device_properties(device.index or 0)
                for name in (
                    "subslice_count",
                    "max_compute_units",
                    "multi_processor_count",
                ):
                    if hasattr(properties, name):
                        count = getattr(properties, name)
                        if _is_count(count):
                            return count
        elif backend == "cuda":
            properties = torch.cuda.get_device_properties(device)
            if (
                hasattr(properties, "multi_processor_count")
                and properties.multi_processor_count > 0
            ):
                return properties.multi_processor_count
    except Exception:
        # Deliberate best-effort: probing must never break the launch path.
        pass
    return default_workers
def kernel_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    packed_weight_t: torch.Tensor = None,
):
    """Compute ``relu(x @ weight.T + bias)`` with the fused Triton GEMM kernels.

    Args:
        x: 2-D activations of shape ``(M, K)``.
        weight: 2-D weight of shape ``(N, K)``.
        bias: 1-D bias with ``N`` elements.
        packed_weight_t: Optional pre-transposed weight of shape ``(K, N)``.
            When given it is used as the GEMM right-hand side and ``weight``
            only supplies shapes, so ``weight`` is not cast or copied.

    Returns:
        An ``(M, N)`` tensor on ``x.device``. Its dtype is ``x.dtype`` when
        that is float16/bfloat16, otherwise bfloat16.

    Raises:
        AssertionError: If the arguments are not tensors of the expected rank
            or their shapes are inconsistent.
    """
    assert (
        isinstance(x, torch.Tensor)
        and isinstance(weight, torch.Tensor)
        and isinstance(bias, torch.Tensor)
    )
    assert x.ndim == 2 and weight.ndim == 2 and bias.ndim == 1
    assert x.shape[1] == weight.shape[1], "Incompatible shapes"
    assert bias.numel() == weight.shape[0], "Bias length mismatch"

    _supported_dtypes = (torch.float16, torch.bfloat16)
    target_dtype = x.dtype if x.dtype in _supported_dtypes else torch.bfloat16

    def _ready(t):
        # Both .to() and .contiguous() return the tensor itself when no
        # conversion is needed, so this is free for already-prepared inputs.
        return t.to(dtype=target_dtype).contiguous()

    x_ready = _ready(x)
    bias_ready = _ready(bias)

    M, K = x_ready.shape
    N = weight.shape[0]

    if packed_weight_t is not None:
        assert tuple(packed_weight_t.shape) == (K, N), "packed_weight_t must be [K, N]"
        weight_t_ready = _ready(packed_weight_t)
    else:
        # Only pay for casting/transposing the weight when no packed copy exists.
        weight_t_ready = weight.to(dtype=target_dtype).transpose(0, 1).contiguous()

    out = torch.empty((M, N), device=x.device, dtype=target_dtype)

    stride_am, stride_ak = x_ready.stride()
    stride_bk, stride_bn = weight_t_ready.stride()
    stride_cm, stride_cn = out.stride()

    # Coarse work estimate (256x256 tiles) used only to choose between the
    # persistent and the one-tile-per-program kernel.
    total_tiles = triton.cdiv(M, 256) * triton.cdiv(N, 256)

    # Use persistent scheduling when there are enough tiles to amortize looping.
    # Fall back to the original kernel for very small grids to avoid persistent overhead.
    if total_tiles >= 8:
        # The persistent kernel advances `tile_id += NUM_PROGS`, so exactly
        # NUM_PROGS programs must be launched. Launching fewer would leave
        # every tile whose index modulo NUM_PROGS is >= the grid size
        # uncomputed. Surplus programs simply find no tile and exit.
        grid = lambda meta: (meta["NUM_PROGS"],)
        _gemm_bias_relu_persistent_kernel[grid](
            x_ready,
            weight_t_ready,
            bias_ready,
            out,
            M,
            N,
            K,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
        )
    else:
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
        )
        _gemm_bias_relu_kernel[grid](
            x_ready,
            weight_t_ready,
            bias_ready,
            out,
            M,
            N,
            K,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
        )
    return out
def _gemm_configs():
    """Return the autotune candidates for the GEMM kernel (one 32x32x32 tiling)."""
    tile_meta = {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}
    return [triton.Config(tile_meta)]
def kernel_function(x, weight_kn, bias) -> torch.Tensor:
    """Run the GEMM + partial-row-max kernel followed by the reduce kernel.

    Args:
        x: 2-D activations of shape ``(M, K)``.
        weight_kn: 2-D packed weight of shape ``(K, N)``.
        bias: 1-D bias read at column offsets ``< N``.

    Returns:
        An ``(M, 1)`` tensor with ``x``'s dtype and device, filled by
        ``_reduce_max_gelu_kernel``.

    Raises:
        ValueError: If the inner dimensions of ``x`` and ``weight_kn`` differ.
    """
    M, K = x.shape
    K_w, N = weight_kn.shape
    if K_w != K:
        raise ValueError(
            f"Inner dimensions differ: x is [{M}, {K}], weight_kn is [{K_w}, {N}]"
        )

    grid_gemm = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    # The GEMM kernel stores one partial maximum per (row, pid_n) with
    # pid_n < cdiv(N, BLOCK_N), so the buffer needs a column for every N-tile
    # of the *smallest* BLOCK_N any autotune config may pick. Deriving it from
    # the config list keeps the buffer in sync with the configs; a hard-coded
    # value larger than the real BLOCK_N makes the kernel write out of bounds.
    min_block_n = min(cfg.kwargs["BLOCK_N"] for cfg in _gemm_configs())
    max_n_tiles = triton.cdiv(N, min_block_n)
    # Columns never written by the chosen config stay at -inf and therefore
    # cannot affect a max-reduction.
    pmax = torch.full((M, max_n_tiles), float("-inf"), device=x.device, dtype=x.dtype)

    _gemm_bias_partial_max_kernel[grid_gemm](
        x,
        weight_kn,
        bias,
        pmax,
        M,
        N,
        K,
        x.stride(0),
        x.stride(1),
        weight_kn.stride(0),
        weight_kn.stride(1),
        pmax.stride(0),
        pmax.stride(1),
    )

    y = torch.empty((M, 1), device=x.device, dtype=x.dtype)
    # Width of the reduction block: a power of two covering all partial columns.
    BLOCK_T = max(16, triton.next_power_of_2(max_n_tiles))
    _reduce_max_gelu_kernel[(M,)](
        pmax,
        y,
        M,
        max_n_tiles,
        pmax.stride(0),
        pmax.stride(1),
        y.stride(0),
        BLOCK_T=BLOCK_T,
        num_warps=1,
        num_stages=1,
    )
    return y
def kernel_function(input, weight_packed, bias, divisor=10.0):
    """
    Launch the fused Triton kernel computing
    GELU((input @ weight_packed + bias) / divisor).

    input: [M, K]
    weight_packed: [K, N]
    bias: [N]

    All three tensors are cast to the compute dtype (input's dtype when it is
    fp16/bf16, otherwise bf16) and made contiguous before the launch. The
    result is an [M, N] tensor of that dtype on input's device.
    """
    half_types = (torch.float16, torch.bfloat16)
    if input.dtype in half_types:
        compute_dtype = input.dtype
    else:
        compute_dtype = torch.bfloat16

    lhs = input.to(dtype=compute_dtype).contiguous()
    rhs = weight_packed.to(dtype=compute_dtype).contiguous()
    bias_vec = bias.to(dtype=compute_dtype).contiguous()

    rows, inner = lhs.shape
    rhs_inner, cols = rhs.shape
    assert inner == rhs_inner and bias_vec.shape[0] == cols

    result = torch.empty((rows, cols), device=lhs.device, dtype=compute_dtype)

    def launch_grid(meta):
        # One program per (BLOCK_M, BLOCK_N) output tile.
        tiles_m = triton.cdiv(rows, meta["BLOCK_M"])
        tiles_n = triton.cdiv(cols, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    lhs_stride_m, lhs_stride_k = lhs.stride(0), lhs.stride(1)
    rhs_stride_k, rhs_stride_n = rhs.stride(0), rhs.stride(1)
    out_stride_m, out_stride_n = result.stride(0), result.stride(1)

    _linear_div_gelu_kernel_packed_rhs[launch_grid](
        lhs,
        rhs,
        bias_vec,
        result,
        rows,
        cols,
        inner,
        lhs_stride_m,
        lhs_stride_k,
        rhs_stride_k,
        rhs_stride_n,
        bias_vec.stride(0),
        out_stride_m,
        out_stride_n,
        float(divisor),
        grf_mode="auto",
    )
    return result
# Fused GEMM + bias + GELU kernel: each program computes one
# (BLOCK_M, BLOCK_N) tile of Y = gelu(X @ W + bias), using the erf form of
# GELU. Despite the name, the activation is applied inside this kernel.
#
# Arguments (as used by the tensor descriptors below):
#   x_ptr    - left operand, described as an (M, K) tensor.
#   w_ptr    - right operand, described as a (K, N) tensor.
#   b_ptr    - per-column bias, read at column offsets < N.
#   y_ptr    - output, described as an (M, N) tensor.
#   stride_* - element strides of the three 2-D tensors (compile-time constants).
@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"])
@triton.jit
def _gemm_bias_kernel(
    x_ptr,
    w_ptr,
    b_ptr,
    y_ptr,
    M,
    N,
    K,
    stride_xm: tl.constexpr,
    stride_xk: tl.constexpr,
    stride_wk: tl.constexpr,
    stride_wn: tl.constexpr,
    stride_ym: tl.constexpr,
    stride_yn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Map the 1-D program id onto a (pid_m, pid_n) output tile, walking
    # GROUP_SIZE_M row-tiles at a time (the last group may be shorter).
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    x_desc = tl.make_tensor_descriptor(
        base=x_ptr,
        shape=(M, K),
        strides=(stride_xm, stride_xk),
        block_shape=(BLOCK_M, BLOCK_K),
    )
    w_desc = tl.make_tensor_descriptor(
        base=w_ptr,
        shape=(K, N),
        strides=(stride_wk, stride_wn),
        block_shape=(BLOCK_K, BLOCK_N),
    )

    # Accumulate the tile in float32, stepping over K in BLOCK_K chunks.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for off_k in range(0, K, BLOCK_K):
        a = x_desc.load([pid_m * BLOCK_M, off_k])
        b = w_desc.load([off_k, pid_n * BLOCK_N])
        acc += tl.dot(a, b)

    # Bias is loaded for this tile's columns only, masked to columns < N.
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
    acc += bias[None, :]
    # GELU(v) = 0.5 * v * (1 + erf(v / sqrt(2))); the constant is 1/sqrt(2).
    acc = 0.5 * acc * (1.0 + tl.math.erf(acc * 0.7071067811865476))

    y_desc = tl.make_tensor_descriptor(
        base=y_ptr,
        shape=(M, N),
        strides=(stride_ym, stride_yn),
        block_shape=(BLOCK_M, BLOCK_N),
    )
    # Cast from the float32 accumulator to the output element type on store.
    y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty))
class Model(nn.Module):
    """Linear layer whose forward pass is delegated to ``kernel_function``.

    The layer's weight is kept in a packed ``[in_features, out_features]``
    copy (``_weight_kn``) that is built lazily and rebuilt whenever the input
    dtype or the parameter device changes.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Lazily built transposed copy of ``linear.weight``; see _packed_weight().
        self._weight_kn = None

    def _packed_weight(self, x):
        """Return the packed weight matching ``x.dtype``, rebuilding if stale.

        As a side effect the layer's weight and bias parameters are converted
        in place to ``x.dtype`` whenever the packed copy is (re)built.
        """
        weight = self.linear.weight
        cached = self._weight_kn
        # The cache is a plain attribute, so it follows neither a dtype change
        # of the input nor a later ``module.to(device)``; check both here.
        stale = (
            cached is None
            or cached.dtype != x.dtype
            or cached.device != weight.device
        )
        if stale:
            w = weight.data.to(dtype=x.dtype).contiguous()
            self.linear.weight.data = w
            self.linear.bias.data = self.linear.bias.data.to(dtype=x.dtype).contiguous()
            self._weight_kn = w.t().contiguous()
        return self._weight_kn

    def forward(self, x):
        """Apply the fused kernels to ``x`` and return their output."""
        weight_kn = self._packed_weight(x)
        return kernel_function(x.contiguous(), weight_kn, self.linear.bias)
# Fused GEMM + bias + subtract + multiply + ReLU kernel: each program computes
# one (BLOCK_M, BLOCK_N) tile of Y = relu((X @ W + bias - sub_val) * mul_val).
#
# Arguments (as used by the tensor descriptors below):
#   x_ptr    - left operand, described as an (M, K) tensor.
#   w_ptr    - right operand, described as a (K, N) tensor.
#   b_ptr    - per-column bias, read at column offsets < N.
#   y_ptr    - output, described as an (M, N) tensor.
#   stride_* - element strides of the three 2-D tensors (compile-time constants).
#   sub_val  - scalar subtracted from every element after the bias add.
#   mul_val  - scalar the result is multiplied by before the ReLU.
@triton.autotune(configs=_gemm_configs(), key=["M", "N", "K"])
@triton.jit
def _gemm_sub_mul_relu_kernel(
    x_ptr,
    w_ptr,
    b_ptr,
    y_ptr,
    M,
    N,
    K,
    stride_xm: tl.constexpr,
    stride_xk: tl.constexpr,
    stride_wk: tl.constexpr,
    stride_wn: tl.constexpr,
    stride_ym: tl.constexpr,
    stride_yn: tl.constexpr,
    sub_val,
    mul_val,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Map the 1-D program id onto a (pid_m, pid_n) output tile, walking
    # GROUP_SIZE_M row-tiles at a time (the last group may be shorter).
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    x_desc = tl.make_tensor_descriptor(
        base=x_ptr,
        shape=(M, K),
        strides=(stride_xm, stride_xk),
        block_shape=(BLOCK_M, BLOCK_K),
    )
    w_desc = tl.make_tensor_descriptor(
        base=w_ptr,
        shape=(K, N),
        strides=(stride_wk, stride_wn),
        block_shape=(BLOCK_K, BLOCK_N),
    )

    # Accumulate the tile in float32, stepping over K in BLOCK_K chunks.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for off_k in range(0, K, BLOCK_K):
        a = x_desc.load([pid_m * BLOCK_M, off_k])
        b = w_desc.load([off_k, pid_n * BLOCK_N])
        acc += tl.dot(a, b)

    # Bias is loaded for this tile's columns only, masked to columns < N.
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
    acc += bias[None, :]
    # Epilogue in float32: subtract, scale, then clamp at zero (ReLU).
    acc = tl.maximum((acc - sub_val) * mul_val, 0.0)

    y_desc = tl.make_tensor_descriptor(
        base=y_ptr,
        shape=(M, N),
        strides=(stride_ym, stride_yn),
        block_shape=(BLOCK_M, BLOCK_N),
    )
    # Cast from the float32 accumulator to the output element type on store.
    y_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(y_ptr.type.element_ty))
class Model(nn.Module):
    """Linear layer followed by subtract, multiply and ReLU, via ``kernel_function``.

    The layer's weight is kept in a packed ``[in_features, out_features]``
    copy (``_weight_kn``) that is built lazily and rebuilt whenever the input
    dtype or the parameter device changes.
    """

    def __init__(self, in_features, out_features, subtract_value, multiply_value):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.subtract_value = subtract_value
        self.multiply_value = multiply_value
        # Lazily built transposed copy of ``linear.weight``; see _packed_weight().
        self._weight_kn = None

    def _packed_weight(self, x):
        """Return the packed weight matching ``x.dtype``, rebuilding if stale.

        As a side effect the layer's weight and bias parameters are converted
        in place to ``x.dtype`` whenever the packed copy is (re)built.
        """
        weight = self.linear.weight
        cached = self._weight_kn
        # The cache is a plain attribute, so it follows neither a dtype change
        # of the input nor a later ``module.to(device)``; check both here.
        stale = (
            cached is None
            or cached.dtype != x.dtype
            or cached.device != weight.device
        )
        if stale:
            w = weight.data.to(dtype=x.dtype).contiguous()
            self.linear.weight.data = w
            self.linear.bias.data = self.linear.bias.data.to(dtype=x.dtype).contiguous()
            self._weight_kn = w.t().contiguous()
        return self._weight_kn

    def forward(self, x):
        """Apply the fused kernel to ``x`` and return its output."""
        weight_kn = self._packed_weight(x)
        return kernel_function(
            x.contiguous(),
            weight_kn,
            self.linear.bias,
            self.subtract_value,
            self.multiply_value,
        )
+ dims: + BATCH_SIZE: 32 + IN_FEATURES: 1024 + OUT_FEATURES: 1024 + SCALING_FACTOR: 0.5 + # larger shapes to trigger persistent kernel + - params: [X] + dtype: bfloat16 + dims: + BATCH_SIZE: 16384 + IN_FEATURES: 4096 + OUT_FEATURES: 4096 + SCALING_FACTOR: 0.5 diff --git a/problems/specs/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.yaml b/problems/specs/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.yaml index f21f068..a4f4bdd 100644 --- a/problems/specs/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.yaml +++ b/problems/specs/KernelBench/level2/53_Gemm_Scaling_Hardtanh_GELU.yaml @@ -20,3 +20,14 @@ ci: SCALING_FACTOR: 0.5 HARDTANH_MIN: -2 HARDTANH_MAX: 2 + +simple-cpu: + - params: [X] + dtype: bfloat16 + dims: + BATCH_SIZE: 2 + IN_FEATURES: 256 + OUT_FEATURES: 256 + SCALING_FACTOR: 0.5 + HARDTANH_MIN: -2 + HARDTANH_MAX: 2 diff --git a/problems/specs/KernelBench/level2/59_Matmul_Swish_Scaling.yaml b/problems/specs/KernelBench/level2/59_Matmul_Swish_Scaling.yaml index b8f00da..da76863 100644 --- a/problems/specs/KernelBench/level2/59_Matmul_Swish_Scaling.yaml +++ b/problems/specs/KernelBench/level2/59_Matmul_Swish_Scaling.yaml @@ -16,3 +16,12 @@ ci: IN_FEATURES: 32 OUT_FEATURES: 32 SCALING_FACTOR: 2.0 + +simple-cpu: + - params: [X] + dtype: bfloat16 + dims: + BATCH_SIZE: 2 + IN_FEATURES: 256 + OUT_FEATURES: 256 + SCALING_FACTOR: 2.0 diff --git a/problems/specs/KernelBench/level2/66_Matmul_Dropout_Softmax.yaml b/problems/specs/KernelBench/level2/66_Matmul_Dropout_Softmax.yaml index 7e6d370..859a0c8 100644 --- a/problems/specs/KernelBench/level2/66_Matmul_Dropout_Softmax.yaml +++ b/problems/specs/KernelBench/level2/66_Matmul_Dropout_Softmax.yaml @@ -16,3 +16,12 @@ ci: IN_FEATURES: 32 OUT_FEATURES: 32 DROPOUT_P: 0.2 + +simple-cpu: + - params: [X] + dtype: bfloat16 + dims: + BATCH_SIZE: 128 + IN_FEATURES: 256 + OUT_FEATURES: 256 + DROPOUT_P: 0.2 diff --git a/problems/specs/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.yaml 
b/problems/specs/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.yaml index 0e2402f..4ad6ab6 100644 --- a/problems/specs/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.yaml +++ b/problems/specs/KernelBench/level2/70_Gemm_Sigmoid_Scaling_ResidualAdd.yaml @@ -16,3 +16,13 @@ ci: INPUT_SIZE: 32 HIDDEN_SIZE: 32 SCALING_FACTOR: 2.0 + +simple-cpu: + - params: [X] + dtype: float32 + dims: + BATCH_SIZE: 2 + INPUT_SIZE: 256 + HIDDEN_SIZE: 256 + SCALING_FACTOR: 2.0 + diff --git a/problems/specs/KernelBench/level2/76_Gemm_Add_ReLU.yaml b/problems/specs/KernelBench/level2/76_Gemm_Add_ReLU.yaml index ab7b11a..a2f9f9f 100644 --- a/problems/specs/KernelBench/level2/76_Gemm_Add_ReLU.yaml +++ b/problems/specs/KernelBench/level2/76_Gemm_Add_ReLU.yaml @@ -16,3 +16,12 @@ ci: IN_FEATURES: 32 OUT_FEATURES: 32 BIAS_SHAPE: [32] # TODO: bind these to other dims + +simple-cpu: + - params: [X] + dtype: bfloat16 + dims: + BATCH_SIZE: 2 + IN_FEATURES: 256 + OUT_FEATURES: 256 + BIAS_SHAPE: [256] # TODO: bind these to other dims diff --git a/problems/specs/KernelBench/level2/80_Gemm_Max_Subtract_GELU.yaml b/problems/specs/KernelBench/level2/80_Gemm_Max_Subtract_GELU.yaml index 803e4db..74778d1 100644 --- a/problems/specs/KernelBench/level2/80_Gemm_Max_Subtract_GELU.yaml +++ b/problems/specs/KernelBench/level2/80_Gemm_Max_Subtract_GELU.yaml @@ -16,3 +16,12 @@ ci: IN_FEATURES: 32 OUT_FEATURES: 32 MAX_DIM: 1 + +simple-cpu: + - params: [X] + dtype: bfloat16 + dims: + BATCH_SIZE: 2 + IN_FEATURES: 256 + OUT_FEATURES: 256 + MAX_DIM: 1 diff --git a/problems/specs/KernelBench/level2/86_Matmul_Divide_GELU.yaml b/problems/specs/KernelBench/level2/86_Matmul_Divide_GELU.yaml index 9e6c981..1fed800 100644 --- a/problems/specs/KernelBench/level2/86_Matmul_Divide_GELU.yaml +++ b/problems/specs/KernelBench/level2/86_Matmul_Divide_GELU.yaml @@ -16,3 +16,12 @@ ci: INPUT_SIZE: 32 OUTPUT_SIZE: 32 DIVISOR: 10.0 + +simple-cpu: + - params: [X] + dtype: float16 + dims: + BATCH_SIZE: 2 + INPUT_SIZE: 1024 + 
OUTPUT_SIZE: 1024 + DIVISOR: 10.0 diff --git a/problems/specs/KernelBench/level2/99_Matmul_GELU_Softmax.yaml b/problems/specs/KernelBench/level2/99_Matmul_GELU_Softmax.yaml index e790068..e7b34a1 100644 --- a/problems/specs/KernelBench/level2/99_Matmul_GELU_Softmax.yaml +++ b/problems/specs/KernelBench/level2/99_Matmul_GELU_Softmax.yaml @@ -16,6 +16,15 @@ ci: OUT_FEAT: 64 flop: "2*BATCH*IN_FEAT*OUT_FEAT" +simple-cpu: + - params: [X] + dtype: float16 + dims: + BATCH: 2 + IN_FEAT: 512 + OUT_FEAT: 512 + flop: "2*BATCH*IN_FEAT*OUT_FEAT" + bench-gpu: - params: [X] dtype: float16 diff --git a/problems/specs/KernelBench/level2/9_Matmul_Subtract_Multiply_ReLU.yaml b/problems/specs/KernelBench/level2/9_Matmul_Subtract_Multiply_ReLU.yaml index b81e184..2d084fd 100644 --- a/problems/specs/KernelBench/level2/9_Matmul_Subtract_Multiply_ReLU.yaml +++ b/problems/specs/KernelBench/level2/9_Matmul_Subtract_Multiply_ReLU.yaml @@ -20,6 +20,17 @@ ci: MUL_VAL: 1.5 flop: "2*BATCH*IN_FEAT*OUT_FEAT + 2*BATCH*OUT_FEAT" +simple-cpu: + - params: [X] + dtype: float16 + dims: + BATCH: 4 + IN_FEAT: 512 + OUT_FEAT: 512 + SUB_VAL: 2.0 + MUL_VAL: 1.5 + flop: "2*BATCH*IN_FEAT*OUT_FEAT + 2*BATCH*OUT_FEAT" + bench-gpu: - params: [X] dtype: float16 diff --git a/pyproject.toml b/pyproject.toml index d0c35b1..ba285e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ override-dependencies = [ torch = { index = "pytorch" } pytorch-triton-xpu = { index = "pytorch" } pytorch-triton = { index = "pytorch" } -triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "270e696" } +triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "eece2e9" } lighthouse = { git = "https://github.com/llvm/lighthouse", rev = "456475d" } mlir-python-bindings = { index = "eudsl" }