diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 218bd31d1b4..7077a9df59c 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -359,15 +359,11 @@ class GraphBuilder {
       vkFn(*compute_graph_, args);
     }
 
-    // Parse the outputs, which will be mostly tensors. For some reason,
-    // mutable buffers are shown to be returned in the fx.Graph but do not get
-    // returned by the delegate; this may be an implementation detail of how the
-    // executorch emitter handles mutable buffers.
+    // Parse the outputs, which will be mostly tensors but may contain tensorref
+    // values as well if the source graph returns parameter nodes.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      if (compute_graph_->val_is_tensor(ref)) {
-        compute_graph_->set_output_tensor(ref);
-      }
+      compute_graph_->set_output_value(ref);
     }
 
     if (compute_graph_->graphconfig().enable_querypool) {
@@ -609,6 +605,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
             compute_graph->outputs()[i].staging,
             args[o]->toTensor().mutable_data_ptr(),
             args[o]->toTensor().numel());
+      } else if (compute_graph->val_is_tref(oref)) {
+        // TensorRef values represent constant tensors which will not have been
+        // modified by the graph execution. Therefore, if a constant tensor is
+        // returned as an output, no action is required.
+        continue;
       } else {
         VK_THROW(
             "Could not handle output with type ",
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index b63f89e299d..cb14a41e98a 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor(
   return idx;
 }
 
+ValueRef ComputeGraph::set_output_value(const ValueRef idx) {
+  if (values_.at(idx).isTensor()) {
+    return set_output_tensor(idx);
+  }
+  outputs_.push_back({idx, kDummyValueRef});
+  return idx;
+}
+
 vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
     const ValueRef idx) {
   if (values_.at(idx).isInt()) {
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index eac632e6d35..78135a434e5 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -658,6 +658,13 @@ class ComputeGraph final {
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
 
+  /*
+   * Marks the value at `idx` as a graph output and returns `idx`. Tensor values
+   * are forwarded to `set_output_tensor()`; any other value type (for example a
+   * TensorRef) is appended to the output list paired with `kDummyValueRef`.
+   */
+  ValueRef set_output_value(const ValueRef idx);
+
   template <typename Block>
   vkapi::BufferBindInfo create_params_buffer(const Block& data) {
     param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index d21d33b75da..5bae0475c28 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -20,6 +20,7 @@
 from executorch.backends.vulkan.utils import (
     is_constant,
     is_get_attr_node,
+    is_mutable_buffer_node,
     is_param_node,
     is_symint_node,
 )
@@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None:
                     "the output node is being serialized before its corresponding "
                     "internal node which is not allowed."
                 )
+            # Mutable buffer outputs are not included as outputs of the
+            # delegate call. Skip marking them as outputs.
+            if is_mutable_buffer_node(out_node, self.program):
+                continue
+
             self.output_ids.append(self.node_to_value_ids[out_node])
 
     def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
diff --git a/backends/vulkan/serialization/vulkan_graph_serialize.py b/backends/vulkan/serialization/vulkan_graph_serialize.py
index ebb13bbb97d..2ceedf73d10 100644
--- a/backends/vulkan/serialization/vulkan_graph_serialize.py
+++ b/backends/vulkan/serialization/vulkan_graph_serialize.py
@@ -191,6 +191,11 @@ def serialize_constant_tensors(
 
     current_offset = len(raw_bytes)
     for tensor in const_tensors:
+        # Empty tensors contribute no bytes; record a zero-length constant.
+        if tensor.numel() == 0:
+            vk_graph.constants.append(VkBytes(current_offset, 0))
+            continue
+
         array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
         array = ctypes.cast(
             tensor.untyped_storage().data_ptr(),
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 5d57ce1e7be..d71c0a35776 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -84,6 +84,19 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
     )
 
 
+def is_mutable_buffer_node(
+    node: torch.fx.Node, exported_program: ExportedProgram
+) -> bool:
+    """
+    Returns true if the given node is a program input bound to a buffer that the
+    graph signature lists in buffers_to_mutate.
+    """
+    if node.target not in exported_program.graph_signature.inputs_to_buffers:
+        return False
+    buf = exported_program.graph_signature.inputs_to_buffers[node.target]
+    return buf in exported_program.graph_signature.buffers_to_mutate.values()
+
+
 def is_symint_node(node: torch.fx.Node) -> bool:
     """
     Returns true if the given node produces a SymInt value