
[runtime] feat: support variant mode #17

Draft
Neawhen wants to merge 4 commits into main from feat/switch_mode

@Neawhen Neawhen commented Apr 22, 2026

Collaborator

No description provided.

@Neawhen Neawhen requested review from Luosuu and pengwu22 and removed request for Luosuu and pengwu22 April 22, 2026 21:20

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a "variant" mode for Flash Attention, allowing for dynamic num_splits selection to improve throughput when batch-invariance is not required. The changes include new CLI arguments, configuration validation, and logic within the inferencer to toggle batch-invariant patches. Review feedback suggests extending this support to Blackwell hardware by adding a fa-variant-cute implementation and recommends explicitly disabling batch-invariant mode when the variant path is selected to ensure consistent global state.


# Non-invariant variant: same kernel but without the num_splits=1 lock,
# so the kernel is free to pick Split-KV counts for best throughput.
flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)

high

The fa-variant implementation currently defaults to use_cute=False, which will cause a crash on Blackwell (SM100) hardware due to the assertion at line 167 (which requires SM90 for the non-cute path). Adding a fa-variant-cute implementation allows variant mode to be used on newer hardware, following the existing pattern for invariant mode.

Suggested change
- flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
+ flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
+ flash_attention_forward_variant_cute = partial(flash_attention_forward, lock_num_splits=False, use_cute=True)
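
The `partial` pattern in this suggestion can be tried in isolation. In the sketch below, `flash_attention_forward` is a hypothetical stand-in (its real signature and body are not shown in this thread); only the keyword names `lock_num_splits` and `use_cute` come from the diff, and the returned dictionary is purely illustrative.

```python
from functools import partial


def flash_attention_forward(q, k, v, *, lock_num_splits=True, use_cute=False):
    """Hypothetical stand-in that only reports which configuration was requested."""
    return {
        # Assumption: locking pins num_splits to 1; None means the kernel chooses.
        "num_splits": 1 if lock_num_splits else None,
        "backend": "cute" if use_cute else "default",
    }


# Same callable, with different keyword arguments pre-bound.
flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
flash_attention_forward_variant_cute = partial(
    flash_attention_forward, lock_num_splits=False, use_cute=True
)
```

Because `partial` only pre-binds keyword arguments, both variants share one implementation, and a caller can still override a bound keyword explicitly if needed.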

Comment thread benchmarks/throughput.py
choices=["fa-invariant", "fa-invariant-cute", "flex"],
default="fa-invariant",
help="Attention implementation (default: fa-invariant, i.e. flash attn).",
choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"],

medium

Add fa-variant-cute to the choices to support variant mode on Blackwell (SM100) hardware, consistent with the fa-invariant-cute option.

Suggested change
- choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"],
+ choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"],
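
A minimal sketch of how the extended `choices` list behaves under `argparse`. The flag name `--attn-impl` is an assumption made for illustration, since the actual flag in `benchmarks/throughput.py` is not visible in this thread; the default and help text are copied from the diff context.

```python
import argparse

ATTN_IMPL_CHOICES = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"]


def build_parser():
    parser = argparse.ArgumentParser()
    # "--attn-impl" is a hypothetical flag name, not taken from the PR.
    parser.add_argument(
        "--attn-impl",
        choices=ATTN_IMPL_CHOICES,
        default="fa-invariant",
        help="Attention implementation (default: fa-invariant, i.e. flash attn).",
    )
    return parser


args = build_parser().parse_args(["--attn-impl", "fa-variant-cute"])
print(args.attn_impl)  # prints "fa-variant-cute"
```

Any value outside `choices` makes `argparse` exit with a usage error, so an unsupported name is rejected before the benchmark starts.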

from .flash_attention import (
    flash_attention_forward,
    flash_attention_forward_cute,
    flash_attention_forward_variant,

medium

Export flash_attention_forward_variant_cute to support variant mode on Blackwell hardware.

    flash_attention_forward_variant,
    flash_attention_forward_variant_cute,

"AttentionBlockSize",
"flash_attention_forward",
"flash_attention_forward_cute",
"flash_attention_forward_variant",

medium

Add flash_attention_forward_variant_cute to __all__.

Suggested change
- "flash_attention_forward_variant",
+ "flash_attention_forward_variant",
+ "flash_attention_forward_variant_cute",

Comment thread vexact/config.py

# Validate attention implementation
- valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex"]
+ valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"]

medium

Include fa-variant-cute in the list of valid attention implementations.

Suggested change
- valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"]
+ valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"]
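
For reference, a validation step of this kind might look like the sketch below. The function name `validate_attn_impl` and the error message are hypothetical; only the list of names comes from the suggestion, and the real check in `vexact/config.py` may be structured differently.

```python
VALID_ATTN_IMPLS = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"]


def validate_attn_impl(attn_impl: str) -> str:
    """Return attn_impl unchanged if it is supported, otherwise raise ValueError."""
    if attn_impl not in VALID_ATTN_IMPLS:
        raise ValueError(
            f"Invalid attention implementation {attn_impl!r}; "
            f"expected one of {VALID_ATTN_IMPLS}"
        )
    return attn_impl
```

Keeping this list identical to the CLI `choices` avoids a name that the parser accepts but the config then rejects.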

)
from vexact.batch_invariant_ops import flash_attention_forward as flash_attention_forward_impl
from vexact.batch_invariant_ops import flash_attention_forward_cute as flash_attention_forward_cute_impl
from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl

medium

Import disable_batch_invariant_mode to allow explicitly disabling the global ATen patches when variant mode is requested. Also import the new cute variant for Blackwell support.

Suggested change
- from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl
+ from vexact.batch_invariant_ops import (
+     disable_batch_invariant_mode,
+     flash_attention_forward_variant as flash_attention_forward_variant_impl,
+     flash_attention_forward_variant_cute as flash_attention_forward_variant_cute_impl,
+ )

ALL_ATTENTION_FUNCTIONS["fa-invariant"] = flash_attention_forward_impl
ALL_ATTENTION_FUNCTIONS["fa-invariant-cute"] = flash_attention_forward_cute_impl
# Non-invariant variant: same kernel path as fa-invariant but without num_splits=1.
ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl

medium

Register the fa-variant-cute implementation.

Suggested change
- ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl
+ ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl
+ ALL_ATTENTION_FUNCTIONS["fa-variant-cute"] = flash_attention_forward_variant_cute_impl
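
Since the same name has to appear in the CLI choices, the config validation list, and this registry, a small consistency check can catch a missed registration. The sketch below uses a plain dict as a stand-in for `ALL_ATTENTION_FUNCTIONS` and placeholder kernels; `_stub` and `missing_registrations` are hypothetical helpers, and treating `flex` as registered in the same mapping is an assumption.

```python
# Plain dict standing in for the real ALL_ATTENTION_FUNCTIONS registry (assumption).
ALL_ATTENTION_FUNCTIONS = {}

VALID_ATTN_IMPLS = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"]


def _stub(name):
    """Build a hypothetical placeholder kernel that just reports its own name."""
    def attention(*args, **kwargs):
        return name
    return attention


# Register every name except "fa-variant-cute" to mimic the state before the suggestion.
for impl in VALID_ATTN_IMPLS:
    if impl != "fa-variant-cute":
        ALL_ATTENTION_FUNCTIONS[impl] = _stub(impl)


def missing_registrations(valid_impls, registry):
    """Return names that pass validation but have no registered function."""
    return sorted(set(valid_impls) - set(registry))


print(missing_registrations(VALID_ATTN_IMPLS, ALL_ATTENTION_FUNCTIONS))  # ['fa-variant-cute']
ALL_ATTENTION_FUNCTIONS["fa-variant-cute"] = _stub("fa-variant-cute")
print(missing_registrations(VALID_ATTN_IMPLS, ALL_ATTENTION_FUNCTIONS))  # []
```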

Comment on lines +95 to +98
if enable_batch_invariant and not is_batch_invariant_mode_enabled():
    enable_batch_invariant_mode()
if not enable_batch_invariant:
    logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")

medium

The current implementation only enables batch-invariant mode if requested, but it doesn't disable it if it was previously enabled (e.g., in a shared process or test environment). To ensure "variant" mode truly disables the patches as described in the CLI help, disable_batch_invariant_mode() should be called when enable_batch_invariant is False.

        if enable_batch_invariant:
            if not is_batch_invariant_mode_enabled():
                enable_batch_invariant_mode()
        else:
            if is_batch_invariant_mode_enabled():
                disable_batch_invariant_mode()
            logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")
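
The effect of this suggestion on shared global state can be illustrated with a self-contained sketch. The three mode functions below are stand-ins that flip a module-level flag; the real ones in `vexact.batch_invariant_ops` patch ATen ops and are not reproduced here. `sync_batch_invariant_mode` is a hypothetical wrapper name for the suggested branch logic.

```python
import logging

logger = logging.getLogger("vexact.sketch")

# Stand-in for the process-wide patch state (assumption, not the real mechanism).
_batch_invariant_enabled = False


def is_batch_invariant_mode_enabled() -> bool:
    return _batch_invariant_enabled


def enable_batch_invariant_mode() -> None:
    global _batch_invariant_enabled
    _batch_invariant_enabled = True


def disable_batch_invariant_mode() -> None:
    global _batch_invariant_enabled
    _batch_invariant_enabled = False


def sync_batch_invariant_mode(enable_batch_invariant: bool) -> None:
    """Make the global state match the request in both directions."""
    if enable_batch_invariant:
        if not is_batch_invariant_mode_enabled():
            enable_batch_invariant_mode()
    else:
        if is_batch_invariant_mode_enabled():
            disable_batch_invariant_mode()
        logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")


# A previous component (for example an earlier test) left the mode enabled.
enable_batch_invariant_mode()
sync_batch_invariant_mode(False)
print(is_batch_invariant_mode_enabled())  # False
```

With the original two `if` statements, the final call would log "DISABLED" while the patches stayed active; the explicit `else` branch removes that mismatch.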

@Neawhen Neawhen force-pushed the feat/switch_mode branch from 4835504 to ea02ac7 Compare May 28, 2026 23:44
@Neawhen Neawhen marked this pull request as draft May 28, 2026 23:54
@Neawhen Neawhen force-pushed the feat/switch_mode branch from ea02ac7 to e77ac83 Compare June 1, 2026 23:14
Neawhen and others added 4 commits June 3, 2026 06:08
Document VeXact and vllm throughput cost of batch-invariance (256/512
requests, Qwen3-1.7B and Qwen3-30B-A3B, on H100). Ignore the local
benchmark runner scripts and result logs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ults

- vexact/core/scheduler.py: optional VEXACT_BATCH_PROBE JSONL probe of per-step
  batch composition (prefill/decode seq + token counts), enabled by env var.
- docs/tasks/vexact_vs_vllm_batch_probe.md: probe analysis write-up.
- benchmarks/throughput.py: pass top_p=1.0 so post-rebase sampler assertion
  ("top_p must be set") no longer fires.
- benchmarks/batch_invariance_results.md: add vexact-prefill-cudagraph rows
  (post mixed-prefill-cudagraph) to both summary tables.
- .gitignore: ignore batch_probe outputs.
Re-ran the throughput benchmark on the rebased feat/switch_mode
branch (now stacked on main which includes #32 prefill-cudagraph
cleanup) and added VeXact-latest rows to both the 256- and
512-request summary tables.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Neawhen Neawhen force-pushed the feat/switch_mode branch from e77ac83 to adcace8 Compare June 3, 2026 02:04