Summary
For learning purposes, it makes sense to introduce a safe_softmax kernel before moving on to online_softmax. It serves as a conceptual bridge between Week 1 and Week 2.
Motivation / Use Case
Specifically, it connects the Week 1 reductions to the Week 2 online softmax:
Target kernels aligned with KernelHeim Weeks 0–2:
Week 0: Tiled copy / transpose
Week 1: Reductions (sum) with multiple implementations (e.g., naive → optimized → warp-shuffle)
[safe_softmax ...] <- add here
Week 2: Single-pass online softmax
For example, reduce_sum could stay simple, with warp-shuffle introduced in safe_softmax:
Target kernels aligned with KernelHeim Weeks 0–2:
Week 0: Tiled copy / transpose
Week 1: Reductions at increasing levels of optimization: 1. reduce_sum (naive), 2. safe_softmax (warp-shuffle)
Week 2: Single-pass online_softmax
In this progression, safe_softmax builds directly on the reduction techniques from Week 1 while laying the groundwork for the single-pass formulation introduced in Week 2.
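To make the bridge concrete, here is a minimal plain-Python sketch of the two formulations. It is not CuTe DSL code and not the repository's reference implementation. The function bodies are illustrative assumptions, and only the names safe_softmax and online_softmax come from this issue. The safe version runs a max reduction and a sum reduction as separate passes, and the online version fuses them by rescaling a running sum.

```python
import math

def safe_softmax(row):
    """Three-pass softmax: max reduction, sum reduction, normalize."""
    row_max = max(row)                            # pass 1: max reduction
    exps = [math.exp(x - row_max) for x in row]   # shift so the largest exponent is 0
    denom = sum(exps)                             # pass 2: sum reduction
    return [e / denom for e in exps]              # pass 3: normalize

def online_softmax(row):
    """Fuses the max and sum passes by rescaling the running sum."""
    running_max = float("-inf")
    running_sum = 0.0
    for x in row:
        new_max = max(running_max, x)
        # Rescale the sum accumulated so far to the new max, then add the new term.
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + math.exp(x - new_max))
        running_max = new_max
    return [math.exp(x - running_max) / running_sum for x in row]
```

Both functions agree on inputs such as [1000.0, 1001.0, 1002.0], where an unshifted math.exp(x) would overflow. That shift is the only difference between safe_softmax and a naive softmax, which is why it fits right after the Week 1 reductions.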
Proposed Solution
Add a safe softmax kernel (safe_softmax) implemented in CuTe DSL under /kernels.
Add a higher-level abstraction using PyTorch custom operators under /ops.
Use the existing reference implementation (softmax_online) in /ref to validate correctness.
The implementation should leverage reduction techniques (e.g., thread reduce, warp reduce, block reduce).
Cluster reduce and similar techniques can be left out.
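The thread, warp, and block reduction levels mentioned above could be emulated on the host roughly as follows. This is a plain-Python sketch for reasoning about the data flow, not CuTe DSL code. The helper names warp_reduce and block_reduce, the 128-thread block size, and the XOR butterfly pattern are all assumptions made for illustration.

```python
WARP_SIZE = 32

def warp_reduce(lanes, op):
    """Emulate a butterfly (XOR) shuffle reduction across one warp's lanes."""
    assert len(lanes) == WARP_SIZE
    vals = list(lanes)
    offset = WARP_SIZE // 2
    while offset > 0:
        # Each lane combines its value with that of lane (lane ^ offset).
        vals = [op(vals[lane], vals[lane ^ offset]) for lane in range(WARP_SIZE)]
        offset //= 2
    return vals[0]  # after the last step every lane holds the same result

def block_reduce(row, op, identity, num_threads=128):
    """Thread reduce, then warp reduce, then combine the per-warp partials."""
    assert num_threads % WARP_SIZE == 0
    # Thread reduce: each emulated thread folds a strided slice of the row.
    per_thread = []
    for tid in range(num_threads):
        acc = identity
        for x in row[tid::num_threads]:
            acc = op(acc, x)
        per_thread.append(acc)
    # Warp reduce: one shuffle-style reduction per group of 32 threads.
    per_warp = [warp_reduce(per_thread[w:w + WARP_SIZE], op)
                for w in range(0, num_threads, WARP_SIZE)]
    # Block reduce: fold the per-warp partials (staged via shared memory on a GPU).
    result = identity
    for partial in per_warp:
        result = op(result, partial)
    return result
```

In a safe_softmax kernel the same hierarchy would run twice per row: once with max and an identity of negative infinity, and once with addition and an identity of 0.0 over the shifted exponentials.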
Scope Alignment
v0.1 scope (Weeks 0-2)
Alternatives Considered
No response
Additional Context
No response