
Skip SMEM for small $N$ in softmax kernel #62

@GiftedNovaHD


Hi, I noticed that the softmax kernel always routes data through shared memory (global -> shared -> registers), even for small row sizes where the data fits entirely in registers. This extra staging step adds some overhead.
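To make "fits entirely in registers" a bit more concrete, here is a rough per-thread register estimate. It is a hypothetical back-of-the-envelope helper, not code from the kernel: `row_registers_per_thread`, the `threads_per_row` values, and the fp32 working copy are all assumptions about how a row is split across threads and computed. Note that both paths end with the row in registers; shared memory is only a staging buffer.

```python
import math

# Per-thread register limit on NVIDIA GPUs since Kepler (includes H100/H200).
MAX_REGS_PER_THREAD = 255


def row_registers_per_thread(n_cols: int, threads_per_row: int,
                             elem_bytes: int = 2, compute_bytes: int = 4) -> int:
    """Lower-bound estimate of 32-bit registers a thread needs for its row slice.

    Counts the raw input values (bf16 = 2 bytes by default) plus an assumed
    fp32 working copy for max/exp/sum. Ignores indices, addresses and other
    temporaries, so real usage will be higher.
    """
    elems = math.ceil(n_cols / threads_per_row)
    raw_regs = math.ceil(elems * elem_bytes / 4)
    compute_regs = math.ceil(elems * compute_bytes / 4)
    return raw_regs + compute_regs


# Hypothetical layouts: a 256-thread CTA per row vs. a single warp per row.
for threads in (256, 32):
    for n in (1024, 2048, 4096, 8192):
        regs = row_registers_per_thread(n, threads_per_row=threads)
        fits = regs <= MAX_REGS_PER_THREAD
        print(f"threads/row={threads:3d} N={n:5d}: >= {regs} regs/thread, fits={fits}")
```

Under these assumptions the whole row stays well within the register budget when enough threads share it, but a single warp per row would overflow at $N = 8192$, so the threshold depends on the kernel's actual thread layout.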

I've experimented with a simple change for $N \le 8192$: instead of the asynchronous shared-memory path that uses `cpasync.CopyG2Op()`, the kernel loads directly from global memory into registers using the synchronous `CopyUniversalOp()`. To the best of my knowledge, `CopyUniversalOp()` derives from `CopyOp()`, which compiles down to `ld.global` PTX instructions and therefore loads from global memory directly into registers.
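The proposed dispatch can be sketched as a simple threshold check. This is plain Python for illustration only: `LoadPath`, `choose_load_path`, and `DIRECT_LOAD_MAX_N` are hypothetical names, and the 8192 cutoff is just the value tested in this issue, not a tuned constant.

```python
from enum import Enum


class LoadPath(Enum):
    # Synchronous global -> register loads (e.g. ld.global via CopyUniversalOp).
    DIRECT_G2R = "global->registers"
    # Asynchronous cp.async global -> shared, then shared -> registers.
    CP_ASYNC_G2S2R = "global->shared->registers"


# Hypothetical cutoff taken from the experiment described above.
DIRECT_LOAD_MAX_N = 8192


def choose_load_path(n_cols: int, threshold: int = DIRECT_LOAD_MAX_N) -> LoadPath:
    """Pick the copy path for a softmax row of n_cols elements."""
    if n_cols <= threshold:
        return LoadPath.DIRECT_G2R
    return LoadPath.CP_ASYNC_G2S2R


print(choose_load_path(4096))   # small row: direct loads
print(choose_load_path(16384))  # large row: cp.async staging
```

In the kernel itself this branch would ideally be resolved when the kernel is compiled for a given $N$, so only one copy path ends up in the generated code.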

I understand the rationale from the blog post: for larger $N$, the asynchronous `cp.async` instructions are beneficial because they overlap memory transfers with computation. But for smaller $N$, where the data fits in registers, there is no computation to overlap with, and the `cp.async` -> wait -> shared-memory load sequence may add some overhead.

Here are my benchmark results (H200, $M = 8192$, bf16, mean of $n = 10$ runs):

| $N$ | Before | After | Delta |
|---|---|---|---|
| 1024 | ~2350 GB/s | ~2434 GB/s | +50-100 GB/s |
| 2048 | ~3003 GB/s | ~3042 GB/s | Minimal |
| 4096 | ~3419 GB/s | ~3578 GB/s | +50-100 GB/s |
| 8192 | ~3685 GB/s | ~3800 GB/s | +100-150 GB/s |
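To put these bandwidth numbers in terms of kernel time, the sketch below converts the mean GB/s values into an implied runtime. It assumes the GB/s metric counts one read of the bf16 input and one write of the bf16 output ($2 \cdot M \cdot N \cdot 2$ bytes) and that 1 GB = $10^9$ bytes; both are assumptions about how the benchmark reports bandwidth, and `kernel_time_us` is a hypothetical helper.

```python
M = 8192
BYTES_PER_ELEM = 2  # bf16


def kernel_time_us(n_cols: int, gb_per_s: float, m_rows: int = M) -> float:
    """Runtime in microseconds implied by an effective-bandwidth figure.

    Assumes bytes moved = one read + one write of an m_rows x n_cols bf16
    tensor, and 1 GB = 1e9 bytes.
    """
    bytes_moved = 2 * m_rows * n_cols * BYTES_PER_ELEM
    return bytes_moved / (gb_per_s * 1e9) * 1e6


# N: (before GB/s, after GB/s), approximate means from the table above.
results = {
    1024: (2350, 2434),
    2048: (3003, 3042),
    4096: (3419, 3578),
    8192: (3685, 3800),
}

for n, (before, after) in results.items():
    t_before = kernel_time_us(n, before)
    t_after = kernel_time_us(n, after)
    gain_pct = (after / before - 1) * 100
    print(f"N={n:5d}: {t_before:6.2f} us -> {t_after:6.2f} us ({gain_pct:+.1f}% bandwidth)")
```

Under that accounting the kernels run in tens of microseconds, so a fixed per-row staging cost is a noticeable fraction of the total.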

Would you be open to a PR for this?

Happy to provide more benchmark data or adjust the approach based on feedback.
