Skip to content

[ops] feat: add triton invariant attention backend#35

Open
Luosuu wants to merge 4 commits into
verl-project:mainfrom
Luosuu:feat/triton-invariant-attn
Open

[ops] feat: add triton invariant attention backend#35
Luosuu wants to merge 4 commits into
verl-project:mainfrom
Luosuu:feat/triton-invariant-attn

Conversation

@Luosuu

@Luosuu Luosuu commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • Add a triton-invariant attention backend based on packed Triton varlen attention kernels.
  • Expose a FlashAttention-compatible flash_attn_varlen_func wrapper with packed non-paged autograd support and paged forward support.
  • Support GQA, packed-only non-paged model inputs, and register the backend for inference plus VeOmni/VeRL actor-side use.
  • Extend packed logits/logprob verification with an optional VeOmni model backend for fused LCE validation.
  • Add CUDA tests for non-paged backward, arbitrary head dims, q_len > kv_len, paged forward aliases, GQA, and prefill/decode bitwise invariance.

Tests

  • python -m py_compile vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py
  • ruff check vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py
  • ruff format --check vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py
  • git diff --check
  • mlx worker 950994: pytest -q submodules/open-vexact/tests/batch_invariant_ops/test_triton_attention_varlen.py -> 18 passed
  • mlx worker 950994 Qwen3-1.7B verifier, --model_backend hf --attn_impl triton-invariant --use_remove_padding -> logits matched 64/64 tokens, max abs diff 0
  • mlx worker 950994 Qwen3-1.7B verifier, --model_backend veomni --attn_impl triton-invariant --use_remove_padding --use_fused_lce -> logprobs matched 64/64 tokens, max abs diff 0, backward produced nonzero grad norm
  • mlx worker 950994 Qwen3-1.7B vexact rollout smoke using examples/getting_started/run_qwen3_1b7.sh with INFER_FA_IMPL=triton-invariant and VEOMNI_ATTN_IMPLEMENTATION=triton-invariant -> training/rollout_probs_diff_max:0.0, training/rollout_probs_diff_mean:0.0

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a Triton-based batch-invariant attention implementation (triton-invariant) to support A100 (SM80) GPUs, including integration with the configuration, inferencer, and verification scripts, along with comprehensive unit tests. The code review feedback highlights several critical improvement opportunities in vexact/batch_invariant_ops/oai_fused_attn.py. These include adding an early exit condition in _paged_attn_fwd_block_kernel to optimize decode attention, dynamically computing max_seqlen_k and max_seqlen_q when they are not provided to avoid massive redundant overhead, and clamping q_local, logical_page, and page_offset to prevent negative or out-of-bounds pointer arithmetic that could lead to undefined behavior or hardware exceptions.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread vexact/batch_invariant_ops/triton_invariant_attention.py
Comment thread vexact/batch_invariant_ops/oai_fused_attn.py Outdated
Comment thread vexact/batch_invariant_ops/oai_fused_attn.py Outdated
Comment thread vexact/batch_invariant_ops/oai_fused_attn.py Outdated
@Luosuu Luosuu changed the title feat: add triton invariant attention backend [ops] feat: add triton invariant attention backend Jun 6, 2026
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch from 1db8d24 to d9acb1c Compare June 6, 2026 18:32
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch from d9acb1c to 46dd656 Compare June 6, 2026 18:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant