Fused Add + RMSNorm pattern #56

@AndreSlavescu

Motivation:

The Qwen3 decoder layer has the following pattern in its forward pass:

```python
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
```

https://github.com/huggingface/transformers/blob/0dc2df5ddafe3cb5824ad24e85beba13e0aa6726/src/transformers/models/qwen3/modeling_qwen3.py#L271

Goal:

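Design a fused CUTLASS kernel that performs the above pattern in a single pass: the residual add followed by the RMSNorm, producing both the updated residual and the normalized hidden states.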
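
For checking a fused kernel against the unfused pattern, a minimal dependency-free sketch of the intended semantics for a single row might look like the following. The function name `fused_add_rmsnorm`, the list-based interface, and the default `eps=1e-6` are assumptions made for illustration and are not taken from the Qwen3 code or from any existing kernel. It also assumes `post_attention_layernorm` is a standard RMSNorm with a learned per-channel weight.

```python
import math


def fused_add_rmsnorm(hidden, residual, weight, eps=1e-6):
    """Hypothetical single-row reference: returns (normed, new_residual).

    hidden, residual and weight are equal-length sequences of floats
    covering the hidden dimension. eps=1e-6 is an assumed default.
    """
    if not hidden:
        raise ValueError("inputs must be non-empty")
    if not (len(hidden) == len(residual) == len(weight)):
        raise ValueError("hidden, residual and weight must have the same length")

    # Step 1: residual add. The sum is reused as the residual for the
    # next sub-block, so a fused kernel has to write it back as well.
    new_residual = [r + h for r, h in zip(residual, hidden)]

    # Step 2: RMSNorm over the hidden dimension of the summed values.
    mean_sq = sum(x * x for x in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normed = [w * x * inv_rms for w, x in zip(weight, new_residual)]

    return normed, new_residual
```

The two return values correspond to `hidden_states` and `residual` after the pattern above. A real kernel would apply this per token row and would likely accumulate the mean of squares in float32 even for half-precision inputs, though that choice is left open here.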
