Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,17 @@ def register_pow_tensor_scalar():
)


@update_features(exir_ops.edge.aten.eq.Scalar)
def register_eq_scalar():
    """Describe the features of the ``aten.eq.Scalar`` op for the Vulkan op registry.

    Returns:
        An ``OpFeatures`` that accepts any input storage, takes ``FP_INT_T``
        input dtypes, produces ``BOOL_T`` outputs, and enables both resize
        and high-dim support.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_INT_T,
        # The comparison result uses the boolean dtype set, not the input dtypes.
        outputs_dtypes=utils.BOOL_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features


# =============================================================================
# ToCopy.cpp
# =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,24 @@
* LICENSE file in the root directory of this source tree.
*/

// Comparison ops produce a boolean result, so their output dtype (uint8) differs
// from the input dtype. IS_COMPARISON_OP is set explicitly per shader variant in
// the .yaml.

#version 450 core

${define_required_extensions(STORAGE, DTYPE)}
$if IS_COMPARISON_OP:
${define_required_extensions(STORAGE, "uint8")}

#define PRECISION ${PRECISION}

#define NAME ${VARIANT_NAME}

#define T ${buffer_scalar_type(DTYPE)}
$if IS_COMPARISON_OP:
#define OUT_T ${buffer_scalar_type("uint8")}
$else:
#define OUT_T ${buffer_scalar_type(DTYPE)}

#define op(X, Y) ${OPERATOR}

Expand All @@ -24,7 +33,11 @@ layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
$if IS_COMPARISON_OP:
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
$else:
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}

${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
Expand All @@ -36,13 +49,14 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "binary_op_defs.glslh"
$if not IS_COMPARISON_OP:
#include "binary_op_defs.glslh"

// Shader entry point: each invocation processes the single buffer element whose
// index is gl_GlobalInvocationID.x.
void main() {
const uint out_bufi = gl_GlobalInvocationID.x;
// Early-out when out_of_bounds() flags this index against the `outp` metadata.
if (out_of_bounds(out_bufi, outp)) {
return;
}

// NOTE(review): the store below is immediately overwritten by the one after it.
// The two lines look like the before/after pair of a diff hunk - confirm which
// one is meant to remain.
t_out[out_bufi] = T(op(t_in[out_bufi], T(scalar_value)));
// Apply op() to the input element and the scalar (cast to T), then convert the
// result to OUT_T before storing it.
t_out[out_bufi] = OUT_T(op(t_in[out_bufi], T(scalar_value)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
binary_scalar_buffer:
parameter_names_with_default_values:
OPERATOR: power_of(X, Y)
NDIM: 3
IS_COMPARISON_OP: false
DTYPE: float
PACKING: C_packed
STORAGE: buffer
generate_variant_forall:
DTYPE:
Expand All @@ -18,3 +17,6 @@ binary_scalar_buffer:
- VALUE: int32
shader_variants:
- NAME: pow_scalar_buffer
- NAME: eq_scalar_buffer
OPERATOR: X == Y
IS_COMPARISON_OP: true
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,25 @@
* LICENSE file in the root directory of this source tree.
*/

// Comparison ops produce a boolean result, so their output dtype (uint8) differs
// from the input dtype. IS_COMPARISON_OP is set explicitly per shader variant in
// the .yaml.

#version 450 core

${define_required_extensions(STORAGE, DTYPE)}
$if IS_COMPARISON_OP:
${define_required_extensions(STORAGE, "uint8")}

#define PRECISION ${PRECISION}

#define NAME ${VARIANT_NAME}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${texel_load_component_type(DTYPE, STORAGE)}
$if IS_COMPARISON_OP:
#define VEC4_OUT_T ${texel_load_type("uint8", STORAGE)}
$else:
#define VEC4_OUT_T VEC4_T

#define op(X, Y) ${OPERATOR}

Expand All @@ -25,7 +34,11 @@ layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
$if IS_COMPARISON_OP:
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
$else:
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}

${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}

${layout_declare_ubo(B, "TextureMetadata", "outp")}
Expand All @@ -37,7 +50,8 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "binary_op_defs.glslh"
$if not IS_COMPARISON_OP:
#include "binary_op_defs.glslh"

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
Expand All @@ -47,7 +61,7 @@ void main() {
}

VEC4_T in_texel = texelFetch(t_in, pos, 0);
VEC4_T out_texel = VEC4_T(op(in_texel, VEC4_T(scalar_value)));
VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, VEC4_T(scalar_value)));

imageStore(t_out, pos, out_texel);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
binary_scalar_texture:
parameter_names_with_default_values:
OPERATOR: power_of(X, Y)
NDIM: 3
IS_COMPARISON_OP: false
DTYPE: float
PACKING: C_packed
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -18,3 +17,6 @@ binary_scalar_texture:
- VALUE: int32
shader_variants:
- NAME: pow_scalar_texture3d
- NAME: eq_scalar_texture3d
OPERATOR: equal(X, Y)
IS_COMPARISON_OP: true
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ void pow_tensor_scalar(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "pow");
}

// Handler for the scalar equality op: forwards the first three entries of
// `args` to add_binary_scalar_op_node together with the op name "eq".
void eq_tensor_scalar(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  add_binary_scalar_op_node(
      graph,
      args[0],
      args[1],
      args[2],
      "eq");
}

// Associate each ATen operator name with the handler function defined above.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar);
VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar);
}

} // namespace vkcompute
29 changes: 29 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,3 +2237,32 @@ def get_pow_tensor_scalar_inputs():
]
test_suite.dtypes = ["at::kFloat"]
return test_suite


@register_test_suite("aten.eq.Scalar")
def get_eq_scalar_inputs():
    """Build the test suite for ``aten.eq.Scalar``.

    Each case pairs a tensor shape with the scalar it is compared against.
    The suite covers buffer and 3D-texture storage, width- and channels-packed
    layouts, and the int dtype, with inputs generated by ``make_seq_tensor``.
    """
    # Scalars are chosen to fall within the make_seq_tensor range (1..numel),
    # so each case exercises a genuine mix of equal / not-equal elements rather
    # than a trivially all-false comparison.
    shape_scalar_cases = [
        ((M1,), 5),
        ((M2, M1), 100),
        ((S1, M1, M2), 1000),
        ((S1, S2, S2, M2), 2000),
        ((S, S1, S2), 50),
        ((M1, M2), 700),
        ((S1, S2), 20),
    ]

    suite = VkTestSuite(shape_scalar_cases)
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.dtypes = ["at::kInt"]
    suite.data_gen = "make_seq_tensor"
    return suite
Loading