Skip to content

Commit e34f63e

Browse files
author
ssjia
committed
Update
[ghstack-poisoned]
1 parent 603d3bb commit e34f63e

5 files changed

Lines changed: 46 additions & 22 deletions

File tree

backends/vulkan/op_registry.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -294,28 +294,16 @@ def register_comparison_ops():
294294
# =============================================================================
295295

296296

297-
@update_features(exir_ops.edge.aten.bitwise_and.Tensor)
298-
def register_bitwise_and():
299-
return OpFeatures(
300-
inputs_storage=utils.ANY_STORAGE,
301-
inputs_dtypes=utils.BOOL_T,
302-
supports_resize=True,
303-
supports_highdim=True,
304-
)
305-
306-
307-
@update_features(exir_ops.edge.aten.bitwise_not.default)
308-
def register_bitwise_not():
309-
return OpFeatures(
310-
inputs_storage=utils.ANY_STORAGE,
311-
inputs_dtypes=utils.BOOL_T,
312-
supports_resize=True,
313-
supports_highdim=True,
314-
)
315-
316-
317-
@update_features(exir_ops.edge.aten.logical_and.default)
318-
def register_logical_and():
297+
@update_features(
298+
[
299+
exir_ops.edge.aten.bitwise_and.Tensor,
300+
exir_ops.edge.aten.bitwise_or.Tensor,
301+
exir_ops.edge.aten.bitwise_not.default,
302+
exir_ops.edge.aten.logical_and.default,
303+
exir_ops.edge.aten.logical_or.default,
304+
]
305+
)
306+
def register_bool_binary_ops():
319307
return OpFeatures(
320308
inputs_storage=utils.ANY_STORAGE,
321309
inputs_dtypes=utils.BOOL_T,

backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,8 @@ binary_op_buffer:
5252
generate_variant_forall:
5353
DTYPE:
5454
- VALUE: uint8
55+
- NAME: binary_bitwise_or_buffer
56+
OPERATOR: X | Y
57+
generate_variant_forall:
58+
DTYPE:
59+
- VALUE: uint8

backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,8 @@ binary_op_texture:
5454
generate_variant_forall:
5555
DTYPE:
5656
- VALUE: uint8
57+
- NAME: binary_bitwise_or_texture3d
58+
OPERATOR: X | Y
59+
generate_variant_forall:
60+
DTYPE:
61+
- VALUE: uint8

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ DEFINE_BINARY_OP_FN(le);
143143
DEFINE_BINARY_OP_FN(gt);
144144
DEFINE_BINARY_OP_FN(ge);
145145
DEFINE_BINARY_OP_FN(bitwise_and);
146+
DEFINE_BINARY_OP_FN(bitwise_or);
146147

147148
REGISTER_OPERATORS {
148149
VK_REGISTER_OP(aten.add.Tensor, add);
@@ -159,6 +160,8 @@ REGISTER_OPERATORS {
159160
VK_REGISTER_OP(aten.ge.Tensor, ge);
160161
VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and);
161162
VK_REGISTER_OP(aten.logical_and.default, bitwise_and);
163+
VK_REGISTER_OP(aten.bitwise_or.Tensor, bitwise_or);
164+
VK_REGISTER_OP(aten.logical_or.default, bitwise_or);
162165
}
163166

164167
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,29 @@ def get_bitwise_and_inputs():
21642164
return test_suite
21652165

21662166

2167+
@register_test_suite("aten.bitwise_or.Tensor")
2168+
def get_bitwise_or_inputs():
2169+
test_suite = VkTestSuite(
2170+
[
2171+
((M1, M2), (M1, M2)),
2172+
((S, S1, S2), (S, S1, S2)),
2173+
((XS, S, S1, S2), (XS, S, S1, S2)),
2174+
((1, M1), (1, M1)),
2175+
]
2176+
)
2177+
test_suite.layouts = [
2178+
"utils::kWidthPacked",
2179+
"utils::kChannelsPacked",
2180+
]
2181+
test_suite.storage_types = [
2182+
"utils::kBuffer",
2183+
"utils::kTexture3D",
2184+
]
2185+
test_suite.dtypes = ["at::kBool"]
2186+
test_suite.data_gen = "make_seq_tensor"
2187+
return test_suite
2188+
2189+
21672190
@register_test_suite("aten.index.Tensor")
21682191
def get_index_tensor_inputs():
21692192
Test = namedtuple("IndexTensorTest", ["self", "indices"])

0 commit comments

Comments
 (0)