[CuTe,Fwd,Sm100] refactor mla sm100 forward and add page table #2558

Open
jayhshah wants to merge 2 commits into main from jshah/mla-paged-kv-refactor
@jayhshah jayhshah commented May 13, 2026

Superseding #2468.

We refactor the kernel file to remove duplication around splitting V and add page table support.
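For readers unfamiliar with paged KV, the idea behind a page table can be sketched in plain Python. This is only an illustration of the lookup and not the kernel's actual memory layout: the function name `gather_paged`, the argument names, and the nested-list representation are all assumptions made for this example.

```python
def gather_paged(pages, page_table, seqlen, page_size):
    """Reassemble one sequence's rows from a paged cache (illustrative only).

    pages:      list of physical pages, each a list of `page_size` rows
    page_table: page_table[i] is the physical page holding logical page i
    seqlen:     number of valid tokens in the sequence
    """
    rows = []
    for token in range(seqlen):
        # Split the logical token index into (logical page, offset in page).
        logical_page, offset = divmod(token, page_size)
        # The page table maps the logical page to wherever it physically lives.
        physical_page = page_table[logical_page]
        rows.append(pages[physical_page][offset])
    return rows


# Three physical pages of two rows each; the sequence uses pages 2 then 0.
pages = [[[0.0], [1.0]], [[2.0], [3.0]], [[4.0], [5.0]]]
print(gather_paged(pages, page_table=[2, 0], seqlen=3, page_size=2))
```

Because consecutive logical pages need not be adjacent in memory, loads go through this extra level of indirection.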

We also support omitting Q and K (which semantically correspond to q_pe and k_pe). In that case, attention is computed according to the formula

O = softmax(scale * (Qv @ V.T)) @ V

as in Deepseek v4 core attention.
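As a reading aid, the formula above can be written as a small dependency-free reference. This is a sketch under simplifying assumptions (single head, no masking, nested lists instead of tensors); the helper names are hypothetical and none of this is the kernel's code.

```python
import math


def softmax(row):
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    # a is (n, k) and b is (k, m), both as lists of rows.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def qv_only_attention(qv, v, scale):
    """O = softmax(scale * (Qv @ V.T)) @ V for one head, without masking."""
    scores = matmul(qv, transpose(v))  # shape (seqlen_q, seqlen_k)
    probs = [softmax([scale * s for s in row]) for row in scores]
    return matmul(probs, v)  # shape (seqlen_q, head_dim_v)


# With scale = 0 every key gets equal weight, so the output is the mean of V's rows.
print(qv_only_attention([[1.0, -1.0]], [[1.0, 2.0], [3.0, 4.0]], scale=0.0))
```

Note that V plays both roles here: it supplies the keys for the scores and the values for the output.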

@jayhshah jayhshah requested a review from Johnsonms May 13, 2026 00:46
),
mCuBlockIdxOffsets=(
blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None
blocksparse_tensors.cu_block_idx_offsets
Collaborator Author

This change is just to fix the linter error.

@jayhshah jayhshah force-pushed the jshah/mla-paged-kv-refactor branch from e4e954f to 0705919 Compare May 13, 2026 01:29
@jayhshah
Collaborator Author

Note that the 64k test in the benchmark script throws an IMA (illegal memory access) on CUTLASS 4.5.0 with the refactor, but this is due to the issue identified in NVIDIA/cutlass#3208.

@jayhshah jayhshah force-pushed the jshah/mla-paged-kv-refactor branch from 0705919 to 9eaec0a Compare May 13, 2026 02:25
@jayhshah
Collaborator Author

The benchmark shows a ~15% drop-off for paged KV when using cp.async loads.

benchmark_mla_paged.txt
