[ops] feat: add triton invariant attention backend #35
Code Review
This pull request introduces a Triton-based batch-invariant attention implementation (triton-invariant) to support A100 (SM80) GPUs, including integration with the configuration, inferencer, and verification scripts, along with comprehensive unit tests. The code review feedback highlights several critical improvement opportunities in vexact/batch_invariant_ops/oai_fused_attn.py. These include adding an early exit condition in _paged_attn_fwd_block_kernel to optimize decode attention, dynamically computing max_seqlen_k and max_seqlen_q when they are not provided to avoid massive redundant overhead, and clamping q_local, logical_page, and page_offset to prevent negative or out-of-bounds pointer arithmetic that could lead to undefined behavior or hardware exceptions.
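Two of the review points above, deriving the maximum sequence length when it is not supplied and clamping paged indices before pointer arithmetic, can be illustrated with a small host-side sketch. This is plain Python, not the Triton kernel code in `oai_fused_attn.py`; the function names, the `cu_seqlens` cumulative-offset layout, and the `page_size`/`num_pages` parameters are assumptions made for illustration only.

```python
from typing import Optional, Sequence, Tuple


def max_seqlen_from_cu(cu_seqlens: Sequence[int], provided: Optional[int] = None) -> int:
    """Return the caller-supplied maximum, or derive it from cumulative offsets.

    Assumes cu_seqlens has the form [0, len0, len0 + len1, ...].
    """
    if provided is not None:
        return provided
    # Per-sequence length is the difference between adjacent offsets.
    return max((b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])), default=0)


def clamp(value: int, low: int, high: int) -> int:
    """Restrict value to the closed interval [low, high]."""
    return max(low, min(value, high))


def safe_page_coords(token_pos: int, kv_len: int, page_size: int, num_pages: int) -> Tuple[int, int]:
    """Clamp a token position, then split it into (logical_page, page_offset).

    Hypothetical helper: keeps both indices in range so they can never
    produce a negative or out-of-bounds address.
    """
    pos = clamp(token_pos, 0, max(kv_len - 1, 0))
    logical_page = clamp(pos // page_size, 0, num_pages - 1)
    page_offset = pos % page_size
    return logical_page, page_offset
```

In a kernel, clamping like this would normally be paired with a load mask, so that clamped lanes read a valid address but do not contribute to the result.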
Summary
- Add `triton-invariant` attention backend based on packed Triton varlen attention kernels.
- Add `flash_attn_varlen_func` wrapper with packed non-paged autograd support and paged forward support.
- Add tests for `q_len > kv_len`, paged forward aliases, GQA, and prefill/decode bitwise invariance.

Tests
- `python -m py_compile vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py`
- `ruff check vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py`
- `ruff format --check vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py`
- `git diff --check`
- `pytest -q submodules/open-vexact/tests/batch_invariant_ops/test_triton_attention_varlen.py` -> 18 passed
- `--model_backend hf --attn_impl triton-invariant --use_remove_padding` -> logits matched 64/64 tokens, max abs diff 0
- `--model_backend veomni --attn_impl triton-invariant --use_remove_padding --use_fused_lce` -> logprobs matched 64/64 tokens, max abs diff 0, backward produced nonzero grad norm
- `examples/getting_started/run_qwen3_1b7.sh` with `INFER_FA_IMPL=triton-invariant` and `VEOMNI_ATTN_IMPLEMENTATION=triton-invariant` -> `training/rollout_probs_diff_max:0.0`, `training/rollout_probs_diff_mean:0.0`
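
For readers unfamiliar with the "matched 64/64 tokens, max abs diff 0" style of result, the following is a minimal sketch of how such a comparison could be computed. It is not the code in `verify_logits_vs_native_hf.py`; the function name `compare_values` and the flat list-of-floats input are assumptions for illustration.

```python
from typing import Sequence, Tuple


def compare_values(ref: Sequence[float], test: Sequence[float]) -> Tuple[int, float]:
    """Return (number of exactly equal entries, maximum absolute difference).

    Hypothetical helper: a bitwise-invariant backend should give
    matched == len(ref) and a maximum absolute difference of 0.
    """
    if len(ref) != len(test):
        raise ValueError("inputs must have the same length")
    # Exact equality is intentional: the goal is bitwise agreement, not closeness.
    matched = sum(1 for r, t in zip(ref, test) if r == t)
    max_abs_diff = max((abs(r - t) for r, t in zip(ref, test)), default=0.0)
    return matched, max_abs_diff
```

Exact comparison, rather than a tolerance such as `math.isclose`, is the point here: any nonzero difference would indicate that the two paths are not bitwise invariant.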