Hi, I noticed that the softmax kernel always routes data through shared memory (global -> shared -> registers), even for small row sizes where the data fits entirely in registers. This adds some overhead.
I've experimented with a simple change: for $N \le 8192$, load directly from global memory into registers using the synchronous `CopyUniversalOp()`, instead of the asynchronous shared-memory path that uses `cpasync.CopyG2Op()`. To the best of my knowledge, `CopyUniversalOp()` derives from `CopyOp()`, which compiles down to `ld.global` PTX instructions and therefore loads from global memory directly into registers.
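Here is a rough check of why rows up to $N = 8192$ can stay resident in registers. It is a back-of-the-envelope sketch that assumes a hypothetical launch shape in which one row is split evenly across `threads_per_row` threads. The real kernel's thread layout may differ, and `row_register_footprint` is an illustrative helper, not part of the kernel.

```python
BF16_BYTES = 2             # size of one bf16 element
REG_BYTES = 4              # one 32-bit register
MAX_REGS_PER_THREAD = 255  # architectural per-thread register limit on recent NVIDIA GPUs


def row_register_footprint(n_cols: int, threads_per_row: int) -> dict:
    """Estimate registers per thread needed to keep a whole bf16 row resident.

    Hypothetical layout: the row is split evenly across `threads_per_row` threads.
    """
    if n_cols % threads_per_row != 0:
        raise ValueError("n_cols must be divisible by threads_per_row")
    elems = n_cols // threads_per_row
    # Loaded data: bf16 values packed two per 32-bit register (rounded up).
    bf16_regs = (elems * BF16_BYTES + REG_BYTES - 1) // REG_BYTES
    # Reduction math (max, exp, sum) is typically done in fp32: one register per element.
    fp32_regs = elems
    return {
        "elems_per_thread": elems,
        "bf16_regs": bf16_regs,
        "fp32_regs": fp32_regs,
        "fp32_fraction_of_limit": fp32_regs / MAX_REGS_PER_THREAD,
    }


if __name__ == "__main__":
    for n in (1024, 2048, 4096, 8192):
        fp = row_register_footprint(n, threads_per_row=256)
        print(f"N={n}: {fp['elems_per_thread']} elems/thread, "
              f"{fp['bf16_regs']} bf16 regs, {fp['fp32_regs']} fp32 regs")
```

Under this assumed layout, even $N = 8192$ needs about 32 fp32 registers per thread for the row itself. That is well under the per-thread limit, though the kernel's other live values also compete for registers.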
I understand the rationale from the blog post: for larger $N$, the asynchronous `cp.async` instructions are beneficial because they overlap memory transfers with computation. But for smaller $N$, where the data fits in registers, there's no computation to overlap with, and the `cp.async` -> wait -> load sequence may add some overhead.
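The byte-count model below contrasts the two load paths for one row. It is a simplified illustration, not a measurement. It counts only the bytes needed to get one bf16 input row into registers, ignoring the output write and any reduction scratch space. It also does not model the latency of the wait and barrier between the shared-memory stages. `load_traffic_per_row` is a hypothetical helper.

```python
BF16_BYTES = 2


def load_traffic_per_row(n_cols: int, via_smem: bool) -> dict:
    """Bytes moved to bring one bf16 input row into registers (output write ignored)."""
    row_bytes = n_cols * BF16_BYTES
    traffic = {"global_read": row_bytes, "smem_write": 0, "smem_read": 0}
    if via_smem:
        # Staged path: cp.async copies global -> shared, then a wait,
        # then the row is read from shared memory into registers.
        traffic["smem_write"] = row_bytes
        traffic["smem_read"] = row_bytes
    # Direct path (ld.global): global -> registers, no shared-memory round trip.
    return traffic


if __name__ == "__main__":
    for n in (1024, 8192):
        direct = load_traffic_per_row(n, via_smem=False)
        staged = load_traffic_per_row(n, via_smem=True)
        extra = sum(staged.values()) - sum(direct.values())
        print(f"N={n}: staged path adds {extra} bytes of shared-memory traffic per row")
```

Global-memory reads are identical in both paths. The staged path only adds a shared-memory round trip plus a synchronization point, which is the overhead the direct load avoids when the row already fits in registers.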
Here are my benchmark results (H200, $M = 8192$, bf16, mean of $n=10$ runs):
| $N$ | Before | After | Delta |
|---|---|---|---|
| 1024 | ~2350 GB/s | ~2434 GB/s | +50-100 GB/s |
| 2048 | ~3003 GB/s | ~3042 GB/s | Minimal |
| 4096 | ~3419 GB/s | ~3578 GB/s | +50-100 GB/s |
| 8192 | ~3685 GB/s | ~3800 GB/s | +100-150 GB/s |
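The script below turns the table's approximate means into a relative improvement and an implied per-call kernel time. Because it works from rounded means, its differences may not line up exactly with the ranges in the Delta column. The time conversion assumes the reported bandwidth counts one read of the input plus one write of the output, each an $M \times N$ bf16 tensor. That convention is an assumption and should be checked against the benchmark script.

```python
# (N, before GB/s, after GB/s): approximate means copied from the benchmark table.
RESULTS = [(1024, 2350, 2434), (2048, 3003, 3042), (4096, 3419, 3578), (8192, 3685, 3800)]
M_ROWS = 8192
BF16_BYTES = 2


def implied_time_us(n_cols: int, gbps: float, m_rows: int = M_ROWS) -> float:
    # Assumption: bandwidth = (read input + write output) bytes / kernel time,
    # with input and output both M x N bf16 tensors.
    total_bytes = 2 * m_rows * n_cols * BF16_BYTES
    return total_bytes / (gbps * 1e9) * 1e6


def summarize(results=RESULTS):
    """Return (N, delta GB/s, delta %, time before us, time after us) per row."""
    rows = []
    for n, before, after in results:
        delta = after - before
        pct = 100.0 * delta / before
        rows.append((n, delta, pct, implied_time_us(n, before), implied_time_us(n, after)))
    return rows


if __name__ == "__main__":
    for n, delta, pct, t_before, t_after in summarize():
        print(f"N={n}: {delta:+d} GB/s ({pct:+.1f}%), ~{t_before:.1f} us -> ~{t_after:.1f} us")
```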
Would you be open to a PR for this?
Happy to provide more benchmark data / adjust the approach based on feedback.