From a6d75729b63d189a57e61873a0ccb669624e4d97 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 3 Jun 2025 21:05:12 +0000 Subject: [PATCH] Add XeVM dialect. --- .../mlir/Dialect/LLVMIR/CMakeLists.txt | 10 + .../include/mlir/Dialect/LLVMIR/XeVMDialect.h | 40 ++ mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 428 ++++++++++++++++++ mlir/include/mlir/InitAllDialects.h | 4 +- mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 22 + mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 366 +++++++++++++++ mlir/test/Dialect/LLVMIR/xevm.mlir | 52 +++ mlir/test/lib/Dialect/GPU/CMakeLists.txt | 1 + 8 files changed, 922 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h create mode 100644 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td create mode 100644 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp create mode 100644 mlir/test/Dialect/LLVMIR/xevm.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 9c5bbae1022f7..cfad07e57021f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -87,3 +87,13 @@ mlir_tablegen(VCIXConversions.inc -gen-llvmir-conversions) mlir_tablegen(VCIXOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=vcix) mlir_tablegen(VCIXOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=vcix) add_public_tablegen_target(MLIRVCIXConversionsIncGen) + +add_mlir_dialect(XeVMOps xevm) +add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm) +set(LLVM_TARGET_DEFINITIONS XeVMOps.td) +mlir_tablegen(XeVMConversions.inc -gen-llvmir-conversions) +mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm) +mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm) +add_public_tablegen_target(MLIRXeVMConversionsIncGen) diff --git 
a/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h new file mode 100644 index 0000000000000..52b95644845ba --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h @@ -0,0 +1,40 @@ +//===-- XeVMDialect.h - MLIR XeVM target definitions ------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include + +namespace mlir::xevm { + +enum class XeVMAddrSpace : uint32_t { + kPrivate = 0, // OpenCL Workitem address space, SPIRV function + kGlobal = 1, // OpenCL Global memory, SPIRV crossworkgroup + kConstant = 2, // OpenCL Constant memory, SPIRV uniform constant + kShared = 3, // OpenCL Local memory, SPIRV workgroup + kGeneric = 4 // OpenCL Generic memory, SPIRV generic +}; + +} // namespace mlir::xevm + +#define GET_ATTRDEF_CLASSES +#include + +#define GET_OP_CLASSES +#include + +#include + +#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */ diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td new file mode 100644 index 0000000000000..c16efd72a38eb --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td @@ -0,0 +1,428 @@ +//===-- XeVMOps.td - XeVM dialect definition ---------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef XEVMIR_OPS +#define XEVMIR_OPS + +include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" + +def XeVM_Dialect : Dialect { + let name = "xevm"; + let cppNamespace = "::mlir::xevm"; + let dependentDialects = ["LLVM::LLVMDialect"]; + + let extraClassDeclaration = [{ + /// Get the name for the attribute used to specify cache control + /// decorations. + static constexpr ::llvm::StringRef getCacheControlsAttrName() { + return ::llvm::StringLiteral("xevm.DecorationCacheControl"); + } + }]; + + let useDefaultAttributePrinterParser = 1; +} + +class XeVM_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +class XeVM_Op traits = []> + : Op; + +def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>; + +def XeVM_LoadCacheControl : I32EnumAttr<"LoadCacheControl", "XeVM load ops cache control", + [ + I32EnumAttrCase<"DEFAULT", 0, "Default">, + I32EnumAttrCase<"UC", 1, "UC">, // uncached + I32EnumAttrCase<"C", 2, "C">, // cached + I32EnumAttrCase<"S", 3, "S">, // streaming + I32EnumAttrCase<"IAR", 4, "IAR">, // invalidate-after-read + ]> { + let cppNamespace = "::mlir::xevm"; + let genSpecializedAttr = 0; +} + +def XeVM_LoadCacheControlAttr + : EnumAttr { + let summary = [{ }]; + let assemblyFormat = "$value"; +} + +def XeVM_StoreCacheControl : I32EnumAttr<"StoreCacheControl", "XeVM store ops cache control", + [ + I32EnumAttrCase<"DEFAULT", 0, "Default">, + I32EnumAttrCase<"UC", 1, "UC">, // uncached + I32EnumAttrCase<"WT", 2, "WT">, // write-through + I32EnumAttrCase<"S", 3, "S">, // streaming + I32EnumAttrCase<"WB", 4, "WB">, // write back + ]> { + let cppNamespace = "::mlir::xevm"; + let genSpecializedAttr = 0; +} + +def 
XeVM_StoreCacheControlAttr + : EnumAttr { + let summary = [{ }]; + let assemblyFormat = "$value"; +} + +def XeVM_BlockLoad2dOp + : XeVM_Op<"blockload2d">, + Results<(outs FixedVectorOfRankAndType<[1, 2, 3], [XeVM_ElemType]>:$res)>, + Arguments<( + ins Arg:$ptr, I32:$base_width, + I32:$base_height, I32:$base_pitch, I32:$x, I32:$y, + I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height, + I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$vnni_transform, + DefaultValuedAttr< + XeVM_LoadCacheControlAttr, + "::mlir::xevm::LoadCacheControl::DEFAULT">:$l1_cache_control, + DefaultValuedAttr< + XeVM_LoadCacheControlAttr, + "::mlir::xevm::LoadCacheControl::DEFAULT">:$l3_cache_control)> { + + let summary = "2D block load"; + + let description = [{ + The `xevm.blockload2d` operation loads a two dimensional matrix tile + from a larger matrix residing in memory. The parameters are: + $ptr - the base address of the matrix containing the tile to load + $base_width, $base_height, $base_pitch - the shape of matrix + $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to load + $elem_size_in_bits - the size in bits of the matrix element + - 32 for f32, bf32 + - 16 for f16, int16, bf16 + - 8 for int8, int4, int2 + $v_blocks - number of tiles to load (a.k.a. array length) + $transpose - transpose the tile in registers (useful for 32 bit element type) + $vnni_transform - transpose and pack the submatrix in registers (useful for < 32 bit element types) + $cache_control - an enumerator that sets the L1 and L3 cache behaviour + + Notes: + - pitch is the physical stride between the first columns of the current row and the subsequent row, + this may include (possibly implicit) padding, alignment, or other factors. + - the $transpose and $vnni_transform parameters are mutually exclusive + - transposing the tile loaded is typically used for the B matrix operand + (D = C + A * B), where A has row-major layout and B should have column-major layout in memory. 
+ - if the tile loaded contains out of bound elements of the matrix, they are filled with 0. + - coordinate is provided in elements, while width and pitch are provided in bytes. + + Example: + ```mlir + %base_width_a = arith.constant 32 : i32 + %base_height_a = arith.constant 8 : i32 + %base_pitch_a = arith.constant 32 : i32 + %x = arith.constant 0 : i32 + %y = arith.constant 0 : i32 + %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + ``` + }]; + + let assemblyFormat = [{ + operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,` + `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `transpose` `=` $transpose `,` + `vnni_transform` `=` $vnni_transform `,` `l1_cache_control` `=` $l1_cache_control `,` + `l3_cache_control` `=` $l3_cache_control `}` attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def XeVM_BlockStore2dOp + : XeVM_Op<"blockstore2d">, + Arguments<( + ins Arg:$ptr, I32:$base_width, + I32:$base_height, I32:$base_pitch, I32:$x, I32:$y, + I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height, + DefaultValuedAttr:$v_blocks, + FixedVectorOfRankAndType<[1, 2, 3], [XeVM_ElemType]>:$stored_val, + DefaultValuedAttr< + XeVM_StoreCacheControlAttr, + "::mlir::xevm::StoreCacheControl::DEFAULT">:$l1_cache_control, + DefaultValuedAttr< + XeVM_StoreCacheControlAttr, + "::mlir::xevm::StoreCacheControl::DEFAULT">:$l3_cache_control)> { + + let summary = "2D block store"; + + let description = [{ + The `xevm.blockstore2d` operation stores a two dimensional tile into a + larger matrix residing in memory. 
The parameters are: + $ptr - the base address of the matrix where to store the tile + $base_width, $base_height, $base_pitch - the shape of the matrix + $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store + $elem_size_in_bits - the size in bits of the matrix element + - 32 for f32, bf32 + - 16 for f16, int16, bf16 + - 8 for int8, int4, int2 + $v_blocks - number of tiles to store + $cache_control - an enumerator that sets the L1 and L3 cache behaviour + $stored_val - the tile to store + + Notes: + - pitch is the physical stride between the first columns of the current row and the subsequent row, + this may include (possibly implicit) padding, alignment, or other factors. + - coordinate is provided in elements, while width and pitch are provided in bytes. + + Example: + ```mlir + %base_width_c = arith.constant 64 : i32 + %base_height_c = arith.constant 8 : i32 + %base_pitch_c = arith.constant 64 : i32 + %x = arith.constant 0 : i32 + %y = arith.constant 0 : i32 + xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + ``` + }]; + + let assemblyFormat = [{ + operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,` + `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,` + `l3_cache_control` `=` $l3_cache_control `}` + attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = 1; +} + +def XeVM_MemoryScope + : I32EnumAttr<"MemoryScope", "Memory scope for memory operations", + [I32EnumAttrCase<"WORKGROUP", 0, "workgroup">, + I32EnumAttrCase<"LOCAL", 1, "local">, + I32EnumAttrCase<"TILE", 2, "tile">, + I32EnumAttrCase<"GPU", 3, "gpu">, + I32EnumAttrCase<"SYSTEM", 4, "system">]> { + let cppNamespace = "mlir::xevm"; +} + +def XeVM_AddrSpace 
: I32EnumAttr<"OclAddrSpace", "Address spaces in OpenCL", + [ + I32EnumAttrCase<"kPrivate", 0, "private">, // OpenCL Workitem address space, SPIRV function + I32EnumAttrCase<"kGlobal", 1, "global">, // OpenCL Global memory, SPIRV crossworkgroup + I32EnumAttrCase<"kConstant", 2, "constant">, // OpenCL Constant memory, SPIRV uniform constant + I32EnumAttrCase<"kShared", 3, "shared">, // OpenCL Local memory, SPIRV workgroup + I32EnumAttrCase<"kGeneric", 4, "generic"> // OpenCL Generic memory, SPIRV generic + ]>{ + let cppNamespace = "mlir::xevm"; +} + +def XeVM_MemfenceOp + : XeVM_Op<"memfence">, + Arguments<(ins XeVM_MemoryScope:$scope, + DefaultValuedAttr:$addrspace)> { + let summary = "Work-item's memory fence."; + let description = [{ + This operation ensures that all prior memory accesses of this + work-item to `addrspace` are visible to all other work-items in `scope`. + Parameters description: + $scope - specify the memory scope at which all other work-items should observe + memory operations prior to the fence. + $addrspace - specify the address space of work-item's memory accesses + to be affected by the fence. + }]; + let assemblyFormat = + [{`addrspace` `=` $addrspace `,` `scope` `=` $scope attr-dict}]; +} + +def XeVM_PrefetchOp + : XeVM_Op<"prefetch">, + Arguments<( + ins Arg:$ptr, + XeVM_AddrSpace:$addrspace, + DefaultValuedAttr< + XeVM_LoadCacheControlAttr, + "::mlir::xevm::LoadCacheControl::DEFAULT">:$l1_cache_control, + DefaultValuedAttr< + XeVM_LoadCacheControlAttr, + "::mlir::xevm::LoadCacheControl::DEFAULT">:$l3_cache_control)> { + let summary = "Prefetch data into a cache subsystem."; + let description = [{ + Work-item issues a prefetch from global memory to L1/L3 cache: + $ptr - memory pointer. + $addrspace - address space of a pointer, must be generic or global. + $cache_control - specify caching options (e.g., L1c, L3uc). 
+ }]; + let assemblyFormat = [{ + operands ` ` `{` `addrspace` `=` $addrspace `,` `l1_cc` `=` $l1_cache_control `,` `l3_cc` `=` $l3_cache_control `}` + attr-dict `:` `(` type(operands) `)` + }]; + + // let hasVerifier = 1; +} + +def XeVM_BlockPrefetch2dOp + : XeVM_Op<"blockprefetch2d">, + Arguments<( + ins Arg:$ptr, I32:$base_width, + I32:$base_height, I32:$base_pitch, I32:$x, I32:$y, + I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height, + I32Attr:$v_blocks, + DefaultValuedAttr< + XeVM_LoadCacheControlAttr, + "::mlir::xevm::LoadCacheControl::DEFAULT">:$l1_cache_control, + DefaultValuedAttr< + XeVM_LoadCacheControlAttr, + "::mlir::xevm::LoadCacheControl::DEFAULT">:$l3_cache_control)> { + + let summary = "2D block prefetch"; + + let description = [{ + The `xevm.blockprefetch2d` operation prefetches a two dimensional tile + from a larger matrix residing in memory. The parameters are: + $ptr - the base address of the matrix containing the tile to prefetch + $base_width, $base_height, $base_pitch - the shape of the matrix + $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile to prefetch + $elem_size_in_bits - the size in bits of the matrix element + - 32 for f32, bf32 + - 16 for f16, int16, bf16 + - 8 for int8, int4, int2 + $v_blocks - number of tiles to prefetch + $cache_control - an enumerator that sets the L1 and L3 cache behaviour + + Notes: + - pitch is the physical stride between the first columns of the current row and the subsequent row, + this may include (possibly implicit) padding, alignment, or other factors. + - coordinate is provided in elements, while width and pitch are provided in bytes. 
+ + Example: + ```mlir + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, l1_cache_control=UC, l3_cache_control=UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + ``` + }]; + + let assemblyFormat = [{ + operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,` + `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,` + `l3_cache_control` `=` $l3_cache_control `}` + attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = 1; +} + +def XeVM_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>; + +/// Enum attribute of the different precision types. +def XeVM_PrecisionTypeAttr + : I32EnumAttr< + "PrecisionType", "XeVM precision type", + [I32EnumAttrCase<"UNUSED", 0, "unused">, + I32EnumAttrCase<"U8", 1, "u8">, I32EnumAttrCase<"U4", 2, "u4">, + I32EnumAttrCase<"U2", 3, "u2">, I32EnumAttrCase<"S8", 4, "i8">, + I32EnumAttrCase<"S4", 5, "i4">, I32EnumAttrCase<"S2", 6, "i2">, + I32EnumAttrCase<"BF8", 7, "bf8">, I32EnumAttrCase<"TF32", 8, "tf32">, + I32EnumAttrCase<"BF16", 9, "bf16">, + I32EnumAttrCase<"FP16", 10, "f16">]> { + let cppNamespace = "::mlir::xevm"; +} + +def XeVM_DpasOp + : XeVM_Op<"dpas">, + Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>, + Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$c, + FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a, + FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b, + XeVM_PrecisionTypeAttr:$pa, XeVM_PrecisionTypeAttr:$pb, + I32Attr:$rc)> { + + let summary = "Matrix multiply-add"; + + let description = [{ + The `xevm.dpas` operation is a matrix multiplication plus accumulation: + + D = C + A x B + + where the A, B, C input matrices and the result D have shapes: + D : MxN + C : MxN + A : MxK + B : KxN + + Shape restrictions: + M : must be 1, 2, 4, or 8 + N : fixed execution 
size, must be 16 + K : systolic_depth * OPS_PER_CHAN + OPS_PER_CHAN + 1 : for TF32 + 2 : for 16-bit precision(BF, HF) + 4 : for 8-bit precision (FP8, UB, B) + 8 : for less-than 8 bit precision (U4/S4, U2/S2). + + If systolic_depth is 8, K would be 8, 16, 32, or 64 (based on OPS_PER_CHAN). + $a, $b, $c, $d - matrix A, B, C, D, respectively + $pa, $pb - precision of matrix A and B respectively + $rc - repeat count + + Further restrictions as well as more details can be found here: + https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html + + Example: + ```mlir + %c_result = xevm.dpas %c, %a, %b {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32> + ``` + }]; + + let assemblyFormat = [{ + operands ` ` `{` `pa` `=` $pa `,` `pb` `=` $pb `,` `rc` `=` $rc `}` attr-dict `:` functional-type(operands, results) + }]; + + // let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// XeVM target attribute. +//===----------------------------------------------------------------------===// + +def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> { + let description = [{ + GPU target attribute for controlling compilation of Intel GPU targets. All + parameters decay into default values if not present. + + Examples: + + 1. Target with default values. + ``` + gpu.module @mymodule [#xevm.target] attributes {...} { + ... + } + ``` + }]; + let parameters = + (ins DefaultValuedParameter<"int", "2", + "Optimization level to apply.">:$O, + StringRefParameter<"Target triple.", + "\"spirv64-unknown-unknown\"">:$triple, + StringRefParameter<"Target chip.", "\"pvc\"">:$chip, + OptionalParameter<"::mlir::DictionaryAttr", + "Target specific flags.">:$flags, + OptionalParameter<"::mlir::ArrayAttr", + "Files to link to the LLVM module.">:$linkFiles); + let assemblyFormat = [{ + (`<` struct($O, $triple, $chip, $flags, $linkFiles)^ `>`)? 
+ }]; + let builders = [AttrBuilder< + (ins CArg<"int", "2">:$optLevel, + CArg<"::llvm::StringRef", "\"spirv64-unknown-unknown\"">:$triple, + CArg<"::llvm::StringRef", "\"pvc\"">:$chip, + CArg<"::mlir::DictionaryAttr", "nullptr">:$targetFlags, + CArg<"::mlir::ArrayAttr", "nullptr">:$linkFiles), + [{ + return Base::get($_ctxt, optLevel, triple, chip, targetFlags, linkFiles); + }]>]; + let skipDefaultBuilders = 1; + let genVerifyDecl = 1; +} + +#endif // XEVMIR_OPS diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index ea285ac7f16e3..048191e7dc42c 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -46,6 +46,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" @@ -151,7 +152,8 @@ inline void registerAllDialects(DialectRegistry ®istry) { ub::UBDialect, vector::VectorDialect, x86vector::X86VectorDialect, - xegpu::XeGPUDialect>(); + xegpu::XeGPUDialect, + xevm::XeVMDialect>(); // clang-format on // Register all external models. 
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index d83fd3800eb91..67081ca61e6e5 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -110,3 +110,25 @@ add_mlir_dialect_library(MLIRVCIXDialect
   MLIRLLVMDialect
   MLIRSideEffectInterfaces
   )
+
+add_mlir_dialect_library(MLIRXeVMDialect
+  IR/XeVMDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+  DEPENDS
+  MLIRGPUCompilationAttrInterfacesIncGen
+  MLIRXeVMOpsIncGen
+  MLIRXeVMConversionsIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  AsmParser
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRSideEffectInterfaces
+)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
new file mode 100644
index 0000000000000..aa391e959bf05
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -0,0 +1,386 @@
+//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
+#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
+
+namespace {
+/// Subgroup size assumed by the verifiers below: the expected result size of
+/// a 2D block load is the whole block size divided by this value.
+constexpr uint32_t subgroupSize = 16;
+
+/// Verifies the constraints shared by the 2D block load, store and prefetch
+/// ops: a constant base pitch must not be smaller than a constant base width,
+/// elem_size_in_bits must be 8, 16 or 32, tile_height must be a power of two
+/// of at most 32, and v_blocks must be a power of two of at most 8.
+template <typename Op>
+LogicalResult verifyMatrixInput(Op op) {
+  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
+                                BlockPrefetch2dOp>::value,
+                "Unexpected template parameter");
+
+  // The pitch/width relation is only checked when both fold to constants.
+  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
+  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
+  if (pitch && width && *pitch < *width)
+    return op->emitOpError(
+        "4th operand (base pitch) should be >= 2nd operand (base width)");
+
+  uint32_t elemSize = op.getElemSizeInBits();
+  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
+    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
+
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
+    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
+    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
+
+  return success();
+}
+
+/// Verifies the result size and the tile shape limits of a 2D block load. The
+/// limits depend on the mode: plain (no transpose, no vnni_transform),
+/// transpose, or vnni_transform; requesting both transforms is an error.
+LogicalResult verify2DBlockLoadHWRestriction(BlockLoad2dOp op) {
+  VectorType resTy = op.getRes().getType();
+  if (!resTy.getElementType().isIntOrFloat())
+    return op.emitOpError()
+           << "expecting result element type to be int or float";
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  unsigned resSize = resTy.getNumElements() * resElemTySize;
+  unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
+                          op.getTileWidth() * op.getVBlocks() / subgroupSize;
+  if (resSize != expectedSize)
+    return op.emitOpError() << "result size of " << resSize
+                            << " bits does not match the expected size of "
+                            << expectedSize << " bits";
+
+  if (op.getTranspose() && op.getVnniTransform())
+    return op.emitOpError(
+        "transpose and vnni_transform are mutually exclusive");
+
+  // Plain load: neither transpose nor vnni_transform is requested.
+  if (!op.getTranspose() && !op.getVnniTransform()) {
+    uint32_t tileHeight = op.getTileHeight();
+    if (tileHeight < 1 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 1 and 32");
+
+    uint32_t tileWidth = op.getTileWidth();
+    uint32_t vBlocks = op.getVBlocks();
+    switch (op.getElemSizeInBits()) {
+    case 8:
+      if (tileWidth < 4 || tileWidth > 64)
+        return op.emitOpError("expecting tile_width to be between 4 and 64");
+      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+      if (tileWidth * vBlocks > 64)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 64 for 8 bit elements");
+      break;
+    case 16:
+      if (tileWidth < 2 || tileWidth > 32)
+        return op.emitOpError("expecting tile_width to be between 2 and 32");
+      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+      if (tileWidth * vBlocks > 32)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 32 for 16 bit elements");
+      break;
+    case 32:
+      if (tileWidth < 1 || tileWidth > 16)
+        return op.emitOpError("expecting tile_width to be between 1 and 16");
+      if (vBlocks != 1 && vBlocks != 2)
+        return op.emitOpError("expecting v_blocks to be 1 or 2");
+      if (tileWidth * vBlocks > 16)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 16 for 32 bit elements");
+      break;
+    case 64:
+      if (tileWidth < 1 || tileWidth > 8)
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      if (vBlocks != 1)
+        return op.emitOpError("expecting v_blocks to be 1");
+      break;
+    default:
+      return op.emitOpError(
+          "expecting elem_size_in_bits to be 8, 16, 32, or 64");
+    }
+
+    return success();
+  }
+
+  // Transposed load.
+  if (op.getTranspose()) {
+    assert(!op.getVnniTransform() &&
+           "Expecting vnni_transform should be false");
+
+    uint32_t vBlocks = op.getVBlocks();
+    if (vBlocks != 1)
+      return op.emitOpError("expecting v_blocks to be 1");
+
+    uint32_t tileHeight = op.getTileHeight();
+    uint32_t tileWidth = op.getTileWidth();
+    switch (op.getElemSizeInBits()) {
+    case 32:
+      if (tileHeight < 1 || tileHeight > 32)
+        return op.emitOpError("expecting tile_height to be between 1 and 32");
+      if (tileWidth < 1 || tileWidth > 8)
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      break;
+    case 64:
+      if (tileHeight != 8)
+        return op.emitOpError(
+            "expecting tile_height to be 8 for 64 bit elements");
+      if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
+        return op.emitOpError("expecting tile_width to be 1, 2, or 4");
+      break;
+    default:
+      return op.emitOpError("transpose is only supported for 32 and 64 bit "
+                            "elements");
+    }
+
+    return success();
+  }
+
+  // VNNI-transformed load: the only remaining combination.
+  assert(op.getVnniTransform() && !op.getTranspose() &&
+         "Expecting vnni_transform should be true and transpose should be "
+         "false");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+    return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+
+  uint32_t tileHeight = op.getTileHeight();
+  uint32_t tileWidth = op.getTileWidth();
+  switch (op.getElemSizeInBits()) {
+  case 8:
+    if (tileHeight < 4 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 4 and 32");
+    if (tileWidth < 4 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 4 and 16");
+    break;
+  case 16:
+    if (tileHeight < 2 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 2 and 32");
+    if (tileWidth < 2 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 2 and 16");
+    if (tileWidth * vBlocks > 32)
+      return op.emitOpError(
+          "tile_width * v_blocks should be less than or equal "
+          "to 32 for 16 bit elements");
+    break;
+  default:
+    return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
+                          "elements");
+  }
+
+  return success();
+}
+
+/// Verifies the tile shape limits of a 2D block store: tile_height in [1, 8],
+/// a tile_width range that depends on the element size, and v_blocks == 1.
+LogicalResult verify2DBlockStoreHWRestriction(BlockStore2dOp op) {
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight < 1 || tileHeight > 8)
+    return op.emitOpError("expecting tile_height to be between 1 and 8");
+
+  uint32_t tileWidth = op.getTileWidth();
+  switch (op.getElemSizeInBits()) {
+  case 8:
+    if (tileWidth < 4 || tileWidth > 64)
+      return op.emitOpError("expecting tile_width to be between 4 and 64");
+    break;
+  case 16:
+    if (tileWidth < 2 || tileWidth > 32)
+      return op.emitOpError("expecting tile_width to be between 2 and 32");
+    break;
+  case 32:
+    if (tileWidth < 1 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 1 and 16");
+    break;
+  case 64:
+    if (tileWidth < 1 || tileWidth > 8)
+      return op.emitOpError("expecting tile_width to be between 1 and 8");
+    break;
+  default:
+    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
+  }
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks != 1)
+    return op.emitOpError("expecting v_blocks to be 1");
+  return success();
+}
+
+} // namespace
+
+/// Verifies a 2D block load: tile shape limits first, then the constraints
+/// shared with the store and prefetch ops, then the result element width.
+LogicalResult BlockLoad2dOp::verify() {
+  if (verify2DBlockLoadHWRestriction(*this).failed())
+    return failure();
+
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  // verify2DBlockLoadHWRestriction has already rejected result element types
+  // that are neither int nor float, so querying the bit width is safe here.
+  VectorType resTy = getRes().getType();
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  if ((getElemSizeInBits() == 32 || getVnniTransform()) && resElemTySize != 32)
+    return emitOpError() << "expecting result element type to be 32 bits";
+
+  if (getVnniTransform() && getTileWidth() != subgroupSize)
+    return emitOpError(
+        "tile_width when vnni_transform is true should be equal "
+        "to subgroup size (16 elements)");
+
+  return success();
+}
+
+/// Verifies a 2D block store: tile shape limits, the constraints shared with
+/// the load and prefetch ops, and the tile_width allowed per element size.
+LogicalResult BlockStore2dOp::verify() {
+  if (verify2DBlockStoreHWRestriction(*this).failed())
+    return failure();
+
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  uint32_t tileWidth = getTileWidth();
+  switch (getElemSizeInBits()) {
+  case 8:
+    if (tileWidth != 16 && tileWidth != 32)
+      return emitOpError("tile_width for 8 bit elements should be equal to "
+                         "16 or 32");
+    break;
+  case 16:
+    if (tileWidth != 16)
+      return emitOpError("tile_width for 16 bit elements should be equal "
+                         "to 16");
+    break;
+  case 32:
+    if (tileWidth != 16)
+      return emitOpError("tile_width for 32 bit elements should be equal "
+                         "to 16");
+    break;
+  default:
+    llvm_unreachable("unexpected element size");
+  }
+
+  return success();
+}
+
+/// Verifies a 2D block prefetch: the constraints shared with the load and
+/// store ops, and the tile_width allowed per element size.
+LogicalResult BlockPrefetch2dOp::verify() {
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  uint32_t tileWidth = getTileWidth();
+  switch (getElemSizeInBits()) {
+  case 8:
+    if (tileWidth != 16 && tileWidth != 32)
+      return emitOpError("tile_width for 8 bit elements should be equal to "
+                         "16 or 32");
+    break;
+  case 16:
+    if (tileWidth != 16)
+      return emitOpError("tile_width for 16 bit elements should be equal "
+                         "to 16");
+    break;
+  case 32:
+    if (tileWidth != 8 && tileWidth != 16)
+      return emitOpError(
+          "tile_width for 32 bit elements should be equal to 8 or 16");
+    break;
+  default:
+    llvm_unreachable("unexpected element size");
+  }
+
+  return success();
+}
+
+/// Verifies the target attribute: the optimization level must be in [0, 3],
+/// triple and chip must be non-empty, and every string entry of linkFiles
+/// must be a non-empty path to a file that exists when the verifier runs.
+/// Entries of linkFiles that are not string attributes are not checked.
+LogicalResult
+XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
+                       StringRef triple, StringRef chip, DictionaryAttr flags,
+                       ArrayAttr linkFiles) {
+  if (O < 0 || O > 3) {
+    emitError() << "The optimization level must be a number between 0 and 3.";
+    return failure();
+  }
+  if (triple.empty()) {
+    emitError() << "The target triple cannot be empty.";
+    return failure();
+  }
+  if (chip.empty()) {
+    emitError() << "The target chip cannot be empty.";
+    return failure();
+  }
+  if (linkFiles) {
+    for (Attribute fileAttr : linkFiles) {
+      if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
+        StringRef filePath = fileStrAttr.getValue();
+        if (filePath.empty()) {
+          emitError() << "File paths in linkFiles cannot be empty.";
+          return failure();
+        }
+        if (!llvm::sys::fs::exists(filePath)) {
+          emitError() << "File '" << filePath << "' does not exist.";
+          return failure();
+        }
+      }
+    }
+  }
+  return success();
+}
+
+/// Registers the dialect's operations and attributes and declares that
+/// XeVMTargetAttr is promised to implement gpu::TargetAttrInterface.
+void XeVMDialect::initialize() {
+  // NOLINTBEGIN
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+      >();
+
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
+      >();
+  // NOLINTEND
+  declarePromisedInterface<gpu::TargetAttrInterface, XeVMTargetAttr>();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
new file mode 100644
index 0000000000000..10338d60fb053
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: func.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, transpose = false, vnni_transform = false, l1_cache_control = Default, l3_cache_control = Default}
: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + return %loaded_a : vector<8xi16> +} + +// ----- +// CHECK: func.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) +func.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]] {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, l1_cache_control = Default, l3_cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + return +} + +// ----- +// CHECK: func.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { + // CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] {elem_size_in_bits = 8, tile_width = 32, tile_height = 8, v_blocks = 1, l1_cache_control = UC, l3_cache_control = UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, 
tile_height=8, v_blocks=1, l1_cache_control=UC, l3_cache_control=UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + return +} + +// ----- +// CHECK: func.func @dpas(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) +func.func @dpas(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { + // CHECK: %0 = xevm.dpas %[[ARG0]], %[[ARG1]], %[[ARG2]] {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32> + %c_result = xevm.dpas %loaded_c_casted, %loaded_a, %loaded_b_casted {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32> + return %c_result : vector<8xf32> +} + +// ----- +func.func @memfence() { + // CHECK: xevm.memfence addrspace = global, scope = workgroup + xevm.memfence addrspace=global, scope=workgroup + return +} + +// ----- +// CHECK: func.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) +func.func @prefetch(%ptr: !llvm.ptr<1>) { + // CHECK: xevm.prefetch %[[ARG0]] {addrspace = global, l1_cc = UC, l3_cc = UC} : (!llvm.ptr<1>) + xevm.prefetch %ptr {addrspace = global, l1_cc = UC, l3_cc = UC} : (!llvm.ptr<1>) + return +} + +// ----- +// CHECK: @xevm_module [#xevm.target] { +gpu.module @xevm_module [#xevm.target]{ +} diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt index 4ca5974ed5a49..418c884dc03b3 100644 --- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt @@ -29,6 +29,7 @@ set(LIBS MLIRTranslateLib MLIRVectorDialect MLIRVectorToLLVMPass + MLIRXeVMDialect ) add_mlir_library(MLIRGPUTestPasses