@@ -42,6 +42,7 @@ X86_64_AVX512_BF16_COPTS = [

 X86_64_AVX512_VNNI_COPTS = [
     "-mavx512f",
+    "-mavx512bw",
     "-mavx512vnni",
 ]

@@ -43,6 +43,7 @@ iree_bitcode_library(
     "iree_uk_mma_x86_avx512vnni_16x16x2_i32_i8_casti16.c"
   COPTS
     "-mavx512f"
+    "-mavx512bw"
     "-mavx512vnni"
 )

@@ -4,7 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 #include <immintrin.h>
Check failure on line 7 in compiler/plugins/target/LLVMCPU/builtins/ukernel/iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16.c (GitHub Actions / clang-tidy, 7:10, clang-diagnostic-error): 'immintrin.h' file not found
 #include "common.h"

 // Microkernel for `iree_codegen.inner_tiled` with
@@ -13,40 +13,94 @@
 // intrinsic name verbatim (lowercased, with the `iree_uk_` prefix), in line
 // with the AMDGPU C ukernel convention.
 //
-// The "inner K loop" the ukernel owns is the loop over the K *tiles* that
-// sits *inside* the outer M/N loops; those outer M/N loops are tiled away by
-// ordinary IREE tiling before this ukernel runs. The ukernel handles
-// arbitrary positive `intrinsics_{m,n,k}` (passed as arguments and looped
-// over); the loops fully unroll after the ukernel is inlined into its
-// constant-`intrinsics_*` caller -- the bitcode-LTO equivalent of a C++
-// template.
+// Adapted from the AMDGPU C ukernel
+// `iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8`: same shape — accumulators
+// held in registers, an outer loop over the K dimension, and inside it the
+// `(intrinsics_m, intrinsics_n, intrinsics_k)` unrolling — minus the
+// GPU-specific shared-memory / subgroup machinery.
+//
+// The "inner K loop" the ukernel owns is the loop over the K *tiles*
+// (`k_outer` below) that sits *inside* the outer M/N loops; those outer M/N
+// loops are tiled away by ordinary IREE tiling before this ukernel runs.
+// This is NOT a restriction to `intrinsics_{m,n,k} = 1`: the ukernel handles
+// arbitrary positive `intrinsics_{m,n,k}` via the `for` loops below.
+//
+// `intrinsics_{m,n,k}` are passed as function arguments and so look like
+// runtime values inside this translation unit, but the ukernel is always
+// inlined into its caller (a bug otherwise) and the caller always passes the
+// matching `DataTiledMMAAttr` constants. Together with post-inline IR
+// optimization on the linked bitcode, the `for` loops fully unroll and the
+// `acc_regs` VLA becomes a fixed register array specialized to each call
+// site's `intrinsics_{m,n,k}` — the bitcode-LTO equivalent of C++ templates.
 //
 // ABI: each shaped operand is passed as (base pointer, element offset) so the
-// caller doesn't need a GEP before the call; the accumulator additionally
-// gets the element stride of its innermost cross-intrinsic (N) dimension.
+// caller doesn't need a GEP before the call (the offset is added here); the
+// accumulator additionally gets the element stride of its innermost
+// cross-intrinsic (N) dimension. Offsets/strides are in units of the operand
+// element type (bf16 for LHS/RHS, f32 for ACC).
 //
-// NOTE (seed scaffolding): this initial seed has a stub body. It exists so
-// that the surrounding *framework* -- bitcode build, embedding,
-// `hal.executable_object` injection, IR rewrite to `ukernel.generic` -- can
-// be landed and lit-tested. A follow-up commit replaces the body with the
-// `_mm512_dpbf16_ps`-based inner loop and adds an e2e matmul test for it.
+// Data-tiled operand layout (matching the `DataTiledMMAAttr` swizzle, same as
+// the AMDGPU ukernel):
+// - ACC: one `__m512` (= M0=1 x N0=16 f32) per (m, n) intrinsic. The (m, n)
+//   grid is row-major with `acc_stride` the element stride of the innermost
+//   cross-intrinsic dim (N), so fragment (m, n) is at
+//   `acc + (m * intrinsics_n + n) * acc_stride`.
+// - LHS: per outer-K step, `intrinsics_m * intrinsics_k` units of 2 bf16
+//   (= one M0=1 x K0=2 fragment = a 4-byte `vdpbf16ps` m_bcst unit),
+//   ordered [m][k]; consecutive outer-K steps are contiguous.
+// - RHS: per outer-K step, `intrinsics_n * intrinsics_k` panels of 32 bf16
+//   (= one N0=16 x K0=2 fragment = one `__m512`), ordered [n][k];
+//   consecutive outer-K steps are contiguous.
 IREE_UK_ALWAYS_INLINE
 void iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16(
Check warning on line 55 in compiler/plugins/target/LLVMCPU/builtins/ukernel/iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16.c (GitHub Actions / clang-tidy, readability-identifier-naming): invalid case style for function 'iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16' (55:6)
     const uint16_t *lhs_base, int64_t lhs_offset, const uint16_t *rhs_base,
Check warning on line 56 in compiler/plugins/target/LLVMCPU/builtins/ukernel/iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16.c (GitHub Actions / clang-tidy, readability-identifier-naming): invalid case style for parameters 'lhs_base' (56:21), 'lhs_offset' (56:39), 'rhs_base' (56:67)
     int64_t rhs_offset, float *acc_base, int64_t acc_offset, int64_t acc_stride,
Check warning on line 57 in compiler/plugins/target/LLVMCPU/builtins/ukernel/iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16.c (GitHub Actions / clang-tidy, readability-identifier-naming): invalid case style for parameters 'rhs_offset' (57:13), 'acc_base' (57:32), 'acc_offset' (57:50), 'acc_stride' (57:70)
     int32_t k_outer, int32_t intrinsics_m, int32_t intrinsics_n,
Check warning on line 58 in compiler/plugins/target/LLVMCPU/builtins/ukernel/iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16.c (GitHub Actions / clang-tidy, readability-identifier-naming): invalid case style for parameters 'k_outer' (58:13), 'intrinsics_m' (58:30)
     int32_t intrinsics_k) {
-  (void)lhs_base;
-  (void)lhs_offset;
-  (void)rhs_base;
-  (void)rhs_offset;
-  (void)acc_base;
-  (void)acc_offset;
-  (void)acc_stride;
-  (void)k_outer;
-  (void)intrinsics_m;
-  (void)intrinsics_n;
-  (void)intrinsics_k;
-  // TODO(ukernels): real inner K loop using `_mm512_dpbf16_ps`, looping over
-  // intrinsics_{m,n,k}.
+  const uint16_t *lhs = lhs_base + lhs_offset;
+  const float *rhs = (const float *)(rhs_base + rhs_offset);
+  float *acc = acc_base + acc_offset;
+
+  // One accumulator register (1x16 f32) per (m, n) intrinsic. The VLA
+  // dimensions are compile-time constants at the inlined call site, so this
+  // lowers to a fixed register array.
+  __m512 acc_regs[intrinsics_m][intrinsics_n];
+  for (int32_t m = 0; m < intrinsics_m; ++m) {
+    for (int32_t n = 0; n < intrinsics_n; ++n) {
+      acc_regs[m][n] =
+          _mm512_loadu_ps(acc + (m * intrinsics_n + n) * acc_stride);
+    }
+  }
+
+  for (int32_t ko = 0; ko < k_outer; ++ko) {
+    const uint16_t *lhs_block =
+        lhs + (int64_t)ko * intrinsics_m * intrinsics_k * 2;
+    const float *rhs_block =
+        rhs + (int64_t)ko * intrinsics_n * intrinsics_k * 16;
+    for (int32_t m = 0; m < intrinsics_m; ++m) {
+      for (int32_t n = 0; n < intrinsics_n; ++n) {
+        for (int32_t k = 0; k < intrinsics_k; ++k) {
+          // LHS fragment: 2 bf16 (one M-row's K-pair) broadcast across the
+          // 16 SIMD lanes via `set1_ps` (the splat shape `vdpbf16ps`'s
+          // m_bcst variant pattern-matches). The bitcast to `__m512bh` is a
+          // width-preserving no-op LLVM elides.
+          __m512 lhs_bcast = _mm512_set1_ps(
+              *(const float *)(lhs_block + (m * intrinsics_k + k) * 2));
+          // RHS fragment: one (N=16 x K=2) bf16 panel = 16 f32.
+          __m512 rhs_panel =
+              _mm512_loadu_ps(rhs_block + (n * intrinsics_k + k) * 16);
+          acc_regs[m][n] =
+              _mm512_dpbf16_ps(acc_regs[m][n], *(const __m512bh *)&lhs_bcast,
+                               *(const __m512bh *)&rhs_panel);
+        }
+      }
+    }
+  }
+
+  for (int32_t m = 0; m < intrinsics_m; ++m) {
+    for (int32_t n = 0; n < intrinsics_n; ++n) {
+      _mm512_storeu_ps(acc + (m * intrinsics_n + n) * acc_stride,
+                       acc_regs[m][n]);
+    }
+  }
 }
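
The layout arithmetic in the bf16 ukernel above can be checked without AVX-512 hardware by comparing against a scalar model. The sketch below is not part of the PR: `bf16_to_f32` and `ref_mma_1x16x2_f32_bf16` are hypothetical names, the pointers are assumed to already include the element offsets that the real ukernel adds itself, and plain `float` arithmetic is assumed to be an adequate stand-in for `vdpbf16ps` on small, exactly representable inputs. It is not claimed to be bit-exact with the instruction in general.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* bf16 holds the upper 16 bits of an IEEE-754 binary32 value. */
static float bf16_to_f32(uint16_t bits) {
  uint32_t u = (uint32_t)bits << 16;
  float f;
  memcpy(&f, &u, sizeof f);
  return f;
}

/* Hypothetical scalar reference (illustrative name, not an IREE symbol).
 * `lhs`, `rhs` and `acc` are assumed to already include their offsets. */
static void ref_mma_1x16x2_f32_bf16(const uint16_t *lhs, const uint16_t *rhs,
                                    float *acc, int64_t acc_stride,
                                    int32_t k_outer, int32_t intrinsics_m,
                                    int32_t intrinsics_n,
                                    int32_t intrinsics_k) {
  assert(intrinsics_m > 0 && intrinsics_n > 0 && intrinsics_k > 0);
  for (int32_t ko = 0; ko < k_outer; ++ko) {
    /* Per outer-K step: [m][k] units of 2 bf16, [n][k] panels of 32 bf16. */
    const uint16_t *lhs_block =
        lhs + (int64_t)ko * intrinsics_m * intrinsics_k * 2;
    const uint16_t *rhs_block =
        rhs + (int64_t)ko * intrinsics_n * intrinsics_k * 32;
    for (int32_t m = 0; m < intrinsics_m; ++m) {
      for (int32_t n = 0; n < intrinsics_n; ++n) {
        float *c = acc + (m * intrinsics_n + n) * acc_stride;
        for (int32_t k = 0; k < intrinsics_k; ++k) {
          const uint16_t *a = lhs_block + (m * intrinsics_k + k) * 2;
          const uint16_t *b = rhs_block + (n * intrinsics_k + k) * 32;
          /* Lane j pairs the broadcast LHS (k0, k1) with RHS column j. */
          for (int j = 0; j < 16; ++j) {
            c[j] += bf16_to_f32(a[0]) * bf16_to_f32(b[2 * j]) +
                    bf16_to_f32(a[1]) * bf16_to_f32(b[2 * j + 1]);
          }
        }
      }
    }
  }
}
```

Running both this reference and the ukernel on the same buffers, for a few `(intrinsics_m, intrinsics_n, intrinsics_k)` combinations, would be one way to exercise the `[m][k]` / `[n][k]` ordering and the `acc_stride` addressing.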
@@ -4,7 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 #include <immintrin.h>
Check failure on line 7 in compiler/plugins/target/LLVMCPU/builtins/ukernel/iree_uk_mma_x86_avx512vnni_16x16x2_i32_i8_casti16.c (GitHub Actions / clang-tidy, 7:10, clang-diagnostic-error): 'immintrin.h' file not found
 #include "common.h"

 // Microkernel for `iree_codegen.inner_tiled` with
@@ -13,49 +13,113 @@
 // intrinsic name verbatim (lowercased, with the `iree_uk_` prefix), in line
 // with the AMDGPU C ukernel convention.
 //
-// Implements the inner K-loop for the unrolled (intrinsics_m, intrinsics_n,
-// intrinsics_k) tile built from the 16x16x2 i8 VNNI intrinsic via AVX-512
-// VNNI `vpdpwssd`. The "CASTI16" in the MMA intrinsic name reflects that
-// the s8 inputs are zero/sign-extended into i16 lanes before being fed to
-// the 16-bit VNNI instruction; that cast is handled in the inner loop.
+// Structurally identical to `iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16`:
+// accumulators held in registers, an outer loop over the K *tiles*
+// (`k_outer`), and inside it the `(intrinsics_m, intrinsics_n, intrinsics_k)`
+// unroll. This is NOT a restriction to `intrinsics_{m,n,k} = 1`: the ukernel
+// handles arbitrary positive `intrinsics_{m,n,k}` via the `for` loops below,
+// which fully unroll once inlined into the constant-`intrinsics_*` caller.
 //
-// `intrinsics_{m,n,k}` are passed as function arguments and so look like
-// runtime values inside this translation unit, but the ukernel is always
-// inlined into its caller (a bug otherwise) and the caller always passes
-// the matching `DataTiledMMAAttr` constants. Together with post-inline IR
-// optimization on the linked bitcode, the body specializes to specific
-// compile-time `intrinsics_{m,n,k}` values at each call site.
+// The "CASTI16" in the intrinsic name reflects that the s8 inputs are
+// sign-extended to i16 lanes before being fed to the 16-bit VNNI instruction
+// `vpdpwssd` (the 8-bit `vpdpbusd` would mishandle the s8 x s8 signedness);
+// that widen happens once per panel in the inner loop.
 //
-// NOTE (seed scaffolding): this initial seed has a stub body. It exists so
-// that the surrounding *framework* -- bitcode build, embedding,
-// `hal.executable_object` injection, IR rewrite to `ukernel.generic` -- can
-// be exercised end-to-end. A follow-up commit replaces the body with the
-// `_mm512_dpwssd_epi32`-based inner loop and adds an e2e matmul test for
-// it. This is the "practically useful" seed: i8x i8->i32 via VNNI is a
-// workhorse for quantized inference, and codegen has a residual perf gap
-// on this case.
-// ABI matches the inner_tiled -> ukernel.generic lowering (see
-// `iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16`): each shaped operand passed
-// as (base, element offset), ACC additionally as its innermost
-// cross-intrinsic stride, then the scalar `k_outer` / `intrinsics_{m,n,k}`.
+// ABI: each shaped operand is passed as (base pointer, element offset) so the
+// caller doesn't need a GEP before the call (the offset is added here); the
+// accumulator additionally gets the element stride of its innermost
+// cross-intrinsic (N) dimension. Offsets/strides are in units of the operand
+// element type (i8 for LHS/RHS, i32 for ACC).
+//
+// Per-intrinsic (16x16x2) tile, matching `lowerX86Avx512Vnni16x16x2I8` in
+// IREECPUAttrs.cpp (the codegen path this ukernel must be bit-compatible
+// with) and the `getIntrinsicSwizzle` data layout:
+// - LHS: one row-major 16x2 i8 panel (= 32 i8 = <32xi8>). Dword `x` of its
+//   i16-widened form holds the (k0, k1) pair of LHS row `x`.
+// - RHS: one row-major 16x2 i8 panel (= 32 i8), same shape; dword `x` holds
+//   column `x`'s (k0, k1) pair.
+// - ACC: one 16x16 i32 tile (= 256 i32) in the block-interleaved
+//   (rlo, chi, rhi, clo) order, with row r = 4*rhi + rlo and column
+//   c = 4*chi + clo. The 16 i32 at flat offset (4*rlo + chi)*16 are one
+//   `vpdpwssd` accumulator, whose dword `4*rhi + clo` holds ACC element
+//   (r, c). Fragments (m, n) are `acc_stride` apart along N.
 IREE_UK_ALWAYS_INLINE
 void iree_uk_mma_x86_avx512vnni_16x16x2_i32_i8_casti16(
     const void *lhs_base, int64_t lhs_offset, const void *rhs_base,
     int64_t rhs_offset, void *acc_base, int64_t acc_offset, int64_t acc_stride,
     int32_t k_outer, int32_t intrinsics_m, int32_t intrinsics_n,
     int32_t intrinsics_k) {
-  (void)lhs_base;
-  (void)lhs_offset;
-  (void)rhs_base;
-  (void)rhs_offset;
-  (void)acc_base;
-  (void)acc_offset;
-  (void)acc_stride;
-  (void)k_outer;
-  (void)intrinsics_m;
-  (void)intrinsics_n;
-  (void)intrinsics_k;
-  // TODO(ukernels): real inner K loop using `_mm512_dpwssd_epi32` after
-  // widening the s8 LHS/RHS halves to i16 lanes (loop over
-  // intrinsics_{m,n,k} like the bf16 ukernel).
+  const int8_t *lhs = (const int8_t *)lhs_base + lhs_offset;
+  const int8_t *rhs = (const int8_t *)rhs_base + rhs_offset;
+  int32_t *acc = (int32_t *)acc_base + acc_offset;
+
+  // 256 i32 (= 16 __m512i `vpdpwssd` accumulators) per (m, n) intrinsic. The
+  // VLA dimensions are compile-time constants at the inlined call site, so
+  // this lowers to a fixed register array.
+  __m512i acc_regs[intrinsics_m][intrinsics_n][16];
+  for (int32_t m = 0; m < intrinsics_m; ++m) {
+    for (int32_t n = 0; n < intrinsics_n; ++n) {
+      const int32_t *frag = acc + (m * intrinsics_n + n) * acc_stride;
+      for (int c = 0; c < 16; ++c) {
+        acc_regs[m][n][c] = _mm512_loadu_si512(frag + c * 16);
+      }
+    }
+  }
+
+  for (int32_t ko = 0; ko < k_outer; ++ko) {
+    // Each (m, k) / (n, k) fragment is one 16x2 i8 panel = 32 i8.
+    const int8_t *lhs_block =
+        lhs + (int64_t)ko * intrinsics_m * intrinsics_k * 32;
+    const int8_t *rhs_block =
+        rhs + (int64_t)ko * intrinsics_n * intrinsics_k * 32;
+    for (int32_t m = 0; m < intrinsics_m; ++m) {
+      for (int32_t n = 0; n < intrinsics_n; ++n) {
+        __m512i(*regs)[16] = &acc_regs[m][n];
+        for (int32_t k = 0; k < intrinsics_k; ++k) {
+          // Widen each i8 panel to i16 once (one `vpmovsxbw`); dword `x` then
+          // holds the (k0, k1) pair of LHS row / RHS column `x`.
+          __m512i lhs_i16 = _mm512_cvtepi8_epi16(_mm256_loadu_si256(
+              (const __m256i *)(lhs_block + (m * intrinsics_k + k) * 32)));
+          __m512i rhs_i16 = _mm512_cvtepi8_epi16(_mm256_loadu_si256(
+              (const __m256i *)(rhs_block + (n * intrinsics_k + k) * 32)));
+          // lhs_dup[rlo]: `vpshufd` broadcasting dword `4*lane + rlo` across
+          // each 128-bit lane (lane L then holds LHS row 4*L + rlo).
+          // rhs_bcast[chi]: `vbroadcasti32x4` of the 128-bit block of columns
+          // [4*chi, 4*chi+4) to all 4 lanes. The shuffle immediates must be
+          // compile-time constants (`s * 0x55` for s = 0..3), so the 4 cases
+          // are spelled out rather than looped.
+          __m512i lhs_dup[4] = {
+              _mm512_shuffle_epi32(lhs_i16, (_MM_PERM_ENUM)0x00),
+              _mm512_shuffle_epi32(lhs_i16, (_MM_PERM_ENUM)0x55),
+              _mm512_shuffle_epi32(lhs_i16, (_MM_PERM_ENUM)0xAA),
+              _mm512_shuffle_epi32(lhs_i16, (_MM_PERM_ENUM)0xFF),
+          };
+          __m512i rhs_bcast[4] = {
+              _mm512_shuffle_i32x4(rhs_i16, rhs_i16, 0x00),
+              _mm512_shuffle_i32x4(rhs_i16, rhs_i16, 0x55),
+              _mm512_shuffle_i32x4(rhs_i16, rhs_i16, 0xAA),
+              _mm512_shuffle_i32x4(rhs_i16, rhs_i16, 0xFF),
+          };
+          // 16 `vpdpwssd` over the 4x4 (rlo, chi) grid; accumulator (rlo, chi)
+          // lives at flat offset (4*rlo + chi)*16.
+          for (int rlo = 0; rlo < 4; ++rlo) {
+            for (int chi = 0; chi < 4; ++chi) {
+              int idx = 4 * rlo + chi;
+              (*regs)[idx] = _mm512_dpwssd_epi32((*regs)[idx], lhs_dup[rlo],
+                                                 rhs_bcast[chi]);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int32_t m = 0; m < intrinsics_m; ++m) {
+    for (int32_t n = 0; n < intrinsics_n; ++n) {
+      int32_t *frag = acc + (m * intrinsics_n + n) * acc_stride;
+      for (int c = 0; c < 16; ++c) {
+        _mm512_storeu_si512(frag + c * 16, acc_regs[m][n][c]);
+      }
+    }
+  }
 }
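
The block-interleaved ACC order described in the VNNI ukernel's header comment is easier to follow as an index function. The sketch below is illustrative only: `acc_flat_index` and `ref_mma_16x16x2_i32_i8` are hypothetical names, the function models a single 16x16x2 intrinsic rather than the full `(intrinsics_m, intrinsics_n, intrinsics_k)` loop nest, and it assumes inputs small enough that the `int32_t` accumulation does not overflow (the instruction wraps, while signed overflow in C is undefined).

```c
#include <assert.h>
#include <stdint.h>

/* Flat offset of ACC element (r, c) inside one 16x16 i32 tile, following the
 * (rlo, chi, rhi, clo) order: r = 4*rhi + rlo, c = 4*chi + clo. */
static int acc_flat_index(int r, int c) {
  assert(r >= 0 && r < 16 && c >= 0 && c < 16);
  int rlo = r % 4, rhi = r / 4;
  int clo = c % 4, chi = c / 4;
  return (4 * rlo + chi) * 16 + 4 * rhi + clo;
}

/* Hypothetical scalar reference for one 16x16x2 intrinsic. Both panels are
 * row-major 16x2 i8, so element (x, k) sits at index 2*x + k. */
static void ref_mma_16x16x2_i32_i8(const int8_t *lhs_panel,
                                   const int8_t *rhs_panel,
                                   int32_t *acc_tile) {
  for (int r = 0; r < 16; ++r) {
    for (int c = 0; c < 16; ++c) {
      int32_t dot = (int32_t)lhs_panel[2 * r] * rhs_panel[2 * c] +
                    (int32_t)lhs_panel[2 * r + 1] * rhs_panel[2 * c + 1];
      acc_tile[acc_flat_index(r, c)] += dot;
    }
  }
}
```

Because `rlo`, `chi`, `rhi` and `clo` each range over 0..3 and act as digits of a base-4 number, the mapping covers every offset in 0..255 exactly once, which is what lets the 16 accumulator registers be loaded and stored as contiguous 16-element rows.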
@@ -98,3 +98,46 @@ func.func @bf16_inner_tiled_ukernel_disabled(
 // CHAIN: iree_codegen.inner_tiled
 // CHAIN-NOT: iree_codegen.ukernel.generic
 // CHAIN-NOT: hal.executable.objects
+
+// -----
+
+// Same configuration as `bf16_inner_tiled_ukernel_enabled` but with
+// `intrinsics_m = 2`. Arbitrary positive `intrinsics_{m,n,k}` are
+// supported — the ukernel loops over them — so `selectCPUUKernel` matches
+// this case just like the unit one, and the chained pipeline rewrites it
+// to a `ukernel.generic`.
+#executable_target_enabled_unrolled = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+  cpu_features = "+avx512f,+avx512bf16",
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  iree_codegen.ukernel_provider = #iree_cpu.ukernel_provider,
+  llvm_ukernels = "inner_tiled",
+  native_vector_size = 64 : index,
+  target_triple = "x86_64-unknown-unknown-eabi-elf"
+}>
+
+func.func @bf16_inner_tiled_ukernel_unrolled_accepted(
+    %lhs: tensor<2x4x2x1x2xbf16>, %rhs: tensor<2x4x1x16x2xbf16>, %acc: tensor<2x2x2x1x1x16xf32>
+) -> tensor<2x2x2x1x1x16xf32>
+    attributes {hal.executable.target = #executable_target_enabled_unrolled} {
+  %0 = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%acc) {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d2)>,
+      affine_map<(d0, d1, d2) -> (d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ],
+    iterator_types = [#linalg.iterator_type<parallel>,
+                      #linalg.iterator_type<parallel>,
+                      #linalg.iterator_type<reduction>],
+    kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_X86_AVX512BF16_1x16x2_F32_BF16, intrinsics_m = 2>,
+    semantics = #iree_cpu.mma_semantics<>
+  } : tensor<2x4x2x1x2xbf16>, tensor<2x4x1x16x2xbf16> into tensor<2x2x2x1x1x16xf32>
+  return %0 : tensor<2x2x2x1x1x16xf32>
+}
+// CHECK-LABEL: func.func @bf16_inner_tiled_ukernel_unrolled_accepted
+// CHECK: iree_codegen.inner_tiled
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16", bitcode>
+
+// CHAIN-LABEL: func.func @bf16_inner_tiled_ukernel_unrolled_accepted
+// CHAIN: iree_codegen.ukernel.generic
+// CHAIN-SAME: "iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16"
+// CHAIN-NOT: iree_codegen.inner_tiled
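
The tensor shapes in the test above line up with the per-step element counts in the bf16 ukernel's layout comment. The sketch below checks that arithmetic under two assumptions that the diff does not state outright: that the trailing dimensions of each tensor type are the per-call inner tile (`2x1x2` for LHS, `1x16x2` for RHS, `2x1x1x16` for ACC), and that `intrinsics_n` and `intrinsics_k` are 1 because the attribute only sets `intrinsics_m = 2`. All function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Product of `rank` dimension sizes. */
static int64_t num_elements(const int64_t *dims, int rank) {
  int64_t n = 1;
  for (int i = 0; i < rank; ++i) {
    assert(dims[i] > 0);
    n *= dims[i];
  }
  return n;
}

/* Element counts from the bf16 layout notes, with M0 = 1, N0 = 16, K0 = 2. */
static int64_t lhs_elems_per_k_step(int32_t intrinsics_m,
                                    int32_t intrinsics_k) {
  return (int64_t)intrinsics_m * intrinsics_k * 1 * 2;
}

static int64_t rhs_elems_per_k_step(int32_t intrinsics_n,
                                    int32_t intrinsics_k) {
  return (int64_t)intrinsics_n * intrinsics_k * 16 * 2;
}

static int64_t acc_elems_per_call(int32_t intrinsics_m,
                                  int32_t intrinsics_n) {
  return (int64_t)intrinsics_m * intrinsics_n * 1 * 16;
}
```

Under those assumptions the LHS inner tile holds 4 bf16, the RHS inner tile 32 bf16, and the ACC inner tile 32 f32, matching `intrinsics_m * intrinsics_k * 2`, `intrinsics_n * intrinsics_k * 32`, and `intrinsics_m * intrinsics_n * 16` respectively.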
@@ -54,6 +54,10 @@ IREE::Codegen::UKernelDescriptorAttr selectCPUUKernel(Operation *op) {
     return {};
   }

+  // Any positive `intrinsics_{m,n,k}` is supported: the ukernel takes them as
+  // arguments and loops over them, and those loops fully unroll after the
+  // ukernel is inlined into its (constant-`intrinsics_*`) caller.
+
   auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op);
   if (!execTarget) {
     return {};
@@ -74,7 +78,16 @@ IREE::Codegen::UKernelDescriptorAttr selectCPUUKernel(Operation *op) {
   // (`ensureUKernelBitcodeAndFinalizeConfig` in
   // `compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp`)
   // and makes the configuration-pass output self-contained for lit tests.
-  IREE::CPU::attachUKernelBitcodeOnOp(op, name);
+  //
+  // Only select the ukernel if its bitcode actually exists. Not every
+  // `MMAIntrinsic` the cost model picks has a built-in ukernel — e.g. the
+  // M<->N-swapped `MMA_X86_AVX512BF16_16x1x2_F32_BF16` orientation has no
+  // seed even though its natural sibling does — and matching one without
+  // bitcode would dangle an undefined symbol at link time. When absent, fall
+  // back to codegen by returning {}.
+  if (!IREE::CPU::attachUKernelBitcodeOnOp(op, name)) {
+    return {};
+  }

   MLIRContext *context = op->getContext();
   return IREE::Codegen::UKernelDescriptorAttr::get(
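
The guard added to `selectCPUUKernel` amounts to "look the name up, and decline when nothing is registered". A minimal C sketch of that pattern follows. It is not IREE's API: `kBuiltinUKernels` and `select_ukernel_or_null` are hypothetical, the table lists only the two seeds touched by this change, and the name of the missing one is derived from the naming convention stated in the ukernel headers (intrinsic name lowercased with the `iree_uk_` prefix).

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Hypothetical registry of ukernels that ship bitcode. */
static const char *const kBuiltinUKernels[] = {
    "iree_uk_mma_x86_avx512bf16_1x16x2_f32_bf16",
    "iree_uk_mma_x86_avx512vnni_16x16x2_i32_i8_casti16",
};

/* Returns the registered name when bitcode exists, or NULL so that the
 * caller can fall back to regular codegen instead of emitting a call to an
 * undefined symbol. */
static const char *select_ukernel_or_null(const char *name) {
  assert(name != NULL);
  size_t count = sizeof kBuiltinUKernels / sizeof kBuiltinUKernels[0];
  for (size_t i = 0; i < count; ++i) {
    if (strcmp(kBuiltinUKernels[i], name) == 0) {
      return kBuiltinUKernels[i];
    }
  }
  return NULL;
}
```

With this table, `select_ukernel_or_null("iree_uk_mma_x86_avx512bf16_16x1x2_f32_bf16")` returns `NULL`, which corresponds to the `return {}` fallback for the M<->N-swapped orientation mentioned in the comment.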