From 4cd537b05a9ad1a98d343dd27a373d04cb8fa3e2 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 10 Jul 2025 11:40:23 -0700 Subject: [PATCH] [ET-VK] Adding get or create int function to read int value. This diff adds a new function `get_or_create_int` to the `ComputeGraph` class, which allows reading an integer value from a `ValueRef` index. The function returns the extracted integer value if the value at the index is an integer, otherwise it throws an error. Additionally, an overload of the function is added to return a default value if the value at the index is `None`. Differential Revision: [D78094858](https://our.internmc.facebook.com/intern/diff/D78094858/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ComputeGraph.cpp | 16 ++++++++++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 10 ++++++++++ 2 files changed, 26 insertions(+) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index cb14a41e98a..2dc02b8b800 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -549,6 +549,22 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer( } } +int32_t ComputeGraph::get_or_create_int(const ValueRef idx) { + if (values_.at(idx).isInt()) { + return extract_scalar(idx); + } + VK_THROW("Cannot create a int param buffer for the given value"); +} + +int32_t ComputeGraph::get_or_create_int( + const ValueRef idx, + const int32_t default_val) { + if (values_.at(idx).isNone()) { + return default_val; + } + return get_or_create_int(idx); +} + void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { get_symint(idx)->set(val); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 78135a434e5..7a73ae1dee5 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -424,6 +424,12 @@ class ComputeGraph final { // Scalar Value Extraction // + bool is_scalar(const ValueRef idx) const { + const Value& value = values_.at(idx); + return value.isInt() || value.isDouble() || value.isBool() || + value.isNone(); + } + template T extract_scalar(const ValueRef idx) { Value& value = values_.at(idx); @@ -679,6 +685,10 @@ class ComputeGraph final { const ValueRef idx, const int32_t default_value); + int32_t get_or_create_int(const ValueRef idx); + + int32_t get_or_create_int(const ValueRef idx, const int32_t default_value); + void set_symint(const ValueRef idx, const int32_t val); int32_t read_symint(const ValueRef idx);