diff --git a/backends/triton/cpu/KernelBench/level1/100_HingeLoss.py b/backends/triton/cpu/KernelBench/level1/100_HingeLoss.py new file mode 100644 index 0000000..92ea966 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/100_HingeLoss.py @@ -0,0 +1,75 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 1024, "BLOCK_Y": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_X": 2048, "BLOCK_Y": 2}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_X": 4096, "BLOCK_Y": 1}, num_warps=8, num_stages=2), + ], + key=["B", "D"], +) +@triton.jit +def _hinge_loss_kernel( + pred_ptr, + targ_ptr, + out_ptr, + B, + D, + stride_pb, + stride_pd, + BLOCK_X: tl.constexpr, + BLOCK_Y: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * BLOCK_Y + rows = row_start + tl.arange(0, BLOCK_Y) + mask_y = rows < B + + acc = tl.zeros((BLOCK_Y,), dtype=tl.float32) + + for col_start in range(0, D, BLOCK_X): + cols = col_start + tl.arange(0, BLOCK_X) + mask_x = cols < D + + targ = tl.load(targ_ptr + cols, mask=mask_x, other=0.0).to(tl.float32) + + offs = rows[:, None] * stride_pb + cols[None, :] * stride_pd + mask = mask_y[:, None] & mask_x[None, :] + pred = tl.load(pred_ptr + offs, mask=mask, other=0.0).to(tl.float32) + + hinge = tl.maximum(1.0 - pred * targ[None, :], 0.0) + acc += tl.sum(hinge, axis=1) + + tl.store(out_ptr + rows, acc, mask=mask_y) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + B, D = predictions.shape + row_sums = torch.empty(B, device=predictions.device, dtype=torch.float32) + + grid = lambda META: (triton.cdiv(B, META["BLOCK_Y"]),) + _hinge_loss_kernel[grid]( + predictions, + 
targets, + row_sums, + B, + D, + predictions.stride(0), + predictions.stride(1), + ) + + return row_sums.sum() / (B * D) diff --git a/backends/triton/cpu/KernelBench/level1/51_Argmax_over_a_dimension.py b/backends/triton/cpu/KernelBench/level1/51_Argmax_over_a_dimension.py new file mode 100644 index 0000000..f706782 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/51_Argmax_over_a_dimension.py @@ -0,0 +1,90 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 128}, num_warps=16, num_stages=2), + ], + key=["D1", "D2"], +) +@triton.jit +def argmax_dim1_kernel( + x_ptr, + out_ptr, + D1: tl.constexpr, + D2: tl.constexpr, + stride_b, + stride_d1, + stride_d2, + stride_ob, + stride_od2, + BLOCK_N: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_b = tl.program_id(1) + + col_start = pid_n * BLOCK_N + cols = col_start + tl.arange(0, BLOCK_N) + col_mask = cols < D2 + + max_val = tl.full((BLOCK_N,), -float("inf"), dtype=tl.float32) + max_idx = tl.zeros((BLOCK_N,), dtype=tl.int32) + + batch_offset = pid_b.to(tl.int64) * stride_b + col_offsets = cols.to(tl.int64) * stride_d2 + base = x_ptr + batch_offset + col_offsets + + for k in tl.range(0, D1): + val = tl.load(base + k * stride_d1, mask=col_mask, other=-float("inf")).to( + tl.float32 + ) + 
update = val > max_val + max_val = tl.where(update, val, max_val) + max_idx = tl.where(update, k, max_idx) + + out_ptrs = out_ptr + pid_b.to(tl.int64) * stride_ob + cols.to(tl.int64) * stride_od2 + tl.store(out_ptrs, max_idx.to(tl.int64), mask=col_mask) + + +class Model(nn.Module): + def __init__(self, dim=1): + super(Model, self).__init__() + try: + self.dim = int(dim) + except (ValueError, TypeError): + self.dim = 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, D1, D2 = x.shape + output = torch.empty((B, D2), device=x.device, dtype=torch.int64) + + grid = lambda META: (triton.cdiv(D2, META["BLOCK_N"]), B) + argmax_dim1_kernel[grid]( + x, + output, + D1, + D2, + x.stride(0), + x.stride(1), + x.stride(2), + output.stride(0), + output.stride(1), + ) + + return output diff --git a/backends/triton/cpu/KernelBench/level1/52_Argmin_over_a_dimension.py b/backends/triton/cpu/KernelBench/level1/52_Argmin_over_a_dimension.py new file mode 100644 index 0000000..5e15e9f --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/52_Argmin_over_a_dimension.py @@ -0,0 +1,135 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_D2": 256, "BLOCK_K": 16, "warp_size": 32}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 256, "BLOCK_K": 32, "warp_size": 32}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 512, "BLOCK_K": 16, "warp_size": 32}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 512, "BLOCK_K": 32, "warp_size": 32}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 256, "BLOCK_K": 64, "warp_size": 32}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 128, "BLOCK_K": 32, "warp_size": 16}, 
num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 256, "BLOCK_K": 32, "warp_size": 16}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 512, "BLOCK_K": 16, "warp_size": 16}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_D2": 512, "BLOCK_K": 32, "warp_size": 16}, + num_warps=16, + num_stages=2, + ), + triton.Config( + {"BLOCK_D2": 1024, "BLOCK_K": 16, "warp_size": 32}, + num_warps=8, + num_stages=2, + ), + ], + key=["D1", "D2"], +) +@triton.jit +def argmin_kernel( + x_ptr, + out_ptr, + B, + D1, + D2, + stride_b, + stride_d1, + stride_d2, + out_stride_b, + out_stride_d2, + BLOCK_D2: tl.constexpr, + BLOCK_K: tl.constexpr, + warp_size: tl.constexpr, +): + pid = tl.program_id(0) + num_d2_blocks = tl.cdiv(D2, BLOCK_D2) + batch_idx = pid // num_d2_blocks + d2_block_idx = pid % num_d2_blocks + + d2_start = d2_block_idx * BLOCK_D2 + d2_offs = d2_start + tl.arange(0, BLOCK_D2) + d2_mask = d2_offs < D2 + + base = x_ptr + batch_idx.to(tl.int64) * stride_b + + min_val = tl.full([BLOCK_D2], float("inf"), dtype=tl.float32) + min_idx = tl.zeros([BLOCK_D2], dtype=tl.int32) + + k_offs_base = tl.arange(0, BLOCK_K) + + for k_start in tl.range(0, D1, BLOCK_K): + k_offs = k_start + k_offs_base + k_mask = k_offs < D1 + ptrs = ( + base + + k_offs[:, None].to(tl.int64) * stride_d1 + + d2_offs[None, :] * stride_d2 + ) + mask = k_mask[:, None] & d2_mask[None, :] + tile = tl.load(ptrs, mask=mask, other=float("inf")).to(tl.float32) + + tile_min = tl.min(tile, axis=0) + + update = tile_min < min_val + + k_indices = k_offs[:, None] + large_k = tl.full([1], D1, dtype=tl.int32) + k_masked = tl.where(tile == tile_min[None, :], k_indices, large_k) + tile_argmin = tl.min(k_masked, axis=0) + + min_idx = tl.where(update, tile_argmin, min_idx) + min_val = tl.where(update, tile_min, min_val) + + out_ptrs = out_ptr + batch_idx.to(tl.int64) * out_stride_b + d2_offs * out_stride_d2 + tl.store(out_ptrs, min_idx.to(tl.int64), mask=d2_mask) + + +class Model(nn.Module): + 
def __init__(self, dim: int): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, D1, D2 = x.shape + output = torch.empty(B, D2, device=x.device, dtype=torch.int64) + + grid = lambda META: (B * triton.cdiv(D2, META["BLOCK_D2"]),) + + argmin_kernel[grid]( + x, + output, + B, + D1, + D2, + x.stride(0), + x.stride(1), + x.stride(2), + output.stride(0), + output.stride(1), + ) + + return output diff --git a/backends/triton/cpu/KernelBench/level1/53_Min_reduction_over_a_dimension.py b/backends/triton/cpu/KernelBench/level1/53_Min_reduction_over_a_dimension.py new file mode 100644 index 0000000..46ffb59 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/53_Min_reduction_over_a_dimension.py @@ -0,0 +1,118 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_D1": 128, "BLOCK_D2": 256, "warp_size": 32}, + num_warps=4, + num_stages=5, + ), + triton.Config( + {"BLOCK_D1": 256, "BLOCK_D2": 256, "warp_size": 32}, + num_warps=4, + num_stages=4, + ), + triton.Config( + {"BLOCK_D1": 64, "BLOCK_D2": 256, "warp_size": 32}, + num_warps=4, + num_stages=6, + ), + triton.Config( + {"BLOCK_D1": 512, "BLOCK_D2": 256, "warp_size": 32}, + num_warps=4, + num_stages=3, + ), + triton.Config( + {"BLOCK_D1": 256, "BLOCK_D2": 128, "warp_size": 32}, + num_warps=4, + num_stages=5, + ), + triton.Config( + {"BLOCK_D1": 128, "BLOCK_D2": 128, "warp_size": 32}, + num_warps=4, + num_stages=6, + ), + ], + key=["D1", "D2"], +) +@triton.jit +def min_reduction_kernel( + x_ptr, + out_ptr, + B, + D1, + D2, + stride_xb, + stride_xd1, + stride_xd2, + stride_ob, + stride_od2, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + warp_size: tl.constexpr, 
+): + pid_d2 = tl.program_id(0) + pid_b = tl.program_id(1) + + d2_start = pid_d2 * BLOCK_D2 + d2_offs = d2_start + tl.arange(0, BLOCK_D2) + d2_mask = d2_offs < D2 + + batch_offset = pid_b.to(tl.int64) * stride_xb + base = x_ptr + batch_offset + + acc = tl.full((BLOCK_D2,), value=float("inf"), dtype=tl.float32) + + for d1_start in range(0, D1, BLOCK_D1): + d1_offs = d1_start + tl.arange(0, BLOCK_D1) + mask = (d1_offs[:, None] < D1) & d2_mask[None, :] + ptrs = base + d1_offs[:, None] * stride_xd1 + d2_offs[None, :] * stride_xd2 + tile = tl.load(ptrs, mask=mask, other=float("inf")).to(tl.float32) + tile_min = tl.min(tile, axis=0) + acc = tl.minimum(acc, tile_min) + + out_ptrs = out_ptr + pid_b.to(tl.int64) * stride_ob + d2_offs * stride_od2 + tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=d2_mask) + + +class Model(nn.Module): + def __init__(self, dim: int): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B = x.shape[0] + D1 = x.shape[1] + D2 = x.shape[2] + + out = torch.empty((B, D2), device=x.device, dtype=x.dtype) + + grid = lambda META: ( + triton.cdiv(D2, META["BLOCK_D2"]), + B, + ) + + min_reduction_kernel[grid]( + x, + out, + B, + D1, + D2, + x.stride(0), + x.stride(1), + x.stride(2), + out.stride(0), + out.stride(1), + ) + + return out diff --git a/backends/triton/cpu/KernelBench/level1/88_MinGPTNewGelu.py b/backends/triton/cpu/KernelBench/level1/88_MinGPTNewGelu.py new file mode 100644 index 0000000..cc326f3 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/88_MinGPTNewGelu.py @@ -0,0 +1,54 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl +from triton.language.extra.cpu import libdevice + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, 
num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2), + ], + key=["n_elements"], +) +@triton.jit +def _gelu_kernel( + x_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < n_elements + + x = tl.load(x_ptr + offs, mask=mask, other=0.0) + + x_f32 = x.to(tl.float32) + inner = x_f32 + 0.044715 * x_f32 * x_f32 * x_f32 + t = libdevice.tanh(0.7978845608028654 * inner) + result = (0.5 * x_f32 * (1.0 + t)).to(x.dtype) + + tl.store(out_ptr + offs, result, mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x_flat = x.contiguous().view(-1) + n_elements = x_flat.numel() + output = torch.empty_like(x_flat) + + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + _gelu_kernel[grid](x_flat, output, n_elements) + + return output.view(x.shape) diff --git a/backends/triton/cpu/KernelBench/level1/89_cumsum.py b/backends/triton/cpu/KernelBench/level1/89_cumsum.py new file mode 100644 index 0000000..44366ea --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/89_cumsum.py @@ -0,0 +1,79 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def scan_add_op(a, b): + return a + b + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_SIZE": 512}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1), + ], + key=["N"], +) +@triton.jit +def cumsum_kernel( + x_ptr, + out_ptr, + M, + N, + stride_xm, + stride_om, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + x_row = x_ptr + row_idx * 
stride_xm + o_row = out_ptr + row_idx * stride_om + + running_total = 0.0 + + for block_start in tl.range(0, N, BLOCK_SIZE): + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + x = tl.load(x_row + offsets, mask=mask, other=0.0).to(tl.float32) + + scanned = tl.associative_scan(x, axis=0, combine_fn=scan_add_op) + result = scanned + running_total + + running_total = running_total + tl.sum(x, axis=0) + + tl.store(o_row + offsets, result, mask=mask) + + +def kernel_function(x): + M, N = x.shape + out = torch.empty_like(x) + assert x.stride(1) == 1 + + grid = (M,) + cumsum_kernel[grid]( + x, + out, + M, + N, + x.stride(0), + out.stride(0), + ) + return out + + +class Model(nn.Module): + def __init__(self, dim): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x): + return kernel_function(x) diff --git a/backends/triton/cpu/KernelBench/level1/90_cumprod.py b/backends/triton/cpu/KernelBench/level1/90_cumprod.py new file mode 100644 index 0000000..47cde28 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/90_cumprod.py @@ -0,0 +1,83 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def _mul_combine(a, b): + return a * b + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + ], + key=["N"], +) +@triton.jit +def _cumprod_kernel( + x_ptr, + out_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + + running_prod = 1.0 + + for col_start in range(0, N, BLOCK_SIZE): + cols = col_start + tl.arange(0, BLOCK_SIZE) + mask = cols < N + + x = tl.load( + x_ptr + row.to(tl.int64) * stride_m + cols * 
stride_n, mask=mask, other=1.0 + ).to(tl.float32) + + cum = tl.associative_scan(x, 0, _mul_combine) + + cum = cum * running_prod + + tl.store( + out_ptr + row.to(tl.int64) * stride_m + cols * stride_n, cum, mask=mask + ) + + block_product = tl.reduce(x, 0, _mul_combine) + running_prod = running_prod * block_product + + +def cumprod_triton(x, dim): + assert dim == 1 + M, N = x.shape + out = torch.empty_like(x) + + grid = (M,) + _cumprod_kernel[grid]( + x, + out, + M, + N, + x.stride(0), + x.stride(1), + ) + return out + + +class Model(nn.Module): + def __init__(self, dim): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x): + return cumprod_triton(x, self.dim) diff --git a/backends/triton/cpu/KernelBench/level1/91_cumsum_reverse.py b/backends/triton/cpu/KernelBench/level1/91_cumsum_reverse.py new file mode 100644 index 0000000..5d9ed34 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/91_cumsum_reverse.py @@ -0,0 +1,80 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_warps=8), + ], + key=["N"], +) +@triton.jit +def reverse_cumsum_kernel( + x_ptr, + out_ptr, + M, + N: tl.constexpr, + stride_xm, + stride_xn, + stride_om, + stride_on, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + x_base = x_ptr + row.to(tl.int64) * stride_xm + out_base = out_ptr + row.to(tl.int64) * stride_om + + # Pass 1: compute total sum of the row + total = 0.0 + for col_start in range(0, N, BLOCK_N): + cols = col_start + tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(x_base + cols * stride_xn, mask=mask, other=0.0).to(tl.float32) + total = total + 
tl.sum(x, axis=0) + + # Pass 2: compute suffix sums (left-to-right) + # suffix[i] = total - exclusive_prefix_sum[i] + # where exclusive_prefix[i] = running_prefix + cumsum_inc[local_i] - x[local_i] + running_prefix = 0.0 + for col_start in range(0, N, BLOCK_N): + cols = col_start + tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(x_base + cols * stride_xn, mask=mask, other=0.0).to(tl.float32) + cumsum_inc = tl.cumsum(x, axis=0) + suffix = total - running_prefix - cumsum_inc + x + tl.store(out_base + cols * stride_on, suffix, mask=mask) + running_prefix = running_prefix + tl.sum(x, axis=0) + + +class Model(nn.Module): + def __init__(self, dim): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x): + x = x.contiguous() + M, N = x.shape + output = torch.empty_like(x) + + grid = (M,) + reverse_cumsum_kernel[grid]( + x, + output, + M, + N, + x.stride(0), + x.stride(1), + output.stride(0), + output.stride(1), + ) + return output diff --git a/backends/triton/cpu/KernelBench/level1/92_cumsum_exclusive.py b/backends/triton/cpu/KernelBench/level1/92_cumsum_exclusive.py new file mode 100644 index 0000000..00ee180 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/92_cumsum_exclusive.py @@ -0,0 +1,76 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=4, num_stages=2), + ], + key=["N"], +) +@triton.jit +def exclusive_cumsum_kernel( + x_ptr, + out_ptr, + M, + N, + stride_xm, + stride_xn, + stride_om, + stride_on, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + x_row_ptr = x_ptr + row_idx.to(tl.int64) * stride_xm + o_row_ptr = out_ptr + 
row_idx.to(tl.int64) * stride_om + + running_sum = tl.zeros([1], dtype=tl.float32) + + for col_start in range(0, N, BLOCK_SIZE): + col_offsets = col_start + tl.arange(0, BLOCK_SIZE) + mask = col_offsets < N + + x = tl.load(x_row_ptr + col_offsets * stride_xn, mask=mask, other=0.0).to( + tl.float32 + ) + + inclusive = tl.cumsum(x, axis=0) + exclusive = inclusive - x + + result = exclusive + running_sum + + tl.store(o_row_ptr + col_offsets * stride_on, result, mask=mask) + + running_sum += tl.sum(x, axis=0) + + +class Model(nn.Module): + def __init__(self, dim): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x): + M, N = x.shape + out = torch.empty_like(x) + grid = (M,) + exclusive_cumsum_kernel[grid]( + x, + out, + M, + N, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + ) + return out diff --git a/backends/triton/cpu/KernelBench/level1/93_masked_cumsum.py b/backends/triton/cpu/KernelBench/level1/93_masked_cumsum.py new file mode 100644 index 0000000..6a18231 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/93_masked_cumsum.py @@ -0,0 +1,76 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def scan_add_op(a, b): + return a + b + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=4, num_stages=1), + ], + key=["N"], +) +@triton.jit +def cumsum_kernel( + x_ptr, + out_ptr, + M, + N, + stride_xm, + stride_om, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + x_row = x_ptr + row_idx * stride_xm + o_row = out_ptr + row_idx * stride_om + + running_total = 0.0 + + for block_start in tl.range(0, N, 
BLOCK_SIZE): + offsets = block_start + tl.arange(0, BLOCK_SIZE) + col_mask = offsets < N + + x = tl.load(x_row + offsets, mask=col_mask, other=0.0).to(tl.float32) + + scanned = tl.associative_scan(x, axis=0, combine_fn=scan_add_op) + result = scanned + running_total + + running_total = running_total + tl.sum(x, axis=0) + + tl.store(o_row + offsets, result.to(out_ptr.dtype.element_ty), mask=col_mask) + + +class Model(nn.Module): + def __init__(self, dim): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x, mask): + masked = x.float() * mask.float() + masked = masked.contiguous() + M, N = masked.shape + out = torch.empty_like(masked) + + grid = (M,) + cumsum_kernel[grid]( + masked, + out, + M, + N, + masked.stride(0), + out.stride(0), + ) + return out diff --git a/backends/triton/cpu/KernelBench/level1/94_MSELoss.py b/backends/triton/cpu/KernelBench/level1/94_MSELoss.py new file mode 100644 index 0000000..3329625 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/94_MSELoss.py @@ -0,0 +1,90 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def get_autotune_configs(): + configs = [] + for block_size in [1024, 2048]: + configs.append( + triton.Config( + {"BLOCK_SIZE": block_size}, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_configs(), + key=["N_COLS"], +) +@triton.jit +def mse_row_kernel( + pred_ptr, + target_ptr, + row_sums_ptr, + N_ROWS, + N_COLS, + stride_pred_row, + stride_pred_col, + stride_target_row, + stride_target_col, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + pred_row_start = pred_ptr + row_idx * stride_pred_row + target_row_start = target_ptr + row_idx * stride_target_row + + acc = 0.0 + + for col_start in tl.range(0, N_COLS, BLOCK_SIZE): + 
cols = col_start + tl.arange(0, BLOCK_SIZE) + mask = cols < N_COLS + + pred_vals = tl.load( + pred_row_start + cols * stride_pred_col, mask=mask, other=0.0 + ).to(tl.float32) + target_vals = tl.load( + target_row_start + cols * stride_target_col, mask=mask, other=0.0 + ).to(tl.float32) + + diff = pred_vals - target_vals + acc += tl.sum(diff * diff, axis=0) + + tl.store(row_sums_ptr + row_idx, acc) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + predictions = predictions.contiguous() + targets = targets.contiguous() + + N_ROWS, N_COLS = predictions.shape + + row_sums = torch.empty(N_ROWS, device=predictions.device, dtype=torch.float32) + + grid = (N_ROWS,) + mse_row_kernel[grid]( + predictions, + targets, + row_sums, + N_ROWS, + N_COLS, + predictions.stride(0), + predictions.stride(1), + targets.stride(0), + targets.stride(1), + ) + + return torch.sum(row_sums) / (N_ROWS * N_COLS) diff --git a/backends/triton/cpu/KernelBench/level1/95_CrossEntropyLoss.py b/backends/triton/cpu/KernelBench/level1/95_CrossEntropyLoss.py new file mode 100644 index 0000000..85d7bf3 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/95_CrossEntropyLoss.py @@ -0,0 +1,85 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _ce_configs(): + configs = [] + for BN in [1024, 2048]: + configs.append(triton.Config({"BLOCK_N": BN})) + return configs + + +@triton.autotune(configs=_ce_configs(), key=["N"]) +@triton.jit +def _cross_entropy_online_kernel( + logits_ptr, + targets_ptr, + losses_ptr, + N, + stride_lm, + stride_ln, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + row_off = row.to(tl.int64) * stride_lm + + LOG2E: tl.constexpr = 1.4426950408889634 + LN2: 
tl.constexpr = 0.6931471805599453 + + running_max = -float("inf") + running_sum = 0.0 + + for start in range(0, N, BLOCK_N): + col_offs = start + tl.arange(0, BLOCK_N) + mask = col_offs < N + x = tl.load( + logits_ptr + row_off + col_offs * stride_ln, mask=mask, other=-float("inf") + ).to(tl.float32) + + block_max = tl.max(x, axis=0) + new_max = tl.maximum(running_max, block_max) + + running_sum = running_sum * tl.math.exp2( + (running_max - new_max) * LOG2E + ) + tl.sum(tl.math.exp2((x - new_max) * LOG2E), axis=0) + running_max = new_max + + log_sum_exp = tl.math.log2(running_sum) * LN2 + + target = tl.load(targets_ptr + row).to(tl.int64) + target_logit = tl.load(logits_ptr + row_off + target * stride_ln).to(tl.float32) + + loss = -target_logit + running_max + log_sum_exp + tl.store(losses_ptr + row, loss) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + predictions = predictions.contiguous() + targets = targets.contiguous() + + M, N = predictions.shape + losses = torch.empty(M, device=predictions.device, dtype=torch.float32) + + grid = (M,) + _cross_entropy_online_kernel[grid]( + predictions, + targets, + losses, + N, + predictions.stride(0), + predictions.stride(1), + ) + + return losses.mean() diff --git a/backends/triton/cpu/KernelBench/level1/96_HuberLoss.py b/backends/triton/cpu/KernelBench/level1/96_HuberLoss.py new file mode 100644 index 0000000..f1d9919 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/96_HuberLoss.py @@ -0,0 +1,81 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 4096}, num_warps=4), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=8), + ], + 
key=["n_cols"], +) +@triton.jit +def smooth_l1_row_kernel( + predictions_ptr, + targets_ptr, + row_sums_ptr, + n_cols, + stride_pred, + stride_targ, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + pred_row_start = row_idx * stride_pred + targ_row_start = row_idx * stride_targ + + row_sum = 0.0 + + for col_start in range(0, n_cols, BLOCK_SIZE): + col_offsets = col_start + tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + pred = tl.load( + predictions_ptr + pred_row_start + col_offsets, mask=mask, other=0.0 + ).to(tl.float32) + targ = tl.load( + targets_ptr + targ_row_start + col_offsets, mask=mask, other=0.0 + ).to(tl.float32) + + diff = pred - targ + abs_diff = tl.abs(diff) + loss = tl.where(abs_diff < 1.0, 0.5 * diff * diff, abs_diff - 0.5) + + row_sum += tl.sum(loss, axis=0) + + tl.store(row_sums_ptr + row_idx, row_sum) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + predictions = predictions.contiguous() + targets = targets.contiguous() + + n_rows = predictions.shape[0] if predictions.ndim > 1 else 1 + n_cols = predictions.numel() // max(n_rows, 1) + n_elements = predictions.numel() + + row_sums = torch.empty(n_rows, device=predictions.device, dtype=torch.float32) + + grid = (n_rows,) + smooth_l1_row_kernel[grid]( + predictions, + targets, + row_sums, + n_cols, + predictions.stride(0) if predictions.ndim > 1 else n_cols, + targets.stride(0) if targets.ndim > 1 else n_cols, + ) + + return row_sums.sum() / n_elements diff --git a/backends/triton/cpu/KernelBench/level1/97_ScaledDotProductAttention.py b/backends/triton/cpu/KernelBench/level1/97_ScaledDotProductAttention.py new file mode 100644 index 0000000..03dbd2c --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/97_ScaledDotProductAttention.py @@ -0,0 +1,280 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# 
Expectation: Correctness-first, performance not representative + +import math + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32}, num_warps=16, num_stages=2 + ), + ], + key=["SEQ_LEN", "HEAD_DIM"], +) +@triton.jit +def _qk_gemm_kernel( + Q_ptr, + K_ptr, + S_ptr, + stride_qb, + stride_qm, + stride_qd, + stride_kb, + stride_kn, + stride_kd, + stride_sb, + stride_sm, + stride_sn, + SEQ_LEN, + HEAD_DIM, + scale, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_tile = tl.program_id(1) + + num_n = tl.cdiv(SEQ_LEN, BLOCK_N) + pid_m = pid_tile // num_n + pid_n = pid_tile % num_n + + q_base = Q_ptr + pid_b.to(tl.int64) * stride_qb + k_base = K_ptr + pid_b.to(tl.int64) * stride_kb + + Q_bp = tl.make_block_ptr( + base=q_base, + shape=(SEQ_LEN, HEAD_DIM), + strides=(stride_qm, stride_qd), + offsets=(pid_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + KT_bp = tl.make_block_ptr( + base=k_base, + shape=(HEAD_DIM, SEQ_LEN), + strides=(stride_kd, stride_kn), + offsets=(0, pid_n * BLOCK_N), + block_shape=(BLOCK_K, BLOCK_N), + order=(1, 0), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for _ in range(0, HEAD_DIM, BLOCK_K): + q = tl.load(Q_bp, boundary_check=(0, 1)) + k_t = tl.load(KT_bp, boundary_check=(0, 1)) + acc = tl.dot(q, k_t, acc) + Q_bp = tl.advance(Q_bp, (0, BLOCK_K)) + KT_bp = tl.advance(KT_bp, (BLOCK_K, 0)) + + acc = acc * scale + + s_base = S_ptr + pid_b.to(tl.int64) * stride_sb + S_bp = tl.make_block_ptr( + base=s_base, + shape=(SEQ_LEN, SEQ_LEN), + strides=(stride_sm, stride_sn), + offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + tl.store(S_bp, acc.to(S_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_K": 64, 
"BLOCK_N": 128}, num_warps=16, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_K": 128, "BLOCK_N": 128}, num_warps=16, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_K": 64, "BLOCK_N": 256}, num_warps=16, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8, num_stages=3 + ), + ], + key=["SEQ_LEN", "HEAD_DIM"], +) +@triton.jit +def _fused_softmax_pv_kernel( + S_ptr, + V_ptr, + O_ptr, + stride_sb, + stride_sm, + stride_sn, + stride_vb, + stride_vn, + stride_vd, + stride_ob, + stride_om, + stride_od, + SEQ_LEN, + HEAD_DIM, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_tile = tl.program_id(1) + + num_n = tl.cdiv(HEAD_DIM, BLOCK_N) + pid_m = pid_tile // num_n + pid_n = pid_tile % num_n + + s_base = S_ptr + pid_b.to(tl.int64) * stride_sb + v_base = V_ptr + pid_b.to(tl.int64) * stride_vb + off_m = pid_m * BLOCK_M + + LOG2E = 1.4426950408889634 + + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for off_k in range(0, SEQ_LEN, BLOCK_K): + S_bp = tl.make_block_ptr( + base=s_base, + shape=(SEQ_LEN, SEQ_LEN), + strides=(stride_sm, stride_sn), + offsets=(off_m, off_k), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + s = tl.load(S_bp, boundary_check=(0, 1)).to(tl.float32) + + chunk_max = tl.max(s, axis=1) + m_new = tl.maximum(m_i, chunk_max) + alpha = tl.math.exp2((m_i - m_new) * LOG2E) + exp_s = tl.math.exp2((s - m_new[:, None]) * LOG2E) + chunk_sum = tl.sum(exp_s, axis=1) + l_i = alpha * l_i + chunk_sum + acc = acc * alpha[:, None] + m_i = m_new + + V_bp = tl.make_block_ptr( + base=v_base, + shape=(SEQ_LEN, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(off_k, pid_n * BLOCK_N), + block_shape=(BLOCK_K, BLOCK_N), + order=(1, 0), + ) + v = tl.load(V_bp, boundary_check=(0, 1)) + acc = tl.dot(exp_s.to(v.dtype), v, 
class Model(nn.Module):
    """Scaled dot-product attention built from two Triton kernels.

    ``forward`` first materialises the scaled score matrix with
    ``_qk_gemm_kernel`` and then applies the softmax and the product with V in
    ``_fused_softmax_pv_kernel``.  Computation runs in float16 (float32 for the
    score matrix); the result is cast back to the dtype of ``Q``.
    """

    def __init__(self):
        super().__init__()
        # Scratch buffer for the score matrix.  It is never handed to the
        # caller, so reusing it across calls is safe.
        self._s_buf = None

    def forward(self, Q, K, V):
        """Return ``softmax(Q @ K^T / sqrt(D)) @ V``.

        Args:
            Q, K, V: tensors whose shape unpacks as ``(B, H, S, D)``; K and V
                are reshaped with Q's dimensions.

        Returns:
            A freshly allocated ``(B, H, S, D)`` tensor in Q's original dtype.
        """
        B, H, S, D = Q.shape
        scale = 1.0 / math.sqrt(D)

        device = Q.device
        input_dtype = Q.dtype
        BH = B * H

        # Cast each operand on its own: testing only Q's dtype would let a
        # float32 K or V slip through to the fp16 kernels.  ``.to`` is a no-op
        # for tensors that are already float16.
        Q = Q.to(torch.float16).reshape(BH, S, D).contiguous()
        K = K.to(torch.float16).reshape(BH, S, D).contiguous()
        V = V.to(torch.float16).reshape(BH, S, D).contiguous()

        if (
            self._s_buf is None
            or self._s_buf.shape != (BH, S, S)
            or self._s_buf.device != device
        ):
            self._s_buf = torch.empty(BH, S, S, device=device, dtype=torch.float32)
        S_mat = self._s_buf

        grid1 = lambda META: (
            BH,
            triton.cdiv(S, META["BLOCK_M"]) * triton.cdiv(S, META["BLOCK_N"]),
        )
        _qk_gemm_kernel[grid1](
            Q,
            K,
            S_mat,
            Q.stride(0),
            Q.stride(1),
            Q.stride(2),
            K.stride(0),
            K.stride(1),
            K.stride(2),
            S_mat.stride(0),
            S_mat.stride(1),
            S_mat.stride(2),
            S,
            D,
            scale,
        )

        # The output is allocated per call.  Returning a view of a cached
        # buffer would let the next forward() overwrite a result the caller
        # may still be holding.
        out = torch.empty(BH, S, D, device=device, dtype=Q.dtype)

        grid2 = lambda META: (
            BH,
            triton.cdiv(S, META["BLOCK_M"]) * triton.cdiv(D, META["BLOCK_N"]),
        )
        _fused_softmax_pv_kernel[grid2](
            S_mat,
            V,
            out,
            S_mat.stride(0),
            S_mat.stride(1),
            S_mat.stride(2),
            V.stride(0),
            V.stride(1),
            V.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            S,
            D,
        )

        result = out.reshape(B, H, S, D)
        if input_dtype != torch.float16:
            result = result.to(input_dtype)

        return result
def get_kl_div_configs():
    """Return the autotune candidates for ``kl_div_row_kernel``."""
    return [triton.Config({"BLOCK_SIZE": block_size}) for block_size in (4096,)]


@triton.autotune(
    configs=get_kl_div_configs(),
    key=["n_cols"],
)
@triton.jit
def kl_div_row_kernel(
    pred_ptr,
    target_ptr,
    out_ptr,
    n_cols,
    stride_pred_row,
    stride_target_row,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``sum_j target[j] * log(target[j] / pred[j])`` for one row.

    One program handles one row, walking it in chunks of BLOCK_SIZE columns.
    Columns are addressed with unit stride, so rows must be contiguous.
    Entries whose target is not positive contribute zero.
    """
    # 64-bit row index so ``row * stride`` cannot wrap on large inputs.
    row_idx = tl.program_id(0).to(tl.int64)
    pred_row_start = pred_ptr + row_idx * stride_pred_row
    target_row_start = target_ptr + row_idx * stride_target_row

    acc = 0.0

    for col_start in tl.range(0, n_cols, BLOCK_SIZE):
        col_offsets = col_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols

        # Masked lanes read pred=1 and target=0, which the ``where`` below
        # turns into a zero contribution.
        pred_vals = tl.load(pred_row_start + col_offsets, mask=mask, other=1.0).to(
            tl.float32
        )
        target_vals = tl.load(target_row_start + col_offsets, mask=mask, other=0.0).to(
            tl.float32
        )

        kl_vals = tl.where(
            target_vals > 0, target_vals * tl.log(target_vals / pred_vals), 0.0
        )
        acc += tl.sum(kl_vals, axis=0)

    tl.store(out_ptr + row_idx, acc)


def kernel_function(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return the per-row KL sums averaged over the batch dimension.

    Args:
        predictions: 2-D tensor of shape ``(B, N)``.
        targets: tensor with the same shape as ``predictions``.

    Returns:
        A 0-d float32 tensor: the sum of the per-row values divided by ``B``.

    Raises:
        ValueError: if the inputs are not 2-D or their shapes differ.
    """
    if predictions.dim() != 2 or predictions.shape != targets.shape:
        raise ValueError(
            "predictions and targets must be 2-D tensors of identical shape, "
            f"got {tuple(predictions.shape)} and {tuple(targets.shape)}"
        )

    # The kernel walks columns with unit stride.  Make the layout contiguous
    # here rather than asserting it, since asserts vanish under ``python -O``.
    predictions = predictions.contiguous()
    targets = targets.contiguous()
    B, N = predictions.shape

    row_sums = torch.empty(B, device=predictions.device, dtype=torch.float32)

    grid = (B,)
    kl_div_row_kernel[grid](
        predictions,
        targets,
        row_sums,
        N,
        predictions.stride(0),
        targets.stride(0),
    )

    return row_sums.sum() / B


class Model(nn.Module):
    """Module wrapper that forwards to ``kernel_function``."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        return kernel_function(predictions, targets)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_K": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_K": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_K": 2048}, num_warps=4, num_stages=2),
    ],
    key=["D"],
    # The kernel accumulates into ``loss_ptr``.  Without a reset, every
    # benchmark launch made by the autotuner would pile onto the result.
    reset_to_zero=["loss_ptr"],
)
@triton.jit
def _triplet_margin_loss_kernel(
    anchor_ptr,
    positive_ptr,
    negative_ptr,
    loss_ptr,
    B,
    D,
    stride_ab,
    stride_ad,
    stride_pb,
    stride_pd,
    stride_nb,
    stride_nd,
    margin,
    eps,
    BLOCK_K: tl.constexpr,
):
    """Add one row's ``max(d(a, p) - d(a, n) + margin, 0)`` to ``loss_ptr``.

    ``d(x, y)`` is ``sqrt(sum((x - y + eps) ** 2))`` over the D columns of the
    row.  One program handles one row; the per-row value is accumulated into
    the single scalar at ``loss_ptr`` with an atomic add, so the caller must
    zero that scalar before launch.
    """
    row = tl.program_id(0)
    if row >= B:
        return

    # 64-bit row offsets so ``row * stride`` cannot wrap on large inputs.
    base_a = row.to(tl.int64) * stride_ab
    base_p = row.to(tl.int64) * stride_pb
    base_n = row.to(tl.int64) * stride_nb

    sum_pos_sq = 0.0
    sum_neg_sq = 0.0

    for k_start in range(0, D, BLOCK_K):
        offs_k = k_start + tl.arange(0, BLOCK_K)
        mask_k = offs_k < D

        a = tl.load(
            anchor_ptr + base_a + offs_k * stride_ad, mask=mask_k, other=0.0
        ).to(tl.float32)
        p = tl.load(
            positive_ptr + base_p + offs_k * stride_pd, mask=mask_k, other=0.0
        ).to(tl.float32)
        n = tl.load(
            negative_ptr + base_n + offs_k * stride_nd, mask=mask_k, other=0.0
        ).to(tl.float32)

        # Masked lanes load zeros, but ``0 - 0 + eps`` is not zero: zero them
        # explicitly so padding does not add eps**2 per lane to the sums
        # (which would make the result depend on the tuned BLOCK_K).
        diff_pos = tl.where(mask_k, a - p + eps, 0.0)
        diff_neg = tl.where(mask_k, a - n + eps, 0.0)
        sum_pos_sq += tl.sum(diff_pos * diff_pos, axis=0)
        sum_neg_sq += tl.sum(diff_neg * diff_neg, axis=0)

    d_pos = tl.sqrt(sum_pos_sq)
    d_neg = tl.sqrt(sum_neg_sq)

    loss_val = tl.maximum(d_pos - d_neg + margin, 0.0)

    tl.atomic_add(loss_ptr, loss_val, sem="relaxed")
class Model(nn.Module):
    """Triplet margin loss averaged over the batch, computed by a Triton kernel.

    Args:
        margin: margin added to ``d(anchor, positive) - d(anchor, negative)``
            before clamping at zero.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin
        # Added to each element-wise difference inside the kernel.
        self.eps = 1e-6

    def forward(self, anchor, positive, negative):
        """Return the mean triplet margin loss as a 0-d float32 tensor.

        Args:
            anchor: 2-D tensor of shape ``(B, D)``.
            positive: tensor with the same shape as ``anchor``.
            negative: tensor with the same shape as ``anchor``.

        Raises:
            ValueError: if ``anchor`` is not 2-D, or ``positive`` / ``negative``
                do not match its shape.
        """
        device = anchor.device
        B, D = anchor.shape

        # The kernel indexes all three tensors with anchor's B and D, so a
        # smaller positive/negative would be read out of bounds.
        if positive.shape != anchor.shape or negative.shape != anchor.shape:
            raise ValueError(
                "anchor, positive and negative must have identical shapes, got "
                f"{tuple(anchor.shape)}, {tuple(positive.shape)} and "
                f"{tuple(negative.shape)}"
            )

        anchor = anchor.contiguous()
        positive = positive.contiguous()
        negative = negative.contiguous()

        # Single scalar the kernel accumulates into; must start at zero.
        loss_accum = torch.zeros((), device=device, dtype=torch.float32)

        grid = (B,)
        _triplet_margin_loss_kernel[grid](
            anchor,
            positive,
            negative,
            loss_accum,
            B,
            D,
            anchor.stride(0),
            anchor.stride(1),
            positive.stride(0),
            positive.stride(1),
            negative.stride(0),
            negative.stride(1),
            self.margin,
            self.eps,
        )

        return loss_accum / B
a/problems/specs/KernelBench/level1/52_Argmin_over_a_dimension.yaml b/problems/specs/KernelBench/level1/52_Argmin_over_a_dimension.yaml index fe08c21..a2947c9 100644 --- a/problems/specs/KernelBench/level1/52_Argmin_over_a_dimension.yaml +++ b/problems/specs/KernelBench/level1/52_Argmin_over_a_dimension.yaml @@ -14,3 +14,12 @@ ci: DIM1: 64 DIM2: 63 ARGMIN_DIM: 1 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 32 + DIM1: 512 + DIM2: 511 + ARGMIN_DIM: 1 diff --git a/problems/specs/KernelBench/level1/53_Min_reduction_over_a_dimension.yaml b/problems/specs/KernelBench/level1/53_Min_reduction_over_a_dimension.yaml index bf5fbda..122a71c 100644 --- a/problems/specs/KernelBench/level1/53_Min_reduction_over_a_dimension.yaml +++ b/problems/specs/KernelBench/level1/53_Min_reduction_over_a_dimension.yaml @@ -14,3 +14,12 @@ ci: DIM1: 64 DIM2: 63 REDUCE_DIM: 1 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 32 + DIM1: 512 + DIM2: 511 + REDUCE_DIM: 1 diff --git a/problems/specs/KernelBench/level1/88_MinGPTNewGelu.yaml b/problems/specs/KernelBench/level1/88_MinGPTNewGelu.yaml index 273d44b..9dd1ca1 100644 --- a/problems/specs/KernelBench/level1/88_MinGPTNewGelu.yaml +++ b/problems/specs/KernelBench/level1/88_MinGPTNewGelu.yaml @@ -11,3 +11,10 @@ ci: dims: BATCH_SIZE: 32 DIM: 32 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 1024 + DIM: 1024 diff --git a/problems/specs/KernelBench/level1/89_cumsum.yaml b/problems/specs/KernelBench/level1/89_cumsum.yaml index 82de8c7..b6ed4dd 100644 --- a/problems/specs/KernelBench/level1/89_cumsum.yaml +++ b/problems/specs/KernelBench/level1/89_cumsum.yaml @@ -13,3 +13,11 @@ ci: BATCH_SIZE: 64 INPUT_DIM: 64 SCAN_DIM: 1 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 + SCAN_DIM: 1 diff --git a/problems/specs/KernelBench/level1/90_cumprod.yaml b/problems/specs/KernelBench/level1/90_cumprod.yaml index 82de8c7..b6ed4dd 100644 --- 
a/problems/specs/KernelBench/level1/90_cumprod.yaml +++ b/problems/specs/KernelBench/level1/90_cumprod.yaml @@ -13,3 +13,11 @@ ci: BATCH_SIZE: 64 INPUT_DIM: 64 SCAN_DIM: 1 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 + SCAN_DIM: 1 diff --git a/problems/specs/KernelBench/level1/91_cumsum_reverse.yaml b/problems/specs/KernelBench/level1/91_cumsum_reverse.yaml index 82de8c7..b6ed4dd 100644 --- a/problems/specs/KernelBench/level1/91_cumsum_reverse.yaml +++ b/problems/specs/KernelBench/level1/91_cumsum_reverse.yaml @@ -13,3 +13,11 @@ ci: BATCH_SIZE: 64 INPUT_DIM: 64 SCAN_DIM: 1 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 + SCAN_DIM: 1 diff --git a/problems/specs/KernelBench/level1/92_cumsum_exclusive.yaml b/problems/specs/KernelBench/level1/92_cumsum_exclusive.yaml index 82de8c7..b6ed4dd 100644 --- a/problems/specs/KernelBench/level1/92_cumsum_exclusive.yaml +++ b/problems/specs/KernelBench/level1/92_cumsum_exclusive.yaml @@ -13,3 +13,11 @@ ci: BATCH_SIZE: 64 INPUT_DIM: 64 SCAN_DIM: 1 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 + SCAN_DIM: 1 diff --git a/problems/specs/KernelBench/level1/93_masked_cumsum.yaml b/problems/specs/KernelBench/level1/93_masked_cumsum.yaml index d296d36..5d36961 100644 --- a/problems/specs/KernelBench/level1/93_masked_cumsum.yaml +++ b/problems/specs/KernelBench/level1/93_masked_cumsum.yaml @@ -17,6 +17,14 @@ ci: INPUT_DIM: 64 SCAN_DIM: 1 +bench-cpu: + - params: [x, mask] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 + SCAN_DIM: 1 + bench-gpu: - params: [x, mask] dtype: float16 diff --git a/problems/specs/KernelBench/level1/94_MSELoss.yaml b/problems/specs/KernelBench/level1/94_MSELoss.yaml index d6893b3..e10f89d 100644 --- a/problems/specs/KernelBench/level1/94_MSELoss.yaml +++ b/problems/specs/KernelBench/level1/94_MSELoss.yaml @@ -15,3 +15,10 @@ ci: dims: BATCH_SIZE: 64 INPUT_DIM: 64 
+ +bench-cpu: + - params: [predictions, targets] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 diff --git a/problems/specs/KernelBench/level1/95_CrossEntropyLoss.yaml b/problems/specs/KernelBench/level1/95_CrossEntropyLoss.yaml index 1925d32..d1a1645 100644 --- a/problems/specs/KernelBench/level1/95_CrossEntropyLoss.yaml +++ b/problems/specs/KernelBench/level1/95_CrossEntropyLoss.yaml @@ -15,3 +15,10 @@ ci: dims: BATCH_SIZE: 64 NUM_CLASSES: 8 + +bench-cpu: + - params: [predictions, targets] + dtype: float32 + dims: + BATCH_SIZE: 1024 + NUM_CLASSES: 128 diff --git a/problems/specs/KernelBench/level1/96_HuberLoss.yaml b/problems/specs/KernelBench/level1/96_HuberLoss.yaml index d6893b3..e10f89d 100644 --- a/problems/specs/KernelBench/level1/96_HuberLoss.yaml +++ b/problems/specs/KernelBench/level1/96_HuberLoss.yaml @@ -15,3 +15,10 @@ ci: dims: BATCH_SIZE: 64 INPUT_DIM: 64 + +bench-cpu: + - params: [predictions, targets] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 diff --git a/problems/specs/KernelBench/level1/97_ScaledDotProductAttention.yaml b/problems/specs/KernelBench/level1/97_ScaledDotProductAttention.yaml index 701f51b..8be8530 100644 --- a/problems/specs/KernelBench/level1/97_ScaledDotProductAttention.yaml +++ b/problems/specs/KernelBench/level1/97_ScaledDotProductAttention.yaml @@ -19,3 +19,12 @@ ci: NUM_HEADS: 8 SEQUENCE_LENGTH: 16 EMBEDDING_DIMENSION: 32 + +bench-cpu: + - params: [Q, K, V] + dtype: float16 + dims: + BATCH_SIZE: 2 + NUM_HEADS: 8 + SEQUENCE_LENGTH: 16 + EMBEDDING_DIMENSION: 32 diff --git a/problems/specs/KernelBench/level1/98_KLDivLoss.yaml b/problems/specs/KernelBench/level1/98_KLDivLoss.yaml index 63efbdc..95b550e 100644 --- a/problems/specs/KernelBench/level1/98_KLDivLoss.yaml +++ b/problems/specs/KernelBench/level1/98_KLDivLoss.yaml @@ -16,3 +16,10 @@ ci: dims: BATCH_SIZE: 64 INPUT_DIM: 64 + +bench-cpu: + - params: [predictions, targets] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 1024 diff 
--git a/problems/specs/KernelBench/level1/99_TripletMarginLoss.yaml b/problems/specs/KernelBench/level1/99_TripletMarginLoss.yaml index 114af17..2db46a3 100644 --- a/problems/specs/KernelBench/level1/99_TripletMarginLoss.yaml +++ b/problems/specs/KernelBench/level1/99_TripletMarginLoss.yaml @@ -20,3 +20,11 @@ ci: BATCH_SIZE: 64 INPUT_DIM: 8 MARGIN: 1.0 + +bench-cpu: + - params: [anchor, positive, negative] + dtype: float32 + dims: + BATCH_SIZE: 1024 + INPUT_DIM: 128 + MARGIN: 1.0 diff --git a/pyproject.toml b/pyproject.toml index d0c35b1..ba285e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ override-dependencies = [ torch = { index = "pytorch" } pytorch-triton-xpu = { index = "pytorch" } pytorch-triton = { index = "pytorch" } -triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "270e696" } +triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "eece2e9" } lighthouse = { git = "https://github.com/llvm/lighthouse", rev = "456475d" } mlir-python-bindings = { index = "eudsl" }