diff --git a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp
index cfdadd21..2175d5e3 100644
--- a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp
+++ b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp
@@ -49,8 +49,7 @@ class GenericFMAVectorMultiplier : public triton::gpu::FMAVectorMultiplier {
       }
 
       // Multiply and accumulate.
-      auto mul = LLVM::FMulOp::create(builder, loc, tgtTy, aElem, bElem);
-      accum = LLVM::FAddOp::create(builder, loc, tgtTy, accum, mul);
+      accum = LLVM::FMAOp::create(builder, loc, tgtTy, aElem, bElem, accum);
     }
     return accum;
   }
diff --git a/test/Conversion/dot.mlir b/test/Conversion/dot.mlir
index a0b59ffd..e0e5038d 100644
--- a/test/Conversion/dot.mlir
+++ b/test/Conversion/dot.mlir
@@ -13,8 +13,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
     // COM: We should see a bunch of repetitions of this pattern:
     // CHECK: [[A:%.*]] = llvm.fpext {{%.*}} : f16 to f32
     // CHECK: [[B:%.*]] = llvm.fpext {{%.*}} : f16 to f32
-    // CHECK: [[MUL:%.*]] = llvm.fmul [[A]], [[B]] : f32
-    // CHECK: {{%.*}} = llvm.fadd {{%.*}}, [[MUL]] : f32
+    // CHECK: [[FMA:%.*]] = llvm.intr.fma([[A]], [[B]], {{%.*}}) : (f32, f32, f32) -> f32
     tt.return %d : tensor<2x2xf32, #blocked>
   }
 }
@@ -34,8 +33,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
     // COM: We should see a bunch of repetitions of this pattern:
     // CHECK: [[A:%.*]] = llvm.fpext {{%.*}} : bf16 to f32
     // CHECK: [[B:%.*]] = llvm.fpext {{%.*}} : bf16 to f32
-    // CHECK: [[MUL:%.*]] = llvm.fmul [[A]], [[B]] : f32
-    // CHECK: {{%.*}} = llvm.fadd {{%.*}}, [[MUL]] : f32
+    // CHECK: [[FMA:%.*]] = llvm.intr.fma([[A]], [[B]], {{%.*}}) : (f32, f32, f32) -> f32
     tt.return %d : tensor<2x2xf32, #blocked>
   }
 }