From 43f2cc722be1a6e4782e1b15bc7f391878343cbd Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 3 Jul 2025 11:17:08 -0700 Subject: [PATCH] [ET-VK][ez] enabling fp64->fp32 converison for vulkan compatibility # Context We need this conversion so that certain operators can handle floating point values that need to be 64bit. This is predominantly applicable to choose_qparams.tensor where it expects a 64bit output. # Changes Simply adding an additional conversion for float64 to vulkan fp32. Differential Revision: [D77746137](https://our.internmc.facebook.com/intern/diff/D77746137/) [ghstack-poisoned] --- backends/vulkan/serialization/vulkan_graph_builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 5bae0475c28..6ee37b4a0bc 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -79,6 +79,10 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: # Narrowing conversion for index tensor produced by max_poolNd_with_indices. elif torch_dtype == torch.int64: return vk_graph_schema.VkDataType.INT32 + # Narrowing conversion for float64 (double) to float32 for Vulkan compatibility + elif torch_dtype == torch.float64: + return vk_graph_schema.VkDataType.FLOAT32 + else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")