From a3a42e425c6a7fa71f34f7f3cd7ad4d9fcaa85ca Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 12:25:09 -0400 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 10 ++++------ backends/apple/metal/runtime/shims/utils.h | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index 50b46ec69d4..f0dc57997ae 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -20,8 +20,10 @@ extern "C" { bool is_dtype_supported_in_et_metal(int32_t dtype) { switch (dtype) { case static_cast(SupportedDTypes::UINT8): + case static_cast(SupportedDTypes::INT32): case static_cast(SupportedDTypes::INT64): case static_cast(SupportedDTypes::FLOAT32): + case static_cast(SupportedDTypes::BOOL): case static_cast(SupportedDTypes::BFLOAT16): return true; default: @@ -37,12 +39,8 @@ AOTITorchError validate_dtype(int32_t dtype) { ET_LOG( Error, - "Unsupported dtype: %d. Supported dtypes: %d (uint8), %d (int64), %d (float32), %d (bfloat16)", - dtype, - static_cast(SupportedDTypes::UINT8), - static_cast(SupportedDTypes::INT64), - static_cast(SupportedDTypes::FLOAT32), - static_cast(SupportedDTypes::BFLOAT16)); + "Unsupported dtype: %d", + dtype); return Error::InvalidArgument; } diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h index 60412812b16..d749ee77947 100644 --- a/backends/apple/metal/runtime/shims/utils.h +++ b/backends/apple/metal/runtime/shims/utils.h @@ -22,12 +22,12 @@ enum class SupportedDTypes : int32_t { UINT8 = 0, // PyTorch's uint8 dtype code // INT8 = 1, // PyTorch's int8 dtype code // INT16 = 2, // PyTorch's int16 dtype code - // INT32 = 3, // PyTorch's int32 dtype code + INT32 = 3, // PyTorch's int32 dtype code INT64 = 4, // PyTorch's int64 dtype code // FLOAT16 = 5, // PyTorch's float16 dtype code FLOAT32 = 6, // PyTorch's float32 dtype code // FLOAT64 = 7, // PyTorch's float64 dtype code - // BOOL = 11, // PyTorch's bool dtype code + BOOL = 11, // PyTorch's bool dtype code BFLOAT16 = 15 // PyTorch's bfloat16 dtype code }; From 1c965c6b34623b738a308c6bd22a9ca8d0019883 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 12:25:14 -0400 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- backends/apple/metal/runtime/ops/op_sdpa.mm | 11 ++++----- backends/apple/metal/tests/test_modules.py | 25 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/backends/apple/metal/runtime/ops/op_sdpa.mm b/backends/apple/metal/runtime/ops/op_sdpa.mm index fdaabcf6b0b..558c16b50e5 100644 --- a/backends/apple/metal/runtime/ops/op_sdpa.mm +++ b/backends/apple/metal/runtime/ops/op_sdpa.mm @@ -226,7 +226,8 @@ #define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \ INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \ INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \ - INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); + INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 256, 256); INSTANTIATE_SDPA_VECTOR_HEADS(float); INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); @@ -430,11 +431,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( throw std::runtime_error("Unsupported dtype for Metal SDPA kernel"); } - // Select head_dim - must match exactly one of the supported sizes (64, 96, 128) + // Select head_dim - must match exactly one of the supported sizes (64, 96, 128, 256) int64_t head_dim = headSize; - if (head_dim != 64 && head_dim != 96 && head_dim != 128) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim); - throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128"); + if (head_dim != 64 && head_dim != 96 && head_dim != 128 && head_dim != 256) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, 128, or 256)", head_dim); + throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, 128, or 256"); } std::string kernel_name = "sdpa_vector_" + type_name + "_" + std::to_string(head_dim) + "_" + std::to_string(head_dim); diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 2bf14a0d10b..97ffed3b33c 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -639,6 +639,31 @@ def __init__(self): } +# ------------------------------------------------------------------------- +# SDPA with head_dim=256 (Qwen 3.5 MoE) +# ------------------------------------------------------------------------- + + +class SDPAHeadDim256(nn.Module): + """SDPA with head_dim=256, required by Qwen 3.5 MoE full attention layers.""" + + def forward( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + + +MODULE_REGISTRY["sdpa_head_dim_256"] = { + "model_class": SDPAHeadDim256, + "input_shapes": [(1, 4, 8, 256), (1, 4, 8, 256), (1, 4, 8, 256)], + "description": "SDPA with head_dim=256 (Qwen 3.5 MoE)", + "atol_float32": 1e-4, + "atol_bfloat16": 5e-2, +} + + # ============================================================================= # Helper Functions # ============================================================================= From e7a7acc6becd8a6986cc73e21a1d3109c8d95b57 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 18:44:41 -0400 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index f0dc57997ae..be61f013c6e 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -37,10 +37,7 @@ AOTITorchError validate_dtype(int32_t dtype) { return Error::Ok; } - ET_LOG( - Error, - "Unsupported dtype: %d", - dtype); + ET_LOG(Error, "Unsupported dtype: %d", dtype); return Error::InvalidArgument; }