From 15b254c0f76cc1063c657a2d7f552e53fc4e7890 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 1 Jul 2025 09:38:44 -0700 Subject: [PATCH] [Executorch][llm] Make mask tensor float only for sdpa Now that we support quantized SDPA, the query tensor can be quantized while the attention mask can only be float (the only type allowed). So the check that the mask's scalar type matches the query's scalar type doesn't make sense anymore. Differential Revision: [D77516821](https://our.internmc.facebook.com/intern/diff/D77516821/) [ghstack-poisoned] --- extension/llm/custom_ops/op_sdpa.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 91802a8445d..c98fa1729fa 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -59,8 +59,8 @@ bool validate_flash_attention_args( ET_CHECK_OR_RETURN_FALSE( !attn_mask.has_value() || - attn_mask.value().scalar_type() == query.scalar_type(), - "Attention mask must be a 2D tensor"); + attn_mask.value().scalar_type() == ScalarType::Float, + "Attention mask must be a Float tensor"); ET_CHECK_OR_RETURN_FALSE( is_contiguous_dim_order(query.dim_order().data(), query.dim()),