diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 2a4e722f68b..466f9d69bde 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -294,28 +294,16 @@ def register_comparison_ops(): # ============================================================================= -@update_features(exir_ops.edge.aten.bitwise_and.Tensor) -def register_bitwise_and(): - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.BOOL_T, - supports_resize=True, - supports_highdim=True, - ) - - -@update_features(exir_ops.edge.aten.bitwise_not.default) -def register_bitwise_not(): - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.BOOL_T, - supports_resize=True, - supports_highdim=True, - ) - - -@update_features(exir_ops.edge.aten.logical_and.default) -def register_logical_and(): +@update_features( + [ + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_not.default, + exir_ops.edge.aten.logical_and.default, + exir_ops.edge.aten.logical_or.default, + ] +) +def register_bool_binary_ops(): return OpFeatures( inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.BOOL_T, diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml index 8aef89cd739..1f217acb127 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml @@ -52,3 +52,8 @@ binary_op_buffer: generate_variant_forall: DTYPE: - VALUE: uint8 + - NAME: binary_bitwise_or_buffer + OPERATOR: X | Y + generate_variant_forall: + DTYPE: + - VALUE: uint8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml index 437803b2410..289466e7845 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml @@ -54,3 +54,8 @@ binary_op_texture: generate_variant_forall: DTYPE: - VALUE: uint8 + - NAME: binary_bitwise_or_texture3d + OPERATOR: X | Y + generate_variant_forall: + DTYPE: + - VALUE: uint8 diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 6ff58e72dc3..9e696a008fe 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -143,6 +143,7 @@ DEFINE_BINARY_OP_FN(le); DEFINE_BINARY_OP_FN(gt); DEFINE_BINARY_OP_FN(ge); DEFINE_BINARY_OP_FN(bitwise_and); +DEFINE_BINARY_OP_FN(bitwise_or); REGISTER_OPERATORS { VK_REGISTER_OP(aten.add.Tensor, add); @@ -159,6 +160,8 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.ge.Tensor, ge); VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and); VK_REGISTER_OP(aten.logical_and.default, bitwise_and); + VK_REGISTER_OP(aten.bitwise_or.Tensor, bitwise_or); + VK_REGISTER_OP(aten.logical_or.default, bitwise_or); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 6efae3d0398..681f2c31475 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -2164,6 +2164,29 @@ def get_bitwise_and_inputs(): return test_suite +@register_test_suite("aten.bitwise_or.Tensor") +def get_bitwise_or_inputs(): + test_suite = VkTestSuite( + [ + ((M1, M2), (M1, M2)), + ((S, S1, S2), (S, S1, S2)), + ((XS, S, S1, S2), (XS, S, S1, S2)), + ((1, M1), (1, M1)), + ] + ) + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.dtypes = ["at::kBool"] + test_suite.data_gen = "make_seq_tensor" + return test_suite + + @register_test_suite("aten.index.Tensor") def get_index_tensor_inputs(): Test = namedtuple("IndexTensorTest", ["self", "indices"])