diff --git a/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py b/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py index a94792e..7ff7a4b 100644 --- a/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py +++ b/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py @@ -10,15 +10,15 @@ import triton.language as tl -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_kernel( a_ptr, @@ -101,6 +101,13 @@ def forward(self, A, B): c_flat = torch.empty((total_m, l), device=a.device, dtype=torch.bfloat16) def grid(META): + assert ( + m % META["BLOCK_M"] == 0 + and l % META["BLOCK_N"] == 0 + and k % META["BLOCK_K"] == 0 + ), ( + "M, L, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) return ( triton.cdiv(total_m, META["BLOCK_M"]) * triton.cdiv(l, META["BLOCK_N"]), ) @@ -118,6 +125,7 @@ def grid(META): b.stride(1), c_flat.stride(0), c_flat.stride(1), + assume_in_bounds=True, ) return c_flat.reshape(batch, m, l) diff --git a/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py b/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py index 99f96b2..a188609 100644 --- a/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py +++ b/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py @@ -29,16 +29,13 @@ def swizzle_tile( return pid_m, pid_n -def get_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - @triton.autotune( - configs=get_autotune_configs(), + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], key=["M", "N", "K"], ) @triton.jit @@ -111,9 +108,15 @@ def forward(self, A, B): C_2d = torch.empty((M, N), device=A.device, dtype=torch.bfloat16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _gemm_kernel[grid]( A_flat, @@ -128,6 +131,7 @@ def forward(self, A, B): B_fp16.stride(1), C_2d.stride(0), C_2d.stride(1), + assume_in_bounds=True, ) result = C_2d.view(b_dim, i_dim, j_dim, k_dim) diff --git a/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py b/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py index 7a1a6c0..27d7980 100644 --- a/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py +++ b/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py @@ -15,6 +15,9 @@ triton.Config( {"BLOCK_M": 32, "BLOCK_N": 32}, ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + ), ], key=["N", "M"], ) diff --git a/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py b/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py index c47d878..dc2718a 100644 --- a/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py +++ b/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py @@ -29,15 +29,15 @@ def swizzle_tile( return pid_m, pid_n -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_kernel( a_ptr, @@ -111,9 +111,16 @@ def forward(self, A, B): N = B.shape[1] C = torch.empty((M, N), device=A.device, dtype=torch.bfloat16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _matmul_kernel[grid]( A, B, @@ -130,6 +137,7 @@ def forward(self, A, B): DIVISIBLE_M=(M % 256 == 0), DIVISIBLE_N=(N % 128 == 0), DIVISIBLE_K=(K % 32 == 0), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py b/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py index 3e05bc4..2d3722a 100644 --- a/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py +++ b/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py @@ -34,8 +34,9 @@ def swizzle_tile( @triton.autotune( configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] ], key=["M", "N", "K"], ) @@ -117,9 +118,15 @@ def forward(self, A, B): N = B.shape[1] C = torch.zeros((M, N), device=A.device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _triu_matmul_kernel[grid]( A, @@ -134,6 +141,7 @@ def forward(self, A, B): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py b/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py index 9dde830..4b4861b 100644 --- a/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py +++ b/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py @@ -32,10 +32,11 @@ def swizzle_tile( @triton.autotune( configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] ], - key=["M"], + key=["M", "N", "K"], ) @triton.jit def tril_matmul_kernel( @@ -111,9 +112,14 @@ def forward(self, A, B): M = A.shape[0] C = torch.zeros(M, M, device=A.device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(M, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and M % META["BLOCK_N"] == 0 + and M % META["BLOCK_K"] == 0 + ), "M must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(M, META["BLOCK_N"]),) + tril_matmul_kernel[grid]( A, B, @@ -125,6 +131,7 @@ def forward(self, A, B): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py b/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py index 348ff28..5d50966 100644 --- a/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py +++ b/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py @@ -10,16 +10,13 @@ import triton.language as tl -def get_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - @triton.autotune( - configs=get_autotune_configs(), + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], key=["M", "N", "K"], ) @triton.jit @@ -96,6 +93,13 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: C = torch.empty((M, N), device=A.device, dtype=A.dtype) def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _matmul_at_kernel[grid]( @@ -111,6 +115,7 @@ def grid(META): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py b/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py index f88ca8d..3840323 100644 --- a/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py +++ b/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py @@ -10,16 +10,13 @@ import triton.language as tl -def get_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - @triton.autotune( - configs=get_autotune_configs(), + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], key=["M", "N", "K"], ) @triton.jit @@ -100,6 +97,13 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: C = torch.empty((M, N), device=A.device, dtype=A.dtype) def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _matmul_bt_kernel[grid]( @@ -115,6 +119,7 @@ def grid(META): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py b/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py index b4c40df..9b7bba9 100644 --- a/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py +++ b/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py @@ -10,15 +10,15 @@ import triton.language as tl -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_tt_kernel( A_ptr, @@ -96,6 +96,13 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: C = torch.empty((M, N), device=A.device, dtype=A.dtype) def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _matmul_tt_kernel[grid]( @@ -111,6 +118,7 @@ def grid(META): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/19_ReLU.py b/backends/triton/cpu/KernelBench/level1/19_ReLU.py index b7be751..6d0562e 100644 --- a/backends/triton/cpu/KernelBench/level1/19_ReLU.py +++ b/backends/triton/cpu/KernelBench/level1/19_ReLU.py @@ -13,29 +13,26 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32, "NUM_PROGRAMS": 64}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @triton.jit -def relu_kernel_persistent( +def relu_kernel( x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, - NUM_PROGRAMS: tl.constexpr, ): pid = tl.program_id(0) - num_blocks = tl.cdiv(n_elements, BLOCK_SIZE) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements - for block_id in tl.range(pid, num_blocks, NUM_PROGRAMS): - block_start = block_id * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - output = tl.maximum(x, 0.0) - tl.store(output_ptr + offsets, output, mask=mask) + x = tl.load(x_ptr + offsets, mask=mask) + out = tl.maximum(x, 0.0) + tl.store(output_ptr + offsets, out, mask=mask) class Model(nn.Module): @@ -45,8 +42,8 @@ def __init__(self): def forward(self, x: torch.Tensor) -> torch.Tensor: output = torch.empty_like(x) n_elements = x.numel() - grid = lambda META: (META["NUM_PROGRAMS"],) - relu_kernel_persistent[grid](x, output, n_elements) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + relu_kernel[grid](x, output, n_elements) return output diff --git a/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py b/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py index 80f240d..6c5206c 100644 --- a/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py +++ b/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py @@ -10,15 +10,15 @@ import triton.language as tl -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_kernel( a_ptr, @@ -89,6 +89,13 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: C = torch.empty((M, N), device=A.device, dtype=A.dtype) def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N and K must be divisible by BLOCK_M, BLOCK_N and BLOCK_K respectively" + ) return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _matmul_kernel[grid]( @@ -104,6 +111,7 @@ def grid(META): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py b/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py index 12f9a28..f73ca5c 100644 --- a/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py +++ b/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) diff --git a/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py b/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py index a6b8f44..930ab4d 100644 --- a/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py +++ b/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["N"], ) @@ -29,13 +30,13 @@ def _sigmoid_kernel( offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < N - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) inv_ln2 = 1.4426950408889634 e = tl.math.exp2((-x) * inv_ln2) y = 1.0 / (1.0 + e) - tl.store(out_ptr + offsets, y.to(tl.bfloat16), mask=mask) + tl.store(out_ptr + offsets, y, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/22_Tanh.py b/backends/triton/cpu/KernelBench/level1/22_Tanh.py index e261dfc..475b2f9 100644 --- a/backends/triton/cpu/KernelBench/level1/22_Tanh.py +++ b/backends/triton/cpu/KernelBench/level1/22_Tanh.py @@ -16,8 +16,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @@ -32,7 +33,7 @@ def _tanh_kernel( offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) # tanh(x) = 2*sigmoid(2x) - 1 # sigmoid(z) = 1/(1 + exp2(-z * log2(e))) @@ -42,7 +43,7 @@ def _tanh_kernel( sig = 1.0 / (1.0 + e) result = 2.0 * sig - 1.0 - tl.store(out_ptr + offsets, result.to(tl.bfloat16), mask=mask) + tl.store(out_ptr + offsets, result, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/23_Softmax.py b/backends/triton/cpu/KernelBench/level1/23_Softmax.py index c813a14..5887f99 100644 --- a/backends/triton/cpu/KernelBench/level1/23_Softmax.py +++ b/backends/triton/cpu/KernelBench/level1/23_Softmax.py @@ -10,15 +10,15 @@ import triton.language as tl -def _softmax_configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_N": 32}, - ), - ] - - -@triton.autotune(configs=_softmax_configs(), key=["N"]) + {"BLOCK_N": n}, + ) + for n in [32, 64, 128, 256, 512, 1024, 2048, 4096] + ], + key=["N"], +) @triton.jit def _softmax_kernel( inp_ptr, @@ -43,9 +43,7 @@ def _softmax_kernel( for start in range(0, N, BLOCK_N): offs = start + tl.arange(0, BLOCK_N) mask = offs < N - x = tl.load(row_inp + offs * stride_in, mask=mask, other=-float("inf")).to( - tl.float32 - ) + x = tl.load(row_inp + offs * stride_in, mask=mask, other=-float("inf")) block_max = tl.max(x, axis=0) new_max = tl.maximum(row_max, block_max) row_sum = row_sum * tl.math.exp2((row_max - new_max) * LOG2E) + tl.sum( @@ -59,9 +57,7 @@ def _softmax_kernel( for start in range(0, N, BLOCK_N): offs = start + tl.arange(0, BLOCK_N) mask = offs < N - x = tl.load(row_inp + offs * stride_in, mask=mask, other=-float("inf")).to( - tl.float32 - ) + x = tl.load(row_inp + offs * stride_in, mask=mask, other=-float("inf")) e = tl.math.exp2((x - row_max) * LOG2E) y = (e * inv_sum).to(tl.bfloat16) tl.store(row_out + offs * stride_on, y, mask=mask) diff --git a/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py b/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py index e314810..986c60b 100644 --- a/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py +++ b/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_N": 32, "warp_size": 32}, - ), + {"BLOCK_N": n}, + ) + for n in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["N"], ) @@ -27,7 +28,6 @@ def _logsoftmax_kernel( stride_im, stride_om, BLOCK_N: tl.constexpr, - warp_size: tl.constexpr, ): pid_m = tl.program_id(0) row_inp = inp_ptr + pid_m.to(tl.int64) * stride_im @@ -42,7 +42,7 @@ def _logsoftmax_kernel( for start in range(0, N, BLOCK_N): offs = start + tl.arange(0, BLOCK_N) mask = offs < N - x = tl.load(row_inp + offs, mask=mask, other=-float("inf")).to(tl.float32) + x = tl.load(row_inp + offs, mask=mask, other=-float("inf")) block_max = tl.max(x, axis=0) m_new = tl.maximum(m, block_max) s = s * tl.math.exp2((m - m_new) * LOG2E) + tl.sum( @@ -55,9 +55,9 @@ def _logsoftmax_kernel( for start in range(0, N, BLOCK_N): offs = start + tl.arange(0, BLOCK_N) mask = offs < N - x = tl.load(row_inp + offs, mask=mask, other=-float("inf")).to(tl.float32) + x = tl.load(row_inp + offs, mask=mask, other=-float("inf")) y = x - m - log_s - tl.store(row_out + offs, y.to(tl.bfloat16), mask=mask) + tl.store(row_out + offs, y, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/25_Swish.py b/backends/triton/cpu/KernelBench/level1/25_Swish.py index 14ea27a..eac2988 100644 --- a/backends/triton/cpu/KernelBench/level1/25_Swish.py +++ b/backends/triton/cpu/KernelBench/level1/25_Swish.py @@ -20,8 +20,9 @@ def _sigmoid_exp2(x): @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @@ -36,10 +37,10 @@ def swish_kernel( offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - x_f32 = x.to(tl.float32) + x_f32 = x sig = _sigmoid_exp2(x_f32) result = x_f32 * sig - tl.store(output_ptr + offsets, result.to(tl.bfloat16), mask=mask) + tl.store(output_ptr + offsets, result, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/26_GELU_.py b/backends/triton/cpu/KernelBench/level1/26_GELU_.py index 777e424..f1c89f2 100644 --- a/backends/triton/cpu/KernelBench/level1/26_GELU_.py +++ b/backends/triton/cpu/KernelBench/level1/26_GELU_.py @@ -13,29 +13,27 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32, "NUM_PROGS": 16}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @triton.jit -def gelu_persistent_kernel( +def gelu_kernel( x_ptr, - out_ptr, + output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, - NUM_PROGS: tl.constexpr, ): pid = tl.program_id(0) - num_tiles = tl.cdiv(n_elements, BLOCK_SIZE) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements - for tile_id in range(pid, num_tiles, NUM_PROGS): - offsets = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - out = 0.5 * x * (1.0 + tl.math.erf(x * 0.70710678118654752440)) - tl.store(out_ptr + offsets, out.to(tl.bfloat16), mask=mask) + x = tl.load(x_ptr + offsets, mask=mask) + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + out = 0.5 * x * (1.0 + tl.math.erf(x * 0.70710678118654752440)) + tl.store(output_ptr + offsets, out, mask=mask) class Model(nn.Module): @@ -46,8 +44,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_flat = x.view(-1) out_flat = torch.empty_like(x_flat) n_elements = x_flat.numel() - grid = lambda META: (META["NUM_PROGS"],) - gelu_persistent_kernel[grid](x_flat, out_flat, n_elements) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + gelu_kernel[grid](x_flat, out_flat, n_elements) return out_flat.view_as(x) diff --git a/backends/triton/cpu/KernelBench/level1/27_SELU_.py b/backends/triton/cpu/KernelBench/level1/27_SELU_.py index 3eb2540..8488716 100644 --- a/backends/triton/cpu/KernelBench/level1/27_SELU_.py +++ b/backends/triton/cpu/KernelBench/level1/27_SELU_.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @@ -32,11 +33,11 @@ def selu_kernel( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - x_f32 = x.to(tl.float32) + x_f32 = x result = tl.where(x_f32 > 0.0, scale * x_f32, scale * alpha * (tl.exp(x_f32) - 1.0)) - tl.store(out_ptr + offsets, result.to(tl.bfloat16), mask=mask) + tl.store(out_ptr + offsets, result, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py b/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py index 2d9d71b..ca2dd0c 100644 --- a/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py +++ b/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["N"], ) @@ -30,13 +31,13 @@ def hardsigmoid_kernel( mask = offsets < N x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - x_f32 = x.to(tl.float32) + x_f32 = x result = x_f32 * (1.0 / 6.0) + 0.5 result = tl.maximum(result, 0.0) result = tl.minimum(result, 1.0) - tl.store(out_ptr + offsets, result.to(x.dtype), mask=mask) + tl.store(out_ptr + offsets, result, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/29_Softplus.py b/backends/triton/cpu/KernelBench/level1/29_Softplus.py index 1829f87..2f75a7c 100644 --- a/backends/triton/cpu/KernelBench/level1/29_Softplus.py +++ b/backends/triton/cpu/KernelBench/level1/29_Softplus.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @@ -29,7 +30,7 @@ def softplus_kernel( offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) # softplus(x) = log(1 + exp(x)), with threshold for numerical stability THRESHOLD: tl.constexpr = 20.0 diff --git a/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py b/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py index e002715..d13d934 100644 --- a/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py +++ b/backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py @@ -10,15 +10,15 @@ import triton.language as tl -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_kernel( a_ptr, @@ -87,9 +87,16 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: N = B_fp16.shape[1] C = torch.empty((M, N), device=A.device, dtype=torch.bfloat16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N and K must be divisible by BLOCK_M, BLOCK_N and BLOCK_K respectively" + ) + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _matmul_kernel[grid]( A_fp16, B_fp16, @@ -103,6 +110,7 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: B_fp16.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/30_Softsign.py b/backends/triton/cpu/KernelBench/level1/30_Softsign.py index 574b897..7b85f7e 100644 --- a/backends/triton/cpu/KernelBench/level1/30_Softsign.py +++ b/backends/triton/cpu/KernelBench/level1/30_Softsign.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @@ -29,10 +30,10 @@ def softsign_kernel( offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) - x_f32 = x.to(tl.float32) + x_f32 = x abs_x = tl.abs(x_f32) result = x_f32 / (1.0 + abs_x) - tl.store(output_ptr + offsets, result.to(x.dtype), mask=mask) + tl.store(output_ptr + offsets, result, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/31_ELU.py b/backends/triton/cpu/KernelBench/level1/31_ELU.py index 0397929..dc07093 100644 --- a/backends/triton/cpu/KernelBench/level1/31_ELU.py +++ b/backends/triton/cpu/KernelBench/level1/31_ELU.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) @@ -31,7 +32,7 @@ def elu_kernel( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - x_f32 = x.to(tl.float32) + x_f32 = x inv_ln2: tl.constexpr = 1.4426950408889634 exp_x = tl.math.exp2(x_f32 * inv_ln2) @@ -39,7 +40,7 @@ def elu_kernel( result = tl.where(x_f32 > 0.0, x_f32, neg_branch) - tl.store(out_ptr + offsets, result.to(x.dtype), mask=mask) + tl.store(out_ptr + offsets, result, mask=mask) class Model(nn.Module): diff --git a/backends/triton/cpu/KernelBench/level1/32_HardTanh.py b/backends/triton/cpu/KernelBench/level1/32_HardTanh.py index 12e4fb8..6264685 100644 --- a/backends/triton/cpu/KernelBench/level1/32_HardTanh.py +++ b/backends/triton/cpu/KernelBench/level1/32_HardTanh.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048, 4096] ], key=["n_elements"], ) diff --git a/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py b/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py index 8e07a2f..f9e068d 100644 --- a/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py +++ b/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] ], key=["M", "N", "K"], ) @@ -102,6 +103,13 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: C = torch.empty((BATCH, M, N), device=A.device, dtype=A.dtype) def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N and K must be divisible by BLOCK_M, BLOCK_N and BLOCK_K respectively" + ) return ( BATCH * triton.cdiv(M, META["BLOCK_M"]) @@ -125,6 +133,7 @@ def grid(META): C.stride(1), C.stride(2), BATCH=BATCH, + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/4_Matrix_vector_multiplication_.py b/backends/triton/cpu/KernelBench/level1/4_Matrix_vector_multiplication_.py index 5407588..55f907e 100644 --- a/backends/triton/cpu/KernelBench/level1/4_Matrix_vector_multiplication_.py +++ b/backends/triton/cpu/KernelBench/level1/4_Matrix_vector_multiplication_.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_K": 32}, - ), + {"BLOCK_K": k}, + ) + for k in [16, 32, 64, 128, 256] ], key=["K"], ) diff --git a/backends/triton/cpu/KernelBench/level1/5_Matrix_scalar_multiplication.py b/backends/triton/cpu/KernelBench/level1/5_Matrix_scalar_multiplication.py index 0b1463a..d3a54cc 100644 --- a/backends/triton/cpu/KernelBench/level1/5_Matrix_scalar_multiplication.py +++ b/backends/triton/cpu/KernelBench/level1/5_Matrix_scalar_multiplication.py @@ -13,8 +13,9 @@ @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE": 32}, - ), + {"BLOCK_SIZE": bs}, + ) + for bs in [32, 64, 128, 256, 512, 1024, 2048] ], key=["n_elements"], ) diff --git a/backends/triton/cpu/KernelBench/level1/6_Matmul_with_large_K_dimension_.py b/backends/triton/cpu/KernelBench/level1/6_Matmul_with_large_K_dimension_.py index 055a1a6..7af5479 100644 --- a/backends/triton/cpu/KernelBench/level1/6_Matmul_with_large_K_dimension_.py +++ b/backends/triton/cpu/KernelBench/level1/6_Matmul_with_large_K_dimension_.py @@ -10,15 +10,15 @@ import triton.language as tl -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_kernel( a_ptr, @@ -87,9 +87,13 @@ def _matmul_triton(A, B): C = torch.empty((M, N), device=A.device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), "Matrix dimensions must be divisible by block sizes" + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _matmul_kernel[grid]( A, @@ -104,6 +108,7 @@ def _matmul_triton(A, B): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/7_Matmul_with_small_K_dimension_.py b/backends/triton/cpu/KernelBench/level1/7_Matmul_with_small_K_dimension_.py index 7dedcbe..bc9f0ac 100644 --- a/backends/triton/cpu/KernelBench/level1/7_Matmul_with_small_K_dimension_.py +++ b/backends/triton/cpu/KernelBench/level1/7_Matmul_with_small_K_dimension_.py @@ -32,8 +32,9 @@ def swizzle_tile( @triton.autotune( configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] ], key=["M", "N", "K"], ) @@ -98,9 +99,15 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: C = torch.empty((M, N), device=A.device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M % META["BLOCK_M"] == 0 + and N % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) _matmul_small_k_kernel[grid]( A, @@ -115,6 +122,7 @@ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/backends/triton/cpu/KernelBench/level1/8_Matmul_with_irregular_shapes_.py b/backends/triton/cpu/KernelBench/level1/8_Matmul_with_irregular_shapes_.py index 15b7b38..9b525c6 100644 --- a/backends/triton/cpu/KernelBench/level1/8_Matmul_with_irregular_shapes_.py +++ b/backends/triton/cpu/KernelBench/level1/8_Matmul_with_irregular_shapes_.py @@ -10,16 +10,13 @@ import triton.language as tl -def get_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - @triton.autotune( - configs=get_autotune_configs(), + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], key=["M", "N", "K"], ) @triton.jit diff --git a/backends/triton/cpu/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.py b/backends/triton/cpu/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.py index e5605d7..78a0ca7 100644 --- a/backends/triton/cpu/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.py +++ b/backends/triton/cpu/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.py @@ -29,15 +29,15 @@ def swizzle_tile( return pid_m, pid_n -def _configs(): - return [ +@triton.autotune( + configs=[ triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, - ), - ] - - -@triton.autotune(configs=_configs(), key=["M", "N", "K"]) + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": gs}, + ) + for gs in [1, 2, 4, 8] + ], + key=["M", "N", "K"], +) @triton.jit def _matmul_kernel( a_ptr, @@ -104,9 +104,18 @@ def forward(self, A, B): C = torch.empty((M_out, N_out), device=device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M_out, META["BLOCK_M"]) * triton.cdiv(N_out, META["BLOCK_N"]), - ) + def grid(META): + assert ( + M_out % META["BLOCK_M"] == 0 + and N_out % META["BLOCK_N"] == 0 + and K % META["BLOCK_K"] == 0 + ), ( + "M, N, and K must be divisible by BLOCK_M, BLOCK_N, and BLOCK_K respectively" + ) + return ( + triton.cdiv(M_out, META["BLOCK_M"]) + * triton.cdiv(N_out, META["BLOCK_N"]), + ) _matmul_kernel[grid]( A, @@ -121,6 +130,7 @@ def forward(self, A, B): B.stride(1), C.stride(0), C.stride(1), + assume_in_bounds=True, ) return C diff --git a/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml b/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml index 2f7f86a..4c873f5 100644 --- a/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml +++ b/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml @@ -26,6 +26,16 @@ simple-cpu: L: 96 flop: "2*N*M*L*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + N: 16 + M: 1024 + K: 2048 + L: 768 + flop: "2*N*M*L*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml b/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml index 81822b0..5065763 100644 --- a/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml +++ b/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml @@ -28,6 +28,17 @@ simple-cpu: K: 96 flop: "2*B*I*J*K*L" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + B: 8 + I: 256 + J: 512 + L: 256 + K: 768 + flop: "2*B*I*J*K*L" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml b/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml index 1e7a076..78221c6 100644 --- a/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml +++ b/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml @@ -22,6 +22,14 @@ simple-cpu: N: 128 flop: "2*M*N" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 4096 + N: 4096 + flop: "2*M*N" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml b/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml index 3034011..63fb38f 100644 --- a/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml +++ b/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml @@ -20,4 +20,11 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - N: 128 + N: 256 + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + N: 4096 + atol: 0.5 diff --git a/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml b/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml index 9247081..566257d 100644 --- a/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml +++ b/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml @@ -20,4 +20,11 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - N: 128 + N: 256 + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + N: 4096 + atol: 0.5 diff --git a/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml b/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml index 6ec276c..ec656e8 100644 --- a/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml +++ b/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml @@ -20,4 +20,11 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 128 + M: 256 + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 4096 + atol: 1 diff --git a/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml b/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml index 3c22b31..698dc2e 100644 --- a/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml +++ b/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml @@ -19,9 +19,19 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 64 - N: 128 - K: 256 + M: 256 + N: 512 + K: 1024 + flop: "2*M*N*K" + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 1024 + N: 2048 + K: 4096 + atol: 1 flop: "2*M*N*K" bench-gpu: diff --git a/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml b/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml index 1e1e095..9eaf2df 100644 --- a/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml +++ b/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml @@ -19,9 +19,19 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 64 - N: 128 - K: 256 + M: 256 + N: 512 + K: 1024 + flop: "2*M*N*K" + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 1024 + N: 2048 + K: 4096 + atol: 1 flop: "2*M*N*K" bench-gpu: diff --git a/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml b/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml index d684c28..f8ae573 100644 --- a/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml +++ b/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml @@ -19,9 +19,19 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 64 - N: 128 - K: 256 + M: 256 + N: 512 + K: 1024 + flop: "2*N*M*K" + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 1024 + N: 2048 + K: 4096 + atol: 1 flop: "2*N*M*K" bench-gpu: diff --git a/problems/specs/KernelBench/level1/19_ReLU.yaml b/problems/specs/KernelBench/level1/19_ReLU.yaml index 3ccc889..effc8f4 100644 --- a/problems/specs/KernelBench/level1/19_ReLU.yaml +++ b/problems/specs/KernelBench/level1/19_ReLU.yaml @@ -19,12 +19,11 @@ simple-cpu: bench-cpu: - params: [X] - dtype: bfloat16 + dtype: float32 dims: - BATCH: 128 - DIM: 2048 - flop: "BATCH*DIM" - mem_bytes: "(2*BATCH*DIM) * 2" # f16 + BATCH: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH*DIM) * 4" # f32 bench-gpu: - params: [X] diff --git a/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml index fef7c31..c5cc18f 100644 --- a/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml @@ -16,18 +16,14 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - N: 256 + N: 512 bench-cpu: - - params: [A, B] - dtype: float32 - dims: - N: 1024 - flop: "2*N*N*N" - params: [A, B] dtype: bfloat16 dims: - N: 1024 + N: 4096 + atol: 1 flop: "2*N*N*N" bench-gpu: diff --git a/problems/specs/KernelBench/level1/20_LeakyReLU.yaml b/problems/specs/KernelBench/level1/20_LeakyReLU.yaml index bed06d7..b3b2804 100644 --- a/problems/specs/KernelBench/level1/20_LeakyReLU.yaml +++ b/problems/specs/KernelBench/level1/20_LeakyReLU.yaml @@ -19,6 +19,14 @@ simple-cpu: DIM: 512 flop: "2*BATCH*DIM" +bench-cpu: + - params: [X] + dtype: float32 + dims: + BATCH: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH*DIM) * 4" # f32 + bench-gpu: - params: [X] dtype: float16 diff --git a/problems/specs/KernelBench/level1/21_Sigmoid.yaml b/problems/specs/KernelBench/level1/21_Sigmoid.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/21_Sigmoid.yaml +++ b/problems/specs/KernelBench/level1/21_Sigmoid.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/22_Tanh.yaml b/problems/specs/KernelBench/level1/22_Tanh.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/22_Tanh.yaml +++ b/problems/specs/KernelBench/level1/22_Tanh.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/23_Softmax.yaml b/problems/specs/KernelBench/level1/23_Softmax.yaml index ef98da3..d6efb6d 100644 --- a/problems/specs/KernelBench/level1/23_Softmax.yaml +++ b/problems/specs/KernelBench/level1/23_Softmax.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/24_LogSoftmax.yaml b/problems/specs/KernelBench/level1/24_LogSoftmax.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/24_LogSoftmax.yaml +++ b/problems/specs/KernelBench/level1/24_LogSoftmax.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/25_Swish.yaml b/problems/specs/KernelBench/level1/25_Swish.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/25_Swish.yaml +++ b/problems/specs/KernelBench/level1/25_Swish.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/26_GELU_.yaml b/problems/specs/KernelBench/level1/26_GELU_.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/26_GELU_.yaml +++ b/problems/specs/KernelBench/level1/26_GELU_.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/27_SELU_.yaml b/problems/specs/KernelBench/level1/27_SELU_.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/27_SELU_.yaml +++ b/problems/specs/KernelBench/level1/27_SELU_.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/28_HardSigmoid.yaml b/problems/specs/KernelBench/level1/28_HardSigmoid.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/28_HardSigmoid.yaml +++ b/problems/specs/KernelBench/level1/28_HardSigmoid.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/29_Softplus.yaml b/problems/specs/KernelBench/level1/29_Softplus.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/29_Softplus.yaml +++ b/problems/specs/KernelBench/level1/29_Softplus.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml index 3e2b28a..54a243a 100644 --- a/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml @@ -18,19 +18,20 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 256 + M: 512 N: 1024 - K: 512 + K: 256 bench-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 128 - N: 256 - K: 512 + M: 3072 + N: 3072 + K: 4096 + atol: 1 flop: "2*M*N*K" - mem_bytes: "(M*K + K*N + M*N) * 2" # f16 + mem_bytes: "(M*K + K*N + M*N) * 2" # bf16 bench-gpu: - params: [A, B] diff --git a/problems/specs/KernelBench/level1/30_Softsign.yaml b/problems/specs/KernelBench/level1/30_Softsign.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/30_Softsign.yaml +++ b/problems/specs/KernelBench/level1/30_Softsign.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/31_ELU.yaml b/problems/specs/KernelBench/level1/31_ELU.yaml index 092c0ed..d1dc663 100644 --- a/problems/specs/KernelBench/level1/31_ELU.yaml +++ b/problems/specs/KernelBench/level1/31_ELU.yaml @@ -21,3 +21,12 @@ simple-cpu: BATCH_SIZE: 128 DIM: 512 ALPHA: 1.0 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + ALPHA: 1.0 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/32_HardTanh.yaml b/problems/specs/KernelBench/level1/32_HardTanh.yaml index 7f40ac7..fa3eec1 100644 --- a/problems/specs/KernelBench/level1/32_HardTanh.yaml +++ b/problems/specs/KernelBench/level1/32_HardTanh.yaml @@ -18,3 +18,11 @@ simple-cpu: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: [x] + dtype: float32 + dims: + BATCH_SIZE: 4096 + DIM: 393216 + mem_bytes: "(2*BATCH_SIZE*DIM) * 4" # f32 diff --git a/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml b/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml index 5c11555..0466954 100644 --- a/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml +++ b/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml @@ -20,9 +20,20 @@ simple-cpu: dtype: bfloat16 dims: BATCH: 2 - M: 64 - N: 128 - K: 128 + M: 512 + N: 1024 + K: 256 + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + BATCH: 128 + M: 512 + N: 1024 + K: 2048 + atol: 1 + flop: "2*BATCH*M*N*K" bench-gpu: - params: [A, B] diff --git a/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml b/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml index 36bce41..cefd90c 100644 --- a/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml @@ -24,6 +24,15 @@ simple-cpu: K: 512 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 2048 + N: 1 + K: 16384 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml b/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml index 99d0101..d6e4c4f 100644 --- a/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml +++ b/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml @@ -24,6 +24,15 @@ simple-cpu: UNIT: 1 flop: "M*N" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 16384 + N: 4096 + UNIT: 1 + flop: "M*N" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml b/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml index ff5ccde..f4a495f 100644 --- a/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml +++ b/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml @@ -19,10 +19,20 @@ simple-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 64 - N: 32 - K: 512 + M: 256 + N: 256 + K: 2048 + flop: "2*M*N*K" + +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 256 + N: 256 + K: 16384 flop: "2*M*N*K" + atol: 0.25 bench-gpu: - params: [A, B] diff --git a/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml b/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml index 855daf5..11ebce6 100644 --- a/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml +++ b/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml @@ -24,6 +24,15 @@ simple-cpu: K: 32 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 8192 + N: 8192 + K: 64 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml b/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml index a05e921..89d7f26 100644 --- a/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml +++ b/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml @@ -24,6 +24,15 @@ simple-cpu: K: 191 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 8205 + N: 5921 + K: 2949 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml index c990777..7376aae 100644 --- a/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml @@ -22,6 +22,13 @@ simple-cpu: N: 32 flop: "2*M*M*N" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 8192 + N: 32 + bench-gpu: - params: [A, B] dtype: float16