Skip to content

Normalize scalar tensors in CuTe JIT cache keys#2560

Open
xyzhang626 wants to merge 1 commit into
Dao-AILab:mainfrom
xyzhang626:xiaoyi/fa4-normalize-scalar-cache-keys
Open

Normalize scalar tensors in CuTe JIT cache keys#2560
xyzhang626 wants to merge 1 commit into
Dao-AILab:mainfrom
xyzhang626:xiaoyi/fa4-normalize-scalar-cache-keys

Conversation

@xyzhang626
Copy link
Copy Markdown

CuTeDSL compile keys are compared and hashed on the Python side. If a scalar tensor gets into the key, two keys with the same scalar value but produced by different layers compare by object identity and miss both the in-memory and persistent JIT caches.

This can happen in FA4 varlen backward when max_seqlen_q/k are scalar tensors and the backward compile key includes derived single-block predicates such as seqlen_q_rounded // m_block_size == 1. Each layer can produce a fresh scalar tensor, causing repeated backward JIT compilation for the same kernel shape.

This PR normalizes scalar tensors in CuTe JIT cache keys before in-memory lookup/storage and before computing the persistent cache hash.

Tests added:

  • in-memory cache lookup with distinct scalar tensor objects but identical values
  • persistent cache hash stability for scalar tensor keys

Local validation:

  • manual CPU scalar cache normalization check passed
  • manual CUDA scalar cache normalization check passed on GPU7

Could not run pytest locally because pytest is not installed in the available environment.

@xyzhang626
Copy link
Copy Markdown
Author

Just realized this some overlaps with #2507. #2507 fixes the specific bwd max_seqlen source, while this PR normalizes scalar tensors at the JIT cache boundary so future scalar tensor compile-key entries do not miss the in-memory or persistent cache.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant