Skip to content

Commit 36aa8be

Browse files
author
ssjia
committed
[ET-VK][runtime] Add prepack cache to avoid duplicate weight prepacking
Pull Request resolved: #18361 When embedding and linear ops share tied weights and both use the same prepacking function (prepack_quantized_linear_weight), the weight gets prepacked twice, wasting GPU memory. Add a cache to ComputeGraph keyed by (input ValueRef, kernel name) that returns the already-prepacked tensor on cache hit, avoiding the duplicate allocation. ghstack-source-id: 355397958 @exported-using-ghexport Differential Revision: [D97430801](https://our.internmc.facebook.com/intern/diff/D97430801/)
1 parent b5e7462 commit 36aa8be

3 files changed

Lines changed: 64 additions & 5 deletions

File tree

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept {
297297
return idx >= 0 && idx < static_cast<int>(values_.size());
298298
}
299299

300+
ValueRef ComputeGraph::get_cached_prepack(
301+
const ValueRef input,
302+
const std::string& kernel_name) const {
303+
auto it = prepack_cache_.find({input, kernel_name});
304+
if (it != prepack_cache_.end()) {
305+
return it->second;
306+
}
307+
return kDummyValueRef;
308+
}
309+
310+
void ComputeGraph::cache_prepack(
311+
const ValueRef input,
312+
const std::string& kernel_name,
313+
const ValueRef prepacked) {
314+
prepack_cache_.emplace(std::make_pair(input, kernel_name), prepacked);
315+
}
316+
300317
std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
301318
const Value& val = values_.at(idx);
302319
if (val.isTensor()) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <optional>
1414
#include <stack>
15+
#include <unordered_map>
1516

1617
#include <executorch/backends/vulkan/runtime/api/api.h>
1718

@@ -204,6 +205,22 @@ class ComputeGraph final {
204205
// Set to track which ValueRefs were updated during inference
205206
std::unordered_set<ValueRef> updated_values_;
206207

208+
// Cache to prevent duplicate prepacking of the same weight tensor with the
209+
// same kernel. Key is (inputValueRef, kernel_name).
210+
struct PrepackCacheHash {
211+
size_t operator()(const std::pair<ValueRef, std::string>& key) const {
212+
size_t h1 = std::hash<ValueRef>{}(key.first);
213+
size_t h2 = std::hash<std::string>{}(key.second);
214+
// Combine hashes using a method similar to boost::hash_combine
215+
return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
216+
}
217+
};
218+
std::unordered_map<
219+
std::pair<ValueRef, std::string>,
220+
ValueRef,
221+
PrepackCacheHash>
222+
prepack_cache_;
223+
207224
// Flag to indicate if re-encoding is required
208225
bool requires_reencode_ = false;
209226

@@ -687,6 +704,22 @@ class ComputeGraph final {
687704
void check_no_active_value_ptrs();
688705

689706
public:
707+
/*
708+
* Check if a prepacked tensor already exists for the given input and kernel.
709+
*/
710+
ValueRef get_cached_prepack(
711+
const ValueRef input,
712+
const std::string& kernel_name) const;
713+
714+
/*
715+
* Store a prepacked tensor in the cache, keyed by input ValueRef and kernel
716+
* name.
717+
*/
718+
void cache_prepack(
719+
const ValueRef input,
720+
const std::string& kernel_name,
721+
const ValueRef prepacked);
722+
690723
/*
691724
* Add a `api::vTensor` value to the graph with the specified properties.
692725
* There are various convenience overloads of this function that may be used

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,19 @@ ValueRef prepack_quantized_linear_weight(
256256
storage_type = utils::kBuffer;
257257
}
258258

259+
std::string kernel_name = weight_quant_config.nbits == 4
260+
? "pack_q4_linear_weight"
261+
: "pack_q8_linear_weight";
262+
add_storage_type_suffix(kernel_name, storage_type);
263+
264+
// Check prepack cache before creating a new prepack node. This avoids
265+
// allocating a duplicate output tensor when the same weight data has already
266+
// been prepacked with the same kernel (e.g. tied embedding/linear weights).
267+
ValueRef cached = graph.get_cached_prepack(qmat2_data, kernel_name);
268+
if (is_valid(cached)) {
269+
return cached;
270+
}
271+
259272
ValueRef qmat2 = graph.add_tensor(
260273
qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked);
261274

@@ -273,11 +286,6 @@ ValueRef prepack_quantized_linear_weight(
273286
1u};
274287
}
275288

276-
std::string kernel_name = weight_quant_config.nbits == 4
277-
? "pack_q4_linear_weight"
278-
: "pack_q8_linear_weight";
279-
add_storage_type_suffix(kernel_name, storage_type);
280-
281289
graph.prepack_nodes().emplace_back(new PrepackNode(
282290
graph,
283291
VK_KERNEL_FROM_STR(kernel_name),
@@ -294,6 +302,7 @@ ValueRef prepack_quantized_linear_weight(
294302
{graph.sizes_pc_of(qmat2),
295303
PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))}));
296304

305+
graph.cache_prepack(qmat2_data, kernel_name, qmat2);
297306
return qmat2;
298307
}
299308

0 commit comments

Comments
 (0)