Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backends/vulkan/_passes/remove_redundant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class RemoveRedundantOpsTransform(ExportPass):
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten.expand_copy.default,
# copy.default(self, src): no-op when src dtype/shape matches self.
exir_ops.edge.aten.copy.default,
}

# For these ops the meaningful input is args[1] (src), not args[0] (self).
_src_arg1_ops: Set[OpType] = {
exir_ops.edge.aten.copy.default,
}

def __init__(self) -> None:
Expand All @@ -41,7 +48,8 @@ def _should_remove(self, node: torch.fx.Node) -> bool:
if node.target not in self.redundant_ops:
return False

orig_node = node.args[0]
src_arg_idx = 1 if node.target in self._src_arg1_ops else 0
orig_node = node.args[src_arg_idx]
assert isinstance(orig_node, torch.fx.Node)

src_dtype = orig_node.meta["val"].dtype
Expand All @@ -61,7 +69,8 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if not self._should_remove(node):
continue

node.replace_all_uses_with(node.args[0])
src_arg_idx = 1 if node.target in self._src_arg1_ops else 0
node.replace_all_uses_with(node.args[src_arg_idx])

graph_module.graph.eliminate_dead_code()

Expand Down
87 changes: 86 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def update_features_impl(op: OpKey):
# Guard and assert ops
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
# copy.default is a no-op when src dtype matches dst dtype; removed by
# RemoveRedundantOpsTransform before execution.
exir_ops.edge.aten.copy.default,
]
)
def register_ephemeral_ops():
Expand Down Expand Up @@ -231,17 +234,46 @@ def register_clamp():
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.div.Tensor_mode,
exir_ops.edge.aten.pow.Tensor_Tensor,
]
)
def register_binaryop_cpp_ops():
    """Build the feature description for the binary ops named in the decorator.

    Returns:
        An ``OpFeatures`` with ``utils.ANY_STORAGE`` input storage,
        ``utils.FP_INT_T`` input dtypes, and both resize and high-dim
        support switched on.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_INT_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features


@update_features(
    [
        exir_ops.edge.aten.eq.Tensor,
        exir_ops.edge.aten.lt.Tensor,
        exir_ops.edge.aten.le.Tensor,
        exir_ops.edge.aten.gt.Tensor,
        exir_ops.edge.aten.ge.Tensor,
    ]
)
def register_comparison_ops():
    """Build the feature description for the element-wise comparison ops.

    These ops are registered separately from ``register_binaryop_cpp_ops``
    because they additionally declare ``outputs_dtypes=utils.BOOL_T``.

    Returns:
        An ``OpFeatures`` with ``utils.ANY_STORAGE`` input storage,
        ``utils.FP_INT_T`` input dtypes, ``utils.BOOL_T`` output dtypes, and
        both resize and high-dim support switched on.
    """
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_INT_T,
        outputs_dtypes=utils.BOOL_T,
        supports_resize=True,
        supports_highdim=True,
    )


# =============================================================================
# BinaryOp.cpp (bitwise)
# =============================================================================


@update_features(exir_ops.edge.aten.bitwise_and.Tensor)
def register_bitwise_and():
    """Build the feature description for ``aten.bitwise_and.Tensor``.

    Returns:
        An ``OpFeatures`` with ``utils.ANY_STORAGE`` input storage,
        ``utils.BOOL_T`` input dtypes, and both resize and high-dim support
        switched on.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.BOOL_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features
Expand Down Expand Up @@ -673,6 +705,7 @@ def register_argreduce_cpp_ops():
return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
inputs_dtypes=utils.FP_T,
outputs_dtypes=utils.INT_T,
supports_resize=True,
supports_highdim=True,
are_node_inputs_supported_fn=is_reduce_node_supported,
Expand Down Expand Up @@ -1157,6 +1190,58 @@ def register_index_select():
)


# =============================================================================
# Where.cpp
# =============================================================================


@update_features(exir_ops.edge.aten.where.self)
def register_where():
    """Build the feature description for ``aten.where.self``.

    Input dtypes are given per argument: ``utils.BOOL_T`` for the first input
    and ``utils.FP_T`` for the second and third.

    Returns:
        An ``OpFeatures`` with ``utils.ANY_STORAGE`` input storage, the
        per-argument input dtypes above, ``utils.FP_T`` output dtypes, and
        resize support switched on.
    """
    per_input_dtypes = [utils.BOOL_T, utils.FP_T, utils.FP_T]
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=per_input_dtypes,
        outputs_dtypes=utils.FP_T,
        supports_resize=True,
    )
    return features


# =============================================================================
# IndexTensor.cpp
# =============================================================================


@update_features(exir_ops.edge.aten.index.Tensor)
def register_index_tensor():
    """Build the feature description for ``aten.index.Tensor``.

    Only a restricted form of the op is accepted; see
    ``check_index_tensor_node`` for the exact conditions.

    Returns:
        An ``OpFeatures`` with ``utils.ANY_STORAGE`` input storage,
        ``utils.FP_INT_T`` input dtypes, resize support switched on, and the
        node-level input check installed.
    """

    def check_index_tensor_node(node: torch.fx.Node) -> bool:
        """Return True when ``node`` indexes a 1D tensor with one index tensor."""
        self_arg, indices = node.args[0], node.args[1]

        # The indexed tensor must be a graph node carrying "val" metadata
        # that describes a rank-1 tensor.
        if not isinstance(self_arg, torch.fx.Node):
            return False
        self_val = self_arg.meta.get("val")
        if self_val is None or len(self_val.size()) != 1:
            return False

        # The indices argument must be a sequence containing exactly one
        # entry that is not None.
        if not isinstance(indices, (list, tuple)):
            return False
        return sum(1 for idx in indices if idx is not None) == 1

    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_INT_T,
        supports_resize=True,
        are_node_inputs_supported_fn=check_index_tensor_node,
    )
    return features


# =============================================================================
# Arange.cpp
# =============================================================================
Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,11 @@ binary_op:
- VALUE: half
- VALUE: float
- VALUE: int32
# Bitwise AND variant of the binary_op shader: the operator expression is
# "X & Y". Variants are generated for buffer and texture3d storage and for the
# uint8 dtype only.
# NOTE(review): presumably uint8 is how boolean tensors are represented in the
# shaders — confirm against the dtype mapping used by the op registry.
- NAME: binary_bitwise_and
OPERATOR: X & Y
generate_variant_forall:
STORAGE:
- VALUE: buffer
- VALUE: texture3d
DTYPE:
- VALUE: uint8
58 changes: 58 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

${define_required_extensions("buffer", DTYPE)}

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type("buffer")}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_self", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_index", "int", "buffer")}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}
${layout_declare_ubo(B, "BufferMetadata", "index")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Implements aten.index.Tensor for the case where self is 1D and there is
// exactly one index tensor. Each output element is:
// output[...] = self[index[...]]
//
// Each invocation produces one element of the output buffer, selected by the
// x component of the global invocation id.

void main() {
// Linear position of this invocation's element in the output buffer.
const uint out_bufi = gl_GlobalInvocationID.x;
// Invocations past the end of the output do nothing.
if (out_of_bounds(out_bufi, outp)) {
return;
}

// Convert output buffer index to tensor index
TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);

// Read the index value at the same tensor position
// NOTE(review): the index tensor is addressed with the output's tensor index,
// so this assumes index and output have the same shape — confirm.
const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx);
const int idx = t_index[index_bufi];

// Construct a tensor index for the 1D self tensor.
// In WHCN ordering, a 1D tensor has its elements along dim 0 (width).
// NOTE(review): idx is cast to uint with no range check and no wrap-around
// for negative values; confirm that 0 <= idx < length of self is guaranteed
// before this shader is dispatched.
TensorIndex self_tidx;
self_tidx.data[0] = uvec4(uint(idx), 0, 0, 0);
self_tidx.data[1] = uvec4(0);
const uint self_bufi = tensor_idx_to_linear_idx(inp, self_tidx);

// Gather: copy the selected element of self into the output.
t_out[out_bufi] = t_self[self_bufi];
}
16 changes: 16 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Variant configuration for the index_tensor_buffer shader.
# Defaults: DTYPE float, STORAGE buffer. One variant is generated per listed
# DTYPE value (half and float).
# NOTE(review): the op registry entry for aten.index.Tensor declares
# inputs_dtypes=utils.FP_INT_T, but no integer DTYPE variant is listed here —
# confirm whether integer self tensors are meant to be supported.
index_tensor_buffer:
parameter_names_with_default_values:
DTYPE: float
STORAGE: buffer
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: index_tensor_buffer
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

${define_required_extensions("texture3d", DTYPE)}

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}

${define_active_storage_type("texture3d")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

#include "common.glslh"
#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_self", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")}

${layout_declare_ubo(B, "TextureMetadata", "outp")}
${layout_declare_ubo(B, "TextureMetadata", "inp")}
${layout_declare_ubo(B, "TextureMetadata", "index")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Implements aten.index.Tensor for the case where self is 1D and there is
// exactly one index tensor. Each output element is:
// output[...] = self[index[...]]
//
// Each invocation produces one output texel (up to four elements).

void main() {
// Texel position handled by this invocation.
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

// Invocations outside the output texture do nothing.
if (out_of_bounds(out_pos, outp)) {
return;
}

TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
// The index texel is fetched at the same position as the output texel.
// NOTE(review): this assumes the index texture has the same extents and
// packing as the output texture — confirm.
ivec4 idx_texel = texelFetch(t_index, out_pos, 0);

// Components that are not written in the loop below keep this zero value.
VEC4_T out_texel = VEC4_T(0);

// Number of components of this texel that lie inside the tensor along the
// packed dimension (at most 4).
int limit = min(
4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]);
for (int comp = 0; comp < limit; comp++) {
// NOTE(review): idx is used without a range check or negative-index wrap;
// confirm that 0 <= idx < length of self is guaranteed before dispatch.
int idx = idx_texel[comp];

// Construct a tensor index for the 1D self tensor.
// In WHCN ordering, a 1D tensor has its elements along dim 0 (width).
TensorIndex4D self_tidx;
self_tidx.data = ivec4(idx, 0, 0, 0);

// Locate the texel and the component within it that hold self[idx].
TextureElementIndex self_elem =
tensor4d_idx_to_texture_element_idx_simple(inp, self_tidx);

VEC4_T self_texel = texelFetch(t_self, self_elem.pos, 0);
out_texel[comp] = self_texel[self_elem.comp];

// NOTE(review): out_tidx is not read again after this increment within this
// function; the statement appears to have no effect on the result.
out_tidx.data[outp.packed_dim]++;
}

// Write the assembled texel to the output image.
imageStore(t_out, out_pos, out_texel);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Variant configuration for the index_tensor_texture shader.
# Defaults: DTYPE float, STORAGE texture3d. One variant is generated per listed
# DTYPE value (half and float); the generated shader is named
# index_tensor_texture3d.
# NOTE(review): the op registry entry for aten.index.Tensor declares
# inputs_dtypes=utils.FP_INT_T, but no integer DTYPE variant is listed here —
# confirm whether integer self tensors are meant to be supported.
index_tensor_texture:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: index_tensor_texture3d
Loading
Loading