Skip to content
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
)
def register_softmax_cpp_ops():
return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
supports_resize=True,
)
Expand Down
152 changes: 33 additions & 119 deletions backends/vulkan/runtime/graph/ops/glsl/softmax.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -20,94 +20,49 @@ ${define_active_storage_type(STORAGE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

layout(push_constant) uniform restrict Block {
ivec4 tin_sizes;
ivec3 tout_limits;
};
${layout_declare_ubo(B, "TextureMetadata", "in_meta")}
${layout_declare_ubo(B, "TextureMetadata", "out_meta")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;
${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
const int packed_dim = get_packed_dim(out_layout);

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4
${layout_declare_spec_const(C, "int", "reduce_dim", "0")}
${layout_declare_spec_const(C, "int", "group_dim", "1")}

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define NWORKERS 4
#define MAX_NTHREADS 16

shared vec4 shared_max[MAX_NTHREADS];
shared vec4 shared_sum[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
return tid.x + tid.y * NWORKERS;
}

/*
* The shaders below compute softmax for a tensor. Softmax is an interesting mix
* between a reduction operator and a unary elementwise operator, defined as
* exp(x) / (sum of exp(x)). The general flow of the computation is:
*
* First, find the maximum element along the reduction dim. The maximum element
* is used to preserve numerical stability, since division of exponents is
* translation invariant.
*
* Next, compute the sum of exp(x - max_element) along the reduction dim.
*
* Finally, for each element along the reduction dim, we compute the output as
* exp(x - max_element) / sum_of_exponents.
*
* The shaders below also utilize shared memory to have multiple threads help
* compute the max and sum reduction operations. A total of NGROUPS x NWORKERS
* threads are launched. Each group works on a unique reduction "row", and
* within a group NWORKERS threads co-operate to compute the max and sum of one
* "row". Each worker in the group is responsible for computing a partial output
* of the "row" and uploading it to shared memory; the overall reduction output
* can then be determined by aggregating the partial outputs stored in shared
* memory.
*
* As a caveat, this shader does not currently support cases where `batch` > 1
* and the reduce dim happens to also be the batch concatenation dim. To support
* this, there will need to be additional logic to set the starting value of
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
* supporting this case is left as an exercise for when it is required.
*
* As a final note, log softmax is supported with this shader as well since via
* the op1 and op2 macro definitions. See the corresponding YAML file for more
* details.
*/

/*
* Computes softmax where the reduction dim is orthogonal to the packed dim.
* This case is simpler because each element of a texel belongs to a separate
* reduction dim, meaning we don't have to perform reduction along a texel.
*/
void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
// shared memory index of this thread
const int smi = tid_to_smi(tid);
// used to iterate over all shared memory in the group
int group_i;

scan_pos[reduce_dim] = tid.x;
vec4 max_elements = load_texel(tin, scan_pos);
// This thread computes a partial maximum
for (int i = tid.x; i < tin_sizes[reduce_dim];
vec4 max_elements = texelFetch(tin, scan_pos, 0);
for (int i = tid.x; i < in_meta.sizes[reduce_dim];
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
max_elements = max(max_elements, load_texel(tin, scan_pos));
max_elements = max(max_elements, texelFetch(tin, scan_pos, 0));
}
shared_max[smi] = max_elements;
barrier();
// Iterate over the partial maximums to obtain the overall maximum
group_i = tid.y * NWORKERS;
max_elements = shared_max[group_i++];
for (int i = 1; i < NWORKERS; ++i, group_i++) {
Expand All @@ -116,63 +71,44 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {

scan_pos[reduce_dim] = tid.x;
vec4 denominators = vec4(0);
// Compute partial sum
for (int i = tid.x; i < tin_sizes[reduce_dim];
for (int i = tid.x; i < in_meta.sizes[reduce_dim];
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
denominators += exp(load_texel(tin, scan_pos) - max_elements);
denominators += exp(texelFetch(tin, scan_pos, 0) - max_elements);
}
shared_sum[smi] = denominators;
barrier();
// Iterate over the partial sums to obtain the overall sum
group_i = tid.y * NWORKERS;
denominators = shared_sum[group_i++];
for (int i = 1; i < NWORKERS; ++i, group_i++) {
denominators += shared_sum[group_i];
}

// Determine if there are any padding elements in the final texel of the
// packed dimension
const int nspill = mod4(tin_sizes[packed_dim]);
// Detect if this thread is working on the final texels of the packed
// dimension, which may have padding elements
const int nspill = mod_4(in_meta.sizes[packed_dim]);
const bool is_last_texel =
scan_pos[packed_dim] == (tout_limits[packed_dim] - 1);
scan_pos[packed_dim] == (out_meta.limits[packed_dim] - 1);

scan_pos[reduce_dim] = tid.x;
for (int i = tid.x; i < tin_sizes[reduce_dim];
for (int i = tid.x; i < in_meta.sizes[reduce_dim];
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
const vec4 numerators = op1(texelFetch(tin, scan_pos, 0) - max_elements);
const vec4 safe_denom = max(denominators, vec4(1e-37));
vec4 outtex = op2(numerators, safe_denom);
// Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation.
// This avoids isnan()/x!=x which may not work reliably on all GPU drivers:
// - OpIsNan may have driver bugs for certain NaN bit patterns
// - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics)
// NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000
{
uvec4 bits = floatBitsToUint(outtex);
// Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0
uvec4 nan_inf_mask = uvec4(
((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u,
((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u);
// Zero out bits where NaN/Inf: normal values are unchanged
outtex = uintBitsToFloat(bits & ~nan_inf_mask);
}
// For the last texel in the packed dim, make sure that the padding elements
// are explicitly set to 0. Otherwise, they may influence computations later
// down the line.
if (is_last_texel && nspill > 0) {
[[unroll]] for (int i = nspill; i < 4; ++i) {
outtex[i] = 0;
}
}
write_texel(tout, scan_pos, outtex);
imageStore(tout, scan_pos, outtex);
}
// Flush outstanding imageStore writes so they're committed to memory and
// visible to subsequent GPU operations on this image.
memoryBarrierImage();
}

Expand All @@ -185,44 +121,31 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
* multiple of 4) so that they do not influence the output of reduction.
*/
void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
// shared memory index of this thread
const int smi = tid_to_smi(tid);
// used to iterate over all shared memory in the group
int group_i;

const int nspill = mod4(tin_sizes[packed_dim]);
const int reduce_len = tin_sizes[packed_dim] - nspill;
const int nspill = mod_4(in_meta.sizes[packed_dim]);
const int reduce_len = in_meta.sizes[packed_dim] - nspill;

scan_pos[reduce_dim] = tid.x;
// Initialize with -FLT_MAX to avoid contaminating the maximum with out-of-
// bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12
// has 3 texels but NWORKERS=4), worker threads with no valid texels would
// otherwise load from an OOB index and get 0, which corrupts the max for
// rows where all values are negative and causes denominator underflow -> NaN.
vec4 max_elements = vec4(-3.402823e+38);
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
max_elements = max(max_elements, load_texel(tin, scan_pos));
max_elements = max(max_elements, texelFetch(tin, scan_pos, 0));
}
// For the last texel in the dim, if there are padding elements then each
// element of the texel needs to be processed individually such that the
// padding elements are ignored
if (scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1 && nspill > 0) {
const vec4 intex = load_texel(tin, scan_pos);
if (scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1 && nspill > 0) {
const vec4 intex = texelFetch(tin, scan_pos, 0);
for (int i = 0; i < nspill; ++i) {
max_elements.x = max(intex[i], max_elements.x);
}
}
shared_max[smi] = max_elements;
barrier();
// Iterate over the partial maximums to obtain the overall maximum
group_i = tid.y * NWORKERS;
max_elements = shared_max[group_i++];
for (int i = 1; i < NWORKERS; ++i, group_i++) {
max_elements = max(max_elements, shared_max[group_i]);
}
// Each element of the texel is itself a partial maximum; iterate over the
// texel to find the actual maximum
float max_element = max_elements.x;
[[unroll]] for (int i = 1; i < 4; ++i) {
max_element = max(max_elements[i], max_element);
Expand All @@ -232,49 +155,40 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
vec4 denominators = vec4(0);
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
denominators += exp(load_texel(tin, scan_pos) - max_element);
denominators += exp(texelFetch(tin, scan_pos, 0) - max_element);
}
// For the last texel in the dim, if there are padding elements then each
// element of the texel needs to be processed individually such that the
// padding elements are ignored
if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
const vec4 intex = load_texel(tin, scan_pos);
if (nspill > 0 && scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1) {
const vec4 intex = texelFetch(tin, scan_pos, 0);
for (int i = 0; i < nspill; ++i) {
denominators.x += exp(intex[i] - max_element);
}
}
shared_sum[smi] = denominators;
barrier();
// Iterate over the partial sums to obtain the overall sum
group_i = tid.y * NWORKERS;
denominators = shared_sum[group_i++];
for (int i = 1; i < NWORKERS; ++i, group_i++) {
denominators += shared_sum[group_i];
}
// Reduce over the accumulated texel to find the overall sum
float denominator = 0;
[[unroll]] for (int i = 0; i < 4; ++i) {
denominator += denominators[i];
}
// Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
const float safe_denominator = max(denominator, 1e-37);

scan_pos[reduce_dim] = tid.x;
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
write_texel(tout, scan_pos, op2(numerators, safe_denominator));
const vec4 numerators = op1(texelFetch(tin, scan_pos, 0) - max_element);
imageStore(tout, scan_pos, op2(numerators, safe_denominator));
}
// For the last texel in the dim, if there are padding elements then the
// padding elements need to be set to 0 explicitly, otherwise they may
// influence subsequent operations.
if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
if (nspill > 0 && scan_pos[reduce_dim] == out_meta.limits[reduce_dim] - 1) {
const vec4 numerator = op1(texelFetch(tin, scan_pos, 0) - max_element);
vec4 outtex = op2(numerator, safe_denominator);
[[unroll]] for (int i = nspill; i < 4; ++i) {
outtex[i] = 0;
}
write_texel(tout, scan_pos, outtex);
imageStore(tout, scan_pos, outtex);
}
}

Expand All @@ -286,7 +200,7 @@ void main() {
gl_LocalInvocationID[reduce_dim],
gl_LocalInvocationID[group_dim]);

if (any(greaterThanEqual(scan_pos, tout_limits))) {
if (any(greaterThanEqual(scan_pos, out_meta.limits))) {
return;
}

Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/softmax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ softmax:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: softmax
- NAME: log_softmax
- NAME: softmax_texture3d
- NAME: log_softmax_texture3d
OPERATOR1: X
OPERATOR2: X - log(Y)
Loading
Loading