diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 61135c18648..bdda551de27 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1390,11 +1390,20 @@ def register_repeat(): @update_features(exir_ops.edge.aten.embedding.default) def register_embedding(): + def check_embedding_weight_size(node: torch.fx.Node) -> bool: + weight = node.args[0] + if isinstance(weight, torch.fx.Node) and utils.is_tensor_node(weight): + numel = weight.meta["val"].numel() + if numel > utils.DEFAULT_BUFFER_LIMIT: + return False + return True + return OpFeatures( - inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + inputs_storage=utils.ANY_STORAGE, inputs_dtypes=[utils.FP_T, utils.INT_T], supports_prepacking=True, supports_resize=True, + are_node_inputs_supported_fn=check_embedding_weight_size, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl index 73a444cd84d..87cea50cdea 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl @@ -16,7 +16,7 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)} -${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture2d")} ${layout_declare_ubo(B, "ivec4", "sizes")} #include "indexing_utils.h" @@ -30,9 +30,6 @@ const lowp int packed_dim = unhash_packed_dim(out_layout); ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); -${layout_declare_spec_const(C, "int", "weight_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 weight_axis_map = unhash_axis_map(weight_layout); - void main() { const ivec3 out_lpos = ivec3(gl_GlobalInvocationID); const ivec4 out_tidx = lpos_to_tidx(out_lpos, sizes, out_axis_map.w, packed_dim); @@ -48,8 +45,8 @@ void main() { const int in_texel_elem = load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4]; // Read weight tensor for embedding, it is height-packed. - const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem / 4, 0); - out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map)[in_texel_elem % 4]; + const ivec2 weight_pos = ivec2(out_tidx.x, in_texel_elem / 4); + out_texel[i] = texelFetch(t_weight, weight_pos, 0)[in_texel_elem % 4]; } write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map); diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp index 61d27d48f6c..b98eb75cebd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp @@ -111,9 +111,7 @@ void add_embedding_legacy_node( // Push Constants {}, // Specialization Constants - {graph.hashed_layout_of(out), - graph.hashed_layout_of(in), - graph.hashed_layout_of(weight)}, + {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, // Resize Args {}, // Resizing Logic diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 99c4bebb64f..87905860081 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1167,14 +1167,13 @@ def get_embedding_inputs(): Test(weight=[10, 9], indices=[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]), ] - # Channels packed test cases currently fail on Mac, so they are not included. - # However the test case definition is kept for later debugging. test_suite_cpack = VkTestSuite( [tuple(tc) + (-1, "false", "false") for tc in test_cases] ) test_suite_cpack.dtypes = ["at::kFloat"] test_suite_cpack.layouts = ["utils::kChannelsPacked"] + test_suite_cpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"] test_suite_cpack.test_name_suffix = "cpacked" test_suite_wpack = VkTestSuite( @@ -1186,7 +1185,7 @@ def get_embedding_inputs(): test_suite_wpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"] test_suite_wpack.test_name_suffix = "wpacked" - return test_suite_wpack + return [test_suite_cpack, test_suite_wpack] @register_test_suite("aten.gather.default")