Skip to content

[CuTe,Sm100] Varlen Dynamic Persistent scheduler and metadata#2559

Open
reubenconducts wants to merge 9 commits into
Dao-AILab:mainfrom
reubenconducts:dynamic_metadata
Open

[CuTe,Sm100] Varlen Dynamic Persistent scheduler and metadata#2559
reubenconducts wants to merge 9 commits into
Dao-AILab:mainfrom
reubenconducts:dynamic_metadata

Conversation

@reubenconducts
Copy link
Copy Markdown
Contributor

@reubenconducts reubenconducts commented May 13, 2026

Draft for now until I get comprehensive benchmark numbers.

This PR adds the VarlenDynamicPersistentScheduler, seen in FA3, to FA4.

  • Accepts scheduler metadata tensors, which can be prepared in the prepare_scheduler_metadata.py FlashPrepareScheduler kernel:
    • num_m_blocks_ptr to precompute number of m blocks for the lpt sort (unused on SM100)
    • num_splits_dynamic_ptr: holds num_splits for each sequence in a batch, which can vary in mixed workloads
    • virtual_batch_idx_ptr: unused currently; used to permute sequences according to "virtual" batch indices, e.g. when sorting a batch for load balancing (to be utilized in a subsequent PR)
    • num_nheads_in_l2_ptr: used for head swizzle computation in the tile scheduler
    • tile_count_semaphore: zeroed out here, used in main kernel for dynamic persistence
  • Threads metadata through interface with options to disable, call FlashPrepareKernel, or pass in pre-computed cached metadata to amortize cost across layers.
  • Refactors tile_scheduler.py so that dynamic persistent schedulers (CLC and "traditional", like that in this PR) share methods and so that varlen schedulers (SingleTile and DynamicPersistent) reuse common methods.

Comprehensive performance numbers are attached. The regressions seen (e.g. with large batch, small seqlen) are attributable mainly to the fact that the combine kernel does not early-exit for 1-split sequences. We see that "traditional" dynamic persistent outperforms CLC (to which I've also wired up the appropriate metadata tensors) almost always. It's worth noting that in the few tests where prepare kernel latency appears extreme, a large proportion of that latency comes from torch.empty (known issue on Grace CPUs; ostensibly fixed with cuda 13.x, but I ran on 12.x).

varlen_dynamic_scheduler_perf.txt

To-do in follow-up PRs:

  • Currently the combine kernel does not early-exit for 1-split sequences; this explodes in many cases from over-launching zero-work CTAs. Solution is to use a persistent scheduler.
  • Implement lpt batch sort in the FlashPrepareKernel. This is proven to have an enormous impact on mixed prefill/decode workloads (including with CLC, from my preliminary testing). (In addition to this varlen case, batch sort will be helpful for load balancing with blocksparsity.)

@reubenconducts reubenconducts marked this pull request as ready for review May 13, 2026 23:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant