diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index b14d0f6ab0b..f0b61e128bb 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -1135,8 +1135,6 @@ void ComputeGraph::prepack() { int i = 0; bool submitted = false; const bool reduce_peak_memory = total_constant_nbytes_ > 10 * MB; - // int count = 0; - context_->set_cmd(); for (std::unique_ptr& node : prepack_nodes_) { // Do not trigger on the first or last prepack node. diff --git a/backends/vulkan/runtime/graph/containers/Constant.h b/backends/vulkan/runtime/graph/containers/Constant.h index a18c284a219..690a25dd9c8 100644 --- a/backends/vulkan/runtime/graph/containers/Constant.h +++ b/backends/vulkan/runtime/graph/containers/Constant.h @@ -29,6 +29,12 @@ struct TensorRef final { // This will be empty (default constructed) for the raw pointer constructor executorch::runtime::FreeableBuffer buffer; + // Number of PrepackNodes that still need to read from this TensorRef. When + // this reaches 0, the buffer can be safely freed. This prevents + // use-after-free when multiple PrepackNodes reference the same TensorRef + // (e.g. shared/tied weights). + int32_t prepack_use_count{0}; + explicit TensorRef( const std::vector& t_sizes, vkapi::ScalarType t_dtype, @@ -44,8 +50,6 @@ struct TensorRef final { return utils::multiply_integers(sizes) * vkapi::element_size(dtype); } - // Manually free the buffer if needed (though it will be freed automatically - // on destruction) void free_buffer() { buffer.Free(); } diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index bb21f4b7c2b..9d89a45e168 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -44,6 +44,9 @@ PrepackNode::PrepackNode( push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ false); graph.update_descriptor_counts(noop_shader_, /*execute = */ false); + if (!graph.val_is_none(tref)) { + graph.get_tref(tref)->prepack_use_count++; + } } api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { @@ -100,9 +103,10 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { } } - // Once the staging buffer is copied, if the TensorRef owns a FreeableBuffer, - // it can be freed. - tref->free_buffer(); + if (--tref->prepack_use_count == 0) { + tref->free_buffer(); + } + return staging; }