Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,11 +1390,20 @@ def register_repeat():

@update_features(exir_ops.edge.aten.embedding.default)
def register_embedding():
    """Build the ``OpFeatures`` registered for ``aten.embedding.default``.

    Returns:
        An ``OpFeatures`` configured with ``utils.ANY_STORAGE`` input storage,
        ``utils.FP_T`` / ``utils.INT_T`` input dtypes, prepacking and resize
        support enabled, and ``check_embedding_weight_size`` as the
        node-input support predicate.
    """

    def check_embedding_weight_size(node: torch.fx.Node) -> bool:
        """Reject nodes whose embedding weight has too many elements.

        Args:
            node: The embedding node; its first positional argument is
                treated as the weight.

        Returns:
            ``False`` when the weight is a tensor node whose element count
            exceeds ``utils.DEFAULT_BUFFER_LIMIT``; ``True`` otherwise,
            including when the weight is not a tensor node.
        """
        weight = node.args[0]
        if isinstance(weight, torch.fx.Node) and utils.is_tensor_node(weight):
            numel = weight.meta["val"].numel()
            if numel > utils.DEFAULT_BUFFER_LIMIT:
                return False
        return True

    return OpFeatures(
        # NOTE(review): inputs_storage was previously passed twice
        # (CHANNELS_PACKED_TEXTURE, then ANY_STORAGE). A repeated keyword
        # argument is a SyntaxError, so only the later value is kept —
        # confirm ANY_STORAGE is the intended setting.
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=[utils.FP_T, utils.INT_T],
        supports_prepacking=True,
        supports_resize=True,
        are_node_inputs_supported_fn=check_embedding_weight_size,
    )


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture2d")}
${layout_declare_ubo(B, "ivec4", "sizes")}

#include "indexing_utils.h"
Expand All @@ -30,9 +30,6 @@ const lowp int packed_dim = unhash_packed_dim(out_layout);
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "weight_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 weight_axis_map = unhash_axis_map(weight_layout);

void main() {
const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
const ivec4 out_tidx = lpos_to_tidx(out_lpos, sizes, out_axis_map.w, packed_dim);
Expand All @@ -48,8 +45,8 @@ void main() {
const int in_texel_elem = load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4];

// Read the weight tensor for the embedding lookup; it is height-packed.
const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem / 4, 0);
out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map)[in_texel_elem % 4];
const ivec2 weight_pos = ivec2(out_tidx.x, in_texel_elem / 4);
out_texel[i] = texelFetch(t_weight, weight_pos, 0)[in_texel_elem % 4];
}

write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);
Expand Down
4 changes: 1 addition & 3 deletions backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ void add_embedding_legacy_node(
// Push Constants
{},
// Specialization Constants
{graph.hashed_layout_of(out),
graph.hashed_layout_of(in),
graph.hashed_layout_of(weight)},
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
// Resize Args
{},
// Resizing Logic
Expand Down
5 changes: 2 additions & 3 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,14 +1167,13 @@ def get_embedding_inputs():
Test(weight=[10, 9], indices=[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]),
]

# Channels packed test cases currently fail on Mac, so they are not included.
# However the test case definition is kept for later debugging.
test_suite_cpack = VkTestSuite(
[tuple(tc) + (-1, "false", "false") for tc in test_cases]
)

test_suite_cpack.dtypes = ["at::kFloat"]
test_suite_cpack.layouts = ["utils::kChannelsPacked"]
test_suite_cpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
test_suite_cpack.test_name_suffix = "cpacked"

test_suite_wpack = VkTestSuite(
Expand All @@ -1186,7 +1185,7 @@ def get_embedding_inputs():
test_suite_wpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
test_suite_wpack.test_name_suffix = "wpacked"

return test_suite_wpack
return [test_suite_cpack, test_suite_wpack]


@register_test_suite("aten.gather.default")
Expand Down
Loading