From 6d9f74c05f6288d5063d74204dd1c513b76068d0 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 18 Jun 2026 08:03:31 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/vulkan/op_registry.py | 11 +++++++ .../graph/ops/glsl/binary_scalar_buffer.glsl | 20 +++++++++++-- .../graph/ops/glsl/binary_scalar_buffer.yaml | 6 ++-- .../graph/ops/glsl/binary_scalar_texture.glsl | 20 +++++++++++-- .../graph/ops/glsl/binary_scalar_texture.yaml | 6 ++-- .../runtime/graph/ops/impl/BinaryScalarOp.cpp | 5 ++++ backends/vulkan/test/op_tests/cases.py | 29 +++++++++++++++++++ 7 files changed, 87 insertions(+), 10 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 466f9d69bde..105bcea89a6 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -327,6 +327,17 @@ def register_pow_tensor_scalar(): ) +@update_features(exir_ops.edge.aten.eq.Scalar) +def register_eq_scalar(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.FP_INT_T, + outputs_dtypes=utils.BOOL_T, + supports_resize=True, + supports_highdim=True, + ) + + # ============================================================================= # ToCopy.cpp # ============================================================================= diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl index 9e3a35bf4f1..42c8b5b2326 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl @@ -6,15 +6,24 @@ * LICENSE file in the root directory of this source tree. */ +// Binary comparison ops require that the output is boolean and not the same as +// input. IS_COMPARISON_OP is set explicitly per shader variant in the .yaml. + #version 450 core ${define_required_extensions(STORAGE, DTYPE)} +$if IS_COMPARISON_OP: + ${define_required_extensions(STORAGE, "uint8")} #define PRECISION ${PRECISION} #define NAME ${VARIANT_NAME} #define T ${buffer_scalar_type(DTYPE)} +$if IS_COMPARISON_OP: + #define OUT_T ${buffer_scalar_type("uint8")} +$else: + #define OUT_T ${buffer_scalar_type(DTYPE)} #define op(X, Y) ${OPERATOR} @@ -24,7 +33,11 @@ layout(std430) buffer; #include "indexing.glslh" -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +$if IS_COMPARISON_OP: + ${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)} +$else: + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} + ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} ${layout_declare_ubo(B, "BufferMetadata", "outp")} @@ -36,7 +49,8 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#include "binary_op_defs.glslh" +$if not IS_COMPARISON_OP: + #include "binary_op_defs.glslh" void main() { const uint out_bufi = gl_GlobalInvocationID.x; @@ -44,5 +58,5 @@ void main() { return; } - t_out[out_bufi] = T(op(t_in[out_bufi], T(scalar_value))); + t_out[out_bufi] = OUT_T(op(t_in[out_bufi], T(scalar_value))); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml index b818132cf9b..e46c3c55332 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml @@ -7,9 +7,8 @@ binary_scalar_buffer: parameter_names_with_default_values: OPERATOR: power_of(X, Y) - NDIM: 3 + IS_COMPARISON_OP: false DTYPE: float - PACKING: C_packed STORAGE: buffer generate_variant_forall: DTYPE: @@ -18,3 +17,6 @@ binary_scalar_buffer: - VALUE: int32 shader_variants: - NAME: pow_scalar_buffer + - NAME: eq_scalar_buffer + OPERATOR: X == Y + IS_COMPARISON_OP: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl index 651dfdd7b5d..34a029521e2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl @@ -6,9 +6,14 @@ * LICENSE file in the root directory of this source tree. */ +// Binary comparison ops require that the output is boolean and not the same as +// input. IS_COMPARISON_OP is set explicitly per shader variant in the .yaml. + #version 450 core ${define_required_extensions(STORAGE, DTYPE)} +$if IS_COMPARISON_OP: + ${define_required_extensions(STORAGE, "uint8")} #define PRECISION ${PRECISION} @@ -16,6 +21,10 @@ ${define_required_extensions(STORAGE, DTYPE)} #define VEC4_T ${texel_load_type(DTYPE, STORAGE)} #define T ${texel_load_component_type(DTYPE, STORAGE)} +$if IS_COMPARISON_OP: + #define VEC4_OUT_T ${texel_load_type("uint8", STORAGE)} +$else: + #define VEC4_OUT_T VEC4_T #define op(X, Y) ${OPERATOR} @@ -25,7 +34,11 @@ layout(std430) buffer; #include "indexing.glslh" -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +$if IS_COMPARISON_OP: + ${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)} +$else: + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} + ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} ${layout_declare_ubo(B, "TextureMetadata", "outp")} @@ -37,7 +50,8 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#include "binary_op_defs.glslh" +$if not IS_COMPARISON_OP: + #include "binary_op_defs.glslh" void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -47,7 +61,7 @@ void main() { } VEC4_T in_texel = texelFetch(t_in, pos, 0); - VEC4_T out_texel = VEC4_T(op(in_texel, VEC4_T(scalar_value))); + VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, VEC4_T(scalar_value))); imageStore(t_out, pos, out_texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml index 3e731bf7a15..c6239ea33bd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml @@ -7,9 +7,8 @@ binary_scalar_texture: parameter_names_with_default_values: OPERATOR: power_of(X, Y) - NDIM: 3 + IS_COMPARISON_OP: false DTYPE: float - PACKING: C_packed STORAGE: texture3d generate_variant_forall: DTYPE: @@ -18,3 +17,6 @@ binary_scalar_texture: - VALUE: int32 shader_variants: - NAME: pow_scalar_texture3d + - NAME: eq_scalar_texture3d + OPERATOR: equal(X, Y) + IS_COMPARISON_OP: true diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp index 15553706494..0470ee9fb6c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp @@ -73,8 +73,13 @@ void pow_tensor_scalar(ComputeGraph& graph, const std::vector& args) { return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "pow"); } +void eq_tensor_scalar(ComputeGraph& graph, const std::vector& args) { + return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "eq"); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar); + VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 681f2c31475..c42d4026733 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -2237,3 +2237,32 @@ def get_pow_tensor_scalar_inputs(): ] test_suite.dtypes = ["at::kFloat"] return test_suite + + +@register_test_suite("aten.eq.Scalar") +def get_eq_scalar_inputs(): + # Scalars are chosen to fall within the make_seq_tensor range (1..numel), + # so each case exercises a genuine mix of equal / not-equal elements rather + # than a trivially all-false comparison. + test_suite = VkTestSuite( + [ + ((M1,), 5), + ((M2, M1), 100), + ((S1, M1, M2), 1000), + ((S1, S2, S2, M2), 2000), + ((S, S1, S2), 50), + ((M1, M2), 700), + ((S1, S2), 20), + ] + ) + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.dtypes = ["at::kInt"] + test_suite.data_gen = "make_seq_tensor" + return test_suite