Add safe softmax kernel #18

@austin362667

Summary

For learning purposes, it makes sense to introduce a safe_softmax kernel before moving on to online_softmax. It serves as a conceptual bridge between Week 1 and Week 2.
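To make the term concrete, here is a minimal pure-Python sketch of what "safe" means here: subtract the row maximum before exponentiating so that no exponent is positive. The function names and the test row are illustrative assumptions, not code from this repository, and the sketch models the math only, not a GPU kernel.

```python
import math

def naive_softmax(xs):
    # Direct definition: exp(x_i) / sum_j exp(x_j). Overflows for large inputs.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(xs):
    # Pass 1: reduce the row to its maximum.
    m = max(xs)
    # Pass 2: reduce the shifted exponentials to their sum; every exponent is <= 0.
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    # Pass 3: normalize.
    return [e / total for e in exps]

row = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(row)
except OverflowError:
    print("naive softmax overflowed")
print(safe_softmax(row))
```

The two reductions (max, then sum) are what tie this kernel to the Week 1 reduction material.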

Motivation / Use Case

Specifically, it connects:

Target kernels aligned with KernelHeim Weeks 0–2:
Week 0: Tiled copy / transpose
Week 1: Reductions (sum) with multiple implementations (e.g., naive → optimized → warp-shuffle)
[safe_softmax ...] <- add here
Week 2: Single-pass online softmax

For example, keep reduce_sum simple and introduce warp-shuffle in safe_softmax:

Target kernels aligned with KernelHeim Weeks 0–2:
Week 0: Tiled copy / transpose
Week 1: Reductions at increasing levels of optimization: (1) reduce_sum (naive), (2) safe_softmax (warp-shuffle)
Week 2: Single-pass online_softmax

In this progression, safe_softmax naturally builds on reduction techniques from Week 1 while preparing the groundwork for the single-pass formulation introduced in Week 2.
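As a rough illustration of that groundwork, the sketch below contrasts the two formulations in plain Python. The two-pass version runs separate max and sum reductions. The online version fuses them into one sweep by rescaling the running sum whenever the running maximum changes. Function names are hypothetical, inputs are assumed finite, and the final normalization still reads the row again in both versions.

```python
import math

def two_pass_softmax(xs):
    # Separate reductions: first the maximum, then the sum of shifted exponentials.
    m = max(xs)
    d = sum(math.exp(x - m) for x in xs)
    return [math.exp(x - m) / d for x in xs]

def online_softmax(xs):
    m = float("-inf")  # running maximum
    d = 0.0            # running sum of exp(x - m) for the elements seen so far
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new maximum, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / d for x in xs]

row = [3.0, -1.0, 7.5, 2.0, 7.5, 0.25]
a = two_pass_softmax(row)
b = online_softmax(row)
print(max(abs(p - q) for p, q in zip(a, b)))  # difference is at floating-point rounding level
```

Because both compute the same function, either one can serve as a correctness oracle for the other.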

Proposed Solution

Add a safe softmax kernel (safe_softmax) implemented in CuTe DSL under /kernels.
Add a higher-level abstraction using PyTorch custom operators under /ops.
Use the existing reference implementation (softmax_online) in /ref to validate correctness.

The implementation should leverage the reduction techniques from Week 1 (e.g., thread reduce, warp reduce, block reduce).
Cluster-level reduction and similar techniques can be left out of scope.
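The thread → warp → block hierarchy can be sketched as a CPU simulation in plain Python. This is not CuTe DSL code and makes no claim about the eventual kernel's structure. `WARP_SIZE`, `num_warps`, and all function names are assumptions for illustration, and the warp step only models the values that flow into lane 0 of a shuffle-down tree.

```python
import math
import operator

WARP_SIZE = 32

def warp_reduce(lanes, op):
    # Simulated shuffle-down tree: at each step lane i combines with lane i + offset.
    vals = list(lanes)
    offset = len(vals) // 2
    while offset > 0:
        for i in range(offset):
            vals[i] = op(vals[i], vals[i + offset])
        offset //= 2
    return vals[0]  # lane 0 holds the warp's result

def block_reduce(data, op, identity, num_warps=4):
    threads = num_warps * WARP_SIZE
    # Thread reduce: each simulated thread accumulates a strided slice of the row.
    partial = [identity] * threads
    for i, x in enumerate(data):
        partial[i % threads] = op(partial[i % threads], x)
    # Warp reduce: one value per warp.
    per_warp = [
        warp_reduce(partial[w * WARP_SIZE:(w + 1) * WARP_SIZE], op)
        for w in range(num_warps)
    ]
    # Block reduce: combine per-warp values (a kernel would stage these in shared memory).
    result = identity
    for v in per_warp:
        result = op(result, v)
    return result

row = [float(i % 17) - 3.5 for i in range(1000)]
row_max = block_reduce(row, max, float("-inf"))
row_sum = block_reduce([math.exp(x - row_max) for x in row], operator.add, 0.0)
softmax = [math.exp(x - row_max) / row_sum for x in row]
```

The same `block_reduce` skeleton serves both reductions that safe softmax needs. Only the operator and its identity change (`max` with negative infinity, addition with zero).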

Scope Alignment

v0.1 scope (Weeks 0-2)

Alternatives Considered

No response

Additional Context

No response

Labels: enhancement (New feature or request)
