[ops] refactor: replace kernel env vars with args #651
Conversation
Code Review
This pull request transitions kernel selection from environment variables to a structured configuration-driven approach using OpsImplementationConfig. It introduces an 'auto' resolution mechanism that selects the most appropriate kernel (Liger, Triton, or Eager) based on hardware and package availability. A global singleton is implemented to allow model-specific patches to access the resolved configuration. Feedback was provided to improve the robustness of the 'auto' resolution for load-balancing loss by explicitly checking for Triton availability to avoid potential crashes in CPU-only or non-Triton environments.
```python
if self.load_balancing_loss_implementation == "auto":
    self.load_balancing_loss_implementation = "eager" if npu else "triton"
```
The auto resolution for load_balancing_loss_implementation assumes that if the hardware is not an NPU, it must be a GPU where Triton is available. This will lead to a crash (import error) on CPU-only environments or systems where the triton package is not installed. It is safer to check for Triton availability explicitly, similar to how Liger is checked.
Suggested change:

```diff
 if self.load_balancing_loss_implementation == "auto":
-    self.load_balancing_loss_implementation = "eager" if npu else "triton"
+    from ..utils.import_utils import is_fused_moe_available
+    self.load_balancing_loss_implementation = "triton" if (not npu and is_fused_moe_available()) else "eager"
```
LGTM. Please fix CI, then we can merge.
I think we're missing the same change for `test_padded_packed_loss.py`?
```python
)
# NOTE: fused MoE patch is applied in build_foundation_model() based on
# the moe_implementation parameter.
logger.info_rank0("✅ VeOmni ops config applied.")
```
Should we also call `format_kernel_functions` here?
| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `attn_implementation` | `Optional[Literal["eager", "sdpa", "flash_attention_2", "flash_attention_3", "flash_attention_4", "native-sparse"]]` | `"flash_attention_2"` | Attention implementation to use. |
| `moe_implementation` | `Optional[Literal["eager", "fused", "fused_quack"]]` | `None` | MoE implementation: `eager` (reference loop), `fused` (Triton), `fused_quack` (Quack CUTLASS, SM90+). |
| `cross_entropy_loss_implementation` | `Literal["auto", "eager", "liger_kernel"]` | `"auto"` | Cross-entropy loss: `liger_kernel` for fused linear CE, `eager` for PyTorch. |
Nit 1: what does `auto` mean?

Nit 2: should we make it a `Literal`? If it stays a plain string, it would be easier for people to register their own kernels, and they would still get a "not found" error when a kernel name has a typo.

Nit 3: how do we deal with NPU? Should we add a `liger_kernel_npu` option? (I'd prefer this to be more explicit.) Then `auto` could resolve to `liger_kernel` or `liger_kernel_npu` depending on the env.
The Ascend CI runner doesn't ship with liger-kernel, so the hard-coded `[True, False]` parametrization produced a `ValueError` in `OpsImplementationConfig.__post_init__` for every "veomni + use_liger=True" case. Derive `_USE_LIGER_KERNEL` from `is_liger_kernel_available()` so those modes are skipped when the package is missing; GPU coverage is unchanged.

Made-with: Cursor
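A sketch of the gating this commit describes (`_USE_LIGER_KERNEL` and `is_liger_kernel_available()` are names taken from the commit message; the import path is an assumption):

```python
from veomni.utils.import_utils import is_liger_kernel_available

# Only parametrize liger modes when the package exists
# (e.g. it is absent on the Ascend CI image).
_USE_LIGER_KERNEL = [True, False] if is_liger_kernel_available() else [False]
```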
…ility

The NPU CI image ships `triton-ascend`, not mainline `triton`, so `veomni/ops/kernels/load_balancing_loss/triton.py` fails with `ModuleNotFoundError` as soon as the registry resolves the `triton` backend. Fall back to `eager` when the mainline `triton` package is absent, mirroring the liger-kernel gating.

Made-with: Cursor
Force-pushed from e8966cc to 9b0989c.
Follow-up to #639. PR #639 fixed the NPU SIGABRT in `tests/data/test_multisource_dataset.py` by forcing `pin_memory=False`, but `tests/data/test_datasets.py` still passes the default `pin_memory=True` and suffers the same flaky crash (seen on PR #651 CI run 24599839122 and on main run 24379970711):

    terminate called without an active exception
    failed (exitcode: -6) local_rank: N -- Signal 6 (SIGABRT)

Root cause is identical to #639: the DataLoader `pin_memory` background thread races with HCCL ProcessGroup teardown inside `destroy_distributed`, invalidating `torch_npu` global state and crashing the pinning thread; the still-joinable `std::thread` then triggers `std::terminate()` at interpreter shutdown.

Unlike `test_multisource_dataset.py` (which overrides `_build_dataloader`), `test_datasets.py` inherits the base dataloader from `BaseTrainer`, so the simplest fix is to pass `--data.dataloader.pin_memory=False` via the torchrun CLI in `build_command`. `DummyDataset` means `pin_memory` has no performance benefit on GPU either, so the change is behaviorally neutral.

Made-with: Cursor
What does this PR do?
Replace the all-or-nothing `VEOMNI_USE_LIGER_KERNEL`/`USE_GROUP_GEMM` env vars with per-op fields on `OpsImplementationConfig`, and reorganize `veomni/ops/` around a unified kernel registry so adding a new kernel/backend is a one-file change.

This started as the env-var → args refactor and grew into a full `veomni/ops/` reorganization once the patch surface was visible across every model.

Motivation
`gpu_patch.py`/`npu_patch.py` had duplicated `if ligerkernel: setattr(...)` blocks that were drifting out of sync.

Key changes
1. Per-op config fields (`veomni/arguments/arguments_types.py`)

Five new `str` fields on `OpsImplementationConfig`, all defaulting to `"eager"` (no implicit "auto"; users opt in):

- `cross_entropy_loss_implementation`: `eager`, `liger_kernel`, `npu` (chunked loss)
- `rms_norm_implementation`: `eager`, `liger_kernel`, `npu`, `triton`*
- `swiglu_mlp_implementation`: `eager`, `liger_kernel`
- `rotary_pos_emb_implementation`: `eager`, `liger_kernel`, `npu`, `triton`*
- `load_balancing_loss_implementation`: `eager`, `triton`*

\* `triton` is registered per-model via `extra_backends` (DeepSeek V3, Wan).

`__post_init__` validates backend availability (`liger_kernel`/`torch_npu` packages) up front instead of failing with a cryptic error at first batch.
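For illustration, a minimal sketch of the field shape and the fail-fast validation (field names are from this PR; the availability check and error message below are assumptions, not the real `arguments_types.py`):

```python
from dataclasses import dataclass


def is_liger_kernel_available() -> bool:
    # Stand-in for the real import check in veomni.utils.import_utils.
    try:
        import liger_kernel  # noqa: F401
        return True
    except ImportError:
        return False


@dataclass
class OpsImplementationConfig:
    cross_entropy_loss_implementation: str = "eager"
    rms_norm_implementation: str = "eager"
    swiglu_mlp_implementation: str = "eager"
    rotary_pos_emb_implementation: str = "eager"
    load_balancing_loss_implementation: str = "eager"

    def __post_init__(self) -> None:
        # Validate backend availability up front, not at the first batch.
        wants_liger = "liger_kernel" in (
            self.cross_entropy_loss_implementation,
            self.rms_norm_implementation,
            self.swiglu_mlp_implementation,
            self.rotary_pos_emb_implementation,
        )
        if wants_liger and not is_liger_kernel_available():
            raise ValueError("liger_kernel requested but the package is not installed")
```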
2. Removed env vars

- `VEOMNI_USE_LIGER_KERNEL` → split into 4 per-op fields
- `USE_GROUP_GEMM` → `moe_implementation`
- `VEOMNI_ENABLE_CHUNK_LOSS` → `cross_entropy_loss_implementation="npu"`
- `MODELING_BACKEND` is kept (still controls the import-time attention patch).
3. New unified registry (`veomni/ops/config/`)

- `registry.py`: `OpSpec`/`BackendSpec`/`OpScope` + `register_op`/`apply_global_ops`/`apply_per_model_patches`.
- `singleton.py`: bridges the resolved config from `BaseTrainer` to each model's `device_patch.py`.

Three dispatch scopes drive every kernel binding (see the sketch after this list):
- `LOSS_MAPPING` install, in `apply_ops_patch()`.
- `setattr` on the HF modeling module, in each model's `device_patch.py`.
- In `build_foundation_model()`.
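A minimal sketch of what such a registry can look like; `OpSpec`/`BackendSpec`/`OpScope` and `register_op` are names from this PR, but every field, the scope semantics, and the `resolve` helper below are assumptions rather than the real `registry.py`:

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict


class OpScope(Enum):
    GLOBAL = auto()    # e.g. the LOSS_MAPPING install in apply_ops_patch()
    MODELING = auto()  # setattr on the HF modeling module via device_patch.py
    BUILD = auto()     # applied inside build_foundation_model()


@dataclass
class BackendSpec:
    loader: Callable[[], Callable]  # lazily imports and returns the kernel fn
    is_available: Callable[[], bool] = lambda: True


@dataclass
class OpSpec:
    name: str
    scope: OpScope
    backends: Dict[str, BackendSpec] = field(default_factory=dict)


_OPS: Dict[str, OpSpec] = {}


def register_op(spec: OpSpec) -> None:
    _OPS[spec.name] = spec


def resolve(op_name: str, implementation: str) -> Callable:
    # Unknown or unavailable backends fail loudly at resolution time.
    backend = _OPS[op_name].backends.get(implementation)
    if backend is None or not backend.is_available():
        raise ValueError(f"{op_name}: backend {implementation!r} is not available")
    return backend.loader()
```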
4. `veomni/ops/` reorg (5 phases)
flash_attn/,fused_cross_entropy/,fused_moe/,npu_patch/,dit/rope_wan/,dcp_consolidation.py, …) are gone — no shims.5. Per-model patches unified
5. Per-model patches unified

All 9 `device_patch.py` files (`llama`, `qwen2/3`, `qwen3_moe`, `seed_oss`, `qwen2_vl`, `qwen3_vl`, `deepseek_v3`, `wan`) now share one pattern:
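The code block showing the shared pattern did not survive extraction; a plausible reconstruction, reusing the registry sketch above (`get_ops_config` is an assumed accessor on the singleton, and the op/attribute names are illustrative):

```python
def apply_per_model_patches(modeling_module, config) -> None:
    # Bind each non-eager kernel onto the HF modeling module.
    impl = config.rms_norm_implementation
    if impl != "eager":
        modeling_module.RMSNorm.forward = resolve("rms_norm", impl)


def device_patch(modeling_module) -> None:
    # Read the config that BaseTrainer resolved and stored in the singleton.
    apply_per_model_patches(modeling_module, get_ops_config())
```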
extra_backends/custom_patches.6. Trainer wiring
6. Trainer wiring

- `BaseTrainer`, `VLMTrainer`, `DitTrainer` call `apply_ops_config(model_args.ops_implementation)` before building the model.
- `apply_ops_patch()` (import-time) also installs VeOmni's `LOSS_MAPPING` so direct `build_foundation_model()` calls (unit tests, scripts) get the right loss function.
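Schematically (`apply_ops_config` and `build_foundation_model` are names from this PR; the surrounding skeleton is assumed):

```python
class BaseTrainer:
    def build_model(self):
        # Publish the resolved ops config before any model code asks for kernels.
        apply_ops_config(self.model_args.ops_implementation)
        # Each model's device_patch.py reads the singleton during construction.
        return build_foundation_model(self.model_args)
```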
7. Tests + CI fixes

- Tests configure `OpsImplementationConfig` instead of toggling env vars.
- `liger_kernel` and `triton` test modes are gated on package availability (the NPU image ships `triton-ascend`, not mainline `triton`).
- `tests/special_sanity/check_device_api_usage.py` whitelist extended for the new MoE kernel paths.
8. Docs

- `veomni/ops/README.md`: layout, dispatch model, all ops/backends/per-model coverage, recipes for adding a new backend / new op.
- Updated `docs/design/kernel_selection.md`, `docs/usage/support_new_models/*.md`, `.agents/knowledge/{architecture,constraints}.md`.

API and Usage Example
YAML (replaces all env-var settings):
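The YAML block itself was lost in extraction; a plausible sketch assembled from the field names above (the exact section layout under `model` is an assumption):

```yaml
model:
  ops_implementation:
    attn_implementation: flash_attention_2
    moe_implementation: fused
    cross_entropy_loss_implementation: liger_kernel
    rms_norm_implementation: liger_kernel
    swiglu_mlp_implementation: liger_kernel
    rotary_pos_emb_implementation: liger_kernel
    load_balancing_loss_implementation: triton
```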
NPU users get chunked cross-entropy via `cross_entropy_loss_implementation: npu`.

Breaking changes
`VEOMNI_USE_LIGER_KERNEL`, `USE_GROUP_GEMM`, and `VEOMNI_ENABLE_CHUNK_LOSS` are no longer recognized; users must set the corresponding `model.ops_implementation.*` fields.

Related: #569.
Test
- `pytest tests/ops/`, `tests/models/test_models_patch.py` (excluding `qwen3_5` due to GPU memory in the dev env).
- NPU CI (the `gate_liger_kernel`/`gate_triton` commits make the Ascend job green).
Checklist Before Submitting

- tasks/training scripts were moved or renamed: updated `docs/` examples and verified `python3 scripts/ci/check_doc_task_paths.py` passes.