[model] feat: [transformers-v5] Introduce new registration based kernel replacement.#569

Open
piyifan123 wants to merge 17 commits into main from piyifan/kernel-patch-exp

Conversation

Collaborator

@piyifan123 commented Mar 16, 2026

Modeling code

Taking qwen3 as an example, the key parts of the generated modeling code are:

1. Register the `OpSlot`s:

```python
# Creates OpSlot objects at the modeling level
veomni_rms_norm = OpSlot("rms_norm", "standard")
veomni_apply_rotary_pos_emb = OpSlot("apply_rotary_pos_emb", "full")
veomni_swiglu_mlp = OpSlot("swiglu_mlp", "standard")
veomni_cross_entropy_loss = OpSlot("cross_entropy_loss", "standard")
```
2. Patch the forward functions to call the `OpSlot` to get the kernel, or fall back to the eager implementation (still the HF implementation):

```python
class Qwen3RMSNorm(nn.Module):
    ...

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Modification: OpSlot guard — use fused RMSNorm kernel when bound.
        if veomni_rms_norm.has_kernel:
            return veomni_rms_norm(hidden_states, self.weight, self.variance_epsilon)
        # Original HF code below, unchanged.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```
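For readers unfamiliar with the guard pattern, here is a minimal, hypothetical sketch of an `OpSlot` along these lines. It is not the PR's actual implementation: in particular, `bind` here takes a callable directly, whereas the PR resolves an implementation name through the registry.

```python
from typing import Any, Callable, Optional


class OpSlot:
    """Illustrative OpSlot: holds a bound kernel, or nothing (eager fallback)."""

    def __init__(self, op_name: str, variant: str) -> None:
        self.op_name = op_name
        self.variant = variant
        self._kernel: Optional[Callable[..., Any]] = None

    @property
    def has_kernel(self) -> bool:
        # Forward code checks this guard before dispatching to the kernel.
        return self._kernel is not None

    def bind(self, kernel: Optional[Callable[..., Any]]) -> None:
        # Binding None leaves the slot empty, so the original HF path runs.
        self._kernel = kernel

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self._kernel is None:
            raise RuntimeError(f"OpSlot '{self.op_name}' is unbound")
        return self._kernel(*args, **kwargs)
```

Forward code then reduces to the two-branch guard shown above: check `has_kernel`, call the slot if bound, otherwise fall through to the unchanged HF code.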

3. In `auto.py`, set the proper kernel on the `OpSlot` instances in the corresponding model modules:

```python
def _bind_veomni_ops(modeling_module, ops_config: OpsImplementationConfig) -> bool:
    """Bind all OpSlot instances found in *modeling_module*.

    Returns ``True`` if at least one OpSlot was found (and bound).
    """
    found = False
    for name in dir(modeling_module):
        obj = getattr(modeling_module, name, None)
        if isinstance(obj, OpSlot):
            impl_name = _resolve_impl_name(obj.op_name, ops_config)
            obj.bind(impl_name)
            logger.info_rank0(f"OpSlot '{name}' bound to '{impl_name}' -> {obj}")
            found = True
    return found
```
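The scan-and-bind loop can be exercised against any namespace; here is a hypothetical demo with a `SimpleNamespace` standing in for the modeling module and a stripped-down `OpSlot` stand-in (both illustrative, not the PR's code):

```python
import types


class OpSlot:
    """Minimal stand-in for demonstration purposes."""

    def __init__(self, op_name: str) -> None:
        self.op_name = op_name
        self.impl_name = None

    def bind(self, impl_name: str) -> None:
        self.impl_name = impl_name


def bind_ops(modeling_module, impl_name: str) -> bool:
    """Bind every OpSlot attribute found on *modeling_module*."""
    found = False
    for name in dir(modeling_module):
        obj = getattr(modeling_module, name, None)
        if isinstance(obj, OpSlot):
            obj.bind(impl_name)
            found = True
    return found


# A fake "modeling module" holding slots at module level, as in the PR.
fake_module = types.SimpleNamespace(
    veomni_rms_norm=OpSlot("rms_norm"),
    veomni_swiglu_mlp=OpSlot("swiglu_mlp"),
    some_constant=42,  # non-OpSlot attributes are skipped by the isinstance check
)
```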

Kernel registry

Created a kernel registry that allows people to register new kernels and describe them like the following:

```python
KernelSpec(
    name="liger",  # kernel name to specify in the models.{op_name}_implementation flag
    op_name="rms_norm",
    # For ops like RMSNorm there are many variants, so this field lets the
    # model's OpSlot select the right one.
    variant="standard",
    factory=_liger_rms_norm_factory,
    # Hardware requirements give the user an early error when the impl is set
    # to a kernel not supported on the specified hardware.
    hardware=HardwareRequirement(device_type="cuda"),
    description="LigerKernel fused RMSNorm",
)
```

Check kernel_defaults.py for details.
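As a rough sketch of what such a registry might look like (illustrative only, not the PR's code; the duplicate-registration guard is included here as an obvious hardening step, and only a subset of the spec fields is shown):

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class KernelSpec:
    """Illustrative subset of the spec fields shown above."""

    name: str       # kernel name users put in the *_implementation flag
    op_name: str    # logical op this kernel implements
    variant: str    # op variant, e.g. "standard" vs a model-specific one
    factory: Callable
    description: str = ""


class KernelRegistry:
    def __init__(self) -> None:
        self._specs: Dict[Tuple[str, str, str], KernelSpec] = {}

    def register(self, spec: KernelSpec) -> None:
        key = (spec.op_name, spec.variant, spec.name)
        # Reject double registration so a re-register fails loudly instead of
        # silently overwriting an earlier kernel.
        if key in self._specs:
            raise ValueError(f"Duplicate kernel registration: {key}")
        self._specs[key] = spec

    def get(self, op_name: str, variant: str, name: str) -> KernelSpec:
        return self._specs[(op_name, variant, name)]
```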

@piyifan123 changed the title from "[models] feat: [transformers-v5] Introduce new registration based kernel replacement." to "[model] feat: [transformers-v5] Introduce new registration based kernel replacement." Mar 16, 2026
Contributor

@gemini-code-assist Bot left a comment


Code Review

This pull request introduces a significant and well-designed refactoring to a unified, registration-based kernel selection system, moving away from ad-hoc patching. The new KernelRegistry and OpSlot mechanism provides a much cleaner, more extensible, and safer way to manage custom kernels, with a clear migration path for existing models. The design documentation is thorough and the implementation largely follows it. However, I've identified two critical issues: one is a discrepancy between the design and implementation of the configuration for MoE experts, which impacts the user-facing API, and the other is a missing check for duplicate registrations in the kernel registry, which could lead to silent errors. Addressing these will make the new system more robust and consistent.

Comment thread veomni/arguments/arguments_types.py Outdated
Comment on lines +623 to +632
```python
@property
def moe_experts_implementation(self) -> str:
    """Resolve moe_implementation to a kernel registry name for the ``moe_experts`` OpSlot."""
    raw = self.moe_implementation
    if raw is None:
        return "eager"
    mapped = self._MOE_IMPL_TO_KERNEL.get(raw)
    if mapped is None:
        raise ValueError(f"Unknown moe_implementation='{raw}'. Valid: {list(self._MOE_IMPL_TO_KERNEL.keys())}")
    return mapped
```
Contributor


critical

The implementation of moe_experts_implementation as a read-only property contradicts the design goal of a unified configuration surface. The design document docs/design/unified_kernel_registry.md specifies moe_experts_implementation as a new, user-configurable field. However, with the current implementation, users cannot set this field in their configuration (e.g., YAML files) and must continue to use the old moe_implementation field with its legacy values (fused, fused_quack). This undermines the goal of having a consistent and direct way to select kernel implementations by their new names.

To align with the design, moe_experts_implementation should be a regular field. Backward compatibility for the old moe_implementation field should be handled, for instance, in the __post_init__ method by mapping the old value to the new field and issuing a deprecation warning.
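One possible shape for that suggestion, as a hedged sketch (the `ModelArgs` stand-in and field defaults are hypothetical, not the PR's dataclass):

```python
import warnings
from dataclasses import dataclass
from typing import Optional

# Mapping of legacy values to registry names, mirroring the PR's table.
_MOE_IMPL_TO_KERNEL = {
    "eager": "eager",
    "fused": "triton_group_gemm",
    "fused_quack": "quack_cutlass",
}


@dataclass
class ModelArgs:
    # Legacy field, kept only for backward compatibility.
    moe_implementation: Optional[str] = None
    # New, directly user-configurable field.
    moe_experts_implementation: str = "eager"

    def __post_init__(self) -> None:
        if self.moe_implementation is not None:
            mapped = _MOE_IMPL_TO_KERNEL.get(self.moe_implementation)
            if mapped is None:
                raise ValueError(f"Unknown moe_implementation='{self.moe_implementation}'")
            warnings.warn(
                "moe_implementation is deprecated; set moe_experts_implementation instead",
                DeprecationWarning,
            )
            self.moe_experts_implementation = mapped
```

Users can then set `moe_experts_implementation` directly in YAML, while old configs still work with a deprecation warning.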

Comment thread veomni/ops/kernel_registry.py Outdated
Comment thread veomni/arguments/arguments_types.py Outdated
```python
_MOE_IMPL_TO_KERNEL = {
    "eager": "eager",
    "fused": "triton_group_gemm",
    "fused_quack": "quack_cutlass",
```
Collaborator


I think just `quack` is OK.

```diff
@@ -0,0 +1,223 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
```
Collaborator


Can we split the default and registry file into the ops dirs? I think it's easier to manage.

```python
        "'eager' for standard PyTorch, 'liger' for LigerKernel fused RMSNorm."
    },
)
swiglu_mlp_implementation: str = field(
```
Collaborator


How do we control Liger ops that are valid on GPU but not on NPU?

Collaborator Author


I think we can register them separately, as:

```python
KERNEL_REGISTRY.register(
    KernelSpec(
        name="liger",
        op_name="rms_norm",
        variant="qwen3_5",
        factory=_liger_rms_norm_qwen3_5_factory,
        hardware=HardwareRequirement(device_type="cuda"),
        description="LigerKernel fused RMSNorm for Qwen3.5 (1+weight, zeros init, gemma casting)",
    )
)

KERNEL_REGISTRY.register(
    KernelSpec(
        name="liger_npu",
        op_name="rms_norm",
        variant="qwen3_5",
        factory=_liger_rms_norm_qwen3_5_npu_factory,
        hardware=HardwareRequirement(device_type="npu"),
        description="LigerKernel fused RMSNorm for Qwen3.5 (1+weight, zeros init, gemma casting) for NPU",
    )
)
```

This also avoids a footgun for users: if they select `rms_norm_impl: liger` in an NPU environment, they get an early error forcing them to select `rms_norm_impl: liger_npu` (and vice versa). WDYT?
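The early-error behavior described here could be as simple as comparing the registered requirement against the current device at resolve time. A sketch: `HardwareRequirement` mirrors the PR's field name, while the check function itself is hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HardwareRequirement:
    """Stand-in mirroring the PR's field name."""

    device_type: str


def check_hardware(required: HardwareRequirement, current_device: str) -> None:
    """Fail fast if a kernel is selected on unsupported hardware."""
    if required.device_type != current_device:
        raise RuntimeError(
            f"Kernel requires device_type='{required.device_type}', "
            f"but current device is '{current_device}'"
        )
```

So selecting `liger` (cuda-only) on an NPU fails immediately with an explicit message, instead of crashing later inside the kernel.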

Comment thread veomni/models/auto.py
```python
    return get_model_config(config_path, trust_remote_code=trust_remote_code, **config_kwargs)


def _apply_legacy_moe_patch(config, moe_implementation):
```
Collaborator


Why do we still keep the legacy patch?

piyifan123 and others added 9 commits March 17, 2026 09:19
Resolve conflicts in qwen3_5_moe patch gen config by combining OpSlot-based
dispatch (branch) with new imports, model init, and dummy vars from main (#602).
Regenerated generated files from merged config.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yanghw116

Excuse me.

I've seen

```python
veomni_load_balancing_loss = OpSlot("moe_load_balancing_loss", "standard")
```

in veomni/models/transformers/qwen3_5_moe/qwen3_5_moe_gpu_patch_gen_config.py, but I haven't found any registration like

```python
KERNEL_REGISTRY.register(
    KernelSpec(
        name="liger",
        op_name="moe_load_balancing_loss",
        variant="standard",
        factory=xxxxxxxx,
        hardware=HardwareRequirement(device_type="gpu"),
        description="xxxxxxxx",
    )
)
```

Could you help me figure out what is going on?


```python
@dataclass(frozen=True)
class KernelSpec:
    """Describes a single kernel implementation."""
```
Collaborator


Each field needs more explanation.

@Luosuu
Collaborator

Luosuu commented Apr 3, 2026

Issues and Suggestions

Bug: MoE OpSlot binding path order issue

In auto.py:192-200, _bind_veomni_ops is called after the model is constructed. But the legacy path _apply_legacy_moe_patch sets config._moe_implementation before model init (it patches the class-level forward). The new OpSlot path binds after init, which is correct for OpSlot's guard pattern, but the problem is:

```python
if (
    ops_implementation is not None
    and modeling_module is not None
    and _bind_veomni_ops(modeling_module, ops_implementation)
):
    ...
elif moe_implementation is not None:
    _apply_legacy_moe_patch(config, moe_implementation)
```

If ops_implementation is provided but the module has zero OpSlots (e.g., a model not yet migrated), _bind_veomni_ops returns False, and the legacy path runs — good. But if a module has some OpSlots but not MoE (unlikely currently but architecturally possible), the legacy MoE path would be skipped.
Consider decoupling these two conditions.
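A decoupled version might look like the following. This is purely illustrative: the function name and the MoE-coverage check are assumptions, not the PR's code.

```python
def apply_kernel_config(config, modeling_module, ops_implementation,
                        moe_implementation, bind_fn, legacy_moe_patch_fn):
    """Hypothetical restructuring: the legacy MoE patch is considered on its
    own merits rather than being skipped whenever *any* OpSlot was bound."""
    bound_any = False
    if ops_implementation is not None and modeling_module is not None:
        bound_any = bind_fn(modeling_module, ops_implementation)
    # Skip the legacy path only when the OpSlot path actually covered MoE,
    # not merely because some unrelated slot (e.g. rms_norm) was bound.
    moe_handled = bound_any and getattr(
        ops_implementation, "moe_experts_implementation", None
    ) not in (None, "eager")
    if moe_implementation is not None and not moe_handled:
        legacy_moe_patch_fn(config, moe_implementation)
    return bound_any, moe_handled
```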

Bug: moe_experts_implementation property masking moe_implementation

In arguments_types.py, moe_experts_implementation is a `@property` that delegates to moe_implementation. _resolve_impl_name does getattr(ops_config, f"{op_name}_implementation", "eager") — for op_name="moe_experts", this correctly resolves to the property. But the property returns self.moe_implementation or "eager", and moe_implementation defaults to None. This means setting moe_implementation="fused" makes the OpSlot try to bind "fused" — which IS registered. This works, but it is confusing because the same string ("fused") means different things in the legacy path (triggers apply_veomni_fused_moe_patch) vs the OpSlot path (resolves via the registry). Make sure downstream callers and docs are clear about this.

Potential issue: Module-global OpSlot shared across model instances

As noted in the design doc's "Open Questions §4" — OpSlot objects are module-level globals. Two instances of the same model class with different ops_implementation configs would share the same slots. The second call to _bind_veomni_ops would overwrite the first's bindings. This is a real risk in evaluation pipelines that compare eager vs fused. Consider at minimum adding a warning in OpSlot.bind() if already bound with a different impl.
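The suggested warning could be a few lines in `bind()`. A sketch with illustrative attribute names:

```python
import warnings


class OpSlot:
    """Minimal stand-in focused on the rebind warning."""

    def __init__(self, op_name: str) -> None:
        self.op_name = op_name
        self._impl_name = None
        self._bound = False

    def bind(self, impl_name: str) -> None:
        if self._bound and impl_name != self._impl_name:
            # Surfacing the overwrite makes the shared-global hazard visible
            # instead of silently changing earlier instances' kernels.
            warnings.warn(
                f"OpSlot '{self.op_name}' rebound from '{self._impl_name}' to "
                f"'{impl_name}'; earlier model instances will silently pick up "
                f"the new kernel.",
                RuntimeWarning,
            )
        self._impl_name = impl_name
        self._bound = True
```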

ForCausalLMLoss changes have subtle semantics

In fused_cross_entropy/__init__.py, the removal of else: loss_func = _cross_entropy on line 86 means that when hidden_states is provided but cross_entropy_fn is None, loss_func defaults to _cross_entropy (set at module level, possibly eager if apply_veomni_loss_patch was never called). Previously, the else-branch explicitly set it. The new code is logically equivalent only because loss_func is initialized from cross_entropy_fn or _cross_entropy on the line above. This is correct but fragile — if someone reorders the code, the fallback changes. Add a brief comment.

Ruff F821 suppression is broad

Adding F821 (undefined name) to all *_patch_gen_config.py files suppresses a useful lint for real bugs. Consider whether # noqa: F821 on specific lines in the generated code would be more targeted.

Test coverage

test_models_patch.py adds _verify_opslot_state which validates that has_kernel matches use_liger_kernel. This is good but only tests the boolean state, not that the bound kernel actually produces correct outputs. Consider adding a numerical test (e.g., small random input through RMSNorm with OpSlot-bound Liger vs eager, check allclose).

Nit: _bound field on OpSlot is redundant

OpSlot tracks both _kernel and _bound. The _bound flag distinguishes "never called bind()" from "called bind('eager')" — both have _kernel = None. This is only used in repr. Acceptable but slightly over-engineered for a repr; a simpler approach would be a sentinel _UNBOUND = object().
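The sentinel alternative, sketched (hypothetical, mirroring only the repr-related behavior described above):

```python
_UNBOUND = object()  # single sentinel: "bind() was never called"


class OpSlot:
    """Sketch of the sentinel approach: no separate _bound flag needed."""

    def __init__(self, op_name: str) -> None:
        self.op_name = op_name
        self._kernel = _UNBOUND

    def bind(self, kernel) -> None:
        # kernel=None means "explicitly bound to eager"; it remains
        # distinguishable from _UNBOUND without a second attribute.
        self._kernel = kernel

    @property
    def has_kernel(self) -> bool:
        return self._kernel is not _UNBOUND and self._kernel is not None

    def __repr__(self) -> str:
        if self._kernel is _UNBOUND:
            state = "unbound"
        elif self._kernel is None:
            state = "eager"
        else:
            state = "kernel"
        return f"OpSlot({self.op_name!r}, {state})"
```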

Summary

This is a well-architected PR with thorough design documentation. The core OpSlot + KernelRegistry pattern is clean, extensible, and minimal in its diff from upstream HF. The main concerns are:

  1. The module-global OpSlot sharing across model instances (documented but unmitigated)
  2. The coupling between OpSlot and legacy MoE paths in auto.py


---

## Open Questions
Collaborator


Great PR! The shift from monkey-patching to the OpSlot-based registry and dependency injection is a massive improvement for readability, maintainability, and upstream compatibility.

Just a few thoughts, bringing up one potential edge case and adding my two cents to your Open Questions:

  1. Potential torch.compile Graph Breaks
    Since OpSlot instances are module-level globals and .has_kernel is a dynamic property evaluated during forward, I'm slightly concerned this might introduce graph breaks in torch.compile (PyTorch 2.x Dynamo) . We might want to verify its compilation behavior or see if we need torch.compiler.assume_constant_result for those checks.

  2. Regarding Q1: Preset Silent Fallback
    I highly recommend warning or explicitly failing rather than silently skipping. In large-scale training, if a user thinks they enabled liger to save memory but it silently falls back to eager due to a variant mismatch, the resulting OOM will be incredibly frustrating to debug.

  3. Regarding Q4: Multi-model processes sharing OpSlots
    You are right that module-level slots mean two instances of the same model cannot have different kernel configs in the same process. While this might be a rare edge case now, it could block future workflows like Teacher-Student distillation or reference-model setups (where one uses fused and the other uses eager). Not a blocker for this PR, but definitely a tech debt to keep an eye on.

Overall, the architecture looks really solid!

@TimYangst force-pushed the piyifan/kernel-patch-exp branch from 245f68b to 7a6ba14 on April 21, 2026 01:12
TimYangst added a commit that referenced this pull request Apr 21, 2026
… PR #569)

Squash of commits from piyifan/kernel-patch-exp bringing in the
kernel selection and patching framework that this branch builds on.

Original piyifan commits preserved on origin/piyifan/kernel-patch-exp.
TimYangst added a commit that referenced this pull request Apr 22, 2026
… PR #569)

Squash of commits from piyifan/kernel-patch-exp bringing in the
kernel selection and patching framework that this branch builds on.

Original piyifan commits preserved on origin/piyifan/kernel-patch-exp.