Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ runtime.python_library(
],
)

# Library target wrapping insert_dtype_promotion.py.
# NOTE(review): insert_dtype_promotion.py also imports executorch.exir.passes
# (dead_code_elimination_pass); confirm that module is reachable through the
# deps listed below.
runtime.python_library(
name = "insert_dtype_promotion",
srcs = ["insert_dtype_promotion.py"],
# Usable by any target under //executorch/backends.
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
Expand Down Expand Up @@ -133,6 +146,7 @@ runtime.python_library(
":fold_qdq",
":fuse_patterns",
":fuse_quantized_ops",
":insert_dtype_promotion",
":insert_prepack_nodes",
":remove_asserts",
":remove_redundant_ops",
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
FuseQuantizedOpsTransform,
)
from executorch.backends.vulkan._passes.insert_dtype_promotion import (
InsertDtypePromotionPass,
)
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
from executorch.backends.vulkan._passes.remove_asserts import (
remove_asserts,
Expand All @@ -28,6 +31,7 @@
"FoldQDQPass",
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"InsertDtypePromotionPass",
"insert_prepack_nodes",
"remove_asserts",
"RemoveAssertsTransform",
Expand Down
102 changes: 102 additions & 0 deletions backends/vulkan/_passes/insert_dtype_promotion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Set, Union

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]

# Two-operand edge ops handled by the dtype promotion pass. For each of these,
# args[0] and args[1] are the tensor inputs whose dtypes may need to be unified.
BINARY_OPS: Set[OpType] = {
    # Arithmetic
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.pow.Tensor_Tensor,
    # Elementwise minimum
    exir_ops.edge.aten.minimum.default,
    # Comparisons
    exir_ops.edge.aten.eq.Tensor,
    exir_ops.edge.aten.lt.Tensor,
    exir_ops.edge.aten.le.Tensor,
    exir_ops.edge.aten.gt.Tensor,
    exir_ops.edge.aten.ge.Tensor,
}


def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype:
"""Promote to common dtype following PyTorch type promotion rules."""
if a == b:
return a
# Any mix of different dtypes promotes to float32
return torch.float32


class InsertDtypePromotionPass(ExportPass):
    """
    Insert _to_copy nodes before binary ops when the two tensor inputs have
    different dtypes, so that both inputs reach the op with a common dtype.

    Only ``call_function`` nodes whose target is in ``BINARY_OPS`` are
    considered, and only when both of the first two positional args are FX
    nodes carrying a ``"val"`` entry in their metadata. The common dtype is
    chosen by ``_promote_dtype``. The metadata of the binary op node itself is
    not modified by this pass.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite ``graph_module`` in place and report whether it changed.

        Args:
            graph_module: The graph module whose binary ops are inspected.

        Returns:
            A ``PassResult`` holding the same ``graph_module`` and a flag that
            is True when at least one cast node was inserted.
        """
        graph = graph_module.graph
        dirty = False

        for node in graph.nodes:
            if node.op != "call_function" or node.target not in BINARY_OPS:
                continue

            # Both tensor operands must be positional; skip the node instead
            # of raising IndexError when they are not.
            if len(node.args) < 2:
                continue

            lhs, rhs = node.args[0], node.args[1]

            # Scalar / constant operands carry no dtype metadata to compare.
            if not isinstance(lhs, torch.fx.Node) or not isinstance(
                rhs, torch.fx.Node
            ):
                continue

            if "val" not in lhs.meta or "val" not in rhs.meta:
                continue

            lhs_dtype = lhs.meta["val"].dtype
            rhs_dtype = rhs.meta["val"].dtype

            if lhs_dtype == rhs_dtype:
                continue

            promoted = _promote_dtype(lhs_dtype, rhs_dtype)

            # Cast whichever operand(s) do not already have the promoted
            # dtype, left operand first.
            for operand, operand_dtype in ((lhs, lhs_dtype), (rhs, rhs_dtype)):
                if operand_dtype == promoted:
                    continue
                with graph.inserting_before(node):
                    cast = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten._to_copy.default,
                        (operand,),
                        {"dtype": promoted},
                    )
                # Give the new node metadata that reflects the cast result.
                cast.meta["val"] = operand.meta["val"].to(promoted)
                node.replace_input_with(operand, cast)
                dirty = True

        if dirty:
            graph.eliminate_dead_code()
            graph_module.recompile()
            dead_code_elimination_pass(graph_module)

        return PassResult(graph_module, dirty)
4 changes: 2 additions & 2 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
outputs_dtypes=utils.FP_T,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.FP_INT_T,
supports_resize=True,
are_node_inputs_supported_fn=check_to_copy_node,
)
Expand Down
10 changes: 5 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ binary_op:
- VALUE: half
- VALUE: float
- NAME: binary_eq_texture3d
OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5)))
OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5))
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -61,7 +61,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_lt_texture3d
OPERATOR: all(lessThan(X, Y))
OPERATOR: lessThan(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -77,7 +77,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_le_texture3d
OPERATOR: all(lessThanEqual(X, Y))
OPERATOR: lessThanEqual(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -93,7 +93,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_gt_texture3d
OPERATOR: all(greaterThan(X, Y))
OPERATOR: greaterThan(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -109,7 +109,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_ge_texture3d
OPERATOR: all(greaterThanEqual(X, Y))
OPERATOR: greaterThanEqual(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define IN_VEC4_T ${texel_type(IN_DTYPE)}
#define OUT_VEC4_T ${texel_type(OUT_DTYPE)}

${define_required_extensions("texture3d", IN_DTYPE)}
${define_required_extensions("texture3d", OUT_DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

// Reads the texel at this invocation's position from t_in, converts it from
// IN_VEC4_T to OUT_VEC4_T, and writes it to the same position in t_out.
void main() {
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);
  const ivec4 tensor_idx = to_tensor_idx(texel_pos, out_sizes, packed_dim);

  // Nothing to do unless every component of the index is inside out_sizes.
  if (!all(lessThan(tensor_idx, out_sizes))) {
    return;
  }

  imageStore(
      t_out,
      texel_pos,
      OUT_VEC4_T(IN_VEC4_T(texelFetch(t_in, texel_pos, 0))));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Variant configuration for the view_convert_texture shader template.
view_convert_texture:
# Defaults for the template parameters.
parameter_names_with_default_values:
IN_DTYPE: float
OUT_DTYPE: float
STORAGE: texture3d
generate_variant_forall:
combination:
parameter_names: [IN_DTYPE, OUT_DTYPE]
# Only the int32 -> float and float -> int32 pairs are listed here.
combos:
- parameter_values: [int32, float]
- parameter_values: [float, int32]
shader_variants:
- NAME: view_convert_texture
36 changes: 23 additions & 13 deletions backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
#include <set>

namespace vkcompute {

Expand All @@ -25,19 +25,29 @@ void resize_to_copy_op_node(
graph->virtual_resize(out, graph->sizes_of(self));
}

// Returns true when dtype is Float or Half, false for every other scalar type.
bool is_float_type(vkapi::ScalarType dtype) {
  if (dtype == vkapi::ScalarType::Float) {
    return true;
  }
  return dtype == vkapi::ScalarType::Half;
}

// Adds the execute node that copies `in` into `out`, converting between
// dtypes when they differ.
//
// - Identical dtypes, or any Float/Half pairing, are handled with a BlitNode.
// - Every other pairing is dispatched to a view_convert shader, selected by
//   whether `in` uses buffer storage or not.
//
// The earlier Float/Half-only dtype check and its unconditional BlitNode are
// intentionally not present: they would reject the conversions handled by the
// second branch and would enqueue the blit twice.
void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
  const vkapi::ScalarType in_dtype = graph.dtype_of(in);
  const vkapi::ScalarType out_dtype = graph.dtype_of(out);

  // Same-dtype or float<->half conversions can use BlitNode
  if (in_dtype == out_dtype ||
      (is_float_type(in_dtype) && is_float_type(out_dtype))) {
    graph.execute_nodes().emplace_back(new BlitNode(graph, in, out));
    return;
  }

  // Other conversions (e.g. int<->float) use view_convert shaders
  if (graph.is_buffer_storage(in)) {
    add_view_copy_convert_buffer_node(
        graph, in, out, {}, resize_to_copy_op_node);
  } else {
    add_view_copy_convert_texture_node(
        graph, in, out, {}, resize_to_copy_op_node);
  }
}

void to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
Expand Down
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,35 @@ void add_view_copy_buffer_node(
resize_fn));
}

// Appends a DynamicDispatchNode to the graph's execute nodes that runs the
// "view_convert_texture" kernel. The kernel name is suffixed first with the
// dtype of `in` and then with the dtype of `out`. `resize_args` and
// `resize_fn` are passed through to the dispatch node unchanged.
void add_view_copy_convert_texture_node(
ComputeGraph& graph,
ValueRef in,
ValueRef out,
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn) {
// Suffix order is input dtype, then output dtype.
std::string kernel_name = "view_convert_texture";
add_dtype_suffix(kernel_name, graph.dtype_of(in));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Parameter Buffers
{},
// Push Constants (sizes of the output tensor)
{{graph.sizes_pc_of(out)}},
// Specialization Constants (packed dim of the output tensor)
{graph.packed_dim_of(out)},
// Resize Args
resize_args,
// Resizing Logic
resize_fn));
}

void add_view_copy_convert_buffer_node(
ComputeGraph& graph,
ValueRef in,
Expand Down
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ void add_view_copy_convert_buffer_node(
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn);

/*
* Dispatches the view_convert_texture compute shader. This can be used to
* convert between different data types for 3D texture tensors while
* preserving the texel positions.
*
* `in` is the tensor that is read and `out` is the tensor that is written.
* `resize_args` and `resize_fn` supply the resize arguments and resize logic
* for the dispatched node.
*/
void add_view_copy_convert_texture_node(
ComputeGraph& graph,
ValueRef in,
ValueRef out,
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn);

void add_view_node(
ComputeGraph& graph,
ValueRef in,
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FoldQDQPass,
FuseQuantizedOpsTransform,
insert_prepack_nodes,
InsertDtypePromotionPass,
RemoveRedundantOpsTransform,
SqueezeUnsqueezeInputs,
TagMemoryMetaPass,
Expand Down Expand Up @@ -165,6 +166,7 @@ def preprocess( # noqa: C901
AddmmToLinearTransform(),
FuseBatchNormPass(program),
AddmmToLinearTransform(),
InsertDtypePromotionPass(),
FusePatternsPass(),
FuseClampPass(),
RemoveRedundantOpsTransform(),
Expand Down
Loading