From e9958cfc1780ae9ce818ed958fc8a50b7b98d4c2 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 26 Jun 2025 08:37:58 -0700 Subject: [PATCH 1/2] [ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers Pull Request resolved: https://github.com/pytorch/executorch/pull/11983 ## Changes * Move the logic skipping output nodes that are mutable buffers from runtime to AOT ## Context A `fx.Graph` may return nodes that are mutable buffers: ``` class GraphModule(torch.nn.Module): def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"): sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1) ... # b_wrapped_module_kv_cache_*_cache are mutable buffers # getitem_2 and getitem_3 are derived from mutable buffers, hence they are # themselves mutable buffers auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1); getitem_1 = b_wrapped_module_kv_cache_k_cache = None getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1]; auto_functionalized = None auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1); aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1]; auto_functionalized_1 = None ... aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]); p_wrapped_module_wo_weight = None aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]); aten_view_copy_default_9 = None aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3); aten_view_copy_default_10 = aten_permute_copy_default_3 = None aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]); aten_mm_default_3 = sym_size = None # getitem_2 and getitem_3 are returned as outputs, presumably to prevent the # update_cache calls from being removed due to dead code elimination return (getitem_2, getitem_3, aten_view_copy_default_11, None) ``` In the graph signature of the `ExportedProgram` these show up as `BUFFER_MUTATION` outputs ``` Graph signature: # inputs p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight' p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight' p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight' p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight' b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True x: USER_INPUT freqs_cos: USER_INPUT freqs_sin: USER_INPUT input_pos: USER_INPUT # outputs getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache' getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache' aten_view_copy_default_11: USER_OUTPUT : USER_OUTPUT ``` Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch betwen the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime. ## Motivation Previously, this mismatch was addressed in the runtime, by skipping the processing of non-tensor outputs. However, this solution does not account for the fact that in some models, paramters of the model may be returned as outputs. In this case, those parameter outputs would be skipped but the ExecuTorch runtime would still expect to receive them as outputs. To solve the problem properly, this diff changes the serialization logic to check if an output node is a mutable buffer, and skip marking it as an output if so. In the runtime, all output nodes are processed instead of only processing tensor outputs. ghstack-source-id: 292864908 @exported-using-ghexport Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/) --- backends/vulkan/runtime/VulkanBackend.cpp | 16 +++++++++------- backends/vulkan/runtime/graph/ComputeGraph.cpp | 8 ++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 2 ++ .../vulkan/serialization/vulkan_graph_builder.py | 6 ++++++ backends/vulkan/utils.py | 9 +++++++++ 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 218bd31d1b4..7077a9df59c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -359,15 +359,11 @@ class GraphBuilder { vkFn(*compute_graph_, args); } - // Parse the outputs, which will be mostly tensors. For some reason, - // mutable buffers are shown to be returned in the fx.Graph but do not get - // returned by the delegate; this may be an implementation detail of how the - // executorch emitter handles mutable buffers. + // Parse the outputs, which will be mostly tensors but may contain tensorref + // values as well if the source graph returns parameter nodes. for (const uint32_t fb_id : *flatbuffer_->output_ids()) { const ValueRef ref = get_fb_id_valueref(fb_id); - if (compute_graph_->val_is_tensor(ref)) { - compute_graph_->set_output_tensor(ref); - } + compute_graph_->set_output_value(ref); } if (compute_graph_->graphconfig().enable_querypool) { @@ -609,6 +605,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->outputs()[i].staging, args[o]->toTensor().mutable_data_ptr(), args[o]->toTensor().numel()); + } + // TensorRef values represent constant tensors which will not have been + // modified by the graph execution. Therefore, if a constant tensor is + // returned as an output, no action is required. + else if (compute_graph->val_is_tref(oref)) { + continue; } else { VK_THROW( "Could not handle output with type ", diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index b63f89e299d..cb14a41e98a 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor( return idx; } +ValueRef ComputeGraph::set_output_value(const ValueRef idx) { + if (values_.at(idx).isTensor()) { + return set_output_tensor(idx); + } + outputs_.push_back({idx, kDummyValueRef}); + return idx; +} + vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer( const ValueRef idx) { if (values_.at(idx).isInt()) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index eac632e6d35..78135a434e5 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -658,6 +658,8 @@ class ComputeGraph final { ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); + ValueRef set_output_value(const ValueRef idx); + template vkapi::BufferBindInfo create_params_buffer(const Block& data) { param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data)); diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index d21d33b75da..5bae0475c28 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -20,6 +20,7 @@ from executorch.backends.vulkan.utils import ( is_constant, is_get_attr_node, + is_mutable_buffer_node, is_param_node, is_symint_node, ) @@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None: "the output node is being serialized before its corresponding " "internal node which is not allowed." ) + # Mutable buffers outputs are not included as an output to the + # delegate call. Skip marking them as an output. + if is_mutable_buffer_node(out_node, self.program): + continue + self.output_ids.append(self.node_to_value_ids[out_node]) def process_node(self, node: Node, call_node_debug_hdl: int) -> None: diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 5d57ce1e7be..d71c0a35776 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -84,6 +84,15 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: ) +def is_mutable_buffer_node( + node: torch.fx.Node, exported_program: ExportedProgram +) -> bool: + if node.target not in exported_program.graph_signature.inputs_to_buffers: + return False + buf = exported_program.graph_signature.inputs_to_buffers[node.target] + return buf in exported_program.graph_signature.buffers_to_mutate.values() + + def is_symint_node(node: torch.fx.Node) -> bool: """ Returns true if the given node produces a SymInt value From e5e3467302280fd34dffcd244a84af40f29136ad Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 26 Jun 2025 08:38:01 -0700 Subject: [PATCH 2/2] [ET-VK][ez] Handle zero-element tensors when building Vulkan graph Pull Request resolved: https://github.com/pytorch/executorch/pull/11984 ## Changes As title. ## Motivation Some models may have parameter tensors which are zero-shape (i.e. no elements). In this case, trying to serialize the tensor data will result in a null pointer exception. ghstack-source-id: 292864904 Differential Revision: [D77281492](https://our.internmc.facebook.com/intern/diff/D77281492/) --- .../serialization/vulkan_graph_serialize.py | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/backends/vulkan/serialization/vulkan_graph_serialize.py b/backends/vulkan/serialization/vulkan_graph_serialize.py index ebb13bbb97d..2ceedf73d10 100644 --- a/backends/vulkan/serialization/vulkan_graph_serialize.py +++ b/backends/vulkan/serialization/vulkan_graph_serialize.py @@ -191,19 +191,23 @@ def serialize_constant_tensors( current_offset = len(raw_bytes) for tensor in const_tensors: - array_type = ctypes.c_char * tensor.untyped_storage().nbytes() - array = ctypes.cast( - tensor.untyped_storage().data_ptr(), - ctypes.POINTER(array_type), - ).contents - - tensor_bytes = bytes(array) - # Pad the tensor bytes to the next 16 byte boundary - raw_bytes += tensor_bytes - raw_bytes += b"\x00" * padding_required(len(tensor_bytes)) - - vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes))) - current_offset += aligned_size(len(tensor_bytes)) + if tensor.numel() == 0: + vk_graph.constants.append(VkBytes(current_offset, 0)) + continue + else: + array_type = ctypes.c_char * tensor.untyped_storage().nbytes() + array = ctypes.cast( + tensor.untyped_storage().data_ptr(), + ctypes.POINTER(array_type), + ).contents + + tensor_bytes = bytes(array) + # Pad the tensor bytes to the next 16 byte boundary + raw_bytes += tensor_bytes + raw_bytes += b"\x00" * padding_required(len(tensor_bytes)) + + vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes))) + current_offset += aligned_size(len(tensor_bytes)) def serialize_custom_shaders(