From 8cee28eb271d2013003efbfc89fd0cd99fce5118 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Mon, 26 Jan 2026 15:24:38 -0800 Subject: [PATCH 1/4] Use `llvm.fma` for `tt.dot` lowering This change replaces the `llvm.fmul` and `llvm.fadd` instructions with the fused `llvm.fma` operation. This should have no downstream impact on the emitted machine code, which, due to auto-vectorization and other LLVM magic, already ends up using `VFMADD213PS`. What _is_ unclear about this change is that we materialize some fastmath flags from thin air: it seems like we should be able to configure this somewhere at the user level (TODO). --- cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp index cfdadd21..99ca9405 100644 --- a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp +++ b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp @@ -49,8 +49,9 @@ class GenericFMAVectorMultiplier : public triton::gpu::FMAVectorMultiplier { } // Multiply and accumulate. 
- auto mul = LLVM::FMulOp::create(builder, loc, tgtTy, aElem, bElem); - accum = LLVM::FAddOp::create(builder, loc, tgtTy, accum, mul); + auto flags = LLVM::FastmathFlagsAttr::get(builder.getContext(), + LLVM::FastmathFlags::fast); + accum = LLVM::FMAOp::create(builder, loc, tgtTy, aElem, bElem, accum, flags); } return accum; } From 23c84ab79a4d54ba74814b041ed1a3435bcd8973 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Mon, 26 Jan 2026 15:29:33 -0800 Subject: [PATCH 2/4] Apply clang formatting --- cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp index 99ca9405..8d4c96d8 100644 --- a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp +++ b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp @@ -51,7 +51,8 @@ class GenericFMAVectorMultiplier : public triton::gpu::FMAVectorMultiplier { // Multiply and accumulate. auto flags = LLVM::FastmathFlagsAttr::get(builder.getContext(), LLVM::FastmathFlags::fast); - accum = LLVM::FMAOp::create(builder, loc, tgtTy, aElem, bElem, accum, flags); + accum = + LLVM::FMAOp::create(builder, loc, tgtTy, aElem, bElem, accum, flags); } return accum; } From f17f2669d995361788ab0b082c6eda3f629dcbdc Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 28 Jan 2026 14:40:56 -0800 Subject: [PATCH 3/4] review: remove fastmath flags --- cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp index 8d4c96d8..2175d5e3 100644 --- a/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp +++ b/cpu/lib/TritonCPUToLLVM/DotOpToLLVM.cpp @@ -49,10 +49,7 @@ class GenericFMAVectorMultiplier : public triton::gpu::FMAVectorMultiplier { } // Multiply and accumulate. 
- auto flags = LLVM::FastmathFlagsAttr::get(builder.getContext(), - LLVM::FastmathFlags::fast); - accum = - LLVM::FMAOp::create(builder, loc, tgtTy, aElem, bElem, accum, flags); + accum = LLVM::FMAOp::create(builder, loc, tgtTy, aElem, bElem, accum); } return accum; } From df6cae7e2213f13d6647075e180a8c2ae3271a9b Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Thu, 29 Jan 2026 09:21:55 -0800 Subject: [PATCH 4/4] Update lit tests --- test/Conversion/dot.mlir | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/Conversion/dot.mlir b/test/Conversion/dot.mlir index a0b59ffd..e0e5038d 100644 --- a/test/Conversion/dot.mlir +++ b/test/Conversion/dot.mlir @@ -13,8 +13,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // COM: We should see a bunch of repetitions of this pattern: // CHECK: [[A:%.*]] = llvm.fpext {{%.*}} : f16 to f32 // CHECK: [[B:%.*]] = llvm.fpext {{%.*}} : f16 to f32 - // CHECK: [[MUL:%.*]] = llvm.fmul [[A]], [[B]] : f32 - // CHECK: {{%.*}} = llvm.fadd {{%.*}}, [[MUL]] : f32 + // CHECK: [[MUL:%.*]] = llvm.intr.fma([[A]], [[B]], {{%.*}}) : (f32, f32, f32) -> f32 tt.return %d : tensor<2x2xf32, #blocked> } } @@ -34,8 +33,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // COM: We should see a bunch of repetitions of this pattern: // CHECK: [[A:%.*]] = llvm.fpext {{%.*}} : bf16 to f32 // CHECK: [[B:%.*]] = llvm.fpext {{%.*}} : bf16 to f32 - // CHECK: [[MUL:%.*]] = llvm.fmul [[A]], [[B]] : f32 - // CHECK: {{%.*}} = llvm.fadd {{%.*}}, [[MUL]] : f32 + // CHECK: [[MUL:%.*]] = llvm.intr.fma([[A]], [[B]], {{%.*}}) : (f32, f32, f32) -> f32 tt.return %d : tensor<2x2xf32, #blocked> } }