
[ops] refactor: replace kernel env vars with args #651

Merged: FoolPlayer merged 10 commits into main from refactor/ops-implementation-config on Apr 20, 2026

Conversation

FoolPlayer (Collaborator) commented Apr 14, 2026

What does this PR do?

Replace the all-or-nothing VEOMNI_USE_LIGER_KERNEL / USE_GROUP_GEMM env vars with per-op fields on OpsImplementationConfig, and reorganize veomni/ops/ around a unified kernel registry so adding a new kernel/backend is a one-file change.

This started as the env-var → args refactor and grew into a full veomni/ops/ reorganization once the patch surface was visible across every model.

Motivation

  • Env-var control was binary and global — you couldn't, e.g., turn on Liger RMSNorm but keep eager cross-entropy.
  • Each model's gpu_patch.py / npu_patch.py had duplicated `if liger_kernel: setattr(...)` blocks that were drifting out of sync.
  • New ops/backends (DeepSeek V3 deterministic RoPE, Wan Triton RMSNorm, NPU chunked loss) had no clean home.

Key changes

1. Per-op config fields (veomni/arguments/arguments_types.py)

Five new str fields on OpsImplementationConfig, all defaulting to "eager" (no implicit "auto" — users opt in):

| Field | Backends |
| --- | --- |
| cross_entropy_loss_implementation | eager, liger_kernel, npu (chunked loss) |
| rms_norm_implementation | eager, liger_kernel, npu, triton* |
| swiglu_mlp_implementation | eager, liger_kernel |
| rotary_pos_emb_implementation | eager, liger_kernel, npu, triton* |
| load_balancing_loss_implementation | eager, triton |

* triton is registered per-model via extra_backends (DeepSeek V3, Wan).

__post_init__ validates backend availability (liger_kernel / torch_npu packages) up-front instead of failing with a cryptic error at first batch.
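
A minimal sketch of the field shape and the fail-fast check (field names are from this PR; the validation body and availability checks are illustrative, not the actual implementation):

    from dataclasses import dataclass, fields
    from importlib.util import find_spec

    @dataclass
    class OpsImplementationConfig:
        cross_entropy_loss_implementation: str = "eager"
        rms_norm_implementation: str = "eager"
        swiglu_mlp_implementation: str = "eager"
        rotary_pos_emb_implementation: str = "eager"
        load_balancing_loss_implementation: str = "eager"

        def __post_init__(self):
            # Fail fast at config time instead of at the first batch.
            requested = {getattr(self, f.name) for f in fields(self)}
            if "liger_kernel" in requested and find_spec("liger_kernel") is None:
                raise ValueError("liger_kernel requested but liger-kernel is not installed")
            if "npu" in requested and find_spec("torch_npu") is None:
                raise ValueError("npu requested but torch_npu is not installed")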

2. Removed env vars

  • VEOMNI_USE_LIGER_KERNEL → split into 4 per-op fields
  • USE_GROUP_GEMM → moe_implementation
  • VEOMNI_ENABLE_CHUNK_LOSS → cross_entropy_loss_implementation="npu"
  • MODELING_BACKEND is kept (still controls the import-time attention patch).

3. New unified registry (veomni/ops/config/)

  • registry.py: OpSpec / BackendSpec / OpScope + register_op / apply_global_ops / apply_per_model_patches.
  • singleton.py: bridges the resolved config from BaseTrainer to each model's device_patch.py.

Four dispatch scopes drive every kernel binding (a simplified registry sketch follows the list):

  • import-time — attention + HF LOSS_MAPPING install, in apply_ops_patch().
  • GLOBAL — module-level function pointer (cross-entropy, load-balancing loss).
  • PER_MODEL — setattr on the HF modeling module, in each model's device_patch.py.
  • build-time — fused MoE binding, in build_foundation_model().
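
As a rough illustration of the registry pattern (the real OpSpec / BackendSpec / register_op in veomni/ops/config/registry.py may differ; the shapes here are assumed):

    from dataclasses import dataclass, field
    from enum import Enum
    from typing import Callable, Dict

    class OpScope(Enum):
        GLOBAL = "global"        # swap a module-level function pointer
        PER_MODEL = "per_model"  # setattr on the HF modeling module

    @dataclass
    class BackendSpec:
        name: str                        # e.g. "liger_kernel"
        factory: Callable[[], Callable]  # returns the kernel to bind

    @dataclass
    class OpSpec:
        name: str                        # e.g. "rms_norm"
        scope: OpScope
        backends: Dict[str, BackendSpec] = field(default_factory=dict)

    _REGISTRY: Dict[str, OpSpec] = {}

    def register_op(spec: OpSpec) -> None:
        _REGISTRY[spec.name] = spec

    def resolve(op: str, backend: str) -> Callable:
        # A typo'd backend name surfaces as a KeyError here,
        # not as a silent fallback deep in training.
        return _REGISTRY[op].backends[backend].factory()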

4. veomni/ops/ reorg (5 phases)

veomni/ops/
├── config/                 dispatch infra (no kernels)
├── kernels/                one subpackage per op
│   ├── attention/  cross_entropy/  load_balancing_loss/
│   ├── rms_norm/   rotary/         swiglu/   moe/
├── platform/npu/           HCCL pre-mul-sum patch
└── batch_invariant_ops/    deterministic-mode toggle

Old paths (flash_attn/, fused_cross_entropy/, fused_moe/, npu_patch/, dit/rope_wan/, dcp_consolidation.py, …) are gone — no shims.

5. Per-model patches unified

All 9 device_patch.py files (llama, qwen2/3, qwen3_moe, seed_oss, qwen2_vl, qwen3_vl, deepseek_v3, wan) now share one pattern:

apply_per_model_patches(
    hf_module=hf_qwen3,
    model_name="Qwen3",
    targets={
        "rms_norm": "Qwen3RMSNorm",
        "rotary_pos_emb": "apply_rotary_pos_emb",
        "swiglu_mlp": "Qwen3MLP",
    },
)

Per-model overrides (DeepSeek V3 deterministic Triton RoPE, Wan Triton RMSNorm, Qwen-VL vision RoPE) go in extra_backends / custom_patches.
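
For instance, a hypothetical Wan-style override could register the Triton RMSNorm next to the standard targets (the exact extra_backends signature is assumed for illustration, not quoted from the PR):

    apply_per_model_patches(
        hf_module=hf_wan,
        model_name="Wan",
        targets={"rms_norm": "WanRMSNorm"},
        # Assumed shape: op -> {backend name -> kernel} for model-local backends.
        extra_backends={"rms_norm": {"triton": wan_triton_rms_norm}},
    )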

6. Trainer wiring

BaseTrainer, VLMTrainer, DitTrainer call apply_ops_config(model_args.ops_implementation) before building the model. apply_ops_patch() (import-time) also installs VeOmni's LOSS_MAPPING so direct build_foundation_model() calls (unit tests, scripts) get the right loss function.
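
In sketch form, the call order is (function names from this PR; surrounding trainer code elided):

    apply_ops_config(model_args.ops_implementation)  # resolve + validate per-op backends
    apply_ops_patch()                                # import-time: attention + LOSS_MAPPING
    model = build_foundation_model(...)              # build-time: fused MoE binding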

7. Tests + CI fixes

  • Tests now build OpsImplementationConfig instead of toggling env vars.
  • NPU CI: gate liger_kernel and triton modes on package availability (NPU image ships triton-ascend, not mainline triton).
  • Updated tests/special_sanity/check_device_api_usage.py whitelist for new MoE kernel paths.
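
A sketch of the availability gating, assuming a pytest-style parametrization (test and variable names are illustrative):

    from importlib.util import find_spec

    import pytest

    LIGER = find_spec("liger_kernel") is not None

    @pytest.mark.parametrize("impl", ["eager"] + (["liger_kernel"] if LIGER else []))
    def test_cross_entropy_backends(impl):
        # The liger_kernel mode is simply absent when the package is missing,
        # so the NPU image (triton-ascend, no liger-kernel) stays green.
        cfg = OpsImplementationConfig(cross_entropy_loss_implementation=impl)
        assert cfg.cross_entropy_loss_implementation == impl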

8. Docs

  • New veomni/ops/README.md: layout, dispatch model, all ops/backends/per-model coverage, recipes for adding a new backend / new op.
  • Updated docs/design/kernel_selection.md, docs/usage/support_new_models/*.md, .agents/knowledge/{architecture,constraints}.md.

API and Usage Example

YAML (replaces all env-var setting):

model:
  ops_implementation:
    attn_implementation: flash_attention_2
    moe_implementation: fused
    cross_entropy_loss_implementation: liger_kernel
    load_balancing_loss_implementation: triton
    rms_norm_implementation: liger_kernel
    rotary_pos_emb_implementation: liger_kernel
    swiglu_mlp_implementation: eager   # mix-and-match per op

NPU users get chunked cross-entropy via cross_entropy_loss_implementation: npu.
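
Programmatically that is just (assuming the dataclass sketch above):

    cfg = OpsImplementationConfig(cross_entropy_loss_implementation="npu")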

Breaking changes

  • VEOMNI_USE_LIGER_KERNEL, USE_GROUP_GEMM, VEOMNI_ENABLE_CHUNK_LOSS no longer recognized — users must set the corresponding model.ops_implementation.* fields.
  • Default behavior changes from "use Liger if installed" to explicit eager; users who relied on the env var must opt in.

Related: #569.

Test

  • Local: full pytest tests/ops/ and tests/models/test_models_patch.py (excluding qwen3_5 due to GPU memory limits in the dev env).
  • CI: GPU + Ascend NPU jobs on the PR (gate_liger_kernel / gate_triton commits make the Ascend job green).

Checklist Before Submitting

  • Read the Contribute Guide
  • Applied pre-commit checks
  • Added/updated documentation
  • If tasks/ training scripts were moved or renamed: updated docs/ examples and verified python3 scripts/ci/check_doc_task_paths.py passes
  • Added tests to CI workflow (or explained why not feasible)

gemini-code-assist (Bot, Contributor) left a comment


Code Review

This pull request transitions kernel selection from environment variables to a structured configuration-driven approach using OpsImplementationConfig. It introduces an 'auto' resolution mechanism that selects the most appropriate kernel (Liger, Triton, or Eager) based on hardware and package availability. A global singleton is implemented to allow model-specific patches to access the resolved configuration. Feedback was provided to improve the robustness of the 'auto' resolution for load-balancing loss by explicitly checking for Triton availability to avoid potential crashes in CPU-only or non-Triton environments.

Comment thread: veomni/arguments/arguments_types.py (outdated)
Comment on lines +732 to +733:

    if self.load_balancing_loss_implementation == "auto":
        self.load_balancing_loss_implementation = "eager" if npu else "triton"
gemini-code-assist (Contributor) commented, severity: high

The auto resolution for load_balancing_loss_implementation assumes that if the hardware is not an NPU, it must be a GPU where Triton is available. This will lead to a crash (import error) on CPU-only environments or systems where the triton package is not installed. It is safer to check for Triton availability explicitly, similar to how Liger is checked.

Suggested change:

    -if self.load_balancing_loss_implementation == "auto":
    -    self.load_balancing_loss_implementation = "eager" if npu else "triton"
    +if self.load_balancing_loss_implementation == "auto":
    +    from ..utils.import_utils import is_fused_moe_available
    +    self.load_balancing_loss_implementation = "triton" if (not npu and is_fused_moe_available()) else "eager"

Luosuu (Collaborator) commented Apr 14, 2026

LGTM. please fix CI then we can merge.

Review comment (Collaborator):

I think we are missing the same change for test_padded_packed_loss.py?

Comment thread: veomni/ops/__init__.py (outdated)

    )
    # NOTE: fused MoE patch is applied in build_foundation_model() based on
    # the moe_implementation parameter.
    logger.info_rank0("✅ VeOmni ops config applied.")
Review comment (Collaborator):

should we also call format_kernel_functions here?

Comment thread: docs/usage/arguments.md (outdated)
| --- | --- | --- | --- |
| attn_implementation | `Optional[Literal["eager", "sdpa", "flash_attention_2", "flash_attention_3", "flash_attention_4", "native-sparse"]]` | `"flash_attention_2"` | Attention implementation to use. |
| moe_implementation | `Optional[Literal["eager", "fused", "fused_quack"]]` | `None` | MoE implementation: `eager` (reference loop), `fused` (Triton), `fused_quack` (Quack CUTLASS, SM90+). |
| cross_entropy_loss_implementation | `Literal["auto", "eager", "liger_kernel"]` | `"auto"` | Cross-entropy loss: `liger_kernel` for fused linear CE, `eager` for PyTorch. |
Review comment (Collaborator):

nit 1: what does auto mean?

nit 2: should we make it a Literal? If it's a plain string rather than a Literal, it would be easier for people to register their own kernels, and they would still get a "not found" error when they have a kernel name typo.

nit 3: how do we deal with NPU? Ask users to add a liger_kernel_npu option? (I'd prefer this to be more explicit.) And we could make auto adapt to liger_kernel or liger_kernel_npu depending on the env?

The Ascend CI runner doesn't ship with liger-kernel, so the hard-coded
[True, False] parametrization produced a ValueError in
OpsImplementationConfig.__post_init__ for every 'veomni + use_liger=True'
case. Derive _USE_LIGER_KERNEL from is_liger_kernel_available() so those
modes are skipped when the package is missing; GPU coverage is unchanged.

Made-with: Cursor
…ility

The NPU CI image ships triton-ascend, not mainline triton, so
veomni/ops/kernels/load_balancing_loss/triton.py fails with
ModuleNotFoundError as soon as the registry resolves the triton backend.
Fall back to 'eager' when the mainline triton package is absent, mirroring
the liger-kernel gating.

Made-with: Cursor
FoolPlayer force-pushed the refactor/ops-implementation-config branch from e8966cc to 9b0989c on April 18, 2026 at 06:57
FoolPlayer added a commit that referenced this pull request Apr 18, 2026
Follow-up to #639. PR #639 fixed the NPU SIGABRT in
tests/data/test_multisource_dataset.py by forcing pin_memory=False, but
tests/data/test_datasets.py still passes the default pin_memory=True and
suffers the same flaky crash (seen on PR #651 CI run 24599839122 and on
main run 24379970711):

    terminate called without an active exception
    failed (exitcode: -6) local_rank: N -- Signal 6 (SIGABRT)

Root cause is identical to #639: the DataLoader pin_memory background
thread races with HCCL ProcessGroup teardown inside destroy_distributed,
invalidating torch_npu global state and crashing the pinning thread; the
still-joinable std::thread then triggers std::terminate() at interpreter
shutdown.

Unlike test_multisource_dataset.py (which overrides _build_dataloader),
test_datasets.py inherits the base dataloader from BaseTrainer, so the
simplest fix is to pass --data.dataloader.pin_memory=False via the
torchrun CLI in build_command. DummyDataset means pin_memory has no
performance benefit on GPU either, so this is behaviorally neutral.

Made-with: Cursor
FoolPlayer merged commit 3282afe into main on Apr 20, 2026; 26 of 28 checks passed.
FoolPlayer deleted the refactor/ops-implementation-config branch on April 20, 2026 at 04:20.