Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/softmax.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,25 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
for (int i = tid.x; i < tin_sizes[reduce_dim];
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
vec4 outtex = op2(numerators, denominators);
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
const vec4 safe_denom = max(denominators, vec4(1e-37));
vec4 outtex = op2(numerators, safe_denom);
// Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation.
// This avoids isnan()/x!=x which may not work reliably on all GPU drivers:
// - OpIsNan may have driver bugs for certain NaN bit patterns
// - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics)
// NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000
{
uvec4 bits = floatBitsToUint(outtex);
// Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0
uvec4 nan_inf_mask = uvec4(
((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u);
// Zero out bits where NaN/Inf: normal values are unchanged
outtex = uintBitsToFloat(bits & ~nan_inf_mask);
}
// For the last texel in the packed dim, make sure that the padding elements
// are explicitly set to 0. Otherwise, they may influence computations later
// down the line.
Expand All @@ -153,6 +171,9 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
}
write_texel(tout, scan_pos, outtex);
}
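// Intended to make outstanding image writes from this loop visible to
// subsequent GPU operations on this image.
// NOTE(review): memoryBarrierImage() only orders image memory transactions
// issued by the calling invocation; visibility to later dispatches is
// normally established by an API-level pipeline/memory barrier. Confirm this
// call is actually required and not masking a missing host-side barrier.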
memoryBarrierImage();
}

/*
Expand All @@ -173,7 +194,12 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
const int reduce_len = tin_sizes[packed_dim] - nspill;

scan_pos[reduce_dim] = tid.x;
vec4 max_elements = vec4(load_texel(tin, scan_pos).x);
// Initialize with -FLT_MAX to avoid contaminating the maximum with out-of-
// bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12
// has 3 texels but NWORKERS=4), worker threads with no valid texels would
// otherwise load from an OOB index and get 0, which corrupts the max for
// rows where all values are negative and causes denominator underflow -> NaN.
vec4 max_elements = vec4(-3.402823e+38);
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
max_elements = max(max_elements, load_texel(tin, scan_pos));
Expand Down Expand Up @@ -230,19 +256,21 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
[[unroll]] for (int i = 0; i < 4; ++i) {
denominator += denominators[i];
}
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
const float safe_denominator = max(denominator, 1e-37);

scan_pos[reduce_dim] = tid.x;
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
write_texel(tout, scan_pos, op2(numerators, denominator));
write_texel(tout, scan_pos, op2(numerators, safe_denominator));
}
// For the last texel in the dim, if there are padding elements then the
// padding elements need to be set to 0 explicitly, otherwise they may
// influence subsequent operations.
if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
vec4 outtex = op2(numerator, denominator);
vec4 outtex = op2(numerator, safe_denominator);
[[unroll]] for (int i = nspill; i < 4; ++i) {
outtex[i] = 0;
}
Expand Down
78 changes: 54 additions & 24 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,21 @@ utils::uvec3 conv2d_global_wg_size(
utils::uvec3 wg_size = create_conv2d_global_wg_size(
*graph, method, out, weight_data, stride_equals_dilation);

if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
if (method == Conv2dMethod::Depthwise) {
// The output_tile shaders (conv2d_dw_output_tile,
// conv2d_dw_sned_output_tile) use a 2D dispatch: (x_tile, y_tile) packed
// into glb_x, channel in glb_y. The base conv2d_dw shader uses a 1D
// dispatch: all (x, y, channel) packed into glb_x. For the base shader, we
// must use {W*H*C_packed, 1, 1}.
const bool uses_output_tile =
shader.kernel_name.find("_output_tile") != std::string::npos;
if (uses_output_tile) {
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
} else {
const utils::uvec3 base_extents = graph->create_global_wg_size(out);
wg_size = {base_extents[0] * base_extents[1] * base_extents[2], 1, 1};
}
} else if (method == Conv2dMethod::Pointwise) {
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};

if (shader.kernel_name.find("s1p0") != std::string::npos) {
Expand Down Expand Up @@ -562,29 +576,45 @@ void add_conv2d_node(
PushConstantDataInfo(&param, sizeof(param)),
};
} else if (method == Conv2dMethod::Depthwise) {
const utils::ivec4 kernel_param_size_stride = {
kernel_params.kernel_size[0],
kernel_params.kernel_size[1],
kernel_params.stride[0],
kernel_params.stride[1]};

const utils::ivec4 kernel_param_pad_dial = {
kernel_params.padding[0],
kernel_params.padding[1],
kernel_params.dilation[0],
kernel_params.dilation[1]};

push_constants = {
graph.logical_limits_pc_of(out),
graph.sizes_pc_of(in),
PushConstantDataInfo(
&kernel_param_size_stride, sizeof(kernel_param_size_stride)),
PushConstantDataInfo(
&kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
PushConstantDataInfo(
&extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
PushConstantDataInfo(&out_params, sizeof(out_params)),
};
// output_tile variants use push constants; the base conv2d_dw shader uses
// UBOs. Distinguish by checking if "_output_tile" is in the shader name.
const bool uses_output_tile =
shader.kernel_name.find("_output_tile") != std::string::npos;

if (uses_output_tile) {
const utils::ivec4 kernel_param_size_stride = {
kernel_params.kernel_size[0],
kernel_params.kernel_size[1],
kernel_params.stride[0],
kernel_params.stride[1]};

const utils::ivec4 kernel_param_pad_dial = {
kernel_params.padding[0],
kernel_params.padding[1],
kernel_params.dilation[0],
kernel_params.dilation[1]};

push_constants = {
graph.logical_limits_pc_of(out),
graph.sizes_pc_of(in),
PushConstantDataInfo(
&kernel_param_size_stride, sizeof(kernel_param_size_stride)),
PushConstantDataInfo(
&kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
PushConstantDataInfo(
&extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
PushConstantDataInfo(&out_params, sizeof(out_params)),
};
} else {
// Base conv2d_dw shader uses UBOs, same as SlidingWindow case
param_buffers = {
graph.logical_limits_ubo(out),
graph.sizes_ubo(in),
graph.create_params_buffer(kernel_params),
graph.create_params_buffer(extra_params),
graph.create_params_buffer(out_params),
};
}
} else {
param_buffers = {
graph.logical_limits_ubo(out),
Expand Down
18 changes: 17 additions & 1 deletion backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,23 @@ def get_softmax_inputs():
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
return test_suite

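# Large negative values regression test (edgeTAM attention scores that
# produced NaN: the row max used for the softmax shift was corrupted by
# out-of-bounds texel reads returning 0, so every exp() underflowed and the
# result became 0/0). The last dim of 12 gives 3 texels, fewer than NWORKERS.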
large_neg_test_suite = VkTestSuite(
[
((1, 8, 512, 12), -1, False),
]
)
large_neg_test_suite.layouts = [
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
large_neg_test_suite.data_range = (-1.8e10, -6.5e9)
large_neg_test_suite.test_name_suffix = "large_negative"
large_neg_test_suite.dtypes = ["at::kFloat"]

return [test_suite, large_neg_test_suite]


@register_test_suite(
Expand Down
Loading