Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ runtime.python_library(
],
)

# Library target wrapping insert_dtype_promotion.py; visible to targets under
# //executorch/backends/.
runtime.python_library(
    name = "insert_dtype_promotion",
    srcs = ["insert_dtype_promotion.py"],
    visibility = [
        "//executorch/backends/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
Expand Down Expand Up @@ -133,6 +146,7 @@ runtime.python_library(
":fold_qdq",
":fuse_patterns",
":fuse_quantized_ops",
":insert_dtype_promotion",
":insert_prepack_nodes",
":remove_asserts",
":remove_redundant_ops",
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
FuseQuantizedOpsTransform,
)
from executorch.backends.vulkan._passes.insert_dtype_promotion import (
InsertDtypePromotionPass,
)
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
from executorch.backends.vulkan._passes.remove_asserts import (
remove_asserts,
Expand All @@ -28,6 +31,7 @@
"FoldQDQPass",
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"InsertDtypePromotionPass",
"insert_prepack_nodes",
"remove_asserts",
"RemoveAssertsTransform",
Expand Down
102 changes: 102 additions & 0 deletions backends/vulkan/_passes/insert_dtype_promotion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Set, Union

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

# Type of the entries stored in BINARY_OPS.
OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]

# Binary ops whose first two positional args are the tensor operands that
# InsertDtypePromotionPass compares for a dtype mismatch. A node whose target
# is not in this set is never touched by the pass.
BINARY_OPS: Set[OpType] = {
    # Arithmetic
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.pow.Tensor_Tensor,
    exir_ops.edge.aten.minimum.default,
    # Comparison
    exir_ops.edge.aten.eq.Tensor,
    exir_ops.edge.aten.lt.Tensor,
    exir_ops.edge.aten.le.Tensor,
    exir_ops.edge.aten.gt.Tensor,
    exir_ops.edge.aten.ge.Tensor,
}


def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype:
"""Promote to common dtype following PyTorch type promotion rules."""
if a == b:
return a
# Any mix of different dtypes promotes to float32
return torch.float32


class InsertDtypePromotionPass(ExportPass):
    """
    Insert ``_to_copy`` nodes before binary ops when the two tensor inputs have
    different dtypes, so that both operands reach the op with a common dtype.

    Only ``call_function`` nodes whose target is in ``BINARY_OPS`` are
    considered. A node is left untouched when it has fewer than two positional
    args, when either of the first two args is not an ``fx.Node``, when either
    operand lacks ``meta["val"]``, or when the operand dtypes already match.
    The common dtype is chosen by ``_promote_dtype``.
    """

    @staticmethod
    def _cast_input(
        graph: torch.fx.Graph,
        user: torch.fx.Node,
        operand: torch.fx.Node,
        dtype: torch.dtype,
    ) -> None:
        """Insert a ``_to_copy`` of ``operand`` before ``user`` and rewire ``user`` to it."""
        with graph.inserting_before(user):
            cast = graph.create_node(
                "call_function",
                exir_ops.edge.aten._to_copy.default,
                (operand,),
                {"dtype": dtype},
            )
        # Give the new node a "val" with the target dtype, derived from the
        # operand's own "val".
        cast.meta["val"] = operand.meta["val"].to(dtype)
        user.replace_input_with(operand, cast)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass over ``graph_module``.

        Returns:
            A ``PassResult`` holding the (possibly mutated) graph module and a
            flag that is True iff at least one cast node was inserted.
        """
        graph = graph_module.graph
        dirty = False

        for node in graph.nodes:
            if node.op != "call_function" or node.target not in BINARY_OPS:
                continue

            # Both operands must be positional; otherwise there is nothing to
            # compare and indexing args[1] would raise.
            if len(node.args) < 2:
                continue

            lhs, rhs = node.args[0], node.args[1]
            if not isinstance(lhs, torch.fx.Node) or not isinstance(
                rhs, torch.fx.Node
            ):
                continue

            if "val" not in lhs.meta or "val" not in rhs.meta:
                continue

            lhs_dtype = lhs.meta["val"].dtype
            rhs_dtype = rhs.meta["val"].dtype
            if lhs_dtype == rhs_dtype:
                continue

            promoted = _promote_dtype(lhs_dtype, rhs_dtype)

            # Cast whichever operand(s) are not already at the promoted dtype,
            # lhs first so the inserted nodes keep operand order.
            # NOTE(review): node.meta["val"] of the binary op itself is not
            # refreshed here; if promotion changes the op's result dtype the
            # recorded value may no longer match -- confirm whether downstream
            # passes rely on it.
            for operand, operand_dtype in ((lhs, lhs_dtype), (rhs, rhs_dtype)):
                if operand_dtype != promoted:
                    self._cast_input(graph, node, operand, promoted)
                    dirty = True

        if dirty:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            dead_code_elimination_pass(graph_module)

        return PassResult(graph_module, dirty)
22 changes: 20 additions & 2 deletions backends/vulkan/_passes/insert_prepack_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
# Vulkan compute graph. This annotation is used in later graph passes.
node.meta["etvk_tensorref"] = True

# Get the list of node users that do not handle their own prepacking
# Get the list of node users that need a prepack node inserted. This
# includes ops that don't handle their own prepacking, as well as ops
# that do handle their own prepacking but use this constant tensor as
# the primary input (since the op expects the primary input to be a GPU
# tensor, not a TensorRef).
nodes_to_replace_input = []
for user in node.users:
if user.op == "call_function" and not handles_own_prepacking(user.target):
if user.op != "call_function":
continue

if not handles_own_prepacking(user.target):
nodes_to_replace_input.append(user)
continue

# Most prepacking ops have the primary input at arg 0, but
# embedding is embedding(weight, indices, ...) where the
# primary input (indices) is at arg 1.
primary_arg_idx = 0
if user.target == exir_ops.edge.aten.embedding.default:
primary_arg_idx = 1

if node in user.args and user.args.index(node) == primary_arg_idx:
nodes_to_replace_input.append(user)

if len(nodes_to_replace_input) == 0:
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
outputs_dtypes=utils.FP_T,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.FP_INT_T,
supports_resize=True,
are_node_inputs_supported_fn=check_to_copy_node,
)
Expand Down
10 changes: 5 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ binary_op:
- VALUE: half
- VALUE: float
- NAME: binary_eq_texture3d
OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5)))
OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5))
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -61,7 +61,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_lt_texture3d
OPERATOR: all(lessThan(X, Y))
OPERATOR: lessThan(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -77,7 +77,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_le_texture3d
OPERATOR: all(lessThanEqual(X, Y))
OPERATOR: lessThanEqual(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -93,7 +93,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_gt_texture3d
OPERATOR: all(greaterThan(X, Y))
OPERATOR: greaterThan(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand All @@ -109,7 +109,7 @@ binary_op:
- VALUE: float
- VALUE: int32
- NAME: binary_ge_texture3d
OPERATOR: all(greaterThanEqual(X, Y))
OPERATOR: greaterThanEqual(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
Expand Down
36 changes: 32 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/softmax.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,25 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
for (int i = tid.x; i < tin_sizes[reduce_dim];
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
vec4 outtex = op2(numerators, denominators);
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
const vec4 safe_denom = max(denominators, vec4(1e-37));
vec4 outtex = op2(numerators, safe_denom);
// Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation.
// This avoids isnan()/x!=x which may not work reliably on all GPU drivers:
// - OpIsNan may have driver bugs for certain NaN bit patterns
// - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics)
// NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000
{
uvec4 bits = floatBitsToUint(outtex);
// Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0
uvec4 nan_inf_mask = uvec4(
((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u);
// Zero out bits where NaN/Inf: normal values are unchanged
outtex = uintBitsToFloat(bits & ~nan_inf_mask);
}
// For the last texel in the packed dim, make sure that the padding elements
// are explicitly set to 0. Otherwise, they may influence computations later
// down the line.
Expand All @@ -153,6 +171,9 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
}
write_texel(tout, scan_pos, outtex);
}
// Flush outstanding imageStore writes so they're committed to memory and
// visible to subsequent GPU operations on this image.
memoryBarrierImage();
}

/*
Expand All @@ -173,7 +194,12 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
const int reduce_len = tin_sizes[packed_dim] - nspill;

scan_pos[reduce_dim] = tid.x;
vec4 max_elements = vec4(load_texel(tin, scan_pos).x);
// Initialize with -FLT_MAX to avoid contaminating the maximum with out-of-
// bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12
// has 3 texels but NWORKERS=4), worker threads with no valid texels would
// otherwise load from an OOB index and get 0, which corrupts the max for
// rows where all values are negative and causes denominator underflow -> NaN.
vec4 max_elements = vec4(-3.402823e+38);
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
max_elements = max(max_elements, load_texel(tin, scan_pos));
Expand Down Expand Up @@ -230,19 +256,21 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
[[unroll]] for (int i = 0; i < 4; ++i) {
denominator += denominators[i];
}
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
const float safe_denominator = max(denominator, 1e-37);

scan_pos[reduce_dim] = tid.x;
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
write_texel(tout, scan_pos, op2(numerators, denominator));
write_texel(tout, scan_pos, op2(numerators, safe_denominator));
}
// For the last texel in the dim, if there are padding elements then the
// padding elements need to be set to 0 explicitly, otherwise they may
// influence subsequent operations.
if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
vec4 outtex = op2(numerator, denominator);
vec4 outtex = op2(numerator, safe_denominator);
[[unroll]] for (int i = nspill; i < 4; ++i) {
outtex[i] = 0;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define IN_VEC4_T ${texel_type(IN_DTYPE)}
#define OUT_VEC4_T ${texel_type(OUT_DTYPE)}

${define_required_extensions("texture3d", IN_DTYPE)}
${define_required_extensions("texture3d", OUT_DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
  // Each invocation handles the single texel addressed by its global ID.
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  // Drop invocations whose tensor index lies outside the output sizes.
  const ivec4 tensor_idx = to_tensor_idx(texel_pos, out_sizes, packed_dim);
  if (any(greaterThanEqual(tensor_idx, out_sizes))) {
    return;
  }

  // Fetch the input texel as IN_VEC4_T, convert it through the OUT_VEC4_T
  // constructor, and store it at the same position in the output image.
  const IN_VEC4_T src_texel = IN_VEC4_T(texelFetch(t_in, texel_pos, 0));
  const OUT_VEC4_T dst_texel = OUT_VEC4_T(src_texel);
  imageStore(t_out, texel_pos, dst_texel);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

view_convert_texture:
parameter_names_with_default_values:
IN_DTYPE: float
OUT_DTYPE: float
STORAGE: texture3d
generate_variant_forall:
combination:
parameter_names: [IN_DTYPE, OUT_DTYPE]
combos:
- parameter_values: [int32, float]
- parameter_values: [float, int32]
shader_variants:
- NAME: view_convert_texture
Loading
Loading