From a3a42e425c6a7fa71f34f7f3cd7ad4d9fcaa85ca Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 12:25:09 -0400 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 10 ++++------ backends/apple/metal/runtime/shims/utils.h | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index 50b46ec69d4..f0dc57997ae 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -20,8 +20,10 @@ extern "C" { bool is_dtype_supported_in_et_metal(int32_t dtype) { switch (dtype) { case static_cast(SupportedDTypes::UINT8): + case static_cast(SupportedDTypes::INT32): case static_cast(SupportedDTypes::INT64): case static_cast(SupportedDTypes::FLOAT32): + case static_cast(SupportedDTypes::BOOL): case static_cast(SupportedDTypes::BFLOAT16): return true; default: @@ -37,12 +39,8 @@ AOTITorchError validate_dtype(int32_t dtype) { ET_LOG( Error, - "Unsupported dtype: %d. Supported dtypes: %d (uint8), %d (int64), %d (float32), %d (bfloat16)", - dtype, - static_cast(SupportedDTypes::UINT8), - static_cast(SupportedDTypes::INT64), - static_cast(SupportedDTypes::FLOAT32), - static_cast(SupportedDTypes::BFLOAT16)); + "Unsupported dtype: %d", + dtype); return Error::InvalidArgument; } diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h index 60412812b16..d749ee77947 100644 --- a/backends/apple/metal/runtime/shims/utils.h +++ b/backends/apple/metal/runtime/shims/utils.h @@ -22,12 +22,12 @@ enum class SupportedDTypes : int32_t { UINT8 = 0, // PyTorch's uint8 dtype code // INT8 = 1, // PyTorch's int8 dtype code // INT16 = 2, // PyTorch's int16 dtype code - // INT32 = 3, // PyTorch's int32 dtype code + INT32 = 3, // PyTorch's int32 dtype code INT64 = 4, // PyTorch's int64 dtype code // FLOAT16 = 5, // PyTorch's float16 dtype code FLOAT32 = 6, // PyTorch's float32 dtype code // FLOAT64 = 7, // PyTorch's float64 dtype code - // BOOL = 11, // PyTorch's bool dtype code + BOOL = 11, // PyTorch's bool dtype code BFLOAT16 = 15 // PyTorch's bfloat16 dtype code }; From 1c965c6b34623b738a308c6bd22a9ca8d0019883 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 12:25:14 -0400 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- backends/apple/metal/runtime/ops/op_sdpa.mm | 11 ++++----- backends/apple/metal/tests/test_modules.py | 25 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/backends/apple/metal/runtime/ops/op_sdpa.mm b/backends/apple/metal/runtime/ops/op_sdpa.mm index fdaabcf6b0b..558c16b50e5 100644 --- a/backends/apple/metal/runtime/ops/op_sdpa.mm +++ b/backends/apple/metal/runtime/ops/op_sdpa.mm @@ -226,7 +226,8 @@ #define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \ INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \ INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \ - INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); + INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 256, 256); INSTANTIATE_SDPA_VECTOR_HEADS(float); INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); @@ -430,11 +431,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( throw std::runtime_error("Unsupported dtype for Metal SDPA kernel"); } - // Select head_dim - must match exactly one of the supported sizes (64, 96, 128) + // Select head_dim - must match exactly one of the supported sizes (64, 96, 128, 256) int64_t head_dim = headSize; - if (head_dim != 64 && head_dim != 96 && head_dim != 128) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim); - throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128"); + if (head_dim != 64 && head_dim != 96 && head_dim != 128 && head_dim != 256) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, 128, or 256)", head_dim); + throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, 128, or 256"); } std::string kernel_name = "sdpa_vector_" + type_name + "_" + std::to_string(head_dim) + "_" + std::to_string(head_dim); diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 2bf14a0d10b..97ffed3b33c 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -639,6 +639,31 @@ def __init__(self): } +# ------------------------------------------------------------------------- +# SDPA with head_dim=256 (Qwen 3.5 MoE) +# ------------------------------------------------------------------------- + + +class SDPAHeadDim256(nn.Module): + """SDPA with head_dim=256, required by Qwen 3.5 MoE full attention layers.""" + + def forward( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + + +MODULE_REGISTRY["sdpa_head_dim_256"] = { + "model_class": SDPAHeadDim256, + "input_shapes": [(1, 4, 8, 256), (1, 4, 8, 256), (1, 4, 8, 256)], + "description": "SDPA with head_dim=256 (Qwen 3.5 MoE)", + "atol_float32": 1e-4, + "atol_bfloat16": 5e-2, +} + + # ============================================================================= # Helper Functions # ============================================================================= From 1be53ab1fbde5632ef0a4cb245b026344cf834a9 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 12:25:19 -0400 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- backends/apple/metal/CMakeLists.txt | 1 + backends/apple/metal/metal_backend.py | 1 + backends/apple/metal/runtime/ops/op_topk.mm | 303 ++++++++++++++++++++ backends/apple/metal/tests/test_modules.py | 25 ++ 4 files changed, 330 insertions(+) create mode 100644 backends/apple/metal/runtime/ops/op_topk.mm diff --git a/backends/apple/metal/CMakeLists.txt b/backends/apple/metal/CMakeLists.txt index 85ffbfc9cc5..17691d29d29 100644 --- a/backends/apple/metal/CMakeLists.txt +++ b/backends/apple/metal/CMakeLists.txt @@ -48,6 +48,7 @@ set(_aoti_metal_sources runtime/ops/op_linear_4bit.mm runtime/ops/op_mm.mm runtime/ops/op_sdpa.mm + runtime/ops/op_topk.mm ) add_library(metal_backend STATIC ${_aoti_metal_sources}) diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 5ddd5e13d88..90d0551fb1a 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -36,6 +36,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "aoti_torch_mps_mm_out": None, "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, "torchao::_linear_fp_act_4bit_weight": None, + "at::_ops::topk::call": None, } @classmethod diff --git a/backends/apple/metal/runtime/ops/op_topk.mm b/backends/apple/metal/runtime/ops/op_topk.mm new file mode 100644 index 00000000000..7dce002419a --- /dev/null +++ b/backends/apple/metal/runtime/ops/op_topk.mm @@ -0,0 +1,303 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Top-k operator using MPSGraph. +// Used by MoE routing (torch.topk in SparseMoE.forward). + +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +AOTITorchError aoti_torch_mps_topk( + AOTITensorHandle self, + int64_t k, + int64_t dim, + int32_t largest, + int32_t sorted, + AOTITensorHandle* ret0, // values + AOTITensorHandle* ret1) { // indices + + ET_LOG(Debug, "aoti_torch_mps_topk: k=%lld, dim=%lld, largest=%d, sorted=%d", + k, dim, largest, sorted); + + if (!self || !ret0 || !ret1) { + ET_LOG(Error, "aoti_torch_mps_topk: null tensor handles"); + return Error::InvalidArgument; + } + + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_topk: Failed to get Metal stream"); + return Error::Internal; + } + + try { + @autoreleasepool { + auto* self_tensor = reinterpret_cast(self); + + int64_t ndim = self_tensor->dim(); + if (dim < 0) { + dim += ndim; + } + if (dim < 0 || dim >= ndim) { + ET_LOG(Error, "aoti_torch_mps_topk: invalid dim"); + return Error::InvalidArgument; + } + + int64_t dim_size = self_tensor->sizes()[dim]; + if (k > dim_size) { + ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld\n", k, dim_size); + return Error::InvalidArgument; + } + + // Determine dtype + int32_t dtype = static_cast(self_tensor->scalar_type()); + size_t element_size; + MPSDataType mps_dtype; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + element_size = sizeof(float); + mps_dtype = MPSDataTypeFloat32; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + element_size = sizeof(uint16_t); + mps_dtype = MPSDataTypeBFloat16; + } else { + ET_LOG(Error, "aoti_torch_mps_topk: Unsupported dtype %d", dtype); + return Error::InvalidArgument; + } + + // Build output shape: same as input but with dim replaced by k + std::vector out_sizes; + for (int64_t i = 0; i < ndim; i++) { + out_sizes.push_back(i == dim ? k : self_tensor->sizes()[i]); + } + + // Compute strides (contiguous) + std::vector out_strides(ndim); + out_strides[ndim - 1] = 1; + for (int64_t i = ndim - 2; i >= 0; i--) { + out_strides[i] = out_strides[i + 1] * out_sizes[i + 1]; + } + + // Total elements + size_t num_elements = 1; + for (auto s : out_sizes) num_elements *= s; + + // Allocate output buffers + size_t values_bytes = num_elements * element_size; + size_t indices_bytes = num_elements * sizeof(int32_t); + + void* values_ptr = nullptr; + void* indices_ptr = nullptr; + allocate_mtl_buffer(&values_ptr, values_bytes); + allocate_mtl_buffer(&indices_ptr, indices_bytes); + + // Build MPSGraph + // Convert input shape to NSArray + NSMutableArray* input_shape = [NSMutableArray arrayWithCapacity:ndim]; + for (int64_t i = 0; i < ndim; i++) { + [input_shape addObject:@(self_tensor->sizes()[i])]; + } + + // Check graph cache + GraphCacheKey cache_key; + cache_key.op_name = "topk"; + cache_key.shape_params.push_back(k); + cache_key.shape_params.push_back(dim); + cache_key.shape_params.push_back(largest); + for (int64_t i = 0; i < ndim; i++) { + cache_key.shape_params.push_back(self_tensor->sizes()[i]); + } + cache_key.dtype = dtype; + cache_key.transpose_flag = false; + + auto cache_it = graph_cache.find(cache_key); + if (cache_it != graph_cache.end()) { + cache_stats.hits++; + auto& cached = cache_it->second; + + id self_buffer = get_mtl_buffer(self_tensor, "topk", "self"); + id values_buffer = ptr_to_mtl_buffer[values_ptr]; + id indices_buffer = ptr_to_mtl_buffer[indices_ptr]; + + NSDictionary* feeds = @{ + cached.input1: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype], + }; + + NSMutableArray* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim]; + for (int64_t i = 0; i < ndim; i++) { + [out_ns_shape addObject:@(out_sizes[i])]; + } + + NSDictionary* results = @{ + cached.output: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype], + cached.input2: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32], + }; + + stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT); + } else { + cache_stats.misses++; + ET_LOG(Debug, "aoti_torch_mps_topk: cache miss, building graph"); + + @try { + MPSGraph* graph = [[MPSGraph alloc] init]; + MPSGraphTensor* input = [graph placeholderWithShape:input_shape + dataType:mps_dtype + name:@"self"]; + + // MPSGraph topK: returns (values, indices) along the last dimension. + // If dim != -1, we need to transpose dim to last, topk, then transpose back. + MPSGraphTensor* work = input; + bool need_transpose = (dim != ndim - 1); + + if (need_transpose) { + work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil]; + } + + // MPSGraph topKWithTensor returns along the last axis + NSArray* topk_results; + if (largest) { + topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil]; + } else { + // For smallest: negate, topk, negate back + MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil]; + topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil]; + topk_results = @[ + [graph negativeWithTensor:topk_results[0] name:nil], + topk_results[1] + ]; + } + + MPSGraphTensor* values_out = topk_results[0]; + MPSGraphTensor* indices_out = topk_results[1]; + + if (need_transpose) { + values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil]; + indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil]; + } + + // Cache the graph + CachedGraph cached_graph; + cached_graph.graph = graph; + cached_graph.input1 = input; + cached_graph.input2 = indices_out; // reuse input2 slot for indices output + cached_graph.output = values_out; + graph_cache[cache_key] = cached_graph; + + // Execute + id self_buffer = get_mtl_buffer(self_tensor, "topk", "self"); + id values_buffer = ptr_to_mtl_buffer[values_ptr]; + id indices_buffer = ptr_to_mtl_buffer[indices_ptr]; + + NSDictionary* feeds = @{ + input: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype], + }; + + NSMutableArray* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim]; + for (int64_t i = 0; i < ndim; i++) { + [out_ns_shape addObject:@(out_sizes[i])]; + } + + NSDictionary* results = @{ + values_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype], + indices_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32], + }; + + ET_LOG(Debug, "aoti_torch_mps_topk: executing MPSGraph"); + stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT); + ET_LOG(Debug, "aoti_torch_mps_topk: MPSGraph done"); + } @catch (NSException* e) { + ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s", + e.name.UTF8String, e.reason.UTF8String); + throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String); + } + } + + // Create output tensor handles + // Values tensor + AOTITensorHandle values_handle = nullptr; + aoti_torch_create_tensor_from_blob_v2( + values_ptr, ndim, out_sizes.data(), out_strides.data(), + 0, dtype, 13, 0, &values_handle, 0, nullptr, 0); + + if (!values_handle) { + ET_LOG(Error, "aoti_torch_mps_topk: failed to create values tensor"); + aoti_torch_mps_free(values_ptr); + aoti_torch_mps_free(indices_ptr); + return Error::Internal; + } + ET_LOG(Debug, "aoti_torch_mps_topk: values tensor created"); + + extern std::unordered_map memory_to_n_tensor; + memory_to_n_tensor[values_ptr] = 1; + + // Indices tensor — MPSGraph outputs int32, AOTInductor expects int64. + // Allocate a new int64 buffer and convert. + size_t indices_i64_bytes = num_elements * sizeof(int64_t); + void* indices_i64_ptr = nullptr; + allocate_mtl_buffer(&indices_i64_ptr, indices_i64_bytes); + + // Copy int32 → int64 on CPU (small tensor, fast) + { + auto* stream_sync = getCurrentMetalStream(); + stream_sync->synchronize(SyncType::COMMIT_AND_WAIT); + + int32_t* src = reinterpret_cast(indices_ptr); + int64_t* dst = reinterpret_cast(indices_i64_ptr); + for (size_t i = 0; i < num_elements; i++) { + dst[i] = static_cast(src[i]); + } + } + aoti_torch_mps_free(indices_ptr); + + int32_t indices_dtype = static_cast(exec_aten::ScalarType::Long); + std::vector indices_strides(ndim); + indices_strides[ndim - 1] = 1; + for (int64_t i = ndim - 2; i >= 0; i--) { + indices_strides[i] = indices_strides[i + 1] * out_sizes[i + 1]; + } + + AOTITensorHandle indices_handle = nullptr; + AOTITorchError idx_err = aoti_torch_create_tensor_from_blob_v2( + indices_i64_ptr, ndim, out_sizes.data(), indices_strides.data(), + 0, indices_dtype, 13, 0, &indices_handle, 0, nullptr, 0); + + if (idx_err != Error::Ok || !indices_handle) { + ET_LOG(Error, "aoti_torch_mps_topk: failed to create indices tensor, err=%d", idx_err); + aoti_torch_mps_free(indices_i64_ptr); + return Error::Internal; + } + memory_to_n_tensor[indices_i64_ptr] = 1; + + *ret0 = values_handle; + *ret1 = indices_handle; + + ET_LOG(Debug, "aoti_torch_mps_topk: Completed successfully"); + + } // @autoreleasepool + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_topk exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_topk: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 97ffed3b33c..9ca529ecdf9 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -664,6 +664,31 @@ def forward( } +# ------------------------------------------------------------------------- +# Top-k (MoE expert routing) +# ------------------------------------------------------------------------- + + +class TopK(nn.Module): + """Top-k routing used by MoE expert selection.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(64, 8, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + scores = self.linear(x) + values, indices = torch.topk(scores, 2, dim=-1) + return values + + +MODULE_REGISTRY["topk"] = { + "model_class": TopK, + "input_shapes": [(4, 64)], + "description": "Top-k routing for MoE expert selection", +} + + # ============================================================================= # Helper Functions # ============================================================================= From e7a7acc6becd8a6986cc73e21a1d3109c8d95b57 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 18:44:41 -0400 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index f0dc57997ae..be61f013c6e 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -37,10 +37,7 @@ AOTITorchError validate_dtype(int32_t dtype) { return Error::Ok; } - ET_LOG( - Error, - "Unsupported dtype: %d", - dtype); + ET_LOG(Error, "Unsupported dtype: %d", dtype); return Error::InvalidArgument; } From 98d2f8144219df980e2a1886f410a15a3a537eda Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 20 Apr 2026 14:12:47 -0400 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- backends/apple/metal/runtime/ops/op_topk.mm | 189 ++++++++++---------- 1 file changed, 97 insertions(+), 92 deletions(-) diff --git a/backends/apple/metal/runtime/ops/op_topk.mm b/backends/apple/metal/runtime/ops/op_topk.mm index 7dce002419a..8d1b6722466 100644 --- a/backends/apple/metal/runtime/ops/op_topk.mm +++ b/backends/apple/metal/runtime/ops/op_topk.mm @@ -8,6 +8,7 @@ // Top-k operator using MPSGraph. // Used by MoE routing (torch.topk in SparseMoE.forward). +// Note: sorted parameter is accepted but MPSGraph always returns sorted results. #include @@ -40,6 +41,9 @@ AOTITorchError aoti_torch_mps_topk( return Error::Internal; } + void* values_ptr = nullptr; + void* indices_ptr = nullptr; + try { @autoreleasepool { auto* self_tensor = reinterpret_cast(self); @@ -55,7 +59,7 @@ AOTITorchError aoti_torch_mps_topk( int64_t dim_size = self_tensor->sizes()[dim]; if (k > dim_size) { - ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld\n", k, dim_size); + ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld", k, dim_size); return Error::InvalidArgument; } @@ -96,18 +100,20 @@ AOTITorchError aoti_torch_mps_topk( size_t values_bytes = num_elements * element_size; size_t indices_bytes = num_elements * sizeof(int32_t); - void* values_ptr = nullptr; - void* indices_ptr = nullptr; allocate_mtl_buffer(&values_ptr, values_bytes); allocate_mtl_buffer(&indices_ptr, indices_bytes); - // Build MPSGraph // Convert input shape to NSArray NSMutableArray* input_shape = [NSMutableArray arrayWithCapacity:ndim]; for (int64_t i = 0; i < ndim; i++) { [input_shape addObject:@(self_tensor->sizes()[i])]; } + NSMutableArray* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim]; + for (int64_t i = 0; i < ndim; i++) { + [out_ns_shape addObject:@(out_sizes[i])]; + } + // Check graph cache GraphCacheKey cache_key; cache_key.op_name = "topk"; @@ -120,101 +126,103 @@ AOTITorchError aoti_torch_mps_topk( cache_key.dtype = dtype; cache_key.transpose_flag = false; + stream->endKernelCoalescing(); + + id self_buffer = get_mtl_buffer(self_tensor, "topk", "self"); + id values_buffer = ptr_to_mtl_buffer[values_ptr]; + id indices_buffer = ptr_to_mtl_buffer[indices_ptr]; + auto cache_it = graph_cache.find(cache_key); if (cache_it != graph_cache.end()) { cache_stats.hits++; + cache_stats.logStats(); auto& cached = cache_it->second; - id self_buffer = get_mtl_buffer(self_tensor, "topk", "self"); - id values_buffer = ptr_to_mtl_buffer[values_ptr]; - id indices_buffer = ptr_to_mtl_buffer[indices_ptr]; + MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype]; + MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype]; + MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32]; NSDictionary* feeds = @{ - cached.input1: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype], + cached.input1: selfData, }; - - NSMutableArray* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim]; - for (int64_t i = 0; i < ndim; i++) { - [out_ns_shape addObject:@(out_sizes[i])]; - } - NSDictionary* results = @{ - cached.output: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype], - cached.input2: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32], + cached.output: valuesData, + cached.input2: indicesData, }; - stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT); + @try { + stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT); + } @catch (NSException* e) { + ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s", + e.name.UTF8String, e.reason.UTF8String); + throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String); + } + + [selfData release]; + [valuesData release]; + [indicesData release]; } else { cache_stats.misses++; + cache_stats.logStats(); ET_LOG(Debug, "aoti_torch_mps_topk: cache miss, building graph"); @try { - MPSGraph* graph = [[MPSGraph alloc] init]; - MPSGraphTensor* input = [graph placeholderWithShape:input_shape - dataType:mps_dtype - name:@"self"]; - - // MPSGraph topK: returns (values, indices) along the last dimension. - // If dim != -1, we need to transpose dim to last, topk, then transpose back. - MPSGraphTensor* work = input; - bool need_transpose = (dim != ndim - 1); - - if (need_transpose) { - work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil]; - } - - // MPSGraph topKWithTensor returns along the last axis - NSArray* topk_results; - if (largest) { - topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil]; - } else { - // For smallest: negate, topk, negate back - MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil]; - topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil]; - topk_results = @[ - [graph negativeWithTensor:topk_results[0] name:nil], - topk_results[1] - ]; - } - - MPSGraphTensor* values_out = topk_results[0]; - MPSGraphTensor* indices_out = topk_results[1]; - - if (need_transpose) { - values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil]; - indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil]; - } - - // Cache the graph - CachedGraph cached_graph; - cached_graph.graph = graph; - cached_graph.input1 = input; - cached_graph.input2 = indices_out; // reuse input2 slot for indices output - cached_graph.output = values_out; - graph_cache[cache_key] = cached_graph; - - // Execute - id self_buffer = get_mtl_buffer(self_tensor, "topk", "self"); - id values_buffer = ptr_to_mtl_buffer[values_ptr]; - id indices_buffer = ptr_to_mtl_buffer[indices_ptr]; - - NSDictionary* feeds = @{ - input: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype], - }; - - NSMutableArray* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim]; - for (int64_t i = 0; i < ndim; i++) { - [out_ns_shape addObject:@(out_sizes[i])]; - } - - NSDictionary* results = @{ - values_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype], - indices_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32], - }; - - ET_LOG(Debug, "aoti_torch_mps_topk: executing MPSGraph"); - stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT); - ET_LOG(Debug, "aoti_torch_mps_topk: MPSGraph done"); + MPSGraph* graph = [[MPSGraph alloc] init]; + MPSGraphTensor* input = [graph placeholderWithShape:input_shape + dataType:mps_dtype + name:@"self"]; + + MPSGraphTensor* work = input; + bool need_transpose = (dim != ndim - 1); + + if (need_transpose) { + work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil]; + } + + NSArray* topk_results; + if (largest) { + topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil]; + } else { + MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil]; + topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil]; + topk_results = @[ + [graph negativeWithTensor:topk_results[0] name:nil], + topk_results[1] + ]; + } + + MPSGraphTensor* values_out = topk_results[0]; + MPSGraphTensor* indices_out = topk_results[1]; + + if (need_transpose) { + values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil]; + indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil]; + } + + CachedGraph cached_graph; + cached_graph.graph = graph; + cached_graph.input1 = input; + cached_graph.input2 = indices_out; + cached_graph.output = values_out; + graph_cache[cache_key] = cached_graph; + + MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype]; + MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype]; + MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32]; + + NSDictionary* feeds = @{ + input: selfData, + }; + NSDictionary* results = @{ + values_out: valuesData, + indices_out: indicesData, + }; + + stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT); + + [selfData release]; + [valuesData release]; + [indicesData release]; } @catch (NSException* e) { ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s", e.name.UTF8String, e.reason.UTF8String); @@ -223,7 +231,6 @@ AOTITorchError aoti_torch_mps_topk( } // Create output tensor handles - // Values tensor AOTITensorHandle values_handle = nullptr; aoti_torch_create_tensor_from_blob_v2( values_ptr, ndim, out_sizes.data(), out_strides.data(), @@ -235,22 +242,17 @@ AOTITorchError aoti_torch_mps_topk( aoti_torch_mps_free(indices_ptr); return Error::Internal; } - ET_LOG(Debug, "aoti_torch_mps_topk: values tensor created"); - extern std::unordered_map memory_to_n_tensor; memory_to_n_tensor[values_ptr] = 1; // Indices tensor — MPSGraph outputs int32, AOTInductor expects int64. - // Allocate a new int64 buffer and convert. size_t indices_i64_bytes = num_elements * sizeof(int64_t); void* indices_i64_ptr = nullptr; allocate_mtl_buffer(&indices_i64_ptr, indices_i64_bytes); // Copy int32 → int64 on CPU (small tensor, fast) + stream->synchronize(SyncType::COMMIT_AND_WAIT); { - auto* stream_sync = getCurrentMetalStream(); - stream_sync->synchronize(SyncType::COMMIT_AND_WAIT); - int32_t* src = reinterpret_cast(indices_ptr); int64_t* dst = reinterpret_cast(indices_i64_ptr); for (size_t i = 0; i < num_elements; i++) { @@ -258,6 +260,7 @@ AOTITorchError aoti_torch_mps_topk( } } aoti_torch_mps_free(indices_ptr); + indices_ptr = nullptr; int32_t indices_dtype = static_cast(exec_aten::ScalarType::Long); std::vector indices_strides(ndim); @@ -281,17 +284,19 @@ AOTITorchError aoti_torch_mps_topk( *ret0 = values_handle; *ret1 = indices_handle; - ET_LOG(Debug, "aoti_torch_mps_topk: Completed successfully"); - } // @autoreleasepool return Error::Ok; } catch (const std::exception& e) { ET_LOG(Error, "aoti_torch_mps_topk exception: %s", e.what()); + if (values_ptr) aoti_torch_mps_free(values_ptr); + if (indices_ptr) aoti_torch_mps_free(indices_ptr); return Error::Internal; } catch (...) { ET_LOG(Error, "aoti_torch_mps_topk: unknown exception"); + if (values_ptr) aoti_torch_mps_free(values_ptr); + if (indices_ptr) aoti_torch_mps_free(indices_ptr); return Error::Internal; } }