From 21a359ee0db744846cc602c0a960a62499e946f6 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 15 Apr 2026 10:09:12 -0700 Subject: [PATCH] [ET-VK] Fix use-after-free in PrepackNode when TensorRefs are shared MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/18906 When a model has shared/tied weights (e.g. tied embeddings in transformers), the serialization deduplicates them into a single TensorRef that multiple PrepackNodes reference. Previously, `PrepackNode::create_staging_buffer()` called `tref->free_buffer()` unconditionally after copying weight data to a GPU staging buffer. This meant the first PrepackNode to execute would free the underlying host memory, and subsequent PrepackNodes sharing the same TensorRef would read from a dangling pointer — producing garbage/NaN values in prepacked weight and bias tensors on the GPU. The fix adds a `prepack_use_count` field to `TensorRef` that tracks how many PrepackNodes still need to read from it. Each PrepackNode increments the count in its constructor and decrements it after copying data. The buffer is only freed when the count reaches zero. This preserves the original eager-free behavior for non-shared weights (freeing immediately after the single consumer copies) while correctly deferring the free for shared weights until the last consumer is done — avoiding both the use-after-free and unnecessary peak memory increase. ghstack-source-id: 367726483 @exported-using-ghexport Differential Revision: [D101009402](https://our.internmc.facebook.com/intern/diff/D101009402/) --- backends/vulkan/runtime/graph/ComputeGraph.cpp | 2 -- backends/vulkan/runtime/graph/containers/Constant.h | 8 ++++++-- backends/vulkan/runtime/graph/ops/PrepackNode.cpp | 10 +++++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index b14d0f6ab0b..f0b61e128bb 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -1135,8 +1135,6 @@ void ComputeGraph::prepack() { int i = 0; bool submitted = false; const bool reduce_peak_memory = total_constant_nbytes_ > 10 * MB; - // int count = 0; - context_->set_cmd(); for (std::unique_ptr& node : prepack_nodes_) { // Do not trigger on the first or last prepack node. diff --git a/backends/vulkan/runtime/graph/containers/Constant.h b/backends/vulkan/runtime/graph/containers/Constant.h index a18c284a219..690a25dd9c8 100644 --- a/backends/vulkan/runtime/graph/containers/Constant.h +++ b/backends/vulkan/runtime/graph/containers/Constant.h @@ -29,6 +29,12 @@ struct TensorRef final { // This will be empty (default constructed) for the raw pointer constructor executorch::runtime::FreeableBuffer buffer; + // Number of PrepackNodes that still need to read from this TensorRef. When + // this reaches 0, the buffer can be safely freed. This prevents + // use-after-free when multiple PrepackNodes reference the same TensorRef + // (e.g. shared/tied weights). + int32_t prepack_use_count{0}; + explicit TensorRef( const std::vector& t_sizes, vkapi::ScalarType t_dtype, @@ -44,8 +50,6 @@ struct TensorRef final { return utils::multiply_integers(sizes) * vkapi::element_size(dtype); } - // Manually free the buffer if needed (though it will be freed automatically - // on destruction) void free_buffer() { buffer.Free(); } diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index bb21f4b7c2b..9d89a45e168 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -44,6 +44,9 @@ PrepackNode::PrepackNode( push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ false); graph.update_descriptor_counts(noop_shader_, /*execute = */ false); + if (!graph.val_is_none(tref)) { + graph.get_tref(tref)->prepack_use_count++; + } } api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { @@ -100,9 +103,10 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { } } - // Once the staging buffer is copied, if the TensorRef owns a FreeableBuffer, - // it can be freed. - tref->free_buffer(); + if (--tref->prepack_use_count == 0) { + tref->free_buffer(); + } + return staging; }