From f918c85575ce48aceb236086e3a77ba41ec4371b Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 3 Mar 2026 15:50:00 +0100 Subject: [PATCH] [harness] Device-specific kernels for Triton and Helion Adds device-specific subdirectories to Triton and Helion backends. The deeper nesting allows for separate kernel implementations that may incode target-specific optimizations. For maintenance simplicity and readability, the split is chosen over maintaining multiple kernel implementations in a single file. The separate files act as entry points for the runner. In the future, truly universal kernels can be stored in a separate location and backend file structure might offer only simple redirection. While these backends support running the same kernel on different devices, encoding target-specific details can improve performance. The baseline PyTorch backend still relies on a single implementation thanks to its higher abstraction. Future backends should pick the most suitable structure for their needs. --- .github/workflows/kernel_bench.yml | 70 ++++++--- README.md | 2 +- .../harness/runner/kernel_bench_runner.py | 8 +- .../level1/1_Square_matrix_multiplication_.py | 103 ++++++++++++++ .../level1/1_Square_matrix_multiplication_.py | 5 + .../level1/1_Square_matrix_multiplication_.py | 134 ++++++++++++++++++ .../level1/16_Matmul_with_transposed_A.py | 0 .../81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py | 0 .../95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py | 0 .../level2/99_Matmul_GELU_Softmax.py | 0 infra/scripts/ci-cuda-run-kernel-bench.sh | 17 ++- tests/test_backend.py | 14 +- 12 files changed, 323 insertions(+), 30 deletions(-) create mode 100644 backends/helion/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py rename backends/helion/{ => xpu}/KernelBench/level1/1_Square_matrix_multiplication_.py (94%) create mode 100644 backends/triton/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py rename backends/triton/{ => xpu}/KernelBench/level1/16_Matmul_with_transposed_A.py (100%) rename backends/triton/{ => xpu}/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py (100%) rename backends/triton/{ => xpu}/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py (100%) rename backends/triton/{ => xpu}/KernelBench/level2/99_Matmul_GELU_Softmax.py (100%) diff --git a/.github/workflows/kernel_bench.yml b/.github/workflows/kernel_bench.yml index b4e628b..d6cc26b 100644 --- a/.github/workflows/kernel_bench.yml +++ b/.github/workflows/kernel_bench.yml @@ -3,32 +3,36 @@ name: KernelBench Perf on: workflow_dispatch: inputs: - RUN_CPU_TORCH: - description: "Run on CPU (PyTorch eager)" + DEVICE_CPU: + description: "Device: CPU" type: boolean default: false - RUN_CPU_MLIR: - description: "Run on CPU (MLIR)" + DEVICE_XPU: + description: "Device: Intel GPU" + type: boolean + default: true + DEVICE_CUDA: + description: "Device: Nvidia GPU" type: boolean default: false - RUN_XPU_TORCH: - description: "Run on Intel GPU (PyTorch eager)" + BACKEND_PYTORCH: + description: "Backend: PyTorch (eager)" type: boolean default: true - RUN_XPU_TORCH_COMPILE: - description: "Run on Intel GPU (PyTorch compile)" + BACKEND_PYTORCH_COMPILE: + description: "Backend: PyTorch (compile)" type: boolean default: false - RUN_XPU_TRITON: - description: "Run on Intel GPU (Triton)" + BACKEND_TRITON: + description: "Backend: Triton" type: boolean default: false - RUN_XPU_HELION: - description: "Run on Intel GPU (Helion)" + BACKEND_HELION: + description: "Backend: Helion" type: boolean default: false - RUN_CUDA_TORCH: - description: "Run on Nvidia GPU (PyTorch eager)" + BACKEND_MLIR: + description: "Backend: MLIR" type: boolean default: false @@ -40,7 +44,7 @@ jobs: CPU-PyTorch: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_CPU_TORCH) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_CPU && inputs.BACKEND_PYTORCH) steps: - uses: actions/checkout@v5 @@ -52,7 +56,7 @@ jobs: CPU-MLIR: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_CPU_MLIR) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_CPU && inputs.BACKEND_MLIR) steps: - uses: actions/checkout@v5 @@ -64,7 +68,7 @@ jobs: XPU-PyTorch: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_XPU_TORCH) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_XPU && inputs.BACKEND_PYTORCH) steps: - uses: actions/checkout@v5 @@ -76,7 +80,7 @@ jobs: XPU-PyTorch-Compile: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_XPU_TORCH_COMPILE) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_XPU && inputs.BACKEND_PYTORCH_COMPILE) steps: - uses: actions/checkout@v5 @@ -88,7 +92,7 @@ jobs: XPU-Triton: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_XPU_TRITON) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_XPU && inputs.BACKEND_TRITON) steps: - uses: actions/checkout@v5 @@ -100,7 +104,7 @@ jobs: XPU-Helion: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_XPU_HELION) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_XPU && inputs.BACKEND_HELION) steps: - uses: actions/checkout@v5 @@ -112,7 +116,7 @@ jobs: CUDA-PyTorch: runs-on: pcl-tiergarten if: | - (github.event_name == 'workflow_dispatch' && inputs.RUN_CUDA_TORCH) + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_CUDA && inputs.BACKEND_PYTORCH) steps: - uses: actions/checkout@v5 @@ -120,3 +124,27 @@ jobs: - name: Nvidia A100 run: "${{ env.SRUN }} --partition=a100 --time=0:15:00 -- \ '${{ github.workspace }}/infra/scripts/ci-cuda-run-kernel-bench.sh -b torch'" + + CUDA-Triton: + runs-on: pcl-tiergarten + if: | + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_CUDA && inputs.BACKEND_TRITON) + + steps: + - uses: actions/checkout@v5 + + - name: Nvidia A100 + run: "${{ env.SRUN }} --partition=a100 --time=0:15:00 -- \ + '${{ github.workspace }}/infra/scripts/ci-cuda-run-kernel-bench.sh -b triton'" + + CUDA-Helion: + runs-on: pcl-tiergarten + if: | + (github.event_name == 'workflow_dispatch' && inputs.DEVICE_CUDA && inputs.BACKEND_HELION) + + steps: + - uses: actions/checkout@v5 + + - name: Nvidia A100 + run: "${{ env.SRUN }} --partition=a100 --time=0:15:00 -- \ + '${{ github.workspace }}/infra/scripts/ci-cuda-run-kernel-bench.sh -b helion'" diff --git a/README.md b/README.md index af209b8..90d2123 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ A benchmarking framework for evaluating AI kernel implementations across multipl |:---:|:---:|:---:|:---:|:---:| | **CPU** | ✅ | ❌ | ❌ | ✅ | | **XPU** | ✅ | ✅ | ✅ | ❌ | -| **CUDA** | ✅ | ⚠️ | ⚠️ | ❌ | +| **CUDA** | ✅ | ✅ | ✅ | ❌ | ✅ - Supported ⚠️ - Partially implemented ❌ - Unsupported diff --git a/ai_bench/harness/runner/kernel_bench_runner.py b/ai_bench/harness/runner/kernel_bench_runner.py index 4d69d6c..f610786 100644 --- a/ai_bench/harness/runner/kernel_bench_runner.py +++ b/ai_bench/harness/runner/kernel_bench_runner.py @@ -75,9 +75,13 @@ def __init__( if self.is_torch_backend(): self.kernels = ai_utils.kernel_bench_dir() / "KernelBench" elif self.backend == ai_hc.Backend.TRITON: - self.kernels = ai_utils.triton_kernels_dir() / "KernelBench" + self.kernels = ( + ai_utils.triton_kernels_dir() / self.device.type / "KernelBench" + ) elif self.backend == ai_hc.Backend.HELION: - self.kernels = ai_utils.helion_kernels_dir() / "KernelBench" + self.kernels = ( + ai_utils.helion_kernels_dir() / self.device.type / "KernelBench" + ) elif self.backend == ai_hc.Backend.MLIR: self.kernels = ( ai_utils.mlir_kernels_dir() / self.device.type / "KernelBench" diff --git a/backends/helion/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py b/backends/helion/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py new file mode 100644 index 0000000..c610116 --- /dev/null +++ b/backends/helion/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py @@ -0,0 +1,103 @@ +# Example Helion CUDA kernel +# Source: helion matmul example +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import helion +import helion.language as hl +import torch +import torch.nn as nn + + +@helion.kernel( + static_shapes=True, + configs=[ + helion.Config( + block_sizes=[64, 128, 16], + indexing="tensor_descriptor", + l2_groupings=[32], + loop_orders=[[1, 0]], + num_stages=2, + num_warps=8, + pid_type="flat", + range_flattens=[None, None], + range_multi_buffers=[None, None], + range_num_stages=[0, 2], + range_unroll_factors=[0, 1], + ), + helion.Config( + block_sizes=[256, 256, 32], + indexing="tensor_descriptor", + l2_groupings=[4], + loop_orders=[[0, 1]], + num_stages=2, + num_warps=32, + pid_type="flat", + range_flattens=[None, False], + range_multi_buffers=[None, False], + range_num_stages=[0, 2], + range_unroll_factors=[0, 1], + ), + helion.Config( + block_sizes=[256, 128, 32], + indexing="tensor_descriptor", + l2_groupings=[32], + loop_orders=[[0, 1]], + num_stages=4, + num_warps=32, + pid_type="persistent_interleaved", + range_flattens=[None, False], + range_multi_buffers=[True, False], + range_num_stages=[1, 4], + range_unroll_factors=[4, 1], + ), + helion.Config( + block_sizes=[128, 256, 16], + indexing="tensor_descriptor", + l2_groupings=[4], + loop_orders=[[0, 1]], + num_stages=5, + num_warps=32, + pid_type="persistent_interleaved", + range_flattens=[None, True], + range_multi_buffers=[False, False], + range_num_stages=[1, 4], + range_unroll_factors=[2, 0], + ), + ], +) +def _square_matmul_kernel(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Performs square matrix multiplication using Helion. + C = A * B + + Args: + A: Input matrix A of shape (N, N) + B: Input matrix B of shape (N, N) + + Returns: + Output matrix C of shape (N, N) + """ + N, N2 = A.size() + N3, N4 = B.size() + assert N == N2 == N3 == N4, f"size mismatch: A{A.size()}, B{B.size()}" + + out = torch.empty( + [N, N], dtype=torch.promote_types(A.dtype, B.dtype), device=A.device + ) + + for tile_m, tile_n in hl.tile([N, N]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(N): + acc = torch.addmm(acc, A[tile_m, tile_k], B[tile_k, tile_n]) + out[tile_m, tile_n] = acc + + return out + + +class Model(nn.Module): + def __init__(self, *args, **kwargs): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return _square_matmul_kernel(A, B) diff --git a/backends/helion/KernelBench/level1/1_Square_matrix_multiplication_.py b/backends/helion/xpu/KernelBench/level1/1_Square_matrix_multiplication_.py similarity index 94% rename from backends/helion/KernelBench/level1/1_Square_matrix_multiplication_.py rename to backends/helion/xpu/KernelBench/level1/1_Square_matrix_multiplication_.py index d60762b..5c54e73 100644 --- a/backends/helion/KernelBench/level1/1_Square_matrix_multiplication_.py +++ b/backends/helion/xpu/KernelBench/level1/1_Square_matrix_multiplication_.py @@ -1,3 +1,8 @@ +# Example Helion XPU kernel +# Source: helion matmul example +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + import helion import helion.language as hl import torch diff --git a/backends/triton/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py b/backends/triton/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py new file mode 100644 index 0000000..9c83761 --- /dev/null +++ b/backends/triton/cuda/KernelBench/level1/1_Square_matrix_multiplication_.py @@ -0,0 +1,134 @@ +# ruff: noqa: E731 +# Example Triton CUDA kernel +# Source: triton-lang/triton matmul tutorial +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=4 + ), + ], + key=["M", "N", "K"], # autotune per problem size +) +@triton.jit +def _matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + + acc = tl.dot(a, b, acc) + + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc, mask=c_mask) + + +def _kernel_function_cuda(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor) + assert hasattr(torch, "cuda"), "torch.cuda is required for this kernel" + assert A.device.type == "cuda" and B.device.type == "cuda", ( + "A and B must be on CUDA" + ) + assert A.is_floating_point() and B.is_floating_point(), ( + "A and B must be floating point tensors" + ) + assert A.dtype == B.dtype, f"dtype mismatch: {A.dtype} vs {B.dtype}" + + orig_dtype = A.dtype + + K, M = A.shape + K2, N = B.shape + assert K == K2, f"Incompatible K dimensions: {K} vs {K2}" + + C32 = torch.empty((M, N), device=A.device, dtype=torch.float32) + + stride_A0, stride_A1 = A.stride() + stride_B0, stride_B1 = B.stride() + stride_C0, stride_C1 = C32.stride() + + # Autotuned grid: depends on BLOCK_M/BLOCK_N chosen by config + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]), + triton.cdiv(N, META["BLOCK_N"]), + ) + + _matmul_kernel[grid]( + A, + B, + C32, + M, + N, + K, + stride_A0, + stride_A1, + stride_B0, + stride_B1, + stride_C0, + stride_C1, + ) + + torch.accelerator.synchronize() + return C32.to(orig_dtype) + + +class Model(nn.Module): + """KernelBench-compatible wrapper""" + + def __init__(self, *args, **kwargs): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return _kernel_function_cuda(A, B) diff --git a/backends/triton/KernelBench/level1/16_Matmul_with_transposed_A.py b/backends/triton/xpu/KernelBench/level1/16_Matmul_with_transposed_A.py similarity index 100% rename from backends/triton/KernelBench/level1/16_Matmul_with_transposed_A.py rename to backends/triton/xpu/KernelBench/level1/16_Matmul_with_transposed_A.py diff --git a/backends/triton/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py b/backends/triton/xpu/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py similarity index 100% rename from backends/triton/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py rename to backends/triton/xpu/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py diff --git a/backends/triton/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py b/backends/triton/xpu/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py similarity index 100% rename from backends/triton/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py rename to backends/triton/xpu/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py diff --git a/backends/triton/KernelBench/level2/99_Matmul_GELU_Softmax.py b/backends/triton/xpu/KernelBench/level2/99_Matmul_GELU_Softmax.py similarity index 100% rename from backends/triton/KernelBench/level2/99_Matmul_GELU_Softmax.py rename to backends/triton/xpu/KernelBench/level2/99_Matmul_GELU_Softmax.py diff --git a/infra/scripts/ci-cuda-run-kernel-bench.sh b/infra/scripts/ci-cuda-run-kernel-bench.sh index b30ce63..befe6fc 100755 --- a/infra/scripts/ci-cuda-run-kernel-bench.sh +++ b/infra/scripts/ci-cuda-run-kernel-bench.sh @@ -9,9 +9,11 @@ SCRIPTS_DIR=$(realpath $(dirname $0)) # Backends BENCH_BACKEND_TORCH="torch" BENCH_BACKEND_TORCH_COMPILE="torch-compile" +BENCH_BACKEND_TRITON="triton" +BENCH_BACKEND_HELION="helion" die_syntax() { - echo "Syntax: $0 [-b (${BENCH_BACKEND_TORCH}|${BENCH_BACKEND_TORCH_COMPILE})]" + echo "Syntax: $0 [-b (${BENCH_BACKEND_TORCH}|${BENCH_BACKEND_TORCH_COMPILE}|${BENCH_BACKEND_TRITON}|${BENCH_BACKEND_HELION})]" echo "" echo " -b: Optional, backend to use (default: torch)" exit 1 @@ -23,7 +25,9 @@ while getopts "b:" arg; do case ${arg} in b) if [ "${OPTARG}" == "${BENCH_BACKEND_TORCH}" ] || \ - [ "${OPTARG}" == "${BENCH_BACKEND_TORCH_COMPILE}" ]; then + [ "${OPTARG}" == "${BENCH_BACKEND_TORCH_COMPILE}" ] || \ + [ "${OPTARG}" == "${BENCH_BACKEND_TRITON}" ] || \ + [ "${OPTARG}" == "${BENCH_BACKEND_HELION}" ]; then BENCH_BACKEND="${OPTARG}" else echo "Invalid backend: ${OPTARG}" @@ -59,6 +63,15 @@ BENCH_FLAGS="--cuda --bench" if [[ "${BENCH_BACKEND}" == "${BENCH_BACKEND_TORCH_COMPILE}" ]]; then BENCH_FLAGS="${BENCH_FLAGS} --torch-compile" fi +if [[ "${BENCH_BACKEND}" == "${BENCH_BACKEND_TRITON}" ]]; then + BENCH_FLAGS="${BENCH_FLAGS} --triton" +fi +if [[ "${BENCH_BACKEND}" == "${BENCH_BACKEND_HELION}" ]]; then + BENCH_FLAGS="${BENCH_FLAGS} --helion" + # Suppress logging to minimize noise in the benchmark output. + export HELION_AUTOTUNE_PROGRESS_BAR=0 + export HELION_AUTOTUNE_LOG_LEVEL=0 +fi ${AI_BENCH_UV} run ai-bench ${BENCH_FLAGS} EXIT_CODE=$? diff --git a/tests/test_backend.py b/tests/test_backend.py index 0e1083a..297dab1 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -302,10 +302,10 @@ def temp_dirs(self): tmpdir / "third_party" / "KernelBench" / "KernelBench" / "level1" ) triton_kernels_dir = ( - tmpdir / "backends" / "triton" / "KernelBench" / "level1" + tmpdir / "backends" / "triton" / "cpu" / "KernelBench" / "level1" ) helion_kernels_dir = ( - tmpdir / "backends" / "helion" / "KernelBench" / "level1" + tmpdir / "backends" / "helion" / "cpu" / "KernelBench" / "level1" ) mlir_cpu_kernels_dir = ( tmpdir / "backends" / "mlir" / "cpu" / "KernelBench" / "level1" @@ -767,8 +767,12 @@ def integration_setup(self): pytorch_dir = ( tmpdir / "third_party" / "KernelBench" / "KernelBench" / "level1" ) - triton_dir = tmpdir / "backends" / "triton" / "KernelBench" / "level1" - helion_dir = tmpdir / "backends" / "helion" / "KernelBench" / "level1" + triton_dir = ( + tmpdir / "backends" / "triton" / "cpu" / "KernelBench" / "level1" + ) + helion_dir = ( + tmpdir / "backends" / "helion" / "cpu" / "KernelBench" / "level1" + ) mlir_dir = tmpdir / "backends" / "mlir" / "cpu" / "KernelBench" / "level1" specs_dir.mkdir(parents=True) @@ -974,6 +978,7 @@ def test_backend_produces_same_results(self, integration_setup): integration_setup / "backends" / "triton" + / "cpu" / "KernelBench" / "level1" / "matmul.py" @@ -982,6 +987,7 @@ def test_backend_produces_same_results(self, integration_setup): integration_setup / "backends" / "helion" + / "cpu" / "KernelBench" / "level1" / "matmul.py"