From b35c5e08f10df586eb79e1cbb4a9bc8867f21899 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 9 May 2026 20:03:28 +0800 Subject: [PATCH 01/36] zaya1 support --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/zaya.md | 55 + src/transformers/conversion_mapping.py | 12 + src/transformers/models/__init__.py | 1 + src/transformers/models/auto/auto_mappings.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/zaya/__init__.py | 28 + .../models/zaya/configuration_zaya.py | 183 +++ src/transformers/models/zaya/modeling_zaya.py | 1126 ++++++++++++++++ src/transformers/models/zaya/modular_zaya.py | 1133 +++++++++++++++++ tests/models/zaya/__init__.py | 1 + 11 files changed, 2544 insertions(+) create mode 100644 docs/source/en/model_doc/zaya.md create mode 100644 src/transformers/models/zaya/__init__.py create mode 100644 src/transformers/models/zaya/configuration_zaya.py create mode 100755 src/transformers/models/zaya/modeling_zaya.py create mode 100644 src/transformers/models/zaya/modular_zaya.py create mode 100644 tests/models/zaya/__init__.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 11e3c9008d56..3fc1d2ef50fd 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -883,6 +883,8 @@ title: Zamba - local: model_doc/zamba2 title: Zamba2 + - local: model_doc/zaya + title: ZAYA title: Text models - sections: - local: model_doc/aimv2 diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md new file mode 100644 index 000000000000..7f881a47efb9 --- /dev/null +++ b/docs/source/en/model_doc/zaya.md @@ -0,0 +1,55 @@ + +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-09.* + +# ZAYA + +## Overview + +ZAYA1 is a 760M active / 8.4B total parameter MoE language model trained by Zyphra. It combines Compressed +Convolutional Attention (CCA), a nonlinear ZAYA1 router, and residual scaling. 
+ +ZAYA1 uses the Gemma 3 tokenizer. For more details, see the [ZAYA1 model card](https://huggingface.co/Zyphra/ZAYA1-8B) +and Zyphra's technical reports. + +## Usage examples + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + + +model_id = "Zyphra/ZAYA1-8B" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + +inputs = tokenizer("What factors contributed to the fall of the Roman Empire?", return_tensors="pt").to(model.device) +outputs = model.generate(**inputs, max_new_tokens=100) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +## ZayaConfig + +[[autodoc]] ZayaConfig + +## ZayaModel + +[[autodoc]] ZayaModel + - forward + +## ZayaForCausalLM + +[[autodoc]] ZayaForCausalLM + - forward diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index e7937fed254f..dff0f65f5b53 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -561,6 +561,18 @@ def _build_checkpoint_conversion_mapping(): operations=[Transpose(1, 2, check_dims=True)], ), ], + "zaya": [ + WeightConverter( + source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight", + target_patterns="zaya_block.experts.gate_up_proj", + operations=[MergeModulelist(dim=0)], + ), + WeightConverter( + source_patterns="zaya_block.experts.local_experts.*.linear_fc2.weight", + target_patterns="zaya_block.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], "phimoe": [ WeightRenaming(".block_sparse_moe.", ".mlp."), WeightRenaming(".gate.weight", ".router.weight"), diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 406c5f7be0fc..b1c0412758b1 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -479,6 +479,7 @@ from .youtu import * from .zamba import * from .zamba2 import * + from .zaya import * from .zoedepth 
import * else: import sys diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 048dd5275537..ed3866d0ed42 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -643,6 +643,7 @@ ("youtu", "YoutuConfig"), ("zamba", "ZambaConfig"), ("zamba2", "Zamba2Config"), + ("zaya", "ZayaConfig"), ("zoedepth", "ZoeDepthConfig"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2202cc773db0..4d90c73183e7 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -510,6 +510,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("yolos", "YolosModel"), ("yoso", "YosoModel"), ("youtu", "YoutuModel"), ("zamba", "ZambaModel"), ("zamba2", "Zamba2Model"), + ("zaya", "ZayaModel"), ] @@ -772,6 +773,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("xlstm", "xLSTMForCausalLM"), ("xmod", "XmodForCausalLM"), ("youtu", "YoutuForCausalLM"), ("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM"), + ("zaya", "ZayaForCausalLM"), ] diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py new file mode 100644 index 000000000000..54cc0c89f303 --- /dev/null +++ b/src/transformers/models/zaya/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 Zyphra and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_zaya import * + from .modeling_zaya import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py new file mode 100644 index 000000000000..506df6eee3f0 --- /dev/null +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -0,0 +1,183 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaConfig(PreTrainedConfig):
    r"""
    num_query_groups (`int`, *optional*, defaults to 2):
        Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
    lm_head_bias (`bool`, *optional*, defaults to `False`):
        Whether to add a bias to the language modeling head.
    ffn_hidden_size (`int`, *optional*, defaults to 4096):
        Dimension of the feed-forward and expert hidden states.
    rope_theta (`float`, *optional*, defaults to 5000000):
        The base period of the RoPE embeddings.
    moe_router_topk (`int`, *optional*, defaults to 1):
        Number of selected experts per token. ZAYA checkpoints use top-1 routing.
    zaya_mlp_expansion (`int`, *optional*, defaults to 256):
        Expansion size used by the dense ZAYA blocks.
    partial_rotary_factor (`float`, *optional*, defaults to 0.5):
        Fraction of each attention head dimension using rotary embeddings.
    cca_time0 (`int`, *optional*, defaults to 2):
        First temporal parameter of the CCA projection.
    cca_time1 (`int`, *optional*, defaults to 2):
        Second temporal parameter of the CCA projection.
    swa_layers (`list[int]`, *optional*):
        Per-layer selector for standard RoPE versus SWA RoPE embeddings.
    swa_rotary_base (`float`, *optional*):
        RoPE base used by SWA layers.

    ```python
    >>> from transformers import ZayaConfig, ZayaModel

    >>> configuration = ZayaConfig()
    >>> model = ZayaModel(configuration)

    >>> configuration = model.config
    ```
    """

    model_type = "zaya"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        num_query_groups=2,
        use_cache=True,
        attention_bias=False,
        lm_head_bias=False,
        vocab_size=262272,
        hidden_size=2048,
        ffn_hidden_size=4096,
        num_hidden_layers=80,
        num_experts=16,
        num_attention_heads=8,
        hidden_act="silu",
        head_dim=128,
        initializer_range=0.02,
        max_position_embeddings=131072,
        norm_epsilon=1e-05,
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=106,
        tie_word_embeddings=True,
        rope_theta=5000000,
        attention_dropout=0.0,
        moe_router_topk=1,
        zaya_mlp_expansion=256,
        rope_parameters=None,
        partial_rotary_factor=0.5,
        num_key_value_heads=2,
        cca_time0=2,
        cca_time1=2,
        swa_layers=None,
        swa_rotary_base=None,
        output_router_logits=False,
        _attn_implementation="eager",
        **kwargs,
    ):
        # Keys that may be present in a checkpoint's config but are not consumed
        # by this implementation; they are dropped so they never reach the base class.
        ignored_checkpoint_keys = (
            "cca",
            "activation_func",
            "normalization",
            "add_bias_linear",
            "gated_linear_unit",
            "fused_add_norm",
            "apply_rope_fusion",
            "bias_activation_fusion",
            "activation_func_fp8_input_store",
            "clamp_temp",
            "residual_in_fp32",
            "rope_scaling",
            "scale_residual_merge",
            "sliding_window",
            "zaya_high_prec",
            "zaya_use_mod",
            "zaya_use_eda",
        )
        for ignored_key in ignored_checkpoint_keys:
            kwargs.pop(ignored_key, None)

        # An explicit `None` falls back to the key/value head count.
        if num_query_groups is None:
            num_query_groups = num_key_value_heads

        # Architectural constraints of the currently supported checkpoints.
        if head_dim is None:
            raise ValueError("`head_dim` must be set for ZAYA.")
        if num_query_groups != num_key_value_heads:
            raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
        if moe_router_topk != 1:
            raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")

        self.num_query_groups = num_query_groups
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.lm_head_bias = lm_head_bias
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_experts = num_experts
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.head_dim = head_dim
        self.initializer_range = initializer_range
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.norm_epsilon = norm_epsilon
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.attention_dropout = attention_dropout
        self.moe_router_topk = moe_router_topk
        self.zaya_mlp_expansion = zaya_mlp_expansion
        self.partial_rotary_factor = partial_rotary_factor
        self.rope_theta = rope_theta

        # Work on a private copy so the caller's mapping is never mutated, then
        # fill in any RoPE entries the caller did not provide.
        if rope_parameters is None:
            rope_parameters = {"rope_type": "default"}
        else:
            rope_parameters = dict(rope_parameters)
        rope_parameters.setdefault("rope_theta", rope_theta)
        rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor)
        self.rope_parameters = rope_parameters

        # Only the (2, 2) CCA kernel configuration is accepted; `None` means "use the default".
        if cca_time0 is None:
            cca_time0 = 2
        if cca_time1 is None:
            cca_time1 = 2
        if (cca_time0, cca_time1) != (2, 2):
            raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")

        # `swa_layers` needs one entry per layer and an accompanying RoPE base.
        if swa_layers is not None:
            if len(swa_layers) != num_hidden_layers:
                raise ValueError("`swa_layers` must have one entry per hidden layer.")
            if swa_rotary_base is None:
                raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")

        self.cca_time0 = cca_time0
        self.cca_time1 = cca_time1
        self.swa_layers = swa_layers
        self.swa_rotary_base = swa_rotary_base
        self.output_router_logits = output_router_logits
        self._attn_implementation = _attn_implementation

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=self.tie_word_embeddings,
            **kwargs,
        )
+ + +import copy +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import init + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_zaya import ZayaConfig + + +class ZayaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: ZayaConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: ZayaConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + 
""" + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
@use_kernel_forward_from_hub("RMSNorm")
class ZayaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel gain."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Args:
            hidden_size: Size of the normalized (last) dimension; also the length of `weight`.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, cast back to the input dtype, then apply the gain."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Show the gain shape and epsilon in the module's `repr`."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
    """
    Store the most recent `conv_kernel_size` convolution inputs for one layer.

    Args:
        layer_idx: Index into `self.conv_states`.
        new_conv_state: Newly seen convolution inputs laid out as
            `(batch, seq_len, channels)`; transposed to `(batch, channels, seq_len)` for storage.

    Returns:
        The layer's state buffer of shape `(batch, channels, conv_kernel_size)`. The same
        tensor object is reused across calls and updated in place.

    When `self.has_previous_state` is set, the new inputs are appended to the cached window
    and the oldest columns are dropped. Previously a chunk shorter than the kernel was
    zero-padded to the full kernel width *before* being concatenated, so the trailing slice
    contained only the padded chunk and the cached history was replaced by zeros on every
    single-token decoding step. Zero padding is now applied only when fewer than
    `conv_kernel_size` real columns exist.
    """
    # NOTE(review): the banner of the generated modeling file asks for changes to be made
    # in modular_zaya.py -- mirror this fix there.
    incoming = new_conv_state.transpose(1, 2)
    cached = self.conv_states[layer_idx]

    if self.has_previous_state and cached is not None:
        # Slide the window: keep real history, append the new columns.
        window = torch.cat([cached, incoming], dim=-1)
    else:
        window = incoming
    window = window[:, :, -self.conv_kernel_size :]

    # Left-pad with zeros only when there is not enough history to fill the window.
    missing = self.conv_kernel_size - window.shape[-1]
    if missing > 0:
        window = F.pad(window, (missing, 0))

    if cached is None:
        self.conv_states[layer_idx] = torch.zeros_like(window)
    # In-place copy keeps the buffer object stable for callers holding a reference.
    self.conv_states[layer_idx].copy_(window)
    return self.conv_states[layer_idx]
class CCA(nn.Module):
    """
    Compressed Convolutional Attention (CCA) projection: turns hidden states into the
    query, key and value tensors consumed by the attention layer.

    Queries and keys are linear projections that pass through two stacked causal 1-D
    convolutions and are then L2-normalized per head. Values concatenate a projection of
    the current token with a projection of the previous token.
    """

    def __init__(
        self,
        config: ZayaConfig,
        num_key_value_heads: int = 2,
        num_attention_heads: int = 8,
        hidden_size: int | None = None,
        head_dim: int = 128,
        cca_time0: int = 2,
        cca_time1: int = 2,
        layer_number: int = 0,
    ):
        """
        Args:
            config: Model configuration; `attention_bias` and (as a fallback) `hidden_size` are read.
            num_key_value_heads: Number of key/value heads.
            num_attention_heads: Number of query heads; must be a multiple of `num_key_value_heads`.
            hidden_size: Input feature size; falls back to `config.hidden_size` when falsy.
            head_dim: Per-head feature size.
            cca_time0: Kernel size of the first (depthwise) convolution.
            cca_time1: Kernel size of the second (grouped) convolution.
            layer_number: Index used to address this layer's slots in the cache.

        Raises:
            ValueError: If `num_attention_heads` is not a multiple of `num_key_value_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_number = layer_number

        self.hidden_size = int(hidden_size or config.hidden_size)

        self.depthwise_kernel_size = cca_time0
        self.grouped_kernel_size = cca_time1
        # Both convolutions use padding=0, so together they shorten the sequence by this
        # many steps; the same amount is prepended to the conv input in `forward`.
        self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)

        self.num_key_value_heads = int(num_key_value_heads)
        self.num_attention_heads = int(num_attention_heads)

        self.head_dim = int(head_dim)
        self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
        self.query_hidden_size = self.num_attention_heads * self.head_dim
        self.sqrt_head_dim = self.head_dim**0.5
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")

        self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
        self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
        # Each value projection produces half of the key/value width; the halves are
        # concatenated in `forward`.
        self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
        self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)

        # Queries and keys are convolved together: channels are [query | key].
        conv_channels = self.key_value_hidden_size + self.query_hidden_size
        self.conv_qk = nn.Sequential(
            # Depthwise: one filter per channel.
            nn.Conv1d(
                in_channels=conv_channels,
                out_channels=conv_channels,
                kernel_size=self.depthwise_kernel_size,
                groups=conv_channels,
                padding=0,
                stride=1,
            ),
            # Grouped: one group per head (query heads plus key/value heads).
            nn.Conv1d(
                in_channels=conv_channels,
                out_channels=conv_channels,
                kernel_size=self.grouped_kernel_size,
                groups=(self.num_key_value_heads + self.num_attention_heads),
                padding=0,
                stride=1,
            ),
        )

        # Per key/value head multiplier applied to the normalized keys; initialized to zero.
        self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: ZayaDynamicCache | None,
        attention_mask: torch.Tensor | None = None,
    ):
        """
        Args:
            hidden_states: Input of shape `(batch, seq_len, hidden_size)`.
            past_key_values: Optional cache. Its `conv_states` / `prev_v2` entries for this
                layer are read when `has_previous_state` is set, and both are updated
                whenever a cache is given.
            attention_mask: Optional mask multiplied into `hidden_states` after a trailing
                axis is added.

        Returns:
            `(query, key, value)` with shapes `(batch, seq_len, query_hidden_size)`,
            `(batch, seq_len, key_value_hidden_size)` and `(batch, seq_len, key_value_hidden_size)`.
        """
        if attention_mask is not None:
            # NOTE(review): assumes the mask is (batch, seq_len) matching `hidden_states`;
            # the caller is not visible here -- confirm it is sliced during cached decoding.
            hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)

        batch_size, seq_length, _ = hidden_states.shape

        projected_queries = self.linear_q(hidden_states)
        projected_keys = self.linear_k(hidden_states)
        qk_states = torch.cat([projected_queries, projected_keys], dim=-1)

        query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
        key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)

        # Repeat each key head across its query group, average it with the queries, then
        # average the result back over each group to obtain the key-side residual.
        key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1)
        key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim)
        query_residual = (query_residual + key_residual) * 0.5
        key_residual = query_residual.view(
            batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
        ).mean(dim=-2)

        # Conv1d expects (batch, channels, seq_len).
        qk_states = qk_states.transpose(1, 2)
        use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state
        if use_precomputed_states:
            # Prepend the cached inputs; this read must happen before the cache update below.
            cached_qk_states = past_key_values.conv_states[self.layer_number]
            conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
        else:
            # No history: left-pad with zeros so the output keeps `seq_length` steps.
            conv_input = F.pad(qk_states, (self.total_padding, 0))

        if past_key_values is not None:
            past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2))

        convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2)

        # Split the convolved channels back into query and key parts and add the residuals.
        query = (
            convolved_qk_states[..., : self.query_hidden_size].view(
                batch_size, seq_length, self.num_attention_heads, self.head_dim
            )
            + query_residual
        )

        key = (
            convolved_qk_states[..., self.query_hidden_size :].view(
                batch_size, seq_length, self.num_key_value_heads, self.head_dim
            )
            + key_residual
        )

        value_current = self.val_proj1(hidden_states)
        projected_v2 = self.val_proj2(hidden_states)
        if use_precomputed_states:
            # Last `val_proj2` output of the previous call (read before it is overwritten below).
            first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1)
        else:
            # No previous token: use `val_proj2` applied to an all-zero input.
            first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
        # Shift `projected_v2` right by one step so position t holds the value of t-1.
        value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)

        if past_key_values is not None:
            past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :])

        value = torch.cat([value_current, value_delayed], dim=-1).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        )

        # L2-normalize each head; the norm is clamped from below to avoid division by zero.
        norm_eps = torch.finfo(query.dtype).eps
        query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
        key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)

        # Rescale to norm sqrt(head_dim); keys are additionally multiplied by the per-head `temp`.
        key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
        query = query * (self.sqrt_head_dim / query_norm)

        query = query.reshape(batch_size, seq_length, self.query_hidden_size)
        key = key.reshape(batch_size, seq_length, self.key_value_hidden_size)
        value = value.reshape(batch_size, seq_length, self.key_value_hidden_size)
        return query, key, value
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Matches `torch.repeat_interleave(x, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input
    tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
        tiled = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
class ZayaAttention(nn.Module):
    """Attention layer whose query/key/value tensors are produced by a `CCA` module."""

    def __init__(self, config: ZayaConfig, layer_n):
        """
        Args:
            config: Model configuration supplying head counts, `head_dim`, dropout and bias flags.
            layer_n: Index of this layer; used for the cache slot and stored as both
                `layer_n` and `layer_idx`.
        """
        super().__init__()
        self.config = config
        self.layer_n = self.layer_idx = layer_n

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.head_dim = config.head_dim
        self.scaling = self.head_dim**-0.5

        # The output projection is created before the CCA module so that the parameter
        # registration order stays the same.
        attention_width = self.num_attention_heads * self.head_dim
        self.o_proj = nn.Linear(attention_width, self.hidden_size, bias=config.attention_bias)
        self.qkv = CCA(
            config=config,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_query_groups,
            hidden_size=self.hidden_size,
            head_dim=config.head_dim,
            cca_time0=config.cca_time0,
            cca_time1=config.cca_time1,
            layer_number=layer_n,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        attention_mask_2d: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
        """
        Args:
            hidden_states: Input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: Optional mask added to the attention scores; it is cropped to the
                current query and key lengths.
            attention_mask_2d: Optional mask forwarded to the CCA projection.
            past_key_values: Optional cache, passed to CCA and updated with the rotated
                keys and the values.
            output_attentions: When false, the returned attention weights are `None`.
            position_embeddings: `(cos, sin)` tensors for the rotary embedding.

        Returns:
            `(attn_output, attn_weights_or_None, past_key_values)`.
        """
        bsz, q_len, _ = hidden_states.shape

        q_flat, k_flat, v_flat = self.qkv(hidden_states, past_key_values, attention_mask_2d)

        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        query_states = q_flat.view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = k_flat.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = v_flat.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)

        cropped_mask = attention_mask
        if cropped_mask is not None:
            cropped_mask = cropped_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Dropout is only active while training.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            cropped_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            output_attentions=output_attentions,
        )

        merged_heads = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)
        attn_output = self.o_proj(merged_heads)

        return attn_output, (attn_weights if output_attentions else None), past_key_values
class ResidualScaling(nn.Module):
    """
    Learned per-channel affine `(x + bias) * scale` applied to the hidden states and, on every
    layer except layer 0, to the residual stream as well.
    """

    def __init__(self, config, layer_n):
        super().__init__()
        width = config.hidden_size
        self.not_first_layer = layer_n != 0
        self.hidden_states_scale = nn.Parameter(torch.ones(width))
        self.hidden_states_bias = nn.Parameter(torch.zeros(width))

        # Layer 0 gets no residual affine: its residual is returned as received.
        if not self.not_first_layer:
            return
        self.residual_scale = nn.Parameter(torch.ones(width))
        self.residual_bias = nn.Parameter(torch.zeros(width))

    def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
        """Return `(residual, hidden_states)` after their respective affine transforms."""
        scaled_hidden = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
        if not self.not_first_layer:
            return residual, scaled_hidden
        scaled_residual = (residual + self.residual_bias) * self.residual_scale
        return scaled_residual, scaled_hidden
self.layer_idx = layer_idx + + self.num_experts = num_moe_experts + 1 + self.topk = int(moe_router_topk) + self.mlp_expansion = int(mlp_expansion) + + self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) + + zaya_first_layer = 1 + self.use_eda = self.layer_idx != zaya_first_layer + + ln_eps = float(getattr(config, "norm_epsilon", 1e-5)) + self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps) + if self.use_eda: + self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) + + self.non_linearity = nn.GELU() + self.router_mlp = nn.Sequential( + nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True), + self.non_linearity, + nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True), + self.non_linearity, + nn.Linear(self.mlp_expansion, self.num_experts, bias=False), + ) + + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, + router_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + seq_length = hidden_states.shape[1] + + router_hidden_states = self.down_proj(hidden_states) + + if self.use_eda and (router_states is not None): + router_hidden_states = router_hidden_states + router_states * self.router_states_scale + + router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() + router_hidden_states = self.rmsnorm_eda(router_hidden_states) + logits = self.router_mlp(router_hidden_states) + expert_prob = torch.softmax(logits, dim=-1) + + expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases + _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1) + route_prob = torch.gather(expert_prob, dim=2, index=expert_choice) + + return ( + route_prob.reshape(-1, self.topk), + expert_choice.reshape(-1, self.topk), + router_hidden_states_next, + logits.reshape(-1, self.num_experts), + ) + + 
@use_experts_implementation
class ZayaExperts(nn.Module):
    """Expert feed-forward weights packed into 3D parameters, one slice per expert."""

    def __init__(self, config, num_experts: int, ffn_hidden_size: int):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_dim = config.hidden_size
        # Half of `ffn_hidden_size` is the width of each of the gate and up branches.
        self.intermediate_dim = ffn_hidden_size // 2
        fused_shape = (self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)
        down_shape = (self.num_experts, self.hidden_dim, self.intermediate_dim)
        self.gate_up_proj = nn.Parameter(torch.empty(*fused_shape))
        self.down_proj = nn.Parameter(torch.empty(*down_shape))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply each selected expert to its tokens and sum the weighted results.

        `hidden_states` is indexed along dim 0 by token, and `top_k_index` / `top_k_weights` are indexed as
        `[token, top_k_slot]`. Tokens routed to index `num_experts` receive no expert output (zeros).
        """
        output = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # One class more than there are weight slices: index `num_experts` has no weights and is skipped.
            one_hot = nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1)
            routing_mask = one_hot.permute(2, 1, 0)
            active = (routing_mask.sum(dim=(-1, -2)) > 0).nonzero()

        for hit in active:
            expert_id = hit[0]
            if expert_id == self.num_experts:
                continue
            slot_idx, rows = torch.where(routing_mask[expert_id])
            projected = nn.functional.linear(hidden_states[rows], self.gate_up_proj[expert_id])
            gate_part, up_part = projected.chunk(2, dim=-1)
            expert_out = nn.functional.linear(self.act_fn(gate_part) * up_part, self.down_proj[expert_id])
            expert_out = expert_out * top_k_weights[rows, slot_idx, None]
            output.index_add_(0, rows, expert_out.to(output.dtype))

        return output
ZayaRouter( + config=self.config, + layer_idx=layer_n, + num_moe_experts=self.num_moe_experts, + moe_router_topk=getattr(self.config, "moe_router_topk", 1), + mlp_expansion=mlp_expansion, + hidden_size=self.hidden_dim, + ) + self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router( + hidden_states, router_states=prev_router_hidden_states + ) + batch_size, seq_length, emb_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) + expert_output = self.experts(hidden_states_flat, expert_choice, route_prob) + expert_output = expert_output.view(batch_size, seq_length, emb_dim) + + return expert_output, prev_router_hidden_states, router_logits + + +class ZayaDecoderMLPLayer(GradientCheckpointingLayer): + def __init__( + self, + config: ZayaConfig, + num_moe_experts: int, + mlp_expansion: int, + ffn_hidden_size: int, + layer_n: int, + ): + + super().__init__() + self.config = config + self.zaya_block = ZayaBlock( + config, + num_moe_experts, + mlp_expansion, + ffn_hidden_size, + layer_n, + ) + self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.res_scale = ResidualScaling(config, layer_n) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + prev_router_hidden_states: torch.Tensor | None = None, + output_router_logits: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: + hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) + + hidden_states, prev_router_hidden_states, router_logits = self.zaya_block( + hidden_states, + prev_router_hidden_states, + 
class ZayaPreTrainedModel(PreTrainedModel):
    """Base class for ZAYA models: declares the capability flags and initializes the custom parameters."""

    config: ZayaConfig
    config_class = ZayaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Router logits are taken from element 3 of the tuple returned by `ZayaRouter.forward`.
    _can_record_outputs = {
        "router_logits": OutputRecorder(ZayaRouter, index=3),
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize `module`; the branches below cover the parameters and buffers ZAYA defines itself."""
        super()._init_weights(module)
        if isinstance(module, ResidualScaling):
            # Identity affine: scale 1, bias 0.
            init.ones_(module.hidden_states_scale)
            init.zeros_(module.hidden_states_bias)
            if module.not_first_layer:
                init.ones_(module.residual_scale)
                init.zeros_(module.residual_bias)
        elif isinstance(module, ZayaRouter):
            if module.use_eda:
                init.ones_(module.router_states_scale)
            # Same values as the router constructor: all zeros except -1 on the last entry.
            init.zeros_(module.balancing_biases)
            module.balancing_biases[-1] = -1.0
        elif isinstance(module, ZayaExperts):
            std = self.config.initializer_range
            init.normal_(module.gate_up_proj, mean=0.0, std=std)
            init.normal_(module.down_proj, mean=0.0, std=std)
        elif isinstance(module, CCA):
            # `temp` is a bare `nn.Parameter` that no other branch covers; use the constructor's
            # value (zeros) so it is never left uninitialized.
            init.zeros_(module.temp)
self.res_scale = ResidualScaling(config, config.num_hidden_layers) + + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + + self.rotary_emb = ZayaRotaryEmbedding(config=config) + if self.config.swa_layers is not None: + swa_config = copy.copy(config) + swa_config.rope_parameters = { + **config.rope_parameters, + "rope_theta": swa_config.swa_rotary_base, + } + self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + output_router_logits: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = ZayaDynamicCache( + self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device + ) + + residual = None + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange( + past_seen_tokens, + past_seen_tokens + 
inputs_embeds.shape[1], + device=inputs_embeds.device, + ).unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + position_ids, + past_key_values, + ) + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.") + # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. + # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state. + attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + if inputs_embeds.shape[1] == 1: + attention_mask_2d = None + + hidden_states = inputs_embeds + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + if self.config.swa_layers is not None: + swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + prev_router_hidden_states = None + + for layer_n, decoder_layer in enumerate(self.layers): + if self.config.swa_layers is not None: + emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings + else: + emb_to_use = position_embeddings + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + residual, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + position_embeddings=emb_to_use, + prev_router_hidden_states=prev_router_hidden_states, + attention_mask_2d=attention_mask_2d, + **kwargs, + ) + + hidden_states = layer_outputs[0] + residual = layer_outputs[2] + prev_router_hidden_states = layer_outputs[3] + + if isinstance(decoder_layer, ZayaDecoderATTLayer): + if output_attentions: + all_self_attns += 
# ZAYA decoder with a language-modeling head. Generation always runs on a `ZayaDynamicCache`,
# which is created here when the caller does not supply one.
@auto_docstring
class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _is_stateful = True

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZayaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.post_init()

    def set_decoder(self, decoder):
        self.model = decoder

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index directly.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=None,
            logits=logits,
            # `use_cache` is still the raw argument here (None unless the caller set it), so gating on it
            # would drop the cache the inner model produced. The inner model already decides whether a
            # cache is returned, so its value is forwarded as is.
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        """Reject any cache that is not a `ZayaDynamicCache`, then defer to the generic input preparation."""
        if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache):
            raise ValueError(
                f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
            )

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs

    def _prepare_cache_for_generation(
        self,
        generation_config,
        model_kwargs: dict,
        generation_mode,
        batch_size: int,
        max_cache_length: int,
    ):
        """Install a `ZayaDynamicCache` in `model_kwargs` unless caching is disabled or a cache was supplied."""
        if generation_config.use_cache is False:
            return

        # A key that is present but set to None counts as "no cache supplied"; otherwise no ZAYA cache
        # would be built here and `prepare_inputs_for_generation` would later reject whatever cache
        # is put in its place.
        if model_kwargs.get("past_key_values") is None:
            cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences)
            model_kwargs["past_key_values"] = ZayaDynamicCache(
                self.config, cache_batch_size, dtype=self.dtype, device=self.device
            )
        # Cleared so the generic implementation does not pick a cache class from this setting.
        generation_config.cache_implementation = None
        return super()._prepare_cache_for_generation(
            generation_config=generation_config,
            model_kwargs=model_kwargs,
            generation_mode=generation_mode,
            batch_size=batch_size,
            max_cache_length=max_cache_length,
        )
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Zaya model.""" + +import copy +from collections.abc import Callable + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import init + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PreTrainedConfig +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, +) +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..glm4.modeling_glm4 import Glm4RotaryEmbedding +from ..qwen3_5_moe.modeling_qwen3_5_moe import ( + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +class ZayaConfig(PreTrainedConfig): + r""" + num_query_groups (`int`, *optional*, defaults to 2): + Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. + ffn_hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the feed-forward and expert hidden states. + rope_theta (`float`, *optional*, defaults to 5000000): + The base period of the RoPE embeddings. + moe_router_topk (`int`, *optional*, defaults to 1): + Number of selected experts per token. ZAYA checkpoints use top-1 routing. 
+ zaya_mlp_expansion (`int`, *optional*, defaults to 256): + Expansion size used by the dense ZAYA blocks. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Fraction of each attention head dimension using rotary embeddings. + cca_time0 (`int`, *optional*, defaults to 2): + First temporal parameter of the CCA projection. + cca_time1 (`int`, *optional*, defaults to 2): + Second temporal parameter of the CCA projection. + swa_layers (`list[int]`, *optional*): + Per-layer selector for standard RoPE versus SWA RoPE embeddings. + swa_rotary_base (`float`, *optional*): + RoPE base used by SWA layers. + + ```python + >>> from transformers import ZayaConfig, ZayaModel + + >>> configuration = ZayaConfig() + >>> model = ZayaModel(configuration) + + >>> configuration = model.config + ``` + """ + + model_type = "zaya" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + num_query_groups=2, + use_cache=True, + attention_bias=False, + lm_head_bias=False, + vocab_size=262272, + hidden_size=2048, + ffn_hidden_size=4096, + num_hidden_layers=80, + num_experts=16, + num_attention_heads=8, + hidden_act="silu", + head_dim=128, + initializer_range=0.02, + max_position_embeddings=131072, + norm_epsilon=1e-05, + pad_token_id=0, + bos_token_id=2, + eos_token_id=106, + tie_word_embeddings=True, + rope_theta=5000000, + attention_dropout=0.0, + moe_router_topk=1, + zaya_mlp_expansion=256, + rope_parameters=None, + partial_rotary_factor=0.5, + num_key_value_heads=2, + cca_time0=2, + cca_time1=2, + swa_layers=None, + swa_rotary_base=None, + output_router_logits=False, + _attn_implementation="eager", + **kwargs, + ): + for unused_checkpoint_kwarg in ( + "cca", + "activation_func", + "normalization", + "add_bias_linear", + "gated_linear_unit", + "fused_add_norm", + "apply_rope_fusion", + "bias_activation_fusion", + "activation_func_fp8_input_store", + "clamp_temp", + "residual_in_fp32", + "rope_scaling", + "scale_residual_merge", + "sliding_window", + 
"zaya_high_prec", + "zaya_use_mod", + "zaya_use_eda", + ): + kwargs.pop(unused_checkpoint_kwarg, None) + + num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups + if head_dim is None: + raise ValueError("`head_dim` must be set for ZAYA.") + if num_query_groups != num_key_value_heads: + raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.") + if moe_router_topk != 1: + raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") + + self.num_query_groups = num_query_groups + self.use_cache = use_cache + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_experts = num_experts + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.head_dim = head_dim + self.initializer_range = initializer_range + self.num_key_value_heads = num_key_value_heads + self.max_position_embeddings = max_position_embeddings + self.norm_epsilon = norm_epsilon + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.attention_dropout = attention_dropout + self.moe_router_topk = moe_router_topk + self.zaya_mlp_expansion = zaya_mlp_expansion + self.partial_rotary_factor = partial_rotary_factor + self.rope_theta = rope_theta + rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"} + rope_parameters.setdefault("rope_theta", rope_theta) + rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) + self.rope_parameters = rope_parameters + cca_time0 = 2 if cca_time0 is None else cca_time0 + cca_time1 = 2 if cca_time1 is None else cca_time1 + if (cca_time0, cca_time1) != (2, 2): + raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` 
class ZayaDynamicCache(DynamicCache):
    """
    Cache that includes both the KV cache and the CCA cache.

    On top of the key/value states handled by `DynamicCache`, two tensors are kept per layer:

    - `conv_states[layer]`: the last `conv_kernel_size` time steps passed to `update_conv_state`, stored
      with the time axis last;
    - `prev_v2[layer]`: the tensor most recently passed to `update_prev_v2`.

    `has_previous_state` is never set to `True` by the cache itself; the code driving the cache flips it
    once a first forward pass has filled these states, and `reset()` clears it again.
    """

    def __init__(
        self,
        config: ZayaConfig,
        batch_size: int,
        dtype: torch.dtype = torch.float16,
        device: str | torch.device | None = None,
    ):
        super().__init__()
        self.config = config
        self.batch_size = batch_size
        self.dtype = dtype
        self.device = device
        # Number of past time steps kept per layer: one less than each of the two CCA temporal sizes.
        self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1)
        self.num_layers = config.num_hidden_layers
        self.key_value_hidden_size = config.num_query_groups * config.head_dim
        self.query_hidden_size = config.num_attention_heads * config.head_dim
        self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size
        self.has_previous_state = False

        # Allocated lazily on the first update of each layer.
        self.conv_states: list[torch.Tensor | None] = [None] * self.num_layers
        self.prev_v2: list[torch.Tensor | None] = [None] * self.num_layers

    def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
        """
        Store the most recent `conv_kernel_size` time steps for `layer_idx`.

        Args:
            layer_idx: Index of the layer whose state is updated.
            new_conv_state: New inputs with time on dim 1; dims 1 and 2 are swapped before storing.

        Returns:
            The stored state for the layer, with `conv_kernel_size` entries on the last axis.
        """
        kernel = self.conv_kernel_size
        incoming = new_conv_state.transpose(1, 2)
        cached = self.conv_states[layer_idx]

        # With an existing state, slide the window over old + new steps. The new chunk must NOT be
        # zero-padded first: padding a short chunk up to `kernel` before merging pushes the cached
        # steps out of the window and replaces them with zeros.
        if self.has_previous_state and cached is not None:
            window = torch.cat([cached, incoming], dim=-1)
        else:
            window = incoming
        window = window[:, :, -kernel:]

        # Left-pad with zeros only when fewer than `kernel` steps exist in total.
        missing = kernel - window.shape[-1]
        if missing > 0:
            window = F.pad(window, (missing, 0))

        if cached is None:
            # Own copy, so the cache never aliases the caller's tensor.
            self.conv_states[layer_idx] = window.clone(memory_format=torch.contiguous_format)
        else:
            # In-place write keeps the same tensor object across steps.
            cached.copy_(window)
        return self.conv_states[layer_idx]

    def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor:
        """Overwrite the stored `prev_v2` tensor of `layer_idx` with `new_prev_v2` and return it."""
        if self.prev_v2[layer_idx] is None:
            self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2)
        self.prev_v2[layer_idx].copy_(new_prev_v2)
        return self.prev_v2[layer_idx]

    def reset(self):
        """Reset the parent cache, zero the allocated auxiliary states and clear `has_previous_state`."""
        super().reset()
        for conv_state in self.conv_states:
            if conv_state is not None:
                conv_state.zero_()
        for prev_v2 in self.prev_v2:
            if prev_v2 is not None:
                prev_v2.zero_()
        self.has_previous_state = False

    def _reorder_auxiliary_states(self, indices: torch.LongTensor):
        """Select rows `indices` along the batch axis of every allocated auxiliary state."""
        for layer_idx, conv_state in enumerate(self.conv_states):
            if conv_state is not None:
                self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device))
        for layer_idx, prev_v2 in enumerate(self.prev_v2):
            if prev_v2 is not None:
                self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device))
        self.batch_size = indices.shape[0]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the parent cache and the auxiliary states with the same `beam_idx`."""
        super().reorder_cache(beam_idx)
        self._reorder_auxiliary_states(beam_idx)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat every batch entry `repeats` times in the parent cache and in the auxiliary states."""
        super().batch_repeat_interleave(repeats)
        for layer_idx, conv_state in enumerate(self.conv_states):
            if conv_state is not None:
                self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0)
        for layer_idx, prev_v2 in enumerate(self.prev_v2):
            if prev_v2 is not None:
                self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0)
        self.batch_size *= repeats

    def batch_select_indices(self, indices: torch.Tensor):
        """Keep only the batch entries `indices` in the parent cache and in the auxiliary states."""
        super().batch_select_indices(indices)
        self._reorder_auxiliary_states(indices)
in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, + ), + ) + + self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads)) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: ZayaDynamicCache | None, + attention_mask: torch.Tensor | None = None, + ): + if attention_mask is not None: + hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype) + + batch_size, seq_length, _ = hidden_states.shape + + projected_queries = self.linear_q(hidden_states) + projected_keys = self.linear_k(hidden_states) + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim) + key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim) + + key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1) + key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view( + batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim + ).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state + if use_precomputed_states: + cached_qk_states = past_key_values.conv_states[self.layer_number] + conv_input = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + conv_input = F.pad(qk_states, (self.total_padding, 0)) + + if past_key_values is not None: + past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2)) + + convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2) + + query = ( + convolved_qk_states[..., : 
self.query_hidden_size].view( + batch_size, seq_length, self.num_attention_heads, self.head_dim + ) + + query_residual + ) + + key = ( + convolved_qk_states[..., self.query_hidden_size :].view( + batch_size, seq_length, self.num_key_value_heads, self.head_dim + ) + + key_residual + ) + + value_current = self.val_proj1(hidden_states) + projected_v2 = self.val_proj2(hidden_states) + if use_precomputed_states: + first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1) + else: + first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size)) + value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) + + if past_key_values is not None: + past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :]) + + value = torch.cat([value_current, value_delayed], dim=-1).view( + batch_size, seq_length, self.num_key_value_heads, self.head_dim + ) + + norm_eps = torch.finfo(query.dtype).eps + query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + + key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1) + query = query * (self.sqrt_head_dim / query_norm) + + query = query.reshape(batch_size, seq_length, self.query_hidden_size) + key = key.reshape(batch_size, seq_length, self.key_value_hidden_size) + value = value.reshape(batch_size, seq_length, self.key_value_hidden_size) + return query, key, value + + +class ZayaAttention(nn.Module): + def __init__(self, config: ZayaConfig, layer_n): + super().__init__() + self.config = config + self.layer_n = layer_n + self.layer_idx = layer_n + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.head_dim = config.head_dim + 
self.scaling = self.head_dim**-0.5 + + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, + self.hidden_size, + bias=self.config.attention_bias, + ) + self.qkv = CCA( + config=self.config, + num_attention_heads=self.config.num_attention_heads, + num_key_value_heads=self.config.num_query_groups, + hidden_size=self.hidden_size, + head_dim=self.config.head_dim, + cca_time0=self.config.cca_time0, + cca_time1=self.config.cca_time1, + layer_number=layer_n, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + attention_mask_2d: torch.Tensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + batch_size, seq_length, _ = hidden_states.shape + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d) + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n) + + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, 
class ZayaDecoderATTLayer(GradientCheckpointingLayer):
    """Attention decoder layer: residual scaling + RMS norm followed by `ZayaAttention`.

    The layer does not add its output to the residual stream itself; it returns the attention output and
    the updated residual separately so that the next layer's `_apply_residual_scaling` call can combine them.
    """

    def __init__(self, config: ZayaConfig, layer_n: int):
        super().__init__()
        self.config = config
        self.self_attn = ZayaAttention(config, layer_n)

        self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
        self.res_scale = ResidualScaling(config, layer_n)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        attention_mask_2d: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        prev_router_hidden_states: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
        """Run the attention sub-layer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream, or `None` when there is none yet.
            attention_mask: Mask forwarded to `ZayaAttention` as `attention_mask`.
            attention_mask_2d: Mask forwarded to `ZayaAttention` as `attention_mask_2d`.
            past_key_values: Optional cache forwarded to the attention module.
            output_attentions: Whether to return the attention weights; `None` is treated as `False`.
            position_embeddings: Rotary `(cos, sin)` pair forwarded to the attention module.
            prev_router_hidden_states: Router state of the preceding MoE layer; passed through untouched.
            **kwargs: Accepted so all decoder layers can be called uniformly; not used here.

        Returns:
            `(attention_output, attention_weights_or_None, residual, prev_router_hidden_states)`, the same
            four-slot layout as `ZayaDecoderMLPLayer`.
        """
        # Normalise the optional flag once so the sub-module always receives a real bool.
        output_attentions = bool(output_attentions)

        hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)

        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            attention_mask_2d=attention_mask_2d,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
        )

        return hidden_states, self_attn_weights if output_attentions else None, residual, prev_router_hidden_states
def _apply_residual_scaling(
    hidden_states: torch.Tensor,
    residual: torch.Tensor | None,
    residual_scaling,
    rms_norm: ZayaRMSNorm,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fold `hidden_states` into the residual stream and return its normalised view.

    `residual_scaling` is called as `residual_scaling(residual, hidden_states)` and must return the pair in
    that same `(residual, hidden_states)` order. If the returned residual is `None`, the scaled hidden states
    (cast to float32) start the residual stream; otherwise the two are summed.

    Returns:
        `(normalised_hidden_states, new_residual)`, where the first item is `rms_norm` applied to the new
        residual after casting it to the dtype of `rms_norm.weight`.
    """
    scaled_residual, scaled_hidden = residual_scaling(residual, hidden_states)

    if scaled_residual is None:
        new_residual = scaled_hidden.to(torch.float32)
    else:
        new_residual = scaled_hidden + scaled_residual

    norm_dtype = rms_norm.weight.dtype
    normalised = rms_norm(new_residual.to(dtype=norm_dtype))
    return normalised, new_residual
@use_experts_implementation
class ZayaExperts(nn.Module):
    """Bank of gated feed-forward experts whose weights are stacked into 3D parameters.

    `gate_up_proj` has shape `(num_experts, 2 * intermediate_dim, hidden_dim)` and holds the fused gate and
    up projections; `down_proj` has shape `(num_experts, hidden_dim, intermediate_dim)`.
    """

    def __init__(self, config, num_experts: int, ffn_hidden_size: int):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_dim = config.hidden_size
        # `ffn_hidden_size` counts the fused gate + up width, so each half is the intermediate size.
        self.intermediate_dim = ffn_hidden_size // 2

        gate_up_shape = (self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)
        down_shape = (self.num_experts, self.hidden_dim, self.intermediate_dim)
        self.gate_up_proj = nn.Parameter(torch.empty(*gate_up_shape))
        self.down_proj = nn.Parameter(torch.empty(*down_shape))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the selected experts to each token and sum the weighted results.

        `hidden_states` is indexed along dim 0 by token, and `top_k_index` / `top_k_weights` are indexed as
        `[token, top_k_slot]`. Index value `self.num_experts` (one past the last real expert) is skipped, so
        tokens routed there receive no contribution.
        """
        output = torch.zeros_like(hidden_states)

        # Routing bookkeeping carries no gradient.
        with torch.no_grad():
            # One extra class for the index that is skipped below.
            routing_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1)
            # -> (expert, top_k_slot, token)
            routing_mask = routing_mask.permute(2, 1, 0)
            tokens_per_expert = routing_mask.sum(dim=(-1, -2))
            active_experts = torch.greater(tokens_per_expert, 0).nonzero()

        for entry in active_experts:
            expert_id = entry[0]
            if expert_id == self.num_experts:
                continue

            slot_idx, token_idx = torch.where(routing_mask[expert_id])
            expert_input = hidden_states[token_idx]

            fused = nn.functional.linear(expert_input, self.gate_up_proj[expert_id])
            gate, up = fused.chunk(2, dim=-1)
            activated = self.act_fn(gate) * up
            projected = nn.functional.linear(activated, self.down_proj[expert_id])

            weighted = projected * top_k_weights[token_idx, slot_idx, None]
            output.index_add_(0, token_idx, weighted.to(output.dtype))

        return output
class ZayaDecoderMLPLayer(GradientCheckpointingLayer):
    """Mixture-of-experts decoder layer: residual scaling + RMS norm followed by a `ZayaBlock`.

    As with the attention layer, the block output and the updated residual are returned separately; they
    are merged by the next `_apply_residual_scaling` call.
    """

    def __init__(
        self,
        config: ZayaConfig,
        num_moe_experts: int,
        mlp_expansion: int,
        ffn_hidden_size: int,
        layer_n: int,
    ):
        super().__init__()
        self.config = config
        self.zaya_block = ZayaBlock(
            config=config,
            num_moe_experts=num_moe_experts,
            mlp_expansion=mlp_expansion,
            ffn_hidden_size=ffn_hidden_size,
            layer_n=layer_n,
        )
        self.input_norm = ZayaRMSNorm(config.hidden_size, eps=config.norm_epsilon)
        self.res_scale = ResidualScaling(config, layer_n)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        prev_router_hidden_states: torch.Tensor | None = None,
        output_router_logits: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
        """Run the MoE sub-layer.

        Returns `(block_output, router_logits_or_None, residual, router_hidden_states)`, where the last item
        is the router state produced by this layer's block. Extra keyword arguments are accepted and ignored.
        """
        normed_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)

        moe_output, next_router_states, router_logits = self.zaya_block(
            normed_states,
            prev_router_hidden_states,
        )

        # Only expose the logits when the caller asked for them.
        if not output_router_logits:
            router_logits = None

        return moe_output, router_logits, residual, next_router_states
OutputRecorder(ZayaRouter, index=3), + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, ResidualScaling): + init.ones_(module.hidden_states_scale) + init.zeros_(module.hidden_states_bias) + if module.not_first_layer: + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaRouter): + if module.use_eda: + init.ones_(module.router_states_scale) + init.zeros_(module.balancing_biases) + module.balancing_biases[-1] = -1.0 + elif isinstance(module, ZayaExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + + +@auto_docstring +class ZayaModel(ZayaPreTrainedModel): + def __init__(self, config: ZayaConfig): + + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = [] + + for layer_n in range(config.num_hidden_layers): + if layer_n % 2 == 1: + self.layers.append( + ZayaDecoderMLPLayer( + config, + config.num_experts, + config.zaya_mlp_expansion, + config.ffn_hidden_size, + layer_n, + ) + ) + else: + self.layers.append(ZayaDecoderATTLayer(config, layer_n)) + self.layers = nn.ModuleList(self.layers) + + self.gradient_checkpointing = False + self.res_scale = ResidualScaling(config, config.num_hidden_layers) + + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + + self.rotary_emb = ZayaRotaryEmbedding(config=config) + if self.config.swa_layers is not None: + swa_config = copy.copy(config) + swa_config.rope_parameters = { + **config.rope_parameters, + "rope_theta": swa_config.swa_rotary_base, + } + self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, 
value): + self.embed_tokens = value + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + output_router_logits: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = ZayaDynamicCache( + self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device + ) + + residual = None + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ).unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + position_ids, + past_key_values, + ) + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.") + # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. + # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state. 
+ attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + if inputs_embeds.shape[1] == 1: + attention_mask_2d = None + + hidden_states = inputs_embeds + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + if self.config.swa_layers is not None: + swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + prev_router_hidden_states = None + + for layer_n, decoder_layer in enumerate(self.layers): + if self.config.swa_layers is not None: + emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings + else: + emb_to_use = position_embeddings + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + residual, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + position_embeddings=emb_to_use, + prev_router_hidden_states=prev_router_hidden_states, + attention_mask_2d=attention_mask_2d, + **kwargs, + ) + + hidden_states = layer_outputs[0] + residual = layer_outputs[2] + prev_router_hidden_states = layer_outputs[3] + + if isinstance(decoder_layer, ZayaDecoderATTLayer): + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: 
@auto_docstring
class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
    # Language-modelling head on top of `ZayaModel`. Notes for the class and for `forward` are written as `#`
    # comments so that the docstrings generated by `@auto_docstring` are left untouched.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _is_stateful = True

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZayaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.post_init()

    def set_decoder(self, decoder):
        """Replace the backbone model used by this head."""
        self.model = decoder

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        # Fall back to the configuration when the caller did not choose explicitly.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # An int keeps the last `logits_to_keep` positions (0 keeps all, since `slice(-0, None)` is the full
        # range); a tensor is used directly as an index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=None,
            logits=logits,
            # `use_cache` may still be `None` here, while the backbone resolves it itself and already
            # returns `None` when caching is off. Re-gating on the local value would drop a cache that the
            # backbone created from the configuration default, so forward its value unchanged.
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        """Validate the cache type, then defer to the generic input preparation.

        Raises:
            ValueError: If `past_key_values` is given but is not a `ZayaDynamicCache`.
        """
        if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache):
            raise ValueError(
                f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
            )

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs

    def _prepare_cache_for_generation(
        self,
        generation_config,
        model_kwargs: dict,
        generation_mode,
        batch_size: int,
        max_cache_length: int,
    ):
        """Install a `ZayaDynamicCache` in `model_kwargs` unless caching is disabled or one was supplied.

        The cache is sized for `batch_size * max(num_beams, num_return_sequences)` rows, and
        `generation_config.cache_implementation` is reset to `None` before delegating to the parent
        implementation.
        """
        if generation_config.use_cache is False:
            return

        if "past_key_values" not in model_kwargs:
            cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences)
            model_kwargs["past_key_values"] = ZayaDynamicCache(
                self.config, cache_batch_size, dtype=self.dtype, device=self.device
            )
        generation_config.cache_implementation = None
        return super()._prepare_cache_for_generation(
            generation_config=generation_config,
            model_kwargs=model_kwargs,
            generation_mode=generation_mode,
            batch_size=batch_size,
            max_cache_length=max_cache_length,
        )
a/tests/models/zaya/__init__.py b/tests/models/zaya/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/models/zaya/__init__.py @@ -0,0 +1 @@ + From d26fffc9a241c89886b19b083f1223e1462b3c5b Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 9 May 2026 20:09:44 +0800 Subject: [PATCH 02/36] add test --- tests/models/zaya/test_modeling_zaya.py | 349 ++++++++++++++++++++++++ 1 file changed, 349 insertions(+) create mode 100644 tests/models/zaya/test_modeling_zaya.py diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py new file mode 100644 index 000000000000..2338d07675af --- /dev/null +++ b/tests/models/zaya/test_modeling_zaya.py @@ -0,0 +1,349 @@ +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch ZAYA model.""" + +import unittest + +from parameterized import parameterized + +from transformers import is_torch_available +from transformers.testing_utils import cleanup, require_torch, slow, torch_device + + +if is_torch_available(): + import torch + + from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel + from transformers.models.zaya.modeling_zaya import CCA, ZayaDynamicCache + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class ZayaModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = ZayaModel + + def __init__(self, parent): + super().__init__( + parent=parent, + batch_size=2, + seq_length=7, + vocab_size=128, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + ) + self.head_dim = 8 + self.ffn_hidden_size = 64 + self.num_query_groups = 2 + self.num_experts = 4 + self.moe_router_topk = 1 + self.zaya_mlp_expansion = 4 + self.tie_word_embeddings = False + self.rope_parameters = { + "rope_theta": 10000, + "rope_type": "default", + } + + +@require_torch +class ZayaModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = ZayaModelTester + test_all_params_have_gradient = False + + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("ZAYA uses MoE routing; equivalent-output comparisons are not stable for this architecture.") + def 
test_model_outputs_equivalence(self, **kwargs): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + config._attn_implementation = "eager" + + for model_class in self.all_model_classes: + model = model_class._from_config(config, attn_implementation="eager") + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class({**inputs_dict, "output_attentions": True}, model_class)) + + expected_attn_layers = (config.num_hidden_layers + 1) // 2 + self.assertEqual(len(outputs.attentions), expected_attn_layers) + self.assertEqual( + outputs.attentions[0].shape, + ( + self.model_tester.batch_size, + config.num_attention_heads, + self.model_tester.seq_length, + self.model_tester.seq_length, + ), + ) + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + @unittest.skip( + "ZAYA uses partial rotary embeddings with CCA, which is not compatible with this generic RoPE test." 
+ ) + def test_model_rope_scaling_from_config(self, scaling_type): + pass + + @unittest.skip("ZAYA needs alternating attention and MoE layers in the tiny test configuration.") + def test_num_layers_is_small(self): + pass + + @unittest.skip("ZAYA uses a custom cache carrying CCA convolution state in addition to KV tensors.") + def test_past_key_values_format(self): + pass + + @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.") + def test_greedy_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + def test_moe_router_logits(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = self.model_tester.causal_lm_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**inputs_dict, output_router_logits=True) + + expected_moe_layers = config.num_hidden_layers // 2 + self.assertEqual(len(outputs.router_logits), expected_moe_layers) + self.assertEqual( + outputs.router_logits[0].shape, + (self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1), + ) + + def test_moe_router_topk_validation(self): + with self.assertRaisesRegex(ValueError, "moe_router_topk=1"): + ZayaConfig(moe_router_topk=2) + + def test_cca_cache_matches_full_forward(self): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + ffn_hidden_size=64, + num_hidden_layers=1, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + num_query_groups=2, + head_dim=8, + zaya_mlp_expansion=4, + tie_word_embeddings=False, + ) + torch.manual_seed(0) + cca = CCA( + config, + num_key_value_heads=config.num_key_value_heads, + num_attention_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + head_dim=config.head_dim, + layer_number=0, + ).to(torch_device) + cca.eval() + hidden_states = torch.randn(1, 5, 
config.hidden_size, device=torch_device) + + with torch.no_grad(): + full = cca(hidden_states, None, None) + cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device) + cca(hidden_states[:, :4], cache, None) + cache.has_previous_state = True + cached = cca(hidden_states[:, 4:], cache, None) + + for full_states, cached_states in zip(full, cached): + torch.testing.assert_close(full_states[:, -1:], cached_states, rtol=1e-5, atol=1e-5) + + def test_cca_cache_matches_full_forward_multi_token(self): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + ffn_hidden_size=64, + num_hidden_layers=1, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + num_query_groups=2, + head_dim=8, + zaya_mlp_expansion=4, + tie_word_embeddings=False, + ) + torch.manual_seed(0) + cca = CCA( + config, + num_key_value_heads=config.num_key_value_heads, + num_attention_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + head_dim=config.head_dim, + layer_number=0, + ).to(torch_device) + cca.eval() + hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) + + with torch.no_grad(): + full = cca(hidden_states, None, None) + cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device) + cca(hidden_states[:, :3], cache, None) + cache.has_previous_state = True + cached = cca(hidden_states[:, 3:], cache, None) + + for full_states, cached_states in zip(full, cached): + torch.testing.assert_close(full_states[:, 3:], cached_states, rtol=1e-5, atol=1e-5) + + def test_zaya_cache_batch_methods(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + cache = ZayaDynamicCache(config, batch_size=2, dtype=torch.float32, device=torch_device) + cache.update_conv_state( + 0, + torch.arange(2 * 2 * cache.conv_state_size, device=torch_device, dtype=torch.float32).view( + 2, 2, cache.conv_state_size + ), + ) + cache.update_prev_v2( + 0, + torch.arange( + 2 * 
config.num_key_value_heads * config.head_dim // 2, device=torch_device, dtype=torch.float32 + ).view(2, config.num_key_value_heads * config.head_dim // 2), + ) + self.assertEqual(cache.prev_v2[0].shape[-1], config.num_key_value_heads * config.head_dim // 2) + + cache.batch_repeat_interleave(2) + self.assertEqual(cache.conv_states[0].shape[0], 4) + self.assertEqual(cache.prev_v2[0].shape[0], 4) + + cache.batch_select_indices(torch.tensor([3, 1], device=torch_device)) + self.assertEqual(cache.conv_states[0].shape[0], 2) + self.assertEqual(cache.prev_v2[0].shape[0], 2) + + cache.reorder_cache(torch.tensor([1, 0], device=torch_device)) + self.assertEqual(cache.batch_size, 2) + + cache.has_previous_state = True + cache.reset() + self.assertFalse(cache.has_previous_state) + self.assertEqual(cache.conv_states[0].sum().item(), 0) + self.assertEqual(cache.prev_v2[0].sum().item(), 0) + + +@require_torch +class ZayaIntegrationTest(unittest.TestCase): + model = None + model_id = "Zyphra/ZAYA1-8B" + + @classmethod + def get_model(cls): + if cls.model is None: + cls.model = ZayaForCausalLM.from_pretrained(cls.model_id, device_map="auto", dtype=torch.bfloat16) + return cls.model + + @classmethod + def tearDownClass(cls): + if cls.model is not None: + del cls.model + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def get_inputs(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = tokenizer("Hello! 
How can I assist you today?", return_tensors="pt") + self.assertEqual( + inputs.input_ids.tolist(), + [[2, 9259, 236888, 2088, 740, 564, 6361, 611, 3124, 236881, 106]], + ) + return inputs + + @slow + def test_model_logits(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + outputs = model(**inputs, use_cache=False, output_hidden_states=True, return_dict=True) + + logits = outputs.logits.float().cpu() + hidden_states = outputs.hidden_states[-1].float().cpu() + + EXPECTED_HIDDEN_MEAN = torch.tensor( + [[0.0399, -0.0123, -0.0560, -0.0480, -0.0627, -0.0362, -0.0220, 0.0004, -0.0321, -0.0263, 0.0046]] + ) + torch.testing.assert_close(hidden_states.mean(-1), EXPECTED_HIDDEN_MEAN, rtol=1e-2, atol=1e-2) + + EXPECTED_HIDDEN_SLICE = torch.tensor([-2.7812, 0.3320, 4.1562, -0.4395, 1.6406, 1.3359, -1.4531, -2.6719, 5.5000, -4.7500, 2.0625, 0.2930, -2.2344, -2.6094, 2.0781, 2.5000, 0.7969, 0.6836, -0.5469, 1.3906]) # fmt: skip + torch.testing.assert_close(hidden_states[0, 0, :20], EXPECTED_HIDDEN_SLICE, rtol=1e-2, atol=1e-2) + + EXPECTED_LOGITS_SLICE = torch.tensor([-2.3438, 1.7344, 3.7656, -3.8750, 0.4707, -0.7422, -2.5938, -2.7188, -2.9375, -2.9844, -3.0000, -3.0000, -3.0000, -3.0000, -3.0156, -3.0000, -3.0000, -3.0000, -3.0000, -3.0000]) # fmt: skip + torch.testing.assert_close(logits[0, -1, :20], EXPECTED_LOGITS_SLICE, rtol=1e-2, atol=1e-2) + self.assertEqual(logits[0, -1].argmax().item(), 107) + + @slow + def test_model_cache_matches_full_forward(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + full_logits = model(**inputs, use_cache=False).logits[:, -1] + prefill_outputs = model( + input_ids=inputs.input_ids[:, :-1], + attention_mask=inputs.attention_mask[:, :-1], + use_cache=True, + return_dict=True, + ) + cached_logits = model( + input_ids=inputs.input_ids[:, -1:], + attention_mask=inputs.attention_mask, 
+ past_key_values=prefill_outputs.past_key_values, + use_cache=True, + return_dict=True, + ).logits[:, -1] + + torch.testing.assert_close(cached_logits.float().cpu(), full_logits.float().cpu(), rtol=1e-4, atol=1e-4) + + @slow + def test_model_generation(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=3, top_k=None, top_p=None) + + self.assertEqual(generated_ids[0, -3:].tolist(), [107, 262146, 108]) From 8191d39741e5cb55347ce78772ac14a5de1a335f Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 9 May 2026 20:11:19 +0800 Subject: [PATCH 03/36] update example --- docs/source/en/model_doc/zaya.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index 7f881a47efb9..01e7a8504e1d 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -35,8 +35,9 @@ model_id = "Zyphra/ZAYA1-8B" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") -inputs = tokenizer("What factors contributed to the fall of the Roman Empire?", return_tensors="pt").to(model.device) -outputs = model.generate(**inputs, max_new_tokens=100) +inputs = tokenizer.apply_chat_template([{"role": "user", "content": "Write a haiku about recursion in programming."}], tokenize=True, add_generation_prompt=True, enable_thinking=False, return_tensors="pt") +inputs = inputs.to(model.device) +outputs = model.generate(**inputs, max_new_tokens=2048) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` From c125ef3124143bede61c1928cdecf75fb1ed7ae3 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 9 May 2026 20:43:13 +0800 Subject: [PATCH 04/36] new config --- .../models/zaya/configuration_zaya.py | 157 +++++++----------- 
src/transformers/models/zaya/modular_zaya.py | 156 +++++++---------- 2 files changed, 125 insertions(+), 188 deletions(-) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 506df6eee3f0..12a7c2999abc 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -18,27 +18,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +from huggingface_hub.dataclasses import strict + from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters from ...utils import auto_docstring @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +@strict class ZayaConfig(PreTrainedConfig): r""" - num_query_groups (`int`, *optional*, defaults to 2): - Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. - lm_head_bias (`bool`, *optional*, defaults to `False`): - Whether to add a bias to the language modeling head. ffn_hidden_size (`int`, *optional*, defaults to 4096): Dimension of the feed-forward and expert hidden states. + num_query_groups (`int`, *optional*, defaults to 2): + Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. rope_theta (`float`, *optional*, defaults to 5000000): The base period of the RoPE embeddings. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Fraction of each attention head dimension using rotary embeddings. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. moe_router_topk (`int`, *optional*, defaults to 1): Number of selected experts per token. ZAYA checkpoints use top-1 routing. zaya_mlp_expansion (`int`, *optional*, defaults to 256): Expansion size used by the dense ZAYA blocks. 
- partial_rotary_factor (`float`, *optional*, defaults to 0.5): - Fraction of each attention head dimension using rotary embeddings. cca_time0 (`int`, *optional*, defaults to 2): First temporal parameter of the CCA projection. cca_time1 (`int`, *optional*, defaults to 2): @@ -61,42 +65,39 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] - def __init__( - self, - num_query_groups=2, - use_cache=True, - attention_bias=False, - lm_head_bias=False, - vocab_size=262272, - hidden_size=2048, - ffn_hidden_size=4096, - num_hidden_layers=80, - num_experts=16, - num_attention_heads=8, - hidden_act="silu", - head_dim=128, - initializer_range=0.02, - max_position_embeddings=131072, - norm_epsilon=1e-05, - pad_token_id=0, - bos_token_id=2, - eos_token_id=106, - tie_word_embeddings=True, - rope_theta=5000000, - attention_dropout=0.0, - moe_router_topk=1, - zaya_mlp_expansion=256, - rope_parameters=None, - partial_rotary_factor=0.5, - num_key_value_heads=2, - cca_time0=2, - cca_time1=2, - swa_layers=None, - swa_rotary_base=None, - output_router_logits=False, - _attn_implementation="eager", - **kwargs, - ): + vocab_size: int = 262272 + hidden_size: int = 2048 + ffn_hidden_size: int = 4096 + num_hidden_layers: int = 80 + num_experts: int = 16 + num_attention_heads: int = 8 + num_key_value_heads: int | None = 2 + num_query_groups: int | None = 2 + hidden_act: str = "silu" + head_dim: int = 128 + max_position_embeddings: int = 131072 + initializer_range: float = 0.02 + norm_epsilon: float = 1e-5 + use_cache: bool = True + tie_word_embeddings: bool = True + rope_parameters: RopeParameters | dict | None = None + rope_theta: float | int = 5000000 + partial_rotary_factor: float = 0.5 + attention_bias: bool = False + lm_head_bias: bool = False + attention_dropout: float | int = 0.0 + moe_router_topk: int = 1 + zaya_mlp_expansion: int = 256 + cca_time0: int | None = 2 + cca_time1: int | None = 2 + swa_layers: list[int] | None = 
None + swa_rotary_base: float | int | None = None + output_router_logits: bool = False + pad_token_id: int | None = 0 + bos_token_id: int | None = 2 + eos_token_id: int | list[int] | None = 106 + + def __post_init__(self, **kwargs): for unused_checkpoint_kwarg in ( "cca", "activation_func", @@ -108,6 +109,8 @@ def __init__( "bias_activation_fusion", "activation_func_fp8_input_store", "clamp_temp", + "kv_channels", + "mamba_cache_dtype", "residual_in_fp32", "rope_scaling", "scale_residual_merge", @@ -118,66 +121,32 @@ def __init__( ): kwargs.pop(unused_checkpoint_kwarg, None) - num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups - if head_dim is None: + self.num_key_value_heads = ( + self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads + ) + self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups + if self.head_dim is None: raise ValueError("`head_dim` must be set for ZAYA.") - if num_query_groups != num_key_value_heads: + if self.num_query_groups != self.num_key_value_heads: raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.") - if moe_router_topk != 1: + if self.moe_router_topk != 1: raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") - self.num_query_groups = num_query_groups - self.use_cache = use_cache - self.attention_bias = attention_bias - self.lm_head_bias = lm_head_bias - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_experts = num_experts - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.head_dim = head_dim - self.initializer_range = initializer_range - self.num_key_value_heads = num_key_value_heads - self.max_position_embeddings = max_position_embeddings - self.norm_epsilon = norm_epsilon - self.pad_token_id = pad_token_id - 
self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - self.attention_dropout = attention_dropout - self.moe_router_topk = moe_router_topk - self.zaya_mlp_expansion = zaya_mlp_expansion - self.partial_rotary_factor = partial_rotary_factor - self.rope_theta = rope_theta - rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"} - rope_parameters.setdefault("rope_theta", rope_theta) - rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) - self.rope_parameters = rope_parameters - cca_time0 = 2 if cca_time0 is None else cca_time0 - cca_time1 = 2 if cca_time1 is None else cca_time1 - if (cca_time0, cca_time1) != (2, 2): + self.rope_parameters = ( + dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"} + ) + self.rope_parameters.setdefault("rope_theta", self.rope_theta) + self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor) + self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0 + self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1 + if (self.cca_time0, self.cca_time1) != (2, 2): raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") - if swa_layers is not None and len(swa_layers) != num_hidden_layers: + if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers: raise ValueError("`swa_layers` must have one entry per hidden layer.") - if swa_layers is not None and swa_rotary_base is None: + if self.swa_layers is not None and self.swa_rotary_base is None: raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.") - self.cca_time0 = cca_time0 - self.cca_time1 = cca_time1 - self.swa_layers = swa_layers - self.swa_rotary_base = swa_rotary_base - self.output_router_logits = output_router_logits - self._attn_implementation = _attn_implementation - - super().__init__( - 
pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=self.tie_word_embeddings, - **kwargs, - ) + super().__post_init__(**kwargs) __all__ = ["ZayaConfig"] diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 60bb870c73a5..ee6f44c840f3 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -20,6 +20,7 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint +from huggingface_hub.dataclasses import strict from torch import nn from torch.nn import init @@ -34,6 +35,7 @@ MoeCausalLMOutputWithPast, MoeModelOutputWithPast, ) +from ...modeling_rope_utils import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -52,22 +54,23 @@ @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +@strict class ZayaConfig(PreTrainedConfig): r""" - num_query_groups (`int`, *optional*, defaults to 2): - Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. - lm_head_bias (`bool`, *optional*, defaults to `False`): - Whether to add a bias to the language modeling head. ffn_hidden_size (`int`, *optional*, defaults to 4096): Dimension of the feed-forward and expert hidden states. + num_query_groups (`int`, *optional*, defaults to 2): + Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. rope_theta (`float`, *optional*, defaults to 5000000): The base period of the RoPE embeddings. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Fraction of each attention head dimension using rotary embeddings. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. moe_router_topk (`int`, *optional*, defaults to 1): Number of selected experts per token. ZAYA checkpoints use top-1 routing. 
zaya_mlp_expansion (`int`, *optional*, defaults to 256): Expansion size used by the dense ZAYA blocks. - partial_rotary_factor (`float`, *optional*, defaults to 0.5): - Fraction of each attention head dimension using rotary embeddings. cca_time0 (`int`, *optional*, defaults to 2): First temporal parameter of the CCA projection. cca_time1 (`int`, *optional*, defaults to 2): @@ -90,42 +93,39 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] - def __init__( - self, - num_query_groups=2, - use_cache=True, - attention_bias=False, - lm_head_bias=False, - vocab_size=262272, - hidden_size=2048, - ffn_hidden_size=4096, - num_hidden_layers=80, - num_experts=16, - num_attention_heads=8, - hidden_act="silu", - head_dim=128, - initializer_range=0.02, - max_position_embeddings=131072, - norm_epsilon=1e-05, - pad_token_id=0, - bos_token_id=2, - eos_token_id=106, - tie_word_embeddings=True, - rope_theta=5000000, - attention_dropout=0.0, - moe_router_topk=1, - zaya_mlp_expansion=256, - rope_parameters=None, - partial_rotary_factor=0.5, - num_key_value_heads=2, - cca_time0=2, - cca_time1=2, - swa_layers=None, - swa_rotary_base=None, - output_router_logits=False, - _attn_implementation="eager", - **kwargs, - ): + vocab_size: int = 262272 + hidden_size: int = 2048 + ffn_hidden_size: int = 4096 + num_hidden_layers: int = 80 + num_experts: int = 16 + num_attention_heads: int = 8 + num_key_value_heads: int | None = 2 + num_query_groups: int | None = 2 + hidden_act: str = "silu" + head_dim: int = 128 + max_position_embeddings: int = 131072 + initializer_range: float = 0.02 + norm_epsilon: float = 1e-5 + use_cache: bool = True + tie_word_embeddings: bool = True + rope_parameters: RopeParameters | dict | None = None + rope_theta: float | int = 5000000 + partial_rotary_factor: float = 0.5 + attention_bias: bool = False + lm_head_bias: bool = False + attention_dropout: float | int = 0.0 + moe_router_topk: int = 1 + zaya_mlp_expansion: 
int = 256 + cca_time0: int | None = 2 + cca_time1: int | None = 2 + swa_layers: list[int] | None = None + swa_rotary_base: float | int | None = None + output_router_logits: bool = False + pad_token_id: int | None = 0 + bos_token_id: int | None = 2 + eos_token_id: int | list[int] | None = 106 + + def __post_init__(self, **kwargs): for unused_checkpoint_kwarg in ( "cca", "activation_func", @@ -137,6 +137,8 @@ def __init__( "bias_activation_fusion", "activation_func_fp8_input_store", "clamp_temp", + "kv_channels", + "mamba_cache_dtype", "residual_in_fp32", "rope_scaling", "scale_residual_merge", @@ -147,66 +149,32 @@ def __init__( ): kwargs.pop(unused_checkpoint_kwarg, None) - num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups - if head_dim is None: + self.num_key_value_heads = ( + self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads + ) + self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups + if self.head_dim is None: raise ValueError("`head_dim` must be set for ZAYA.") - if num_query_groups != num_key_value_heads: + if self.num_query_groups != self.num_key_value_heads: raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.") - if moe_router_topk != 1: + if self.moe_router_topk != 1: raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") - self.num_query_groups = num_query_groups - self.use_cache = use_cache - self.attention_bias = attention_bias - self.lm_head_bias = lm_head_bias - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_experts = num_experts - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.head_dim = head_dim - self.initializer_range = initializer_range - self.num_key_value_heads = num_key_value_heads - self.max_position_embeddings = 
max_position_embeddings - self.norm_epsilon = norm_epsilon - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - self.attention_dropout = attention_dropout - self.moe_router_topk = moe_router_topk - self.zaya_mlp_expansion = zaya_mlp_expansion - self.partial_rotary_factor = partial_rotary_factor - self.rope_theta = rope_theta - rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"} - rope_parameters.setdefault("rope_theta", rope_theta) - rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) - self.rope_parameters = rope_parameters - cca_time0 = 2 if cca_time0 is None else cca_time0 - cca_time1 = 2 if cca_time1 is None else cca_time1 - if (cca_time0, cca_time1) != (2, 2): + self.rope_parameters = ( + dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"} + ) + self.rope_parameters.setdefault("rope_theta", self.rope_theta) + self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor) + self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0 + self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1 + if (self.cca_time0, self.cca_time1) != (2, 2): raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") - if swa_layers is not None and len(swa_layers) != num_hidden_layers: + if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers: raise ValueError("`swa_layers` must have one entry per hidden layer.") - if swa_layers is not None and swa_rotary_base is None: + if self.swa_layers is not None and self.swa_rotary_base is None: raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.") - self.cca_time0 = cca_time0 - self.cca_time1 = cca_time1 - self.swa_layers = swa_layers - self.swa_rotary_base = swa_rotary_base - self.output_router_logits = 
output_router_logits - self._attn_implementation = _attn_implementation - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=self.tie_word_embeddings, - **kwargs, - ) + super().__post_init__(**kwargs) class ZayaRotaryEmbedding(Glm4RotaryEmbedding): From c90df6f33e910c666115264a9ad22fe0381cd7df Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 9 May 2026 20:58:11 +0800 Subject: [PATCH 05/36] remove empty line --- src/transformers/models/zaya/modeling_zaya.py | 4 ---- src/transformers/models/zaya/modular_zaya.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index bbbecaeb1907..ab68cbc73d36 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -543,7 +543,6 @@ def _apply_residual_scaling( class ZayaDecoderATTLayer(GradientCheckpointingLayer): def __init__(self, config: ZayaConfig, layer_n: int): - super().__init__() self.config = config self.self_attn = ZayaAttention(config, layer_n) @@ -715,7 +714,6 @@ def __init__( ffn_hidden_size: int, layer_n: int, ): - super().__init__() self.config = config self.hidden_dim = config.hidden_size @@ -755,7 +753,6 @@ def __init__( ffn_hidden_size: int, layer_n: int, ): - super().__init__() self.config = config self.zaya_block = ZayaBlock( @@ -829,7 +826,6 @@ def _init_weights(self, module): @auto_docstring class ZayaModel(ZayaPreTrainedModel): def __init__(self, config: ZayaConfig): - super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index ee6f44c840f3..d5bb4efd767e 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -501,7 +501,6 @@ def forward( class 
ZayaDecoderATTLayer(GradientCheckpointingLayer): def __init__(self, config: ZayaConfig, layer_n: int): - super().__init__() self.config = config self.self_attn = ZayaAttention(config, layer_n) @@ -685,7 +684,6 @@ def __init__( ffn_hidden_size: int, layer_n: int, ): - super().__init__() self.config = config self.hidden_dim = config.hidden_size @@ -725,7 +723,6 @@ def __init__( ffn_hidden_size: int, layer_n: int, ): - super().__init__() self.config = config self.zaya_block = ZayaBlock( @@ -799,7 +796,6 @@ def _init_weights(self, module): @auto_docstring class ZayaModel(ZayaPreTrainedModel): def __init__(self, config: ZayaConfig): - super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size From b90759f1de9b000ab16abdc9c6cc7190ad60381f Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 9 May 2026 21:04:32 +0800 Subject: [PATCH 06/36] pass ci --- src/transformers/models/auto/modeling_auto.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4d90c73183e7..6ced59e8556f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -510,9 +510,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("yolos", "YolosModel"), ("yoso", "YosoModel"), ("youtu", "YoutuModel"), - ("zaya", "ZayaModel"), ("zamba", "ZambaModel"), ("zamba2", "Zamba2Model"), + ("zaya", "ZayaModel"), ] ) @@ -773,9 +773,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("xlstm", "xLSTMForCausalLM"), ("xmod", "XmodForCausalLM"), ("youtu", "YoutuForCausalLM"), - ("zaya", "ZayaForCausalLM"), ("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM"), + ("zaya", "ZayaForCausalLM"), ] ) From 7e2999929315219eb0fa2bd9ed5dbd00962deb7e Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 11:39:48 +0800 Subject: [PATCH 
07/36] modify config, laguna-sytle rope --- docs/source/en/model_doc/zaya.md | 13 +- src/transformers/models/zaya/modular_zaya.py | 131 +++++++++++-------- tests/models/zaya/test_modeling_zaya.py | 99 +++++++++++++- 3 files changed, 180 insertions(+), 63 deletions(-) diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index 01e7a8504e1d..468f7327dd86 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -25,19 +25,26 @@ Convolutional Attention (CCA), a nonlinear ZAYA1 router, and residual scaling. ZAYA1 uses the Gemma 3 tokenizer. For more details, see the [ZAYA1 model card](https://huggingface.co/Zyphra/ZAYA1-8B) and Zyphra's technical reports. +This model was contributed by [JJJYmmm](https://github.com/JJJYmmm). + ## Usage examples ```python from transformers import AutoModelForCausalLM, AutoTokenizer - model_id = "Zyphra/ZAYA1-8B" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") -inputs = tokenizer.apply_chat_template([{"role": "user", "content": "Write a haiku about recursion in programming."}], tokenize=True, add_generation_prompt=True, enable_thinking=False, return_tensors="pt") +inputs = tokenizer.apply_chat_template( + [{"role": "user", "content": "Write a haiku about recursion in programming."}], + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + return_tensors="pt", +) inputs = inputs.to(model.device) -outputs = model.generate(**inputs, max_new_tokens=2048) +outputs = model.generate(**inputs, max_new_tokens=256) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index d5bb4efd767e..6b7af760e37e 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -1,4 +1,4 @@ -# Copyright 2025 Zyphra and the HuggingFace Inc. team. 
All rights reserved. +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ """PyTorch Zaya model.""" -import copy from collections.abc import Callable +from typing import Any, Literal import torch import torch.nn.functional as F @@ -45,7 +45,7 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs -from ..glm4.modeling_glm4 import Glm4RotaryEmbedding +from ..laguna.modeling_laguna import LagunaRotaryEmbedding from ..qwen3_5_moe.modeling_qwen3_5_moe import ( apply_rotary_pos_emb, eager_attention_forward, @@ -58,11 +58,9 @@ class ZayaConfig(PreTrainedConfig): r""" ffn_hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the feed-forward and expert hidden states. - num_query_groups (`int`, *optional*, defaults to 2): - Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. - rope_theta (`float`, *optional*, defaults to 5000000): - The base period of the RoPE embeddings. + Dimension of the feed-forward and expert hidden states, translate it to `intermediate_size`. + num_key_value_heads (`int`, *optional*, defaults to 2): + Number of key/value groups. partial_rotary_factor (`float`, *optional*, defaults to 0.5): Fraction of each attention head dimension using rotary embeddings. lm_head_bias (`bool`, *optional*, defaults to `False`): @@ -75,7 +73,7 @@ class ZayaConfig(PreTrainedConfig): First temporal parameter of the CCA projection. cca_time1 (`int`, *optional*, defaults to 2): Second temporal parameter of the CCA projection. - swa_layers (`list[int]`, *optional*): + layer_types (`list[str]`, *optional*): Per-layer selector for standard RoPE versus SWA RoPE embeddings. swa_rotary_base (`float`, *optional*): RoPE base used by SWA layers. 
@@ -92,6 +90,7 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] + default_theta = 5000000.0 vocab_size: int = 262272 hidden_size: int = 2048 @@ -100,7 +99,6 @@ class ZayaConfig(PreTrainedConfig): num_experts: int = 16 num_attention_heads: int = 8 num_key_value_heads: int | None = 2 - num_query_groups: int | None = 2 hidden_act: str = "silu" head_dim: int = 128 max_position_embeddings: int = 131072 @@ -109,7 +107,6 @@ class ZayaConfig(PreTrainedConfig): use_cache: bool = True tie_word_embeddings: bool = True rope_parameters: RopeParameters | dict | None = None - rope_theta: float | int = 5000000 partial_rotary_factor: float = 0.5 attention_bias: bool = False lm_head_bias: bool = False @@ -118,8 +115,8 @@ class ZayaConfig(PreTrainedConfig): zaya_mlp_expansion: int = 256 cca_time0: int | None = 2 cca_time1: int | None = 2 - swa_layers: list[int] | None = None - swa_rotary_base: float | int | None = None + layer_types: list[str] | None = None + swa_rotary_base: float | int = 10000.0 output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 @@ -128,6 +125,7 @@ class ZayaConfig(PreTrainedConfig): def __post_init__(self, **kwargs): for unused_checkpoint_kwarg in ( "cca", + "num_query_groups", "activation_func", "normalization", "add_bias_linear", @@ -149,35 +147,68 @@ def __post_init__(self, **kwargs): ): kwargs.pop(unused_checkpoint_kwarg, None) + self.intermediate_size = self.ffn_hidden_size + self.num_experts_per_tok = self.moe_router_topk + self.num_key_value_heads = ( self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads ) - self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups - if self.head_dim is None: - raise ValueError("`head_dim` must be set for ZAYA.") - if self.num_query_groups != self.num_key_value_heads: - raise ValueError("`num_query_groups` must be equal to 
`num_key_value_heads` for ZAYA.") - if self.moe_router_topk != 1: - raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") - self.rope_parameters = ( - dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"} - ) - self.rope_parameters.setdefault("rope_theta", self.rope_theta) - self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor) + legacy_swa_layers = kwargs.pop("swa_layers", None) + if self.layer_types is None: + if legacy_swa_layers is None: + self.layer_types = ["full_attention"] * self.num_hidden_layers + else: + self.layer_types = [ + "full_attention" if layer_type == 0 else "sliding_attention" for layer_type in legacy_swa_layers + ] + else: + self.layer_types = list(self.layer_types) + self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0 self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1 - if (self.cca_time0, self.cca_time1) != (2, 2): - raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") - if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers: - raise ValueError("`swa_layers` must have one entry per hidden layer.") - if self.swa_layers is not None and self.swa_rotary_base is None: - raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.") super().__post_init__(**kwargs) + def convert_rope_params_to_dict(self, **kwargs): + default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { + "full_attention": { + "rope_type": "default", + "rope_theta": kwargs.pop("rope_theta", self.default_theta), + "partial_rotary_factor": self.partial_rotary_factor, + }, + "sliding_attention": { + "rope_type": "default", + "rope_theta": self.swa_rotary_base, + "partial_rotary_factor": self.partial_rotary_factor, + }, + } + layer_types = set(self.layer_types) + + if self.rope_parameters is None: + self.rope_parameters = {layer_type: 
default_rope_params[layer_type] for layer_type in layer_types} + else: + self.rope_parameters = { + layer_type: {**default_rope_params[layer_type], **(self.rope_parameters.get(layer_type) or {})} + for layer_type in layer_types + } + + return kwargs -class ZayaRotaryEmbedding(Glm4RotaryEmbedding): + def validate_architecture(self): + if self.head_dim is None: + raise ValueError("`head_dim` must be set for ZAYA.") + if self.num_experts_per_tok != 1: + raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError("`layer_types` must have one entry per hidden layer.") + if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: + raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") + if (self.cca_time0, self.cca_time1) != (2, 2): + raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") + + +class ZayaRotaryEmbedding(LagunaRotaryEmbedding): pass @@ -204,7 +235,7 @@ def __init__( self.device = device self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1) self.num_layers = config.num_hidden_layers - self.key_value_hidden_size = config.num_query_groups * config.head_dim + self.key_value_hidden_size = config.num_key_value_heads * config.head_dim self.query_hidden_size = config.num_attention_heads * config.head_dim self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size self.has_previous_state = False @@ -439,7 +470,7 @@ def __init__(self, config: ZayaConfig, layer_n): self.qkv = CCA( config=self.config, num_attention_heads=self.config.num_attention_heads, - num_key_value_heads=self.config.num_query_groups, + num_key_value_heads=self.config.num_key_value_heads, hidden_size=self.hidden_size, head_dim=self.config.head_dim, cca_time0=self.config.cca_time0, @@ -639,11 +670,11 @@ def forward( class ZayaExperts(nn.Module): """Collection of expert weights stored as 
3D tensors.""" - def __init__(self, config, num_experts: int, ffn_hidden_size: int): + def __init__(self, config, num_experts: int, intermediate_size: int): super().__init__() self.num_experts = num_experts self.hidden_dim = config.hidden_size - self.intermediate_dim = ffn_hidden_size // 2 + self.intermediate_dim = intermediate_size // 2 self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_act] @@ -681,7 +712,7 @@ def __init__( config, num_moe_experts: int, mlp_expansion: int, - ffn_hidden_size: int, + intermediate_size: int, layer_n: int, ): super().__init__() @@ -696,7 +727,7 @@ def __init__( mlp_expansion=mlp_expansion, hidden_size=self.hidden_dim, ) - self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size) + self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size) def forward( self, @@ -720,7 +751,7 @@ def __init__( config: ZayaConfig, num_moe_experts: int, mlp_expansion: int, - ffn_hidden_size: int, + intermediate_size: int, layer_n: int, ): super().__init__() @@ -729,7 +760,7 @@ def __init__( config, num_moe_experts, mlp_expansion, - ffn_hidden_size, + intermediate_size, layer_n, ) self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) @@ -809,7 +840,7 @@ def __init__(self, config: ZayaConfig): config, config.num_experts, config.zaya_mlp_expansion, - config.ffn_hidden_size, + config.intermediate_size, layer_n, ) ) @@ -823,13 +854,6 @@ def __init__(self, config: ZayaConfig): self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) self.rotary_emb = ZayaRotaryEmbedding(config=config) - if self.config.swa_layers is not None: - swa_config = copy.copy(config) - swa_config.rope_parameters = { - **config.rope_parameters, - "rope_theta": 
swa_config.swa_rotary_base, - } - self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config) self.post_init() @@ -896,19 +920,16 @@ def forward( hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - if self.config.swa_layers is not None: - swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) for layer_type in set(self.config.layer_types) + } all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None prev_router_hidden_states = None for layer_n, decoder_layer in enumerate(self.layers): - if self.config.swa_layers is not None: - emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings - else: - emb_to_use = position_embeddings + emb_to_use = position_embeddings[self.config.layer_types[layer_n]] if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 2338d07675af..6264c05421b3 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -15,6 +15,7 @@ import unittest +from huggingface_hub.errors import StrictDataclassClassValidationError from parameterized import parameterized from transformers import is_torch_available @@ -48,7 +49,6 @@ def __init__(self, parent): ) self.head_dim = 8 self.ffn_hidden_size = 64 - self.num_query_groups = 2 self.num_experts = 4 self.moe_router_topk = 1 self.zaya_mlp_expansion = 4 @@ -115,11 +115,95 @@ def test_attention_outputs(self): @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) @unittest.skip( - "ZAYA uses partial rotary embeddings with CCA, which is not compatible with this generic RoPE test." + "RoPE-scaling-from-config test doesn't match ZAYA's nested per-layer-type rope_parameters (same as e.g. Laguna, Gemma3)." 
) def test_model_rope_scaling_from_config(self, scaling_type): pass + def test_model_rope_scaling_frequencies(self): + """ + Tests the frequency properties of the different RoPE scaling types on the model RoPE layer. + Copied from Laguna to adapt to per-layer-type rope configs. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config.layer_types = ["full_attention", "sliding_attention"] + partial_rotary_factor = config.partial_rotary_factor + + def set_rope_params(rope_params): + config.rope_parameters = { + "full_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + "sliding_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + } + + set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) + + base_model = self.model_tester.base_model_class(config) + possible_rope_attributes = [ + "pos_emb", + "rotary_emb", + "global_rotary_emb", + "local_rotary_emb", + ] + for name, module in base_model.named_modules(): + if any(potential_name in name for potential_name in possible_rope_attributes): + rope_class = type(module) + break + + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + x = torch.randn(1, dtype=torch.float32, device=torch_device) + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + + set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) + original_rope = rope_class(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="sliding_attention") + original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="sliding_attention") + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + 
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + set_rope_params({"rope_type": "linear", "factor": scaling_factor, "rope_theta": 10_000.0}) + linear_scaling_rope = rope_class(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="sliding_attention") + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="sliding_attention") + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + set_rope_params({"rope_type": "dynamic", "factor": scaling_factor, "rope_theta": 10_000.0}) + ntk_scaling_rope = rope_class(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="sliding_attention") + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="sliding_attention") + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue( + (ntk_scaling_rope.sliding_attention_inv_freq <= original_rope.sliding_attention_inv_freq).all() + ) + + set_rope_params({"rope_type": "yarn", "factor": scaling_factor, "rope_theta": 10_000.0}) + yarn_scaling_rope = rope_class(config=config).to(torch_device) + 
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="sliding_attention") + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="sliding_attention") + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + @unittest.skip("ZAYA needs alternating attention and MoE layers in the tiny test configuration.") def test_num_layers_is_small(self): pass @@ -153,9 +237,16 @@ def test_moe_router_logits(self): ) def test_moe_router_topk_validation(self): - with self.assertRaisesRegex(ValueError, "moe_router_topk=1"): + with self.assertRaisesRegex(StrictDataclassClassValidationError, "moe_router_topk=1"): ZayaConfig(moe_router_topk=2) + def test_legacy_swa_layers_translate_to_layer_types(self): + config = ZayaConfig(num_hidden_layers=4, swa_layers=[0, 1, 0, 1], swa_rotary_base=10000) + + self.assertEqual(config.layer_types, ["full_attention", "sliding_attention", "full_attention", "sliding_attention"]) + self.assertEqual(config.rope_parameters["full_attention"]["rope_theta"], config.default_theta) + self.assertEqual(config.rope_parameters["sliding_attention"]["rope_theta"], 10000) + def test_cca_cache_matches_full_forward(self): config = ZayaConfig( vocab_size=128, @@ -165,7 +256,6 @@ def test_cca_cache_matches_full_forward(self): num_experts=4, num_attention_heads=4, num_key_value_heads=2, - num_query_groups=2, head_dim=8, zaya_mlp_expansion=4, tie_word_embeddings=False, @@ -201,7 
+291,6 @@ def test_cca_cache_matches_full_forward_multi_token(self): num_experts=4, num_attention_heads=4, num_key_value_heads=2, - num_query_groups=2, head_dim=8, zaya_mlp_expansion=4, tie_word_embeddings=False, From cf083aa17fbea23edab44cf62d114fa430eca9ed Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 12:01:07 +0800 Subject: [PATCH 08/36] use existing cache --- src/transformers/models/zaya/modular_zaya.py | 130 +++---------------- tests/models/zaya/test_modeling_zaya.py | 85 ++++++------ 2 files changed, 66 insertions(+), 149 deletions(-) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 6b7af760e37e..a48be2edf7ee 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -14,6 +14,7 @@ """PyTorch Zaya model.""" +import copy from collections.abc import Callable from typing import Any, Literal @@ -216,95 +217,12 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm): pass -class ZayaDynamicCache(DynamicCache): - """ - Cache that includes both the KV cache and the CCA cache. 
- """ - - def __init__( - self, - config: ZayaConfig, - batch_size: int, - dtype: torch.dtype = torch.float16, - device: str | None = None, - ): - super().__init__() - self.config = config - self.batch_size = batch_size - self.dtype = dtype - self.device = device - self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1) - self.num_layers = config.num_hidden_layers - self.key_value_hidden_size = config.num_key_value_heads * config.head_dim - self.query_hidden_size = config.num_attention_heads * config.head_dim - self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size - self.has_previous_state = False - - self.conv_states = [None for _ in range(self.num_layers)] - self.prev_v2 = [None for _ in range(self.num_layers)] - - def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor: - if new_conv_state.shape[1] < self.conv_kernel_size: - new_conv_state = F.pad( - new_conv_state.transpose(1, 2), (self.conv_kernel_size - new_conv_state.shape[1], 0) - ) - else: - new_conv_state = new_conv_state[:, -self.conv_kernel_size :, :].transpose(1, 2) - - if self.conv_states[layer_idx] is None: - self.conv_states[layer_idx] = torch.zeros_like(new_conv_state) - - if not self.has_previous_state: - self.conv_states[layer_idx].copy_(new_conv_state) - else: - conv_state = torch.cat([self.conv_states[layer_idx], new_conv_state], dim=-1)[ - :, :, -self.conv_kernel_size : - ] - self.conv_states[layer_idx].copy_(conv_state) - return self.conv_states[layer_idx] - - def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor: - if self.prev_v2[layer_idx] is None: - self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2) - self.prev_v2[layer_idx].copy_(new_prev_v2) - return self.prev_v2[layer_idx] - - def reset(self): - super().reset() - for conv_state in self.conv_states: - if conv_state is not None: - conv_state.zero_() - for prev_v2 in self.prev_v2: - if prev_v2 is not None: - prev_v2.zero_() - 
self.has_previous_state = False - - def _reorder_auxiliary_states(self, indices: torch.LongTensor): - for layer_idx, conv_state in enumerate(self.conv_states): - if conv_state is not None: - self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device)) - for layer_idx, prev_v2 in enumerate(self.prev_v2): - if prev_v2 is not None: - self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device)) - self.batch_size = indices.shape[0] - - def reorder_cache(self, beam_idx: torch.LongTensor): - super().reorder_cache(beam_idx) - self._reorder_auxiliary_states(beam_idx) - - def batch_repeat_interleave(self, repeats: int): - super().batch_repeat_interleave(repeats) - for layer_idx, conv_state in enumerate(self.conv_states): - if conv_state is not None: - self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0) - for layer_idx, prev_v2 in enumerate(self.prev_v2): - if prev_v2 is not None: - self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0) - self.batch_size *= repeats - - def batch_select_indices(self, indices: torch.Tensor): - super().batch_select_indices(indices) - self._reorder_auxiliary_states(indices) +def _make_zaya_cache(config: ZayaConfig) -> DynamicCache: + cache_config = copy.copy(config) + # layer_types is used to distinct the rope_type (full or swa) + # so need to construct a new layer_types to construct cache + cache_config.layer_types = ["hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers)] + return DynamicCache(config=cache_config) class CCA(nn.Module): @@ -370,7 +288,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - past_key_values: ZayaDynamicCache | None, + past_key_values: Cache | None, attention_mask: torch.Tensor | None = None, ): if attention_mask is not None: @@ -393,15 +311,18 @@ def forward( ).mean(dim=-2) qk_states = qk_states.transpose(1, 2) - use_precomputed_states = past_key_values is not None and 
past_key_values.has_previous_state + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_number) if use_precomputed_states: - cached_qk_states = past_key_values.conv_states[self.layer_number] + cached_qk_states = past_key_values.layers[self.layer_number].conv_states conv_input = torch.cat([cached_qk_states, qk_states], dim=-1) else: conv_input = F.pad(qk_states, (self.total_padding, 0)) if past_key_values is not None: - past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2)) + new_conv_state = qk_states[..., -self.total_padding :] + if new_conv_state.shape[-1] < self.total_padding: + new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_number) convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2) @@ -422,13 +343,13 @@ def forward( value_current = self.val_proj1(hidden_states) projected_v2 = self.val_proj2(hidden_states) if use_precomputed_states: - first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1) + first_v2 = past_key_values.layers[self.layer_number].recurrent_states.unsqueeze(1) else: first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size)) value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) if past_key_values is not None: - past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :]) + past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_number) value = torch.cat([value_current, value_delayed], dim=-1).view( batch_size, seq_length, self.num_key_value_heads, self.head_dim @@ -890,9 +811,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = ZayaDynamicCache( - self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device - ) + past_key_values = _make_zaya_cache(self.config) residual = 
None @@ -912,7 +831,7 @@ def forward( ) if attention_mask is not None and attention_mask.ndim != 2: raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.") - # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. + # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state. attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None if inputs_embeds.shape[1] == 1: @@ -959,9 +878,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1067,11 +983,6 @@ def prepare_inputs_for_generation( logits_to_keep=None, **kwargs, ): - if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache): - raise ValueError( - f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." 
- ) - model_inputs = super().prepare_inputs_for_generation( input_ids=input_ids, past_key_values=past_key_values, @@ -1096,10 +1007,7 @@ def _prepare_cache_for_generation( return if "past_key_values" not in model_kwargs: - cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences) - model_kwargs["past_key_values"] = ZayaDynamicCache( - self.config, cache_batch_size, dtype=self.dtype, device=self.device - ) + model_kwargs["past_key_values"] = _make_zaya_cache(self.config) generation_config.cache_implementation = None return super()._prepare_cache_for_generation( generation_config=generation_config, diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 6264c05421b3..b710d136a554 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -26,7 +26,8 @@ import torch from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel - from transformers.models.zaya.modeling_zaya import CCA, ZayaDynamicCache + from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer + from transformers.models.zaya.modeling_zaya import CCA, _make_zaya_cache from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester @@ -64,6 +65,36 @@ class ZayaModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = ZayaModelTester test_all_params_have_gradient = False + def _get_conv_state_shape(self, batch_size: int, config): + conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim + return (batch_size, conv_state_size, config.cca_time0 + config.cca_time1 - 2) + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.num_key_value_heads * config.head_dim // 2) + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + if not isinstance(past_key_values, DynamicCache): + 
raise ValueError("The cache does not use the correct Cache") + + config = config.get_text_config(decoder=True) + self.assertEqual(config.num_hidden_layers, len(past_key_values)) + attention_shape = (batch_size, config.num_key_value_heads, seq_length, config.head_dim) + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) + + for layer_idx, layer in enumerate(past_key_values.layers): + if layer_idx % 2 == 0: + self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer) + self.assertEqual(layer.keys.shape, attention_shape) + self.assertEqual(layer.values.shape, attention_shape) + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) + else: + self.assertIsNone(layer.keys) + self.assertIsNone(layer.values) + self.assertIsNone(layer.conv_states) + self.assertIsNone(layer.recurrent_states) + def is_pipeline_test_to_skip( self, pipeline_test_case_name, @@ -208,18 +239,6 @@ def set_rope_params(rope_params): def test_num_layers_is_small(self): pass - @unittest.skip("ZAYA uses a custom cache carrying CCA convolution state in addition to KV tensors.") - def test_past_key_values_format(self): - pass - - @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.") - def test_greedy_generate_dict_outputs_use_cache(self): - pass - - @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.") - def test_beam_search_generate_dict_outputs_use_cache(self): - pass - def test_moe_router_logits(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = self.model_tester.causal_lm_class(config) @@ -274,9 +293,8 @@ def test_cca_cache_matches_full_forward(self): with torch.no_grad(): full = cca(hidden_states, None, None) - cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device) + cache = _make_zaya_cache(config) 
cca(hidden_states[:, :4], cache, None) - cache.has_previous_state = True cached = cca(hidden_states[:, 4:], cache, None) for full_states, cached_states in zip(full, cached): @@ -309,47 +327,38 @@ def test_cca_cache_matches_full_forward_multi_token(self): with torch.no_grad(): full = cca(hidden_states, None, None) - cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device) + cache = _make_zaya_cache(config) cca(hidden_states[:, :3], cache, None) - cache.has_previous_state = True cached = cca(hidden_states[:, 3:], cache, None) for full_states, cached_states in zip(full, cached): torch.testing.assert_close(full_states[:, 3:], cached_states, rtol=1e-5, atol=1e-5) - def test_zaya_cache_batch_methods(self): + def test_zaya_cache_reorder_and_reset(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - cache = ZayaDynamicCache(config, batch_size=2, dtype=torch.float32, device=torch_device) + cache = _make_zaya_cache(config) + conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim cache.update_conv_state( - 0, - torch.arange(2 * 2 * cache.conv_state_size, device=torch_device, dtype=torch.float32).view( - 2, 2, cache.conv_state_size + torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view( + 2, conv_state_size, 2 ), - ) - cache.update_prev_v2( 0, + ) + cache.update_recurrent_state( torch.arange( 2 * config.num_key_value_heads * config.head_dim // 2, device=torch_device, dtype=torch.float32 ).view(2, config.num_key_value_heads * config.head_dim // 2), + 0, ) - self.assertEqual(cache.prev_v2[0].shape[-1], config.num_key_value_heads * config.head_dim // 2) - - cache.batch_repeat_interleave(2) - self.assertEqual(cache.conv_states[0].shape[0], 4) - self.assertEqual(cache.prev_v2[0].shape[0], 4) - - cache.batch_select_indices(torch.tensor([3, 1], device=torch_device)) - self.assertEqual(cache.conv_states[0].shape[0], 2) - 
self.assertEqual(cache.prev_v2[0].shape[0], 2) + self.assertEqual(cache.layers[0].recurrent_states.shape[-1], config.num_key_value_heads * config.head_dim // 2) cache.reorder_cache(torch.tensor([1, 0], device=torch_device)) - self.assertEqual(cache.batch_size, 2) + self.assertEqual(cache.layers[0].conv_states.shape[0], 2) - cache.has_previous_state = True cache.reset() - self.assertFalse(cache.has_previous_state) - self.assertEqual(cache.conv_states[0].sum().item(), 0) - self.assertEqual(cache.prev_v2[0].sum().item(), 0) + self.assertFalse(cache.has_previous_state(0)) + self.assertEqual(cache.layers[0].conv_states.sum().item(), 0) + self.assertEqual(cache.layers[0].recurrent_states.sum().item(), 0) @require_torch From 69d09f3f56692bcfd917856dea6ee98d311aa51b Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 14:05:24 +0800 Subject: [PATCH 09/36] cca refine + use llama attn --- src/transformers/conversion_mapping.py | 2 + src/transformers/models/zaya/modular_zaya.py | 140 +++++++++---------- 2 files changed, 67 insertions(+), 75 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index dff0f65f5b53..0bf2c311845b 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -562,6 +562,8 @@ def _build_checkpoint_conversion_mapping(): ), ], "zaya": [ + WeightRenaming(r"self_attn\.qkv\.conv_qk\.0\.", "self_attn.qkv.conv_qk_depthwise."), + WeightRenaming(r"self_attn\.qkv\.conv_qk\.1\.", "self_attn.qkv.conv_qk_grouped."), WeightConverter( source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight", target_patterns="zaya_block.experts.gate_up_proj", diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index a48be2edf7ee..14c97c4dbc56 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -47,6 +47,7 @@ from ...utils.generic import 
merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ..laguna.modeling_laguna import LagunaRotaryEmbedding +from ..llama.modeling_llama import LlamaAttention from ..qwen3_5_moe.modeling_qwen3_5_moe import ( apply_rotary_pos_emb, eager_attention_forward, @@ -201,6 +202,8 @@ def validate_architecture(self): raise ValueError("`head_dim` must be set for ZAYA.") if self.num_experts_per_tok != 1: raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") if len(self.layer_types) != self.num_hidden_layers: raise ValueError("`layer_types` must have one entry per hidden layer.") if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: @@ -221,42 +224,43 @@ def _make_zaya_cache(config: ZayaConfig) -> DynamicCache: cache_config = copy.copy(config) # layer_types is used to distinct the rope_type (full or swa) # so need to construct a new layer_types to construct cache - cache_config.layer_types = ["hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers)] + cache_config.layer_types = [ + "hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers) + ] return DynamicCache(config=cache_config) -class CCA(nn.Module): - def __init__( - self, - config: ZayaConfig, - num_key_value_heads: int = 2, - num_attention_heads: int = 8, - hidden_size: int | None = None, - head_dim: int = 128, - cca_time0: int = 2, - cca_time1: int = 2, - layer_number: int = 0, - ): +class ZayaCCAProjection(nn.Module): + """ + Projects hidden states into attention q/k/v states with ZAYA's CCA path. + + `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. 
The causal + `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; + for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. + Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses + `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**. + + The final q/k states are L2-normalized. `temp` is the learned per-KV-head scale applied to keys. + """ + + def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config - self.layer_number = layer_number + self.layer_idx = layer_idx - self.hidden_size = int(hidden_size or config.hidden_size) + self.hidden_size = config.hidden_size - self.depthwise_kernel_size = cca_time0 - self.grouped_kernel_size = cca_time1 + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) - self.num_key_value_heads = int(num_key_value_heads) - self.num_attention_heads = int(num_attention_heads) - - self.head_dim = int(head_dim) + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim self.key_value_hidden_size = self.num_key_value_heads * self.head_dim self.query_hidden_size = self.num_attention_heads * self.head_dim self.sqrt_head_dim = self.head_dim**0.5 self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - if self.num_attention_heads % self.num_key_value_heads != 0: - raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias) self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, 
bias=self.config.attention_bias) @@ -264,23 +268,21 @@ def __init__( self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias) conv_channels = self.key_value_hidden_size + self.query_hidden_size - self.conv_qk = nn.Sequential( - nn.Conv1d( - in_channels=conv_channels, - out_channels=conv_channels, - kernel_size=self.depthwise_kernel_size, - groups=conv_channels, - padding=0, - stride=1, - ), - nn.Conv1d( - in_channels=conv_channels, - out_channels=conv_channels, - kernel_size=self.grouped_kernel_size, - groups=(self.num_key_value_heads + self.num_attention_heads), - padding=0, - stride=1, - ), + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, ) self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads)) @@ -311,9 +313,9 @@ def forward( ).mean(dim=-2) qk_states = qk_states.transpose(1, 2) - use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_number) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) if use_precomputed_states: - cached_qk_states = past_key_values.layers[self.layer_number].conv_states + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states conv_input = torch.cat([cached_qk_states, qk_states], dim=-1) else: conv_input = F.pad(qk_states, (self.total_padding, 0)) @@ -322,9 +324,10 @@ def forward( new_conv_state = qk_states[..., -self.total_padding :] if new_conv_state.shape[-1] < self.total_padding: new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0)) - 
past_key_values.update_conv_state(new_conv_state, self.layer_number) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) - convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2) + convolved_qk_states = self.conv_qk_depthwise(conv_input) + convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2) query = ( convolved_qk_states[..., : self.query_hidden_size].view( @@ -343,13 +346,13 @@ def forward( value_current = self.val_proj1(hidden_states) projected_v2 = self.val_proj2(hidden_states) if use_precomputed_states: - first_v2 = past_key_values.layers[self.layer_number].recurrent_states.unsqueeze(1) + first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) else: first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size)) value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) if past_key_values is not None: - past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_number) + past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx) value = torch.cat([value_current, value_delayed], dim=-1).view( batch_size, seq_length, self.num_key_value_heads, self.head_dim @@ -368,35 +371,20 @@ def forward( return query, key, value -class ZayaAttention(nn.Module): - def __init__(self, config: ZayaConfig, layer_n): - super().__init__() - self.config = config - self.layer_n = layer_n - self.layer_idx = layer_n +class ZayaAttention(LlamaAttention): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.layer_n = layer_idx self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.is_causal = True - self.attention_dropout = config.attention_dropout - self.head_dim = config.head_dim - self.scaling = self.head_dim**-0.5 - 
self.o_proj = nn.Linear( - self.num_attention_heads * self.head_dim, - self.hidden_size, - bias=self.config.attention_bias, - ) - self.qkv = CCA( + del self.q_proj + del self.k_proj + del self.v_proj + self.qkv = ZayaCCAProjection( config=self.config, - num_attention_heads=self.config.num_attention_heads, - num_key_value_heads=self.config.num_key_value_heads, - hidden_size=self.hidden_size, - head_dim=self.config.head_dim, - cca_time0=self.config.cca_time0, - cca_time1=self.config.cca_time1, - layer_number=layer_n, + layer_idx=layer_idx, ) def forward( @@ -541,8 +529,7 @@ def __init__( zaya_first_layer = 1 self.use_eda = self.layer_idx != zaya_first_layer - ln_eps = float(getattr(config, "norm_epsilon", 1e-5)) - self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps) + self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon) if self.use_eda: self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) @@ -830,9 +817,11 @@ def forward( past_key_values, ) if attention_mask is not None and attention_mask.ndim != 2: - raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.") + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. - # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state. + # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. 
attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None if inputs_embeds.shape[1] == 1: attention_mask_2d = None @@ -840,7 +829,8 @@ def forward( hidden_states = inputs_embeds position_embeddings = { - layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) for layer_type in set(self.config.layer_types) + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) } all_hidden_states = () if output_hidden_states else None From d936d54a15f496cb4a85c0780db7885cf6cb9306 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 14:10:30 +0800 Subject: [PATCH 10/36] use dict for 2d/4d mask --- src/transformers/models/zaya/modular_zaya.py | 28 +++++++++++--------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 14c97c4dbc56..5863d125f619 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -390,14 +390,21 @@ def __init__(self, config: ZayaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - attention_mask_2d: torch.Tensor | None = None, + attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None, past_key_values: Cache | None = None, output_attentions: bool = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: batch_size, seq_length, _ = hidden_states.shape - query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d) + + if isinstance(attention_mask, dict): + causal_mask = attention_mask.get("causal") + padding_mask = attention_mask.get("padding") + else: + causal_mask = attention_mask + padding_mask = None + + query_states, key_states, value_states = 
self.qkv(hidden_states, past_key_values, padding_mask) query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim) key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) @@ -412,8 +419,7 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n) - causal_mask = attention_mask - if causal_mask is not None: + if isinstance(causal_mask, torch.Tensor): causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( @@ -452,8 +458,7 @@ def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor, - attention_mask: torch.Tensor | None = None, - attention_mask_2d: torch.Tensor | None = None, + attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None, past_key_values: Cache | None = None, output_attentions: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, @@ -465,7 +470,6 @@ def forward( hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - attention_mask_2d=attention_mask_2d, past_key_values=past_key_values, output_attentions=output_attentions, position_embeddings=position_embeddings, @@ -822,9 +826,10 @@ def forward( ) # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. 
- attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None if inputs_embeds.shape[1] == 1: - attention_mask_2d = None + padding_mask = None + attention_masks = {"causal": causal_mask, "padding": padding_mask} hidden_states = inputs_embeds @@ -845,13 +850,12 @@ def forward( layer_outputs = decoder_layer( hidden_states, residual, - attention_mask=causal_mask, + attention_mask=attention_masks, position_ids=position_ids, past_key_values=past_key_values, output_attentions=output_attentions, position_embeddings=emb_to_use, prev_router_hidden_states=prev_router_hidden_states, - attention_mask_2d=attention_mask_2d, **kwargs, ) From 733e687cf069ea12078dfd51f8f03b57c97b9595 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 15:39:12 +0800 Subject: [PATCH 11/36] optimize, reuse existing code --- src/transformers/models/zaya/modular_zaya.py | 230 ++++++++----------- tests/models/zaya/test_modeling_zaya.py | 57 +++-- 2 files changed, 135 insertions(+), 152 deletions(-) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 5863d125f619..14f35f909634 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -29,8 +29,7 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin -from ...integrations import use_experts_implementation -from ...masking_utils import create_causal_mask +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -47,12 +46,12 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, 
capture_outputs from ..laguna.modeling_laguna import LagunaRotaryEmbedding -from ..llama.modeling_llama import LlamaAttention +from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel from ..qwen3_5_moe.modeling_qwen3_5_moe import ( apply_rotary_pos_emb, eager_attention_forward, ) -from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeExperts, Qwen3MoeRMSNorm @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") @@ -117,6 +116,7 @@ class ZayaConfig(PreTrainedConfig): zaya_mlp_expansion: int = 256 cca_time0: int | None = 2 cca_time1: int | None = 2 + sliding_window: int | None = None layer_types: list[str] | None = None swa_rotary_base: float | int = 10000.0 output_router_logits: bool = False @@ -142,7 +142,6 @@ def __post_init__(self, **kwargs): "residual_in_fp32", "rope_scaling", "scale_residual_merge", - "sliding_window", "zaya_high_prec", "zaya_use_mod", "zaya_use_eda", @@ -157,6 +156,9 @@ def __post_init__(self, **kwargs): ) legacy_swa_layers = kwargs.pop("swa_layers", None) + swa_window_sizes = {int(window_size) for window_size in (legacy_swa_layers or []) if int(window_size) > 0} + if self.sliding_window is None and swa_window_sizes: + self.sliding_window = max(swa_window_sizes) if self.layer_types is None: if legacy_swa_layers is None: self.layer_types = ["full_attention"] * self.num_hidden_layers @@ -201,13 +203,17 @@ def validate_architecture(self): if self.head_dim is None: raise ValueError("`head_dim` must be set for ZAYA.") if self.num_experts_per_tok != 1: - raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") + raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") if len(self.layer_types) != self.num_hidden_layers: raise ValueError("`layer_types` must have one entry per hidden layer.") if 
invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") + if "sliding_attention" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.") + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError("`sliding_window` must be a strictly positive integer.") if (self.cca_time0, self.cca_time1) != (2, 2): raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") @@ -239,8 +245,6 @@ class ZayaCCAProjection(nn.Module): for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**. - - The final q/k states are L2-normalized. `temp` is the learned per-KV-head scale applied to keys. 
""" def __init__(self, config: ZayaConfig, layer_idx: int): @@ -259,7 +263,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.head_dim = config.head_dim self.key_value_hidden_size = self.num_key_value_heads * self.head_dim self.query_hidden_size = self.num_attention_heads * self.head_dim - self.sqrt_head_dim = self.head_dim**0.5 self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias) @@ -296,20 +299,20 @@ def forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype) - batch_size, seq_length, _ = hidden_states.shape + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) projected_queries = self.linear_q(hidden_states) projected_keys = self.linear_k(hidden_states) qk_states = torch.cat([projected_queries, projected_keys], dim=-1) - query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim) - key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim) + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim) - key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1) - key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim) + key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2) query_residual = (query_residual + key_residual) * 0.5 key_residual = query_residual.view( - batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim + *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim ).mean(dim=-2) qk_states = qk_states.transpose(1, 2) @@ -331,14 +334,14 @@ def forward( query = ( 
convolved_qk_states[..., : self.query_hidden_size].view( - batch_size, seq_length, self.num_attention_heads, self.head_dim + *input_shape, self.num_attention_heads, self.head_dim ) + query_residual ) key = ( convolved_qk_states[..., self.query_hidden_size :].view( - batch_size, seq_length, self.num_key_value_heads, self.head_dim + *input_shape, self.num_key_value_heads, self.head_dim ) + key_residual ) @@ -348,26 +351,16 @@ def forward( if use_precomputed_states: first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) else: - first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size)) + first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) if past_key_values is not None: past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx) value = torch.cat([value_current, value_delayed], dim=-1).view( - batch_size, seq_length, self.num_key_value_heads, self.head_dim + *input_shape, self.num_key_value_heads, self.head_dim ) - norm_eps = torch.finfo(query.dtype).eps - query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) - key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) - - key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1) - query = query * (self.sqrt_head_dim / query_norm) - - query = query.reshape(batch_size, seq_length, self.query_hidden_size) - key = key.reshape(batch_size, seq_length, self.key_value_hidden_size) - value = value.reshape(batch_size, seq_length, self.key_value_hidden_size) return query, key, value @@ -375,6 +368,8 @@ class ZayaAttention(LlamaAttention): def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__(config, layer_idx) self.layer_n = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None 
self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads @@ -405,9 +400,14 @@ def forward( padding_mask = None query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask) - query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim) - key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) - value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + norm_eps = torch.finfo(query_states.dtype).eps + head_dim_scale = self.scaling**-1 + query_states = query_states * ( + head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) + key_states = key_states * self.qkv.temp[None, None, :, None] query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -433,15 +433,13 @@ def forward( causal_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, + sliding_window=self.sliding_window, output_attentions=output_attentions, ) attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_values @@ -457,14 +455,14 @@ def __init__(self, config: ZayaConfig, layer_n: int): def forward( self, hidden_states: torch.Tensor, - residual: torch.Tensor, + residual: torch.Tensor | None, attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None, past_key_values: Cache | None = None, output_attentions: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, prev_router_hidden_states: torch.Tensor | None = None, **kwargs, 
- ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) hidden_states, self_attn_weights, _ = self.self_attn( @@ -508,13 +506,27 @@ def _apply_residual_scaling( return hidden_states, residual +class ZayaRouterMLP(nn.Module): + def __init__(self, hidden_size: int, num_experts: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) + self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.act_fn(self.fc1(hidden_states)) + hidden_states = self.act_fn(self.fc2(hidden_states)) + return self.out_proj(hidden_states) + + class ZayaRouter(nn.Module): def __init__( self, config, layer_idx: int, num_moe_experts: int, - moe_router_topk: int, + num_experts_per_tok: int, mlp_expansion: int, hidden_size: int | None = None, ) -> None: @@ -525,7 +537,7 @@ def __init__( self.layer_idx = layer_idx self.num_experts = num_moe_experts + 1 - self.topk = int(moe_router_topk) + self.topk = int(num_experts_per_tok) self.mlp_expansion = int(mlp_expansion) self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) @@ -537,14 +549,7 @@ def __init__( if self.use_eda: self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) - self.non_linearity = nn.GELU() - self.router_mlp = nn.Sequential( - nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True), - self.non_linearity, - nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True), - self.non_linearity, - nn.Linear(self.mlp_expansion, self.num_experts, bias=False), - ) + self.router_mlp = ZayaRouterMLP(self.mlp_expansion, self.num_experts) 
self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) self.balancing_biases[-1] = -1.0 @@ -578,12 +583,9 @@ def forward( ) -@use_experts_implementation -class ZayaExperts(nn.Module): - """Collection of expert weights stored as 3D tensors.""" - +class ZayaExperts(Qwen3MoeExperts): def __init__(self, config, num_experts: int, intermediate_size: int): - super().__init__() + nn.Module.__init__(self) self.num_experts = num_experts self.hidden_dim = config.hidden_size self.intermediate_dim = intermediate_size // 2 @@ -591,34 +593,8 @@ def __init__(self, config, num_experts: int, intermediate_size: int): self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_act] - def forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit: - expert_idx = expert_idx[0] - if expert_idx == self.num_experts: - continue - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) - current_hidden_states = self.act_fn(gate) * up - current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) - - return final_hidden_states - - -class ZayaBlock(nn.Module): + +class ZayaSparseMoeBlock(nn.Module): def __init__( self, config, @@ 
-635,7 +611,7 @@ def __init__( config=self.config, layer_idx=layer_n, num_moe_experts=self.num_moe_experts, - moe_router_topk=getattr(self.config, "moe_router_topk", 1), + num_experts_per_tok=self.config.num_experts_per_tok, mlp_expansion=mlp_expansion, hidden_size=self.hidden_dim, ) @@ -649,6 +625,10 @@ def forward( route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router( hidden_states, router_states=prev_router_hidden_states ) + skip_expert = expert_choice == self.num_moe_experts + route_prob = route_prob.masked_fill(skip_expert, 0) + expert_choice = expert_choice.masked_fill(skip_expert, 0) + batch_size, seq_length, emb_dim = hidden_states.shape hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) expert_output = self.experts(hidden_states_flat, expert_choice, route_prob) @@ -668,7 +648,7 @@ def __init__( ): super().__init__() self.config = config - self.zaya_block = ZayaBlock( + self.zaya_block = ZayaSparseMoeBlock( config, num_moe_experts, mlp_expansion, @@ -701,24 +681,21 @@ def forward( ) -class ZayaPreTrainedModel(PreTrainedModel): +class ZayaPreTrainedModel(LlamaPreTrainedModel): config: ZayaConfig config_class = ZayaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True + # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache. 
+ _can_compile_fullgraph = False _can_record_outputs = { "router_logits": OutputRecorder(ZayaRouter, index=3), + "hidden_states": [ZayaDecoderATTLayer, ZayaDecoderMLPLayer], + "attentions": ZayaAttention, } @torch.no_grad() def _init_weights(self, module): - super()._init_weights(module) + PreTrainedModel._init_weights(self, module) if isinstance(module, ResidualScaling): init.ones_(module.hidden_states_scale) init.zeros_(module.hidden_states_bias) @@ -786,18 +763,11 @@ def forward( past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_router_logits: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -814,22 +784,25 @@ def forward( device=inputs_embeds.device, ).unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - position_ids, - past_key_values, - ) - if attention_mask is not None and attention_mask.ndim != 2: + if isinstance(attention_mask, dict): + causal_mask_mapping = attention_mask + padding_mask = None + else: + causal_mask_mapping = self._update_causal_mask( + attention_mask, + inputs_embeds, + position_ids, + past_key_values, + ) + padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + if attention_mask is not None and not isinstance(attention_mask, dict) and attention_mask.ndim != 2: raise ValueError( "ZAYA CCA projection requires a 2D 
`attention_mask` to mask padding tokens before convolution." ) # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. - padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None if inputs_embeds.shape[1] == 1: padding_mask = None - attention_masks = {"causal": causal_mask, "padding": padding_mask} hidden_states = inputs_embeds @@ -838,22 +811,18 @@ def forward( for layer_type in set(self.config.layer_types) } - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None prev_router_hidden_states = None for layer_n, decoder_layer in enumerate(self.layers): - emb_to_use = position_embeddings[self.config.layer_types[layer_n]] - if output_hidden_states: - all_hidden_states += (hidden_states,) - + layer_type = self.config.layer_types[layer_n] + emb_to_use = position_embeddings[layer_type] + attention_mask = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} layer_outputs = decoder_layer( hidden_states, residual, - attention_mask=attention_masks, + attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - output_attentions=output_attentions, position_embeddings=emb_to_use, prev_router_hidden_states=prev_router_hidden_states, **kwargs, @@ -863,20 +832,11 @@ def forward( residual = layer_outputs[2] prev_router_hidden_states = layer_outputs[3] - if isinstance(decoder_layer, ZayaDecoderATTLayer): - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) - if output_hidden_states: - all_hidden_states += (hidden_states,) - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, - 
hidden_states=all_hidden_states, - attentions=all_self_attns, ) def _update_causal_mask( @@ -886,13 +846,21 @@ def _update_causal_mask( position_ids: torch.Tensor, past_key_values: Cache, ): - return create_causal_mask( - config=self.config, - inputs_embeds=input_tensor, - attention_mask=attention_mask, - past_key_values=past_key_values, - position_ids=position_ids, - ) + mask_kwargs = { + "config": self.config, + "inputs_embeds": input_tensor, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + mask_creation_functions = { + "full_attention": lambda: create_causal_mask(**mask_kwargs), + "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs), + } + causal_mask_mapping = {} + for layer_type in set(self.config.layer_types): + causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]() + return causal_mask_mapping @auto_docstring diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index b710d136a554..5e16b744c989 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -27,7 +27,7 @@ from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer - from transformers.models.zaya.modeling_zaya import CCA, _make_zaya_cache + from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, _make_zaya_cache from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester @@ -90,8 +90,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(layer.conv_states.shape, conv_shape) self.assertEqual(layer.recurrent_states.shape, recurrent_shape) else: - self.assertIsNone(layer.keys) - self.assertIsNone(layer.values) + self.assertIsNone(getattr(layer, "keys", None)) + self.assertIsNone(getattr(layer, "values", None)) 
self.assertIsNone(layer.conv_states) self.assertIsNone(layer.recurrent_states) @@ -260,12 +260,41 @@ def test_moe_router_topk_validation(self): ZayaConfig(moe_router_topk=2) def test_legacy_swa_layers_translate_to_layer_types(self): - config = ZayaConfig(num_hidden_layers=4, swa_layers=[0, 1, 0, 1], swa_rotary_base=10000) + config = ZayaConfig(num_hidden_layers=4, swa_layers=[4096, 0, 4096, 0], swa_rotary_base=10000) - self.assertEqual(config.layer_types, ["full_attention", "sliding_attention", "full_attention", "sliding_attention"]) + self.assertEqual( + config.layer_types, ["sliding_attention", "full_attention", "sliding_attention", "full_attention"] + ) + self.assertEqual(config.sliding_window, 4096) self.assertEqual(config.rope_parameters["full_attention"]["rope_theta"], config.default_theta) self.assertEqual(config.rope_parameters["sliding_attention"]["rope_theta"], 10000) + def test_sliding_attention_mask_is_used(self): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + ffn_hidden_size=64, + num_hidden_layers=4, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + zaya_mlp_expansion=4, + layer_types=["sliding_attention", "full_attention", "full_attention", "full_attention"], + sliding_window=3, + tie_word_embeddings=False, + attn_implementation="eager", + ) + model = ZayaModel(config).to(torch_device) + model.eval() + input_ids = torch.arange(6, device=torch_device).unsqueeze(0) + + with torch.no_grad(): + outputs = model(input_ids=input_ids, output_attentions=True) + + sliding_attention = outputs.attentions[0] + self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0)) + def test_cca_cache_matches_full_forward(self): config = ZayaConfig( vocab_size=128, @@ -280,14 +309,7 @@ def test_cca_cache_matches_full_forward(self): tie_word_embeddings=False, ) torch.manual_seed(0) - cca = CCA( - config, - num_key_value_heads=config.num_key_value_heads, - num_attention_heads=config.num_attention_heads, - 
hidden_size=config.hidden_size, - head_dim=config.head_dim, - layer_number=0, - ).to(torch_device) + cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device) cca.eval() hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) @@ -314,14 +336,7 @@ def test_cca_cache_matches_full_forward_multi_token(self): tie_word_embeddings=False, ) torch.manual_seed(0) - cca = CCA( - config, - num_key_value_heads=config.num_key_value_heads, - num_attention_heads=config.num_attention_heads, - hidden_size=config.hidden_size, - head_dim=config.head_dim, - layer_number=0, - ).to(torch_device) + cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device) cca.eval() hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) From eb7c8cc7cc5224fe32b8e735402f5659ea708da3 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 16:20:13 +0800 Subject: [PATCH 12/36] inherit from AfmoeForCausalLM, but need to construct cache from _make_zaya_cache --- src/transformers/models/zaya/modular_zaya.py | 129 +++---------------- 1 file changed, 16 insertions(+), 113 deletions(-) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 14f35f909634..50cad3bd10ea 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -26,25 +26,21 @@ from torch.nn import init from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer from ...configuration_utils import PreTrainedConfig -from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, -) +from ...modeling_outputs import MoeModelOutputWithPast from ...modeling_rope_utils 
import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, auto_docstring, - can_return_tuple, ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..afmoe.modeling_afmoe import AfmoeForCausalLM from ..laguna.modeling_laguna import LagunaRotaryEmbedding from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel from ..qwen3_5_moe.modeling_qwen3_5_moe import ( @@ -236,6 +232,14 @@ def _make_zaya_cache(config: ZayaConfig) -> DynamicCache: return DynamicCache(config=cache_config) +def _is_zaya_cache(past_key_values: Cache) -> bool: + return ( + isinstance(past_key_values, DynamicCache) + and len(past_key_values.layers) > 0 + and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer) + ) + + class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's CCA path. 
@@ -771,7 +775,9 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and past_key_values is None: + if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): + if past_key_values is not None and past_key_values.get_seq_length() > 0: + raise ValueError("ZAYA requires a native hybrid cache created from `_make_zaya_cache`.") past_key_values = _make_zaya_cache(self.config) residual = None @@ -863,8 +869,8 @@ def _update_causal_mask( return causal_mask_mapping -@auto_docstring -class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +class ZayaForCausalLM(ZayaPreTrainedModel, AfmoeForCausalLM): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _is_stateful = True @@ -873,112 +879,9 @@ def __init__(self, config, **kwargs): self.model = ZayaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight self.post_init() - def set_decoder(self, decoder): - self.model = decoder - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_router_logits: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> MoeCausalLMOutputWithPast: - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - 
past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_router_logits=output_router_logits, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=None, - logits=logits, - past_key_values=outputs.past_key_values if use_cache else None, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - model_inputs = super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - return model_inputs - - def _prepare_cache_for_generation( - self, - generation_config, - model_kwargs: dict, - generation_mode, - batch_size: int, - max_cache_length: int, - ): - if generation_config.use_cache is False: - return - - if "past_key_values" not in model_kwargs: - model_kwargs["past_key_values"] = _make_zaya_cache(self.config) - generation_config.cache_implementation = None - return super()._prepare_cache_for_generation( - generation_config=generation_config, - model_kwargs=model_kwargs, - generation_mode=generation_mode, - batch_size=batch_size, - max_cache_length=max_cache_length, - ) - __all__ = [ "ZayaConfig", From 4d5bda4e29b7d5c6cd9ef062fb9e1f540b77300c Mon Sep 17 00:00:00 2001 From: JJJYmmm 
<1650675829@qq.com> Date: Tue, 12 May 2026 18:43:45 +0800 Subject: [PATCH 13/36] checkpoint conversion --- docs/source/en/model_doc/zaya.md | 8 + src/transformers/conversion_mapping.py | 3 + .../models/zaya/configuration_zaya.py | 120 ++- .../models/zaya/convert_zaya_weights_to_hf.py | 335 ++++++++ src/transformers/models/zaya/modeling_zaya.py | 761 +++++++----------- src/transformers/models/zaya/modular_zaya.py | 300 +++---- tests/models/zaya/test_modeling_zaya.py | 62 +- 7 files changed, 838 insertions(+), 751 deletions(-) create mode 100644 src/transformers/models/zaya/convert_zaya_weights_to_hf.py diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index 468f7327dd86..24468b8df65f 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -27,6 +27,14 @@ and Zyphra's technical reports. This model was contributed by [JJJYmmm](https://github.com/JJJYmmm). + + +When building a manual generation loop with `past_key_values`, use [`~models.zaya.modeling_zaya.make_zaya_cache`] to +create ZAYA's cache. ZAYA uses `config.layer_types` for full/sliding attention masks and RoPE parameters, while its +cache uses the native hybrid layout needed by the attention, convolution, and recurrent states. 
+ + + ## Usage examples ```python diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 0bf2c311845b..bb333bdcb4ce 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -564,6 +564,9 @@ def _build_checkpoint_conversion_mapping(): "zaya": [ WeightRenaming(r"self_attn\.qkv\.conv_qk\.0\.", "self_attn.qkv.conv_qk_depthwise."), WeightRenaming(r"self_attn\.qkv\.conv_qk\.1\.", "self_attn.qkv.conv_qk_grouped."), + WeightRenaming(r"zaya_block\.router\.router_mlp\.0\.", "zaya_block.router.router_mlp.fc1."), + WeightRenaming(r"zaya_block\.router\.router_mlp\.2\.", "zaya_block.router.router_mlp.fc2."), + WeightRenaming(r"zaya_block\.router\.router_mlp\.4\.", "zaya_block.router.router_mlp.out_proj."), WeightConverter( source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight", target_patterns="zaya_block.experts.gate_up_proj", diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 12a7c2999abc..479d07dea7d4 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_zaya.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any, Literal + from huggingface_hub.dataclasses import strict from ...configuration_utils import PreTrainedConfig @@ -29,17 +31,15 @@ @strict class ZayaConfig(PreTrainedConfig): r""" - ffn_hidden_size (`int`, *optional*, defaults to 4096): + intermediate_size (`int`, *optional*, defaults to 4096): Dimension of the feed-forward and expert hidden states. - num_query_groups (`int`, *optional*, defaults to 2): - Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`. - rope_theta (`float`, *optional*, defaults to 5000000): - The base period of the RoPE embeddings. + num_key_value_heads (`int`, *optional*, defaults to 2): + Number of key/value groups. partial_rotary_factor (`float`, *optional*, defaults to 0.5): Fraction of each attention head dimension using rotary embeddings. lm_head_bias (`bool`, *optional*, defaults to `False`): Whether to add a bias to the language modeling head. - moe_router_topk (`int`, *optional*, defaults to 1): + num_experts_per_tok (`int`, *optional*, defaults to 1): Number of selected experts per token. ZAYA checkpoints use top-1 routing. zaya_mlp_expansion (`int`, *optional*, defaults to 256): Expansion size used by the dense ZAYA blocks. @@ -47,7 +47,7 @@ class ZayaConfig(PreTrainedConfig): First temporal parameter of the CCA projection. cca_time1 (`int`, *optional*, defaults to 2): Second temporal parameter of the CCA projection. - swa_layers (`list[int]`, *optional*): + layer_types (`list[str]`, *optional*): Per-layer selector for standard RoPE versus SWA RoPE embeddings. swa_rotary_base (`float`, *optional*): RoPE base used by SWA layers. 
@@ -64,15 +64,15 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] + default_theta = 5000000.0 vocab_size: int = 262272 hidden_size: int = 2048 - ffn_hidden_size: int = 4096 - num_hidden_layers: int = 80 + intermediate_size: int = 4096 + num_hidden_layers: int = 40 num_experts: int = 16 num_attention_heads: int = 8 - num_key_value_heads: int | None = 2 - num_query_groups: int | None = 2 + num_key_value_heads: int = 2 hidden_act: str = "silu" head_dim: int = 128 max_position_embeddings: int = 131072 @@ -81,72 +81,64 @@ class ZayaConfig(PreTrainedConfig): use_cache: bool = True tie_word_embeddings: bool = True rope_parameters: RopeParameters | dict | None = None - rope_theta: float | int = 5000000 partial_rotary_factor: float = 0.5 attention_bias: bool = False lm_head_bias: bool = False attention_dropout: float | int = 0.0 - moe_router_topk: int = 1 + num_experts_per_tok: int = 1 zaya_mlp_expansion: int = 256 - cca_time0: int | None = 2 - cca_time1: int | None = 2 - swa_layers: list[int] | None = None - swa_rotary_base: float | int | None = None + cca_time0: int = 2 + cca_time1: int = 2 + sliding_window: int | None = None + layer_types: list[str] | None = None + swa_rotary_base: float | int = 10000.0 output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 eos_token_id: int | list[int] | None = 106 def __post_init__(self, **kwargs): - for unused_checkpoint_kwarg in ( - "cca", - "activation_func", - "normalization", - "add_bias_linear", - "gated_linear_unit", - "fused_add_norm", - "apply_rope_fusion", - "bias_activation_fusion", - "activation_func_fp8_input_store", - "clamp_temp", - "kv_channels", - "mamba_cache_dtype", - "residual_in_fp32", - "rope_scaling", - "scale_residual_merge", - "sliding_window", - "zaya_high_prec", - "zaya_use_mod", - "zaya_use_eda", - ): - kwargs.pop(unused_checkpoint_kwarg, None) - - self.num_key_value_heads = ( - self.num_attention_heads if 
self.num_key_value_heads is None else self.num_key_value_heads - ) - self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups - if self.head_dim is None: - raise ValueError("`head_dim` must be set for ZAYA.") - if self.num_query_groups != self.num_key_value_heads: - raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.") - if self.moe_router_topk != 1: - raise ValueError("ZAYA currently supports `moe_router_topk=1` only.") - - self.rope_parameters = ( - dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"} + self.layer_types = ( + ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) ) - self.rope_parameters.setdefault("rope_theta", self.rope_theta) - self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor) - self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0 - self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1 - if (self.cca_time0, self.cca_time1) != (2, 2): - raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") - if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers: - raise ValueError("`swa_layers` must have one entry per hidden layer.") - if self.swa_layers is not None and self.swa_rotary_base is None: - raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.") + + default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { + "full_attention": { + "rope_type": "default", + "rope_theta": self.default_theta, + "partial_rotary_factor": self.partial_rotary_factor, + }, + "sliding_attention": { + "rope_type": "default", + "rope_theta": self.swa_rotary_base, + "partial_rotary_factor": self.partial_rotary_factor, + }, + } + if self.rope_parameters is None: + self.rope_parameters = { + layer_type: default_rope_params[layer_type] for 
layer_type in set(self.layer_types) + } super().__post_init__(**kwargs) + def convert_rope_params_to_dict(self, **kwargs): + # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them + # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`. + return kwargs + + def validate_architecture(self): + if self.num_experts_per_tok != 1: + raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError("`layer_types` must have one entry per hidden layer.") + if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: + raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") + if "sliding_attention" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.") + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError("`sliding_window` must be a strictly positive integer.") + __all__ = ["ZayaConfig"] diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py new file mode 100644 index 000000000000..ba9198b9c666 --- /dev/null +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -0,0 +1,335 @@ +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert original alternating-layer ZAYA checkpoints to the Transformers-native decoder-layer layout.""" + +import argparse +import json +import re +import shutil +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from transformers import ZayaConfig + + +_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$") +_LOCAL_EXPERT_PATTERN = re.compile( + r"^model\.layers\.(\d+)\.zaya_block\.experts\.local_experts\.(\d+)\.linear_fc([12])\.weight$" +) + +_UNUSED_CONFIG_KEYS = ( + "cca", + "num_query_groups", + "ffn_hidden_size", + "moe_router_topk", + "activation_func", + "normalization", + "add_bias_linear", + "gated_linear_unit", + "fused_add_norm", + "apply_rope_fusion", + "bias_activation_fusion", + "activation_func_fp8_input_store", + "clamp_temp", + "kv_channels", + "mamba_cache_dtype", + "residual_in_fp32", + "rope_scaling", + "scale_residual_merge", + "zaya_high_prec", + "zaya_use_mod", + "zaya_use_eda", +) + + +def _rename_common(rest: str) -> str: + replacements = ( + ("self_attn.qkv.conv_qk.0.", "self_attn.qkv.conv_qk_depthwise."), + ("self_attn.qkv.conv_qk.1.", "self_attn.qkv.conv_qk_grouped."), + ("zaya_block.router.router_mlp.0.", "zaya_block.router.router_mlp.fc1."), + ("zaya_block.router.router_mlp.2.", "zaya_block.router.router_mlp.fc2."), + ("zaya_block.router.router_mlp.4.", "zaya_block.router.router_mlp.out_proj."), + ) + for old, new in replacements: + if rest.startswith(old): + return new + rest.removeprefix(old) + return rest + + 
+def _expert_target(name: str) -> tuple[str, int] | None: + match = _LOCAL_EXPERT_PATTERN.match(name) + if match is None: + return None + + old_layer_idx = int(match.group(1)) + if old_layer_idx % 2 != 1: + raise ValueError(f"Expert weights are expected on odd ZAYA layers, got: {name}") + + new_layer_idx = old_layer_idx // 2 + expert_idx = int(match.group(2)) + projection = "gate_up_proj" if match.group(3) == "1" else "down_proj" + target = f"model.layers.{new_layer_idx}.zaya_block.experts.{projection}" + return target, expert_idx + + +def convert_weight_name(name: str) -> str | None: + if _expert_target(name) is not None: + return None + + match = _LAYER_PATTERN.match(name) + if match is None: + return name + + old_layer_idx = int(match.group(1)) + rest = match.group(2) + new_layer_idx = old_layer_idx // 2 + + if old_layer_idx % 2 == 0: + rest = _rename_common(rest) + if rest.startswith(("self_attn.", "input_norm.", "res_scale.")): + return f"model.layers.{new_layer_idx}.{rest}" + else: + rest = _rename_common(rest) + if rest.startswith("zaya_block."): + return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.layers.{new_layer_idx}.post_attention_norm.{rest.removeprefix('input_norm.')}" + if rest.startswith("res_scale."): + return f"model.layers.{new_layer_idx}.post_attention_res_scale.{rest.removeprefix('res_scale.')}" + + raise ValueError(f"Unexpected ZAYA layer weight name: {name}") + + +def _convert_layer_types(config_dict: dict, old_num_hidden_layers: int, new_num_hidden_layers: int) -> list[str]: + layer_types = config_dict.get("layer_types") + if layer_types is not None: + if len(layer_types) == old_num_hidden_layers: + return layer_types[::2] + if len(layer_types) == new_num_hidden_layers: + return list(layer_types) + raise ValueError("`layer_types` must match either the original or converted number of hidden layers.") + + swa_layers = config_dict.get("swa_layers") + if swa_layers is None: + return 
["full_attention"] * new_num_hidden_layers + if len(swa_layers) == old_num_hidden_layers: + swa_layers = swa_layers[::2] + elif len(swa_layers) != new_num_hidden_layers: + raise ValueError("`swa_layers` must match either the original or converted number of hidden layers.") + return ["full_attention" if int(window_size) == 0 else "sliding_attention" for window_size in swa_layers] + + +def convert_config(input_dir: Path, output_dir: Path) -> None: + config_dict = json.loads((input_dir / "config.json").read_text()) + old_num_hidden_layers = int(config_dict["num_hidden_layers"]) + if old_num_hidden_layers % 2 != 0: + raise ValueError("Original ZAYA checkpoints must have an even number of alternating attention/MoE layers.") + + new_num_hidden_layers = old_num_hidden_layers // 2 + layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers) + partial_rotary_factor = config_dict.get("partial_rotary_factor", ZayaConfig.partial_rotary_factor) + rope_theta = config_dict.get("rope_theta", ZayaConfig.default_theta) + swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.swa_rotary_base) + intermediate_size = config_dict.get( + "intermediate_size", config_dict.get("ffn_hidden_size", ZayaConfig.intermediate_size) + ) + num_experts_per_tok = config_dict.get( + "num_experts_per_tok", config_dict.get("moe_router_topk", ZayaConfig.num_experts_per_tok) + ) + + swa_layers = config_dict.get("swa_layers") or [] + sliding_window = config_dict.get("sliding_window") + if sliding_window is None: + positive_windows = [int(window_size) for window_size in swa_layers if int(window_size) > 0] + sliding_window = max(positive_windows) if positive_windows else None + + rope_parameters = { + "full_attention": { + "rope_type": "default", + "rope_theta": rope_theta, + "partial_rotary_factor": partial_rotary_factor, + }, + "sliding_attention": { + "rope_type": "default", + "rope_theta": swa_rotary_base, + "partial_rotary_factor": partial_rotary_factor, + }, + 
} + + for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta"): + config_dict.pop(key, None) + + config_dict.update( + { + "architectures": ["ZayaForCausalLM"], + "num_hidden_layers": new_num_hidden_layers, + "intermediate_size": intermediate_size, + "num_experts_per_tok": num_experts_per_tok, + "layer_types": layer_types, + "sliding_window": sliding_window, + "swa_rotary_base": swa_rotary_base, + "rope_parameters": {layer_type: rope_parameters[layer_type] for layer_type in set(layer_types)}, + } + ) + ZayaConfig(**config_dict).save_pretrained(output_dir) + + +def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None: + for path in input_dir.iterdir(): + if path.name == "config.json": + continue + if path.name.endswith(".safetensors") or path.name.endswith(".bin"): + continue + if path.name in {"model.safetensors.index.json", "pytorch_model.bin.index.json"}: + continue + + output_path = output_dir / path.name + if path.is_dir(): + shutil.copytree(path, output_path, dirs_exist_ok=True) + else: + shutil.copy2(path, output_path) + + +def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[str]], dict[str, str], dict]: + index = json.loads((input_dir / "model.safetensors.index.json").read_text()) + old_weight_map = index["weight_map"] + converted_weight_map = {} + normal_sources_by_output_file = defaultdict(list) + expert_sources_by_target = defaultdict(list) + output_file_by_target = {} + + for source_key, filename in old_weight_map.items(): + expert_info = _expert_target(source_key) + if expert_info is not None: + target_key, expert_idx = expert_info + expert_sources_by_target[target_key].append((expert_idx, source_key)) + output_file_by_target.setdefault(target_key, filename) + converted_weight_map[target_key] = output_file_by_target[target_key] + continue + + target_key = convert_weight_name(source_key) + if target_key in converted_weight_map: + raise ValueError(f"Duplicate converted weight name: {target_key}") + 
converted_weight_map[target_key] = filename + normal_sources_by_output_file[filename].append((source_key, target_key)) + + index["weight_map"] = converted_weight_map + return normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, index + + +def _load_sources(input_dir: Path, source_keys: list[str], old_weight_map: dict[str, str]) -> dict[str, torch.Tensor]: + sources_by_file = defaultdict(list) + for source_key in source_keys: + sources_by_file[old_weight_map[source_key]].append(source_key) + + tensors = {} + for filename, keys in sources_by_file.items(): + with safe_open(input_dir / filename, framework="pt", device="cpu") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + return tensors + + +def convert_safetensors(input_dir: Path, output_dir: Path) -> None: + index_path = input_dir / "model.safetensors.index.json" + if not index_path.exists(): + safetensors_path = input_dir / "model.safetensors" + if not safetensors_path.exists(): + raise FileNotFoundError("Only safetensors ZAYA checkpoints are supported by this converter.") + + with safe_open(safetensors_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + state_dict = {} + expert_groups = defaultdict(list) + for key in f.keys(): + expert_info = _expert_target(key) + if expert_info is not None: + target_key, expert_idx = expert_info + expert_groups[target_key].append((expert_idx, f.get_tensor(key))) + continue + state_dict[convert_weight_name(key)] = f.get_tensor(key) + for target_key, expert_tensors in expert_groups.items(): + state_dict[target_key] = torch.stack([tensor for _, tensor in sorted(expert_tensors)], dim=0) + save_file(state_dict, output_dir / "model.safetensors", metadata=metadata) + return + + old_index = json.loads(index_path.read_text()) + old_weight_map = old_index["weight_map"] + normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, converted_index = ( + _build_weight_plan(input_dir) + ) + output_filenames = 
sorted(set(converted_index["weight_map"].values())) + + metadata_by_file = {} + for filename in sorted(set(old_weight_map.values())): + with safe_open(input_dir / filename, framework="pt", device="cpu") as f: + metadata_by_file[filename] = f.metadata() + + for output_filename in output_filenames: + shard = {} + normal_sources = normal_sources_by_output_file.get(output_filename, []) + source_keys = [source_key for source_key, _ in normal_sources] + + expert_groups_for_shard = { + target_key: sorted(sources) + for target_key, sources in expert_sources_by_target.items() + if output_file_by_target[target_key] == output_filename + } + for sources in expert_groups_for_shard.values(): + source_keys.extend(source_key for _, source_key in sources) + + loaded_tensors = _load_sources(input_dir, source_keys, old_weight_map) + for source_key, target_key in normal_sources: + shard[target_key] = loaded_tensors[source_key] + for target_key, sources in expert_groups_for_shard.items(): + shard[target_key] = torch.stack([loaded_tensors[source_key] for _, source_key in sources], dim=0) + + save_file(shard, output_dir / output_filename, metadata=metadata_by_file.get(output_filename)) + + (output_dir / "model.safetensors.index.json").write_text( + json.dumps(converted_index, indent=2, sort_keys=True) + "\n" + ) + + +def convert_checkpoint(input_dir: str, output_dir: str, overwrite: bool = False) -> None: + input_path = Path(input_dir).expanduser().resolve() + output_path = Path(output_dir).expanduser().resolve() + if input_path == output_path: + raise ValueError("Please write the converted checkpoint to a different output directory.") + if output_path.exists() and any(output_path.iterdir()): + if not overwrite: + raise FileExistsError(f"{output_path} already exists and is not empty. 
Pass --overwrite to replace it.") + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + copy_non_weight_files(input_path, output_path) + convert_config(input_path, output_path) + convert_safetensors(input_path, output_path) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--input_dir", required=True, help="Path to the original alternating-layer ZAYA checkpoint.") + parser.add_argument("--output_dir", required=True, help="Path where the converted checkpoint should be written.") + parser.add_argument("--overwrite", action="store_true", help="Overwrite a non-empty output directory.") + args = parser.parse_args() + convert_checkpoint(args.input_dir, args.output_dir, overwrite=args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index ab68cbc73d36..20662110b172 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_zaya.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ import copy from collections.abc import Callable -from typing import Optional +from typing import Any, Optional import torch import torch.nn.functional as F @@ -29,10 +29,10 @@ from torch.nn import init from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer from ...generation import GenerationMixin -from ...integrations import use_experts_implementation, use_kernel_forward_from_hub -from ...masking_utils import create_causal_mask +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -47,27 +47,33 @@ class ZayaRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` - def __init__(self, config: ZayaConfig, device=None): + def __init__(self, config: ZayaConfig): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings - self.config = config + self.layer_types = list(set(config.layer_types)) + self.rope_type = {} + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue - self.rope_type = self.config.rope_parameters["rope_type"] - rope_init_fn: Callable = self.compute_default_rope_parameters - if self.rope_type != "default": - rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + 
self.rope_type[layer_type] = rope_params["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) @staticmethod def compute_default_rope_parameters( config: ZayaConfig | None = None, device: Optional["torch.device"] = None, seq_len: int | None = None, + layer_type: str | None = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation @@ -78,12 +84,16 @@ def compute_default_rope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
""" - base = config.rope_parameters["rope_theta"] - partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + base = config.rope_parameters[layer_type]["rope_theta"] + # key difference to gemma3: partial rope + partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0) head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) @@ -97,16 +107,19 @@ def compute_default_rope_parameters( @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + def forward(self, x, position_ids, layer_type=None): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -132,129 +145,36 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class ZayaDynamicCache(DynamicCache): - """ - Cache that includes both the KV cache and the CCA cache. +class ZayaCCAProjection(nn.Module): """ + Projects hidden states into attention q/k/v states with ZAYA's CCA path. 
- def __init__( - self, - config: ZayaConfig, - batch_size: int, - dtype: torch.dtype = torch.float16, - device: str | None = None, - ): - super().__init__() - self.config = config - self.batch_size = batch_size - self.dtype = dtype - self.device = device - self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1) - self.num_layers = config.num_hidden_layers - self.key_value_hidden_size = config.num_query_groups * config.head_dim - self.query_hidden_size = config.num_attention_heads * config.head_dim - self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size - self.has_previous_state = False - - self.conv_states = [None for _ in range(self.num_layers)] - self.prev_v2 = [None for _ in range(self.num_layers)] - - def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor: - if new_conv_state.shape[1] < self.conv_kernel_size: - new_conv_state = F.pad( - new_conv_state.transpose(1, 2), (self.conv_kernel_size - new_conv_state.shape[1], 0) - ) - else: - new_conv_state = new_conv_state[:, -self.conv_kernel_size :, :].transpose(1, 2) + `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal + `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; + for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. + Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses + `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**. - if self.conv_states[layer_idx] is None: - self.conv_states[layer_idx] = torch.zeros_like(new_conv_state) + Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys. 
+ """ - if not self.has_previous_state: - self.conv_states[layer_idx].copy_(new_conv_state) - else: - conv_state = torch.cat([self.conv_states[layer_idx], new_conv_state], dim=-1)[ - :, :, -self.conv_kernel_size : - ] - self.conv_states[layer_idx].copy_(conv_state) - return self.conv_states[layer_idx] - - def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor: - if self.prev_v2[layer_idx] is None: - self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2) - self.prev_v2[layer_idx].copy_(new_prev_v2) - return self.prev_v2[layer_idx] - - def reset(self): - super().reset() - for conv_state in self.conv_states: - if conv_state is not None: - conv_state.zero_() - for prev_v2 in self.prev_v2: - if prev_v2 is not None: - prev_v2.zero_() - self.has_previous_state = False - - def _reorder_auxiliary_states(self, indices: torch.LongTensor): - for layer_idx, conv_state in enumerate(self.conv_states): - if conv_state is not None: - self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device)) - for layer_idx, prev_v2 in enumerate(self.prev_v2): - if prev_v2 is not None: - self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device)) - self.batch_size = indices.shape[0] - - def reorder_cache(self, beam_idx: torch.LongTensor): - super().reorder_cache(beam_idx) - self._reorder_auxiliary_states(beam_idx) - - def batch_repeat_interleave(self, repeats: int): - super().batch_repeat_interleave(repeats) - for layer_idx, conv_state in enumerate(self.conv_states): - if conv_state is not None: - self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0) - for layer_idx, prev_v2 in enumerate(self.prev_v2): - if prev_v2 is not None: - self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0) - self.batch_size *= repeats - - def batch_select_indices(self, indices: torch.Tensor): - super().batch_select_indices(indices) - self._reorder_auxiliary_states(indices) - - -class CCA(nn.Module): - def 
__init__( - self, - config: ZayaConfig, - num_key_value_heads: int = 2, - num_attention_heads: int = 8, - hidden_size: int | None = None, - head_dim: int = 128, - cca_time0: int = 2, - cca_time1: int = 2, - layer_number: int = 0, - ): + def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config - self.layer_number = layer_number + self.layer_idx = layer_idx - self.hidden_size = int(hidden_size or config.hidden_size) + self.hidden_size = config.hidden_size - self.depthwise_kernel_size = cca_time0 - self.grouped_kernel_size = cca_time1 + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) - self.num_key_value_heads = int(num_key_value_heads) - self.num_attention_heads = int(num_attention_heads) - - self.head_dim = int(head_dim) + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim self.key_value_hidden_size = self.num_key_value_heads * self.head_dim self.query_hidden_size = self.num_attention_heads * self.head_dim - self.sqrt_head_dim = self.head_dim**0.5 self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - if self.num_attention_heads % self.num_key_value_heads != 0: - raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias) self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias) @@ -262,23 +182,21 @@ def __init__( self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias) conv_channels = self.key_value_hidden_size + self.query_hidden_size - self.conv_qk = nn.Sequential( - nn.Conv1d( - in_channels=conv_channels, - out_channels=conv_channels, - 
kernel_size=self.depthwise_kernel_size, - groups=conv_channels, - padding=0, - stride=1, - ), - nn.Conv1d( - in_channels=conv_channels, - out_channels=conv_channels, - kernel_size=self.grouped_kernel_size, - groups=(self.num_key_value_heads + self.num_attention_heads), - padding=0, - stride=1, - ), + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, ) self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads)) @@ -286,51 +204,55 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - past_key_values: ZayaDynamicCache | None, - attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, ): - if attention_mask is not None: - hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype) + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) - batch_size, seq_length, _ = hidden_states.shape + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) projected_queries = self.linear_q(hidden_states) projected_keys = self.linear_k(hidden_states) qk_states = torch.cat([projected_queries, projected_keys], dim=-1) - query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim) - key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim) + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim) - key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, 
self.num_key_value_groups, -1) - key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim) + key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2) query_residual = (query_residual + key_residual) * 0.5 key_residual = query_residual.view( - batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim + *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim ).mean(dim=-2) qk_states = qk_states.transpose(1, 2) - use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) if use_precomputed_states: - cached_qk_states = past_key_values.conv_states[self.layer_number] + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states conv_input = torch.cat([cached_qk_states, qk_states], dim=-1) else: conv_input = F.pad(qk_states, (self.total_padding, 0)) if past_key_values is not None: - past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2)) + new_conv_state = qk_states[..., -self.total_padding :] + if new_conv_state.shape[-1] < self.total_padding: + new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) - convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2) + convolved_qk_states = self.conv_qk_depthwise(conv_input) + convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2) query = ( convolved_qk_states[..., : self.query_hidden_size].view( - batch_size, seq_length, self.num_attention_heads, self.head_dim + *input_shape, self.num_attention_heads, self.head_dim ) + query_residual ) key = ( convolved_qk_states[..., self.query_hidden_size :].view( - batch_size, seq_length, self.num_key_value_heads, self.head_dim + 
*input_shape, self.num_key_value_heads, self.head_dim ) + key_residual ) @@ -338,28 +260,18 @@ def forward( value_current = self.val_proj1(hidden_states) projected_v2 = self.val_proj2(hidden_states) if use_precomputed_states: - first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1) + first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) else: - first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size)) + first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) if past_key_values is not None: - past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :]) + past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx) value = torch.cat([value_current, value_delayed], dim=-1).view( - batch_size, seq_length, self.num_key_value_heads, self.head_dim + *input_shape, self.num_key_value_heads, self.head_dim ) - norm_eps = torch.finfo(query.dtype).eps - query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) - key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) - - key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1) - query = query * (self.sqrt_head_dim / query_norm) - - query = query.reshape(batch_size, seq_length, self.query_hidden_size) - key = key.reshape(batch_size, seq_length, self.key_value_hidden_size) - value = value.reshape(batch_size, seq_length, self.key_value_hidden_size) return query, key, value @@ -446,51 +358,56 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernelized_func(apply_rotary_pos_emb) class ZayaAttention(nn.Module): - def __init__(self, config: ZayaConfig, layer_n): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config - self.layer_n = layer_n - self.layer_idx = 
layer_n - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.is_causal = True - self.attention_dropout = config.attention_dropout - self.head_dim = config.head_dim + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - + self.attention_dropout = config.attention_dropout + self.is_causal = True self.o_proj = nn.Linear( - self.num_attention_heads * self.head_dim, - self.hidden_size, - bias=self.config.attention_bias, + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.qkv = CCA( + self.layer_n = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.qkv = ZayaCCAProjection( config=self.config, - num_attention_heads=self.config.num_attention_heads, - num_key_value_heads=self.config.num_query_groups, - hidden_size=self.hidden_size, - head_dim=self.config.head_dim, - cca_time0=self.config.cca_time0, - cca_time1=self.config.cca_time1, - layer_number=layer_n, + layer_idx=layer_idx, ) def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - attention_mask_2d: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None = None, past_key_values: Cache | None = None, - output_attentions: bool = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | 
None, Cache | None]: batch_size, seq_length, _ = hidden_states.shape - query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d) - query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim) - key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) - value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") + + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask) + + norm_eps = torch.finfo(query_states.dtype).eps + head_dim_scale = self.scaling**-1 + query_states = query_states * ( + head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) + key_states = key_states * self.qkv.temp[None, None, :, None] query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -502,8 +419,7 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n) - causal_mask = attention_mask - if causal_mask is not None: + if isinstance(causal_mask, torch.Tensor): causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( @@ -517,15 +433,13 @@ def forward( causal_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - output_attentions=output_attentions, + sliding_window=self.sliding_window, + **kwargs, ) attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = 
None - return attn_output, attn_weights, past_key_values @@ -541,66 +455,94 @@ def _apply_residual_scaling( return hidden_states, residual -class ZayaDecoderATTLayer(GradientCheckpointingLayer): - def __init__(self, config: ZayaConfig, layer_n: int): +class ZayaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config - self.self_attn = ZayaAttention(config, layer_n) - + self.self_attn = ZayaAttention(config, layer_idx) self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.res_scale = ResidualScaling(config, layer_n) + self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0) + self.zaya_block = ZayaSparseMoeBlock( + config, + config.num_experts, + config.zaya_mlp_expansion, + config.intermediate_size, + layer_idx, + ) + self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.post_attention_res_scale = ResidualScaling(config.hidden_size) def forward( self, hidden_states: torch.Tensor, - residual: torch.Tensor, - attention_mask: torch.Tensor | None = None, - attention_mask_2d: torch.Tensor | None = None, + residual: torch.Tensor | None, + prev_router_hidden_states: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None = None, past_key_values: Cache | None = None, - output_attentions: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - prev_router_hidden_states: torch.Tensor | None = None, - **kwargs, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, 
class ResidualScaling(nn.Module):
    """
    Learnable per-feature affine rescaling of the hidden states and, optionally, of the residual stream.

    Every scale starts at one and every bias at zero. With `has_residual_scale=False` no residual parameters are
    created and the residual is handed back untouched.
    """

    def __init__(self, hidden_size: int, has_residual_scale: bool = True):
        super().__init__()
        self.has_residual_scale = has_residual_scale
        self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
        self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size))
        if not has_residual_scale:
            return

        self.residual_scale = nn.Parameter(torch.ones(hidden_size))
        self.residual_bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
        """Return `(residual, hidden_states)`, each shifted by its bias and then multiplied by its scale."""
        shifted_hidden = hidden_states + self.hidden_states_bias
        scaled_hidden = shifted_hidden * self.hidden_states_scale
        if not self.has_residual_scale:
            return residual, scaled_hidden

        shifted_residual = residual + self.residual_bias
        return shifted_residual * self.residual_scale, scaled_hidden


class ZayaRouterMLP(nn.Module):
    """
    Router head: two biased `hidden_size -> hidden_size` linear layers, each followed by GELU, and a final
    bias-free projection to `num_experts` logits.
    """

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
        self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
        self.act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Both hidden layers share the same activation; only the output projection is left linear.
        for hidden_layer in (self.fc1, self.fc2):
            hidden_states = self.act_fn(hidden_layer(hidden_states))
        router_logits = self.out_proj(hidden_states)
        return router_logits
__init__(self, config, num_experts: int, intermediate_size: int): super().__init__() self.num_experts = num_experts self.hidden_dim = config.hidden_size - self.intermediate_dim = ffn_hidden_size // 2 + self.intermediate_dim = intermediate_size // 2 self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_act] @@ -686,7 +619,7 @@ def forward( ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() @@ -705,14 +638,14 @@ def forward( return final_hidden_states -class ZayaBlock(nn.Module): +class ZayaSparseMoeBlock(nn.Module): def __init__( self, config, num_moe_experts: int, mlp_expansion: int, - ffn_hidden_size: int, - layer_n: int, + intermediate_size: int, + layer_idx: int, ): super().__init__() self.config = config @@ -720,13 +653,13 @@ def __init__( self.num_moe_experts = num_moe_experts self.router = ZayaRouter( config=self.config, - layer_idx=layer_n, + layer_idx=layer_idx, num_moe_experts=self.num_moe_experts, - moe_router_topk=getattr(self.config, "moe_router_topk", 1), + num_experts_per_tok=self.config.num_experts_per_tok, mlp_expansion=mlp_expansion, hidden_size=self.hidden_dim, ) - self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size) + self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size) def forward( self, @@ -736,6 +669,13 @@ def forward( route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router( hidden_states, 
router_states=prev_router_hidden_states ) + + # if the router outputs num_moe_experts, just skip the tokens + # by masking them with id=0 and prob=0 to reuse the expert code + skip_expert = expert_choice == self.num_moe_experts + route_prob = route_prob.masked_fill(skip_expert, 0) + expert_choice = expert_choice.masked_fill(skip_expert, 0) + batch_size, seq_length, emb_dim = hidden_states.shape hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) expert_output = self.experts(hidden_states_flat, expert_choice, route_prob) @@ -744,64 +684,25 @@ def forward( return expert_output, prev_router_hidden_states, router_logits -class ZayaDecoderMLPLayer(GradientCheckpointingLayer): - def __init__( - self, - config: ZayaConfig, - num_moe_experts: int, - mlp_expansion: int, - ffn_hidden_size: int, - layer_n: int, - ): - super().__init__() - self.config = config - self.zaya_block = ZayaBlock( - config, - num_moe_experts, - mlp_expansion, - ffn_hidden_size, - layer_n, - ) - self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.res_scale = ResidualScaling(config, layer_n) - - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - prev_router_hidden_states: torch.Tensor | None = None, - output_router_logits: bool = False, - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) - - hidden_states, prev_router_hidden_states, router_logits = self.zaya_block( - hidden_states, - prev_router_hidden_states, - ) - - return ( - hidden_states, - router_logits if output_router_logits else None, - residual, - prev_router_hidden_states, - ) - - +@auto_docstring class ZayaPreTrainedModel(PreTrainedModel): config: ZayaConfig - config_class = ZayaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = 
def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
    """
    Build the native hybrid `DynamicCache` used by ZAYA.

    The cache is created from a shallow copy of `config` whose `layer_types` is replaced by one `"hybrid"` entry
    per hidden layer. `config.layer_types` itself stays reserved for the full/sliding attention masks and RoPE
    parameters; the hybrid layout is used for the cache because every ZAYA decoder layer has attention,
    convolution, and recurrent states.
    """
    hybrid_config = copy.copy(config)
    hybrid_config.layer_types = ["hybrid" for _ in range(config.num_hidden_layers)]
    return DynamicCache(config=hybrid_config)


def _is_zaya_cache(past_key_values: Cache) -> bool:
    """Tell whether `past_key_values` is a non-empty `DynamicCache` whose first layer is a hybrid cache layer."""
    if not isinstance(past_key_values, DynamicCache):
        return False
    cache_layers = past_key_values.layers
    if len(cache_layers) == 0:
        return False
    # Only the first layer is inspected, exactly like the original check.
    return isinstance(cache_layers[0], LinearAttentionAndFullAttentionLayer)
MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and past_key_values is None: - past_key_values = ZayaDynamicCache( - self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device - ) + if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): + if past_key_values is not None and past_key_values.get_seq_length() > 0: + raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") + past_key_values = make_zaya_cache(self.config) residual = None @@ -910,48 +811,44 @@ def forward( device=inputs_embeds.device, ).unsqueeze(0) - causal_mask = self._update_causal_mask( + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + causal_mask_mapping = self._update_causal_mask( attention_mask, inputs_embeds, position_ids, past_key_values, ) - if attention_mask is not None and attention_mask.ndim != 2: - raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.") - # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. - # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state. 
- attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + + # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. + # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. if inputs_embeds.shape[1] == 1: - attention_mask_2d = None + padding_mask = None hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - if self.config.swa_layers is not None: - swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) + } - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None prev_router_hidden_states = None for layer_n, decoder_layer in enumerate(self.layers): - if self.config.swa_layers is not None: - emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings - else: - emb_to_use = position_embeddings - if output_hidden_states: - all_hidden_states += (hidden_states,) - + layer_type = self.config.layer_types[layer_n] + emb_to_use = position_embeddings[layer_type] + mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} layer_outputs = decoder_layer( hidden_states, residual, - attention_mask=causal_mask, - position_ids=position_ids, + prev_router_hidden_states, + attention_mask=mask_mapping, past_key_values=past_key_values, - output_attentions=output_attentions, position_embeddings=emb_to_use, - prev_router_hidden_states=prev_router_hidden_states, - attention_mask_2d=attention_mask_2d, **kwargs, ) @@ -959,23 +856,11 @@ def forward( residual = layer_outputs[2] 
prev_router_hidden_states = layer_outputs[3] - if isinstance(decoder_layer, ZayaDecoderATTLayer): - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, ) def _update_causal_mask( @@ -985,33 +870,37 @@ def _update_causal_mask( position_ids: torch.Tensor, past_key_values: Cache, ): - return create_causal_mask( - config=self.config, - inputs_embeds=input_tensor, - attention_mask=attention_mask, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - -@auto_docstring + mask_kwargs = { + "config": self.config, + "inputs_embeds": input_tensor, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + mask_creation_functions = { + "full_attention": lambda: create_causal_mask(**mask_kwargs), + "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs), + } + causal_mask_mapping = {} + for layer_type in set(self.config.layer_types): + causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]() + return causal_mask_mapping + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _is_stateful = True def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(config) self.model = ZayaModel(config) self.vocab_size = config.vocab_size 
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.post_init() - def set_decoder(self, decoder): - self.model = decoder - @can_return_tuple @auto_docstring def forward( @@ -1027,11 +916,28 @@ def forward( logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> MoeCausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, ZayaForCausalLM + + >>> model = ZayaForCausalLM.from_pretrained("meta-zaya/Zaya-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-zaya/Zaya-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1043,80 +949,21 @@ def forward( ) hidden_states = outputs.last_hidden_state - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) return MoeCausalLMOutputWithPast( loss=loss, - aux_loss=None, logits=logits, - past_key_values=outputs.past_key_values if use_cache else None, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache): - raise ValueError( - f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." 
- ) - - model_inputs = super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - return model_inputs - - def _prepare_cache_for_generation( - self, - generation_config, - model_kwargs: dict, - generation_mode, - batch_size: int, - max_cache_length: int, - ): - if generation_config.use_cache is False: - return - - if "past_key_values" not in model_kwargs: - cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences) - model_kwargs["past_key_values"] = ZayaDynamicCache( - self.config, cache_batch_size, dtype=self.dtype, device=self.device - ) - generation_config.cache_implementation = None - return super()._prepare_cache_for_generation( - generation_config=generation_config, - model_kwargs=model_kwargs, - generation_mode=generation_mode, - batch_size=batch_size, - max_cache_length=max_cache_length, - ) - __all__ = ["ZayaPreTrainedModel", "ZayaModel", "ZayaForCausalLM"] diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 50cad3bd10ea..cbafc3200146 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -31,7 +31,7 @@ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeModelOutputWithPast -from ...modeling_rope_utils import RopeParameters +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -54,15 +54,15 @@ @strict class ZayaConfig(PreTrainedConfig): r""" - ffn_hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the 
feed-forward and expert hidden states, translate it to `intermediate_size`. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the feed-forward and expert hidden states. num_key_value_heads (`int`, *optional*, defaults to 2): Number of key/value groups. partial_rotary_factor (`float`, *optional*, defaults to 0.5): Fraction of each attention head dimension using rotary embeddings. lm_head_bias (`bool`, *optional*, defaults to `False`): Whether to add a bias to the language modeling head. - moe_router_topk (`int`, *optional*, defaults to 1): + num_experts_per_tok (`int`, *optional*, defaults to 1): Number of selected experts per token. ZAYA checkpoints use top-1 routing. zaya_mlp_expansion (`int`, *optional*, defaults to 256): Expansion size used by the dense ZAYA blocks. @@ -91,11 +91,11 @@ class ZayaConfig(PreTrainedConfig): vocab_size: int = 262272 hidden_size: int = 2048 - ffn_hidden_size: int = 4096 - num_hidden_layers: int = 80 + intermediate_size: int = 4096 + num_hidden_layers: int = 40 num_experts: int = 16 num_attention_heads: int = 8 - num_key_value_heads: int | None = 2 + num_key_value_heads: int = 2 hidden_act: str = "silu" head_dim: int = 128 max_position_embeddings: int = 131072 @@ -108,10 +108,10 @@ class ZayaConfig(PreTrainedConfig): attention_bias: bool = False lm_head_bias: bool = False attention_dropout: float | int = 0.0 - moe_router_topk: int = 1 + num_experts_per_tok: int = 1 zaya_mlp_expansion: int = 256 - cca_time0: int | None = 2 - cca_time1: int | None = 2 + cca_time0: int = 2 + cca_time1: int = 2 sliding_window: int | None = None layer_types: list[str] | None = None swa_rotary_base: float | int = 10000.0 @@ -121,60 +121,14 @@ class ZayaConfig(PreTrainedConfig): eos_token_id: int | list[int] | None = 106 def __post_init__(self, **kwargs): - for unused_checkpoint_kwarg in ( - "cca", - "num_query_groups", - "activation_func", - "normalization", - "add_bias_linear", - "gated_linear_unit", - "fused_add_norm", - 
"apply_rope_fusion", - "bias_activation_fusion", - "activation_func_fp8_input_store", - "clamp_temp", - "kv_channels", - "mamba_cache_dtype", - "residual_in_fp32", - "rope_scaling", - "scale_residual_merge", - "zaya_high_prec", - "zaya_use_mod", - "zaya_use_eda", - ): - kwargs.pop(unused_checkpoint_kwarg, None) - - self.intermediate_size = self.ffn_hidden_size - self.num_experts_per_tok = self.moe_router_topk - - self.num_key_value_heads = ( - self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads + self.layer_types = ( + ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) ) - legacy_swa_layers = kwargs.pop("swa_layers", None) - swa_window_sizes = {int(window_size) for window_size in (legacy_swa_layers or []) if int(window_size) > 0} - if self.sliding_window is None and swa_window_sizes: - self.sliding_window = max(swa_window_sizes) - if self.layer_types is None: - if legacy_swa_layers is None: - self.layer_types = ["full_attention"] * self.num_hidden_layers - else: - self.layer_types = [ - "full_attention" if layer_type == 0 else "sliding_attention" for layer_type in legacy_swa_layers - ] - else: - self.layer_types = list(self.layer_types) - - self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0 - self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1 - - super().__post_init__(**kwargs) - - def convert_rope_params_to_dict(self, **kwargs): default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { "full_attention": { "rope_type": "default", - "rope_theta": kwargs.pop("rope_theta", self.default_theta), + "rope_theta": self.default_theta, "partial_rotary_factor": self.partial_rotary_factor, }, "sliding_attention": { @@ -183,21 +137,19 @@ def convert_rope_params_to_dict(self, **kwargs): "partial_rotary_factor": self.partial_rotary_factor, }, } - layer_types = set(self.layer_types) - if self.rope_parameters is None: - 
self.rope_parameters = {layer_type: default_rope_params[layer_type] for layer_type in layer_types} - else: self.rope_parameters = { - layer_type: {**default_rope_params[layer_type], **(self.rope_parameters.get(layer_type) or {})} - for layer_type in layer_types + layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types) } + super().__post_init__(**kwargs) + + def convert_rope_params_to_dict(self, **kwargs): + # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them + # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`. return kwargs def validate_architecture(self): - if self.head_dim is None: - raise ValueError("`head_dim` must be set for ZAYA.") if self.num_experts_per_tok != 1: raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") if self.num_attention_heads % self.num_key_value_heads != 0: @@ -210,8 +162,6 @@ def validate_architecture(self): raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.") if self.sliding_window is not None and self.sliding_window <= 0: raise ValueError("`sliding_window` must be a strictly positive integer.") - if (self.cca_time0, self.cca_time1) != (2, 2): - raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.") class ZayaRotaryEmbedding(LagunaRotaryEmbedding): @@ -222,13 +172,14 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm): pass -def _make_zaya_cache(config: ZayaConfig) -> DynamicCache: +def make_zaya_cache(config: ZayaConfig) -> DynamicCache: + """ + Create ZAYA's native hybrid cache. + + `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states. 
+ """ cache_config = copy.copy(config) - # layer_types is used to distinct the rope_type (full or swa) - # so need to construct a new layer_types to construct cache - cache_config.layer_types = [ - "hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers) - ] + cache_config.layer_types = ["hybrid"] * config.num_hidden_layers return DynamicCache(config=cache_config) @@ -249,6 +200,8 @@ class ZayaCCAProjection(nn.Module): for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**. + + Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys. """ def __init__(self, config: ZayaConfig, layer_idx: int): @@ -298,10 +251,10 @@ def forward( self, hidden_states: torch.Tensor, past_key_values: Cache | None, - attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, ): - if attention_mask is not None: - hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype) + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -389,19 +342,16 @@ def __init__(self, config: ZayaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None, + attention_mask: dict[str, Any] | None = None, past_key_values: Cache | None = None, - output_attentions: bool = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: batch_size, 
seq_length, _ = hidden_states.shape - if isinstance(attention_mask, dict): - causal_mask = attention_mask.get("causal") - padding_mask = attention_mask.get("padding") - else: - causal_mask = attention_mask - padding_mask = None + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask) @@ -438,7 +388,7 @@ def forward( dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, - output_attentions=output_attentions, + **kwargs, ) attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim) @@ -447,25 +397,32 @@ def forward( return attn_output, attn_weights, past_key_values -class ZayaDecoderATTLayer(GradientCheckpointingLayer): - def __init__(self, config: ZayaConfig, layer_n: int): +class ZayaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config - self.self_attn = ZayaAttention(config, layer_n) - + self.self_attn = ZayaAttention(config, layer_idx) self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.res_scale = ResidualScaling(config, layer_n) + self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0) + self.zaya_block = ZayaSparseMoeBlock( + config, + config.num_experts, + config.zaya_mlp_expansion, + config.intermediate_size, + layer_idx, + ) + self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.post_attention_res_scale = ResidualScaling(config.hidden_size) def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor | None, - attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None, + prev_router_hidden_states: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None 
= None, past_key_values: Cache | None = None, - output_attentions: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - prev_router_hidden_states: torch.Tensor | None = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) @@ -473,27 +430,36 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, - output_attentions=output_attentions, position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states, residual = _apply_residual_scaling( + hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm + ) + + hidden_states, prev_router_hidden_states, _ = self.zaya_block( + hidden_states, + prev_router_hidden_states, ) - return hidden_states, self_attn_weights if output_attentions else None, residual, prev_router_hidden_states + return hidden_states, self_attn_weights, residual, prev_router_hidden_states class ResidualScaling(nn.Module): - def __init__(self, config, layer_n): + def __init__(self, hidden_size: int, has_residual_scale: bool = True): super().__init__() - self.not_first_layer = layer_n != 0 - self.hidden_states_scale = torch.nn.Parameter(torch.ones(config.hidden_size)) - self.hidden_states_bias = torch.nn.Parameter(torch.zeros(config.hidden_size)) + self.has_residual_scale = has_residual_scale + self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) + self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) - if self.not_first_layer: - self.residual_scale = torch.nn.Parameter(torch.ones(config.hidden_size)) - self.residual_bias = torch.nn.Parameter(torch.zeros(config.hidden_size)) + if self.has_residual_scale: + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = 
nn.Parameter(torch.zeros(hidden_size)) def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor): hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale - if self.not_first_layer: + if self.has_residual_scale: residual = (residual + self.residual_bias) * self.residual_scale return residual, hidden_states @@ -546,8 +512,7 @@ def __init__( self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) - zaya_first_layer = 1 - self.use_eda = self.layer_idx != zaya_first_layer + self.use_eda = self.layer_idx != 0 self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon) if self.use_eda: @@ -605,7 +570,7 @@ def __init__( num_moe_experts: int, mlp_expansion: int, intermediate_size: int, - layer_n: int, + layer_idx: int, ): super().__init__() self.config = config @@ -613,7 +578,7 @@ def __init__( self.num_moe_experts = num_moe_experts self.router = ZayaRouter( config=self.config, - layer_idx=layer_n, + layer_idx=layer_idx, num_moe_experts=self.num_moe_experts, num_experts_per_tok=self.config.num_experts_per_tok, mlp_expansion=mlp_expansion, @@ -629,6 +594,9 @@ def forward( route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router( hidden_states, router_states=prev_router_hidden_states ) + + # if the router outputs num_moe_experts, just skip the tokens + # by masking them with id=0 and prob=0 to reuse the expert code skip_expert = expert_choice == self.num_moe_experts route_prob = route_prob.masked_fill(skip_expert, 0) expert_choice = expert_choice.masked_fill(skip_expert, 0) @@ -641,59 +609,15 @@ def forward( return expert_output, prev_router_hidden_states, router_logits -class ZayaDecoderMLPLayer(GradientCheckpointingLayer): - def __init__( - self, - config: ZayaConfig, - num_moe_experts: int, - mlp_expansion: int, - intermediate_size: int, - layer_n: int, - ): - super().__init__() - self.config = config - self.zaya_block = ZayaSparseMoeBlock( - config, - num_moe_experts, - 
mlp_expansion, - intermediate_size, - layer_n, - ) - self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.res_scale = ResidualScaling(config, layer_n) - - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - prev_router_hidden_states: torch.Tensor | None = None, - output_router_logits: bool = False, - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) - - hidden_states, prev_router_hidden_states, router_logits = self.zaya_block( - hidden_states, - prev_router_hidden_states, - ) - - return ( - hidden_states, - router_logits if output_router_logits else None, - residual, - prev_router_hidden_states, - ) - - class ZayaPreTrainedModel(LlamaPreTrainedModel): config: ZayaConfig config_class = ZayaConfig - _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"] + _no_split_modules = ["ZayaDecoderLayer"] # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache. 
_can_compile_fullgraph = False _can_record_outputs = { "router_logits": OutputRecorder(ZayaRouter, index=3), - "hidden_states": [ZayaDecoderATTLayer, ZayaDecoderMLPLayer], + "hidden_states": ZayaDecoderLayer, "attentions": ZayaAttention, } @@ -703,7 +627,7 @@ def _init_weights(self, module): if isinstance(module, ResidualScaling): init.ones_(module.hidden_states_scale) init.zeros_(module.hidden_states_bias) - if module.not_first_layer: + if module.has_residual_scale: init.ones_(module.residual_scale) init.zeros_(module.residual_bias) elif isinstance(module, ZayaRouter): @@ -715,6 +639,14 @@ def _init_weights(self, module): std = self.config.initializer_range init.normal_(module.gate_up_proj, mean=0.0, std=std) init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, ZayaRotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq) + getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq) @auto_docstring @@ -724,25 +656,12 @@ def __init__(self, config: ZayaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = [] - - for layer_n in range(config.num_hidden_layers): - if layer_n % 2 == 1: - self.layers.append( - ZayaDecoderMLPLayer( - config, - config.num_experts, - config.zaya_mlp_expansion, - config.intermediate_size, - layer_n, - ) - ) - else: - self.layers.append(ZayaDecoderATTLayer(config, layer_n)) - self.layers = nn.ModuleList(self.layers) + self.layers = nn.ModuleList( + [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False 
- self.res_scale = ResidualScaling(config, config.num_hidden_layers) + self.res_scale = ResidualScaling(config.hidden_size) self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) @@ -777,8 +696,8 @@ def forward( if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): if past_key_values is not None and past_key_values.get_seq_length() > 0: - raise ValueError("ZAYA requires a native hybrid cache created from `_make_zaya_cache`.") - past_key_values = _make_zaya_cache(self.config) + raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") + past_key_values = make_zaya_cache(self.config) residual = None @@ -790,21 +709,19 @@ def forward( device=inputs_embeds.device, ).unsqueeze(0) - if isinstance(attention_mask, dict): - causal_mask_mapping = attention_mask - padding_mask = None - else: - causal_mask_mapping = self._update_causal_mask( - attention_mask, - inputs_embeds, - position_ids, - past_key_values, - ) - padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None - if attention_mask is not None and not isinstance(attention_mask, dict) and attention_mask.ndim != 2: + if attention_mask is not None and attention_mask.ndim != 2: raise ValueError( "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." ) + + causal_mask_mapping = self._update_causal_mask( + attention_mask, + inputs_embeds, + position_ids, + past_key_values, + ) + padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None + # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. 
if inputs_embeds.shape[1] == 1: @@ -822,15 +739,14 @@ def forward( for layer_n, decoder_layer in enumerate(self.layers): layer_type = self.config.layer_types[layer_n] emb_to_use = position_embeddings[layer_type] - attention_mask = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} + mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} layer_outputs = decoder_layer( hidden_states, residual, - attention_mask=attention_mask, - position_ids=position_ids, + prev_router_hidden_states, + attention_mask=mask_mapping, past_key_values=past_key_values, position_embeddings=emb_to_use, - prev_router_hidden_states=prev_router_hidden_states, **kwargs, ) diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 5e16b744c989..027901d2ffce 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -27,7 +27,7 @@ from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer - from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, _make_zaya_cache + from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, make_zaya_cache from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester @@ -49,14 +49,16 @@ def __init__(self, parent): intermediate_size=64, ) self.head_dim = 8 - self.ffn_hidden_size = 64 self.num_experts = 4 - self.moe_router_topk = 1 + self.num_experts_per_tok = 1 self.zaya_mlp_expansion = 4 self.tie_word_embeddings = False self.rope_parameters = { - "rope_theta": 10000, - "rope_type": "default", + "full_attention": { + "rope_theta": 10000, + "rope_type": "default", + "partial_rotary_factor": 0.5, + }, } @@ -82,18 +84,12 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l conv_shape = self._get_conv_state_shape(batch_size, config) recurrent_shape = 
self._get_recurrent_state_shape(batch_size, config) - for layer_idx, layer in enumerate(past_key_values.layers): - if layer_idx % 2 == 0: - self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer) - self.assertEqual(layer.keys.shape, attention_shape) - self.assertEqual(layer.values.shape, attention_shape) - self.assertEqual(layer.conv_states.shape, conv_shape) - self.assertEqual(layer.recurrent_states.shape, recurrent_shape) - else: - self.assertIsNone(getattr(layer, "keys", None)) - self.assertIsNone(getattr(layer, "values", None)) - self.assertIsNone(layer.conv_states) - self.assertIsNone(layer.recurrent_states) + for layer in past_key_values.layers: + self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer) + self.assertEqual(layer.keys.shape, attention_shape) + self.assertEqual(layer.values.shape, attention_shape) + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) def is_pipeline_test_to_skip( self, @@ -132,7 +128,7 @@ def test_attention_outputs(self): with torch.no_grad(): outputs = model(**self._prepare_for_class({**inputs_dict, "output_attentions": True}, model_class)) - expected_attn_layers = (config.num_hidden_layers + 1) // 2 + expected_attn_layers = config.num_hidden_layers self.assertEqual(len(outputs.attentions), expected_attn_layers) self.assertEqual( outputs.attentions[0].shape, @@ -248,32 +244,22 @@ def test_moe_router_logits(self): with torch.no_grad(): outputs = model(**inputs_dict, output_router_logits=True) - expected_moe_layers = config.num_hidden_layers // 2 + expected_moe_layers = config.num_hidden_layers self.assertEqual(len(outputs.router_logits), expected_moe_layers) self.assertEqual( outputs.router_logits[0].shape, (self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1), ) - def test_moe_router_topk_validation(self): - with self.assertRaisesRegex(StrictDataclassClassValidationError, "moe_router_topk=1"): - 
ZayaConfig(moe_router_topk=2) - - def test_legacy_swa_layers_translate_to_layer_types(self): - config = ZayaConfig(num_hidden_layers=4, swa_layers=[4096, 0, 4096, 0], swa_rotary_base=10000) - - self.assertEqual( - config.layer_types, ["sliding_attention", "full_attention", "sliding_attention", "full_attention"] - ) - self.assertEqual(config.sliding_window, 4096) - self.assertEqual(config.rope_parameters["full_attention"]["rope_theta"], config.default_theta) - self.assertEqual(config.rope_parameters["sliding_attention"]["rope_theta"], 10000) + def test_num_experts_per_tok_validation(self): + with self.assertRaisesRegex(StrictDataclassClassValidationError, "num_experts_per_tok=1"): + ZayaConfig(num_experts_per_tok=2) def test_sliding_attention_mask_is_used(self): config = ZayaConfig( vocab_size=128, hidden_size=32, - ffn_hidden_size=64, + intermediate_size=64, num_hidden_layers=4, num_experts=4, num_attention_heads=4, @@ -299,7 +285,7 @@ def test_cca_cache_matches_full_forward(self): config = ZayaConfig( vocab_size=128, hidden_size=32, - ffn_hidden_size=64, + intermediate_size=64, num_hidden_layers=1, num_experts=4, num_attention_heads=4, @@ -315,7 +301,7 @@ def test_cca_cache_matches_full_forward(self): with torch.no_grad(): full = cca(hidden_states, None, None) - cache = _make_zaya_cache(config) + cache = make_zaya_cache(config) cca(hidden_states[:, :4], cache, None) cached = cca(hidden_states[:, 4:], cache, None) @@ -326,7 +312,7 @@ def test_cca_cache_matches_full_forward_multi_token(self): config = ZayaConfig( vocab_size=128, hidden_size=32, - ffn_hidden_size=64, + intermediate_size=64, num_hidden_layers=1, num_experts=4, num_attention_heads=4, @@ -342,7 +328,7 @@ def test_cca_cache_matches_full_forward_multi_token(self): with torch.no_grad(): full = cca(hidden_states, None, None) - cache = _make_zaya_cache(config) + cache = make_zaya_cache(config) cca(hidden_states[:, :3], cache, None) cached = cca(hidden_states[:, 3:], cache, None) @@ -351,7 +337,7 @@ def 
test_cca_cache_matches_full_forward_multi_token(self): def test_zaya_cache_reorder_and_reset(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - cache = _make_zaya_cache(config) + cache = make_zaya_cache(config) conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim cache.update_conv_state( torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view( From f3e8e02c7b87632dc40ded0e062d93cec888e33a Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 19:05:14 +0800 Subject: [PATCH 14/36] align with official implementation, check 74b conversion --- src/transformers/models/zaya/convert_zaya_weights_to_hf.py | 4 +++- src/transformers/models/zaya/modeling_zaya.py | 4 +++- src/transformers/models/zaya/modular_zaya.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py index ba9198b9c666..a1b9b357dc52 100644 --- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -156,7 +156,9 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: sliding_window = config_dict.get("sliding_window") if sliding_window is None: positive_windows = [int(window_size) for window_size in swa_layers if int(window_size) > 0] - sliding_window = max(positive_windows) if positive_windows else None + # Original ZAYA stores the number of previous tokens attended by SWA layers. Transformers' sliding window + # is the total local attention span, including the current token.
+ sliding_window = max(positive_windows) + 1 if positive_windows else None rope_parameters = { "full_attention": { diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index 20662110b172..3f59f4fbee2c 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -877,9 +877,11 @@ def _update_causal_mask( "past_key_values": past_key_values, "position_ids": position_ids, } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. + sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} mask_creation_functions = { "full_attention": lambda: create_causal_mask(**mask_kwargs), - "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs), + "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), } causal_mask_mapping = {} for layer_type in set(self.config.layer_types): diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index cbafc3200146..7c1fd957cab6 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -775,9 +775,11 @@ def _update_causal_mask( "past_key_values": past_key_values, "position_ids": position_ids, } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} mask_creation_functions = { "full_attention": lambda: create_causal_mask(**mask_kwargs), - "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs), + "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), } causal_mask_mapping = {} for layer_type in set(self.config.layer_types): From f4f206c576e9b873aef15692172406fd365e49a6 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 19:20:33 +0800 Subject: [PATCH 15/36] easier test --- tests/models/zaya/test_modeling_zaya.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 027901d2ffce..80f9112a97fc 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -397,22 +397,13 @@ def test_model_logits(self): inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) with torch.no_grad(): - outputs = model(**inputs, use_cache=False, output_hidden_states=True, return_dict=True) + logits = model(**inputs, use_cache=False, return_dict=True).logits.float().cpu() - logits = outputs.logits.float().cpu() - hidden_states = outputs.hidden_states[-1].float().cpu() + self.assertEqual(logits.shape, (1, inputs.input_ids.shape[-1], model.config.vocab_size)) + self.assertTrue(torch.isfinite(logits).all().item()) - EXPECTED_HIDDEN_MEAN = torch.tensor( - [[0.0399, -0.0123, -0.0560, -0.0480, -0.0627, -0.0362, -0.0220, 0.0004, -0.0321, -0.0263, 0.0046]] - ) - torch.testing.assert_close(hidden_states.mean(-1), EXPECTED_HIDDEN_MEAN, rtol=1e-2, atol=1e-2) - - EXPECTED_HIDDEN_SLICE = torch.tensor([-2.7812, 0.3320, 4.1562, -0.4395, 1.6406, 1.3359, -1.4531, -2.6719, 5.5000, -4.7500, 2.0625, 0.2930, -2.2344, -2.6094, 2.0781, 2.5000, 0.7969, 0.6836, -0.5469, 1.3906]) # fmt: skip - torch.testing.assert_close(hidden_states[0, 0, :20], 
EXPECTED_HIDDEN_SLICE, rtol=1e-2, atol=1e-2) - - EXPECTED_LOGITS_SLICE = torch.tensor([-2.3438, 1.7344, 3.7656, -3.8750, 0.4707, -0.7422, -2.5938, -2.7188, -2.9375, -2.9844, -3.0000, -3.0000, -3.0000, -3.0000, -3.0156, -3.0000, -3.0000, -3.0000, -3.0000, -3.0000]) # fmt: skip - torch.testing.assert_close(logits[0, -1, :20], EXPECTED_LOGITS_SLICE, rtol=1e-2, atol=1e-2) - self.assertEqual(logits[0, -1].argmax().item(), 107) + expected_argmax = torch.tensor([[105, 9731, 107, 740, 564, 1601, 611, 236881, 236881, 107, 107]]) + torch.testing.assert_close(logits.argmax(-1), expected_argmax) @slow def test_model_cache_matches_full_forward(self): From 7c48ee10eadc7465bd08a51812c320458b5dd1cc Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 20:09:31 +0800 Subject: [PATCH 16/36] remove mapping since we convert the ckpt --- src/transformers/conversion_mapping.py | 17 ----------------- src/transformers/models/zaya/modeling_zaya.py | 15 ++++++++++++++- src/transformers/models/zaya/modular_zaya.py | 15 ++++++++++++++- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 4c0a7942a698..09bea78c96d6 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -561,23 +561,6 @@ def _build_checkpoint_conversion_mapping(): operations=[Transpose(1, 2, check_dims=True)], ), ], - "zaya": [ - WeightRenaming(r"self_attn\.qkv\.conv_qk\.0\.", "self_attn.qkv.conv_qk_depthwise."), - WeightRenaming(r"self_attn\.qkv\.conv_qk\.1\.", "self_attn.qkv.conv_qk_grouped."), - WeightRenaming(r"zaya_block\.router\.router_mlp\.0\.", "zaya_block.router.router_mlp.fc1."), - WeightRenaming(r"zaya_block\.router\.router_mlp\.2\.", "zaya_block.router.router_mlp.fc2."), - WeightRenaming(r"zaya_block\.router\.router_mlp\.4\.", "zaya_block.router.router_mlp.out_proj."), - WeightConverter( - 
source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight", - target_patterns="zaya_block.experts.gate_up_proj", - operations=[MergeModulelist(dim=0)], - ), - WeightConverter( - source_patterns="zaya_block.experts.local_experts.*.linear_fc2.weight", - target_patterns="zaya_block.experts.down_proj", - operations=[MergeModulelist(dim=0)], - ), - ], "phimoe": [ WeightRenaming(".block_sparse_moe.", ".mlp."), WeightRenaming(".gate.weight", ".router.weight"), diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index 3f59f4fbee2c..58d2201ce747 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -736,7 +736,20 @@ def make_zaya_cache(config: ZayaConfig) -> DynamicCache: """ Create ZAYA's native hybrid cache. - `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states. + ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or + `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native + `"hybrid"` cache layer because it stores all three states used during decoding: + + - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application. + - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel + dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is + `cca_time0 + cca_time1 - 2`. + - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy + `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1` + output plus the cached delayed `val_proj2`. 
+ + The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates + `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types. """ cache_config = copy.copy(config) cache_config.layer_types = ["hybrid"] * config.num_hidden_layers diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 7c1fd957cab6..423848c7f01d 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -176,7 +176,20 @@ def make_zaya_cache(config: ZayaConfig) -> DynamicCache: """ Create ZAYA's native hybrid cache. - `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states. + ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or + `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native + `"hybrid"` cache layer because it stores all three states used during decoding: + + - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application. + - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel + dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is + `cca_time0 + cca_time1 - 2`. + - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy + `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1` + output plus the cached delayed `val_proj2`. + + The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates + `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types. 
""" cache_config = copy.copy(config) cache_config.layer_types = ["hybrid"] * config.num_hidden_layers From 498c2522c4cfae79b3f4fe0abeeb4946435eda74 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 12 May 2026 21:17:11 +0800 Subject: [PATCH 17/36] use default_swa_theta --- src/transformers/models/zaya/configuration_zaya.py | 6 ++---- src/transformers/models/zaya/convert_zaya_weights_to_hf.py | 5 ++--- src/transformers/models/zaya/modular_zaya.py | 6 ++---- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 479d07dea7d4..26bf600d413b 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -49,8 +49,6 @@ class ZayaConfig(PreTrainedConfig): Second temporal parameter of the CCA projection. layer_types (`list[str]`, *optional*): Per-layer selector for standard RoPE versus SWA RoPE embeddings. - swa_rotary_base (`float`, *optional*): - RoPE base used by SWA layers. 
```python >>> from transformers import ZayaConfig, ZayaModel @@ -65,6 +63,7 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] default_theta = 5000000.0 + default_swa_theta = 10000.0 vocab_size: int = 262272 hidden_size: int = 2048 @@ -91,7 +90,6 @@ class ZayaConfig(PreTrainedConfig): cca_time1: int = 2 sliding_window: int | None = None layer_types: list[str] | None = None - swa_rotary_base: float | int = 10000.0 output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 @@ -110,7 +108,7 @@ def __post_init__(self, **kwargs): }, "sliding_attention": { "rope_type": "default", - "rope_theta": self.swa_rotary_base, + "rope_theta": self.default_swa_theta, "partial_rotary_factor": self.partial_rotary_factor, }, } diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py index a1b9b357dc52..63a0bd94142f 100644 --- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -144,7 +144,7 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers) partial_rotary_factor = config_dict.get("partial_rotary_factor", ZayaConfig.partial_rotary_factor) rope_theta = config_dict.get("rope_theta", ZayaConfig.default_theta) - swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.swa_rotary_base) + swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.default_swa_theta) intermediate_size = config_dict.get( "intermediate_size", config_dict.get("ffn_hidden_size", ZayaConfig.intermediate_size) ) @@ -173,7 +173,7 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: }, } - for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta"): + for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta", "swa_rotary_base"): 
config_dict.pop(key, None) config_dict.update( @@ -184,7 +184,6 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: "num_experts_per_tok": num_experts_per_tok, "layer_types": layer_types, "sliding_window": sliding_window, - "swa_rotary_base": swa_rotary_base, "rope_parameters": {layer_type: rope_parameters[layer_type] for layer_type in set(layer_types)}, } ) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 423848c7f01d..55d2219fac3f 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -72,8 +72,6 @@ class ZayaConfig(PreTrainedConfig): Second temporal parameter of the CCA projection. layer_types (`list[str]`, *optional*): Per-layer selector for standard RoPE versus SWA RoPE embeddings. - swa_rotary_base (`float`, *optional*): - RoPE base used by SWA layers. ```python >>> from transformers import ZayaConfig, ZayaModel @@ -88,6 +86,7 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] default_theta = 5000000.0 + default_swa_theta = 10000.0 vocab_size: int = 262272 hidden_size: int = 2048 @@ -114,7 +113,6 @@ class ZayaConfig(PreTrainedConfig): cca_time1: int = 2 sliding_window: int | None = None layer_types: list[str] | None = None - swa_rotary_base: float | int = 10000.0 output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 @@ -133,7 +131,7 @@ def __post_init__(self, **kwargs): }, "sliding_attention": { "rope_type": "default", - "rope_theta": self.swa_rotary_base, + "rope_theta": self.default_swa_theta, "partial_rotary_factor": self.partial_rotary_factor, }, } From 3d6306129c568f8dd32b9dd3936dc9c8340db55a Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 13 May 2026 10:51:29 +0800 Subject: [PATCH 18/36] update date --- docs/source/en/model_doc/zaya.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index 24468b8df65f..e6a220adbecf 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-09.* +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-13.* # ZAYA From 4d742969fb022a7485d767b31d907f121fdd396d Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 13 May 2026 10:54:01 +0800 Subject: [PATCH 19/36] temp init --- src/transformers/models/zaya/modular_zaya.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 55d2219fac3f..fb66f6b0522f 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -641,6 +641,8 @@ def _init_weights(self, module): if module.has_residual_scale: init.ones_(module.residual_scale) init.zeros_(module.residual_bias) + elif isinstance(module, ZayaCCAProjection): + init.ones_(module.temp) elif isinstance(module, ZayaRouter): if module.use_eda: init.ones_(module.router_states_scale) From d77d5d47e7a554ece9cc85a05545aa3034e164c8 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 13 May 2026 11:11:32 +0800 Subject: [PATCH 20/36] modular --- src/transformers/models/zaya/modeling_zaya.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index 58d2201ce747..008e166b416e 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -713,6 +713,8 @@ def _init_weights(self, module): if module.has_residual_scale: init.ones_(module.residual_scale) 
init.zeros_(module.residual_bias) + elif isinstance(module, ZayaCCAProjection): + init.ones_(module.temp) elif isinstance(module, ZayaRouter): if module.use_eda: init.ones_(module.router_states_scale) From 1c16fecb90f8a0ea54ba68db51e541ec697557e9 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 13 May 2026 18:05:14 +0800 Subject: [PATCH 21/36] better residual scaling --- .../models/zaya/convert_zaya_weights_to_hf.py | 23 ++- src/transformers/models/zaya/modeling_zaya.py | 168 ++++++++---------- src/transformers/models/zaya/modular_zaya.py | 95 +++++----- 3 files changed, 133 insertions(+), 153 deletions(-) diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py index 63a0bd94142f..228532e53fd4 100644 --- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -59,8 +59,10 @@ def _rename_common(rest: str) -> str: replacements = ( - ("self_attn.qkv.conv_qk.0.", "self_attn.qkv.conv_qk_depthwise."), - ("self_attn.qkv.conv_qk.1.", "self_attn.qkv.conv_qk_grouped."), + ("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."), + ("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."), + ("self_attn.qkv.temp", "self_attn.temp"), + ("self_attn.qkv.", "self_attn.qkv_proj."), ("zaya_block.router.router_mlp.0.", "zaya_block.router.router_mlp.fc1."), ("zaya_block.router.router_mlp.2.", "zaya_block.router.router_mlp.fc2."), ("zaya_block.router.router_mlp.4.", "zaya_block.router.router_mlp.out_proj."), @@ -87,12 +89,15 @@ def _expert_target(name: str) -> tuple[str, int] | None: return target, expert_idx -def convert_weight_name(name: str) -> str | None: +def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> str | None: if _expert_target(name) is not None: return None match = _LAYER_PATTERN.match(name) if match is None: + if old_num_hidden_layers is not None and 
name.startswith("model.res_scale."): + new_layer_idx = old_num_hidden_layers // 2 - 1 + return f"model.layers.{new_layer_idx}.post_mlp_res_scale.{name.removeprefix('model.res_scale.')}" return name old_layer_idx = int(match.group(1)) @@ -101,8 +106,12 @@ def convert_weight_name(name: str) -> str | None: if old_layer_idx % 2 == 0: rest = _rename_common(rest) - if rest.startswith(("self_attn.", "input_norm.", "res_scale.")): + if rest.startswith(("self_attn.", "input_norm.")): return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("res_scale."): + if old_layer_idx == 0: + return f"model.input_{rest.removeprefix('res_scale.')}" + return f"model.layers.{new_layer_idx - 1}.post_mlp_res_scale.{rest.removeprefix('res_scale.')}" else: rest = _rename_common(rest) if rest.startswith("zaya_block."): @@ -209,6 +218,7 @@ def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None: def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[str]], dict[str, str], dict]: index = json.loads((input_dir / "model.safetensors.index.json").read_text()) old_weight_map = index["weight_map"] + old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"]) converted_weight_map = {} normal_sources_by_output_file = defaultdict(list) expert_sources_by_target = defaultdict(list) @@ -223,7 +233,7 @@ def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[ converted_weight_map[target_key] = output_file_by_target[target_key] continue - target_key = convert_weight_name(source_key) + target_key = convert_weight_name(source_key, old_num_hidden_layers) if target_key in converted_weight_map: raise ValueError(f"Duplicate converted weight name: {target_key}") converted_weight_map[target_key] = filename @@ -253,6 +263,7 @@ def convert_safetensors(input_dir: Path, output_dir: Path) -> None: if not safetensors_path.exists(): raise FileNotFoundError("Only safetensors ZAYA checkpoints are supported by this 
converter.") + old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"]) with safe_open(safetensors_path, framework="pt", device="cpu") as f: metadata = f.metadata() state_dict = {} @@ -263,7 +274,7 @@ def convert_safetensors(input_dir: Path, output_dir: Path) -> None: target_key, expert_idx = expert_info expert_groups[target_key].append((expert_idx, f.get_tensor(key))) continue - state_dict[convert_weight_name(key)] = f.get_tensor(key) + state_dict[convert_weight_name(key, old_num_hidden_layers)] = f.get_tensor(key) for target_key, expert_tensors in expert_groups.items(): state_dict[target_key] = torch.stack([tensor for _, tensor in sorted(expert_tensors)], dim=0) save_file(state_dict, output_dir / "model.safetensors", metadata=metadata) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index 008e166b416e..f9e1537e4972 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer from ...generation import GenerationMixin -from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -199,8 +199,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int): stride=1, ) - self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads)) - def forward( self, hidden_states: torch.Tensor, @@ -282,6 +280,43 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + 
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -321,44 +356,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -@use_kernelized_func(apply_rotary_pos_emb) class ZayaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -368,22 +365,23 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True + self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, 
bias=config.attention_bias ) - self.layer_n = layer_idx + self.qkv_proj = ZayaCCAProjection( + config=self.config, + layer_idx=layer_idx, + ) self.layer_type = config.layer_types[layer_idx] self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.qkv = ZayaCCAProjection( - config=self.config, - layer_idx=layer_idx, - ) + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) def forward( self, @@ -399,7 +397,7 @@ def forward( causal_mask = mask_mapping.get("causal") padding_mask = mask_mapping.get("padding") - query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask) + query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) norm_eps = torch.finfo(query_states.dtype).eps head_dim_scale = self.scaling**-1 @@ -407,7 +405,7 @@ def forward( head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) ) key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) - key_states = key_states * self.qkv.temp[None, None, :, None] + key_states = key_states * self.temp[None, None, :, None] query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -417,7 +415,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) if isinstance(causal_mask, torch.Tensor): causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] @@ -443,25 +441,12 @@ def forward( return attn_output, attn_weights, past_key_values -def 
_apply_residual_scaling( - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - residual_scaling, - rms_norm: ZayaRMSNorm, -) -> tuple[torch.Tensor, torch.Tensor]: - residual, hidden_states = residual_scaling(residual, hidden_states) - residual = hidden_states.to(torch.float32) if residual is None else hidden_states + residual - hidden_states = rms_norm(residual.to(dtype=rms_norm.weight.dtype)) - return hidden_states, residual - - class ZayaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config self.self_attn = ZayaAttention(config, layer_idx) self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0) self.zaya_block = ZayaSparseMoeBlock( config, config.num_experts, @@ -471,18 +456,21 @@ def __init__(self, config: ZayaConfig, layer_idx: int): ) self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) self.post_attention_res_scale = ResidualScaling(config.hidden_size) + self.post_mlp_res_scale = ResidualScaling(config.hidden_size) def forward( self, hidden_states: torch.Tensor, - residual: torch.Tensor | None, prev_router_hidden_states: torch.Tensor | None = None, attention_mask: dict[str, Any] | None = None, past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + residual = hidden_states + # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below. 
+ residual = residual.to(torch.float32) + hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, @@ -492,34 +480,31 @@ def forward( **kwargs, ) - hidden_states, residual = _apply_residual_scaling( - hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm - ) + residual = self.post_attention_res_scale(hidden_states, residual) + hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype)) hidden_states, prev_router_hidden_states, _ = self.zaya_block( hidden_states, prev_router_hidden_states, ) - return hidden_states, self_attn_weights, residual, prev_router_hidden_states + hidden_states = self.post_mlp_res_scale(hidden_states, residual) + + return hidden_states, self_attn_weights, prev_router_hidden_states class ResidualScaling(nn.Module): - def __init__(self, hidden_size: int, has_residual_scale: bool = True): + def __init__(self, hidden_size: int): super().__init__() - self.has_residual_scale = has_residual_scale self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) - if self.has_residual_scale: - self.residual_scale = nn.Parameter(torch.ones(hidden_size)) - self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) - - def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale - if self.has_residual_scale: - residual = (residual + self.residual_bias) * self.residual_scale - return residual, hidden_states + residual = (residual + self.residual_bias) * self.residual_scale + return hidden_states + residual class 
ZayaRouterMLP(nn.Module): @@ -710,11 +695,11 @@ def _init_weights(self, module): if isinstance(module, ResidualScaling): init.ones_(module.hidden_states_scale) init.zeros_(module.hidden_states_bias) - if module.has_residual_scale: - init.ones_(module.residual_scale) - init.zeros_(module.residual_bias) - elif isinstance(module, ZayaCCAProjection): - init.ones_(module.temp) + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaModel): + init.ones_(module.input_hidden_states_scale) + init.zeros_(module.input_hidden_states_bias) elif isinstance(module, ZayaRouter): if module.use_eda: init.ones_(module.router_states_scale) @@ -778,8 +763,9 @@ def __init__(self, config: ZayaConfig): ) self.gradient_checkpointing = False - self.res_scale = ResidualScaling(config.hidden_size) + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) self.rotary_emb = ZayaRotaryEmbedding(config=config) @@ -816,8 +802,6 @@ def forward( raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") past_key_values = make_zaya_cache(self.config) - residual = None - if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange( @@ -851,6 +835,8 @@ def forward( for layer_type in set(self.config.layer_types) } + hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + prev_router_hidden_states = None for layer_n, decoder_layer in enumerate(self.layers): @@ -859,7 +845,6 @@ def forward( mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} layer_outputs = decoder_layer( hidden_states, - residual, prev_router_hidden_states, attention_mask=mask_mapping, past_key_values=past_key_values, @@ 
-868,10 +853,9 @@ def forward( ) hidden_states = layer_outputs[0] - residual = layer_outputs[2] - prev_router_hidden_states = layer_outputs[3] + prev_router_hidden_states = layer_outputs[2] - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) + hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) return MoeModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index fb66f6b0522f..f0becacb968c 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -42,7 +42,8 @@ from ...utils.output_capturing import OutputRecorder, capture_outputs from ..afmoe.modeling_afmoe import AfmoeForCausalLM from ..laguna.modeling_laguna import LagunaRotaryEmbedding -from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel +from ..llama.modeling_llama import LlamaPreTrainedModel +from ..phi3.modeling_phi3 import Phi3Attention from ..qwen3_5_moe.modeling_qwen3_5_moe import ( apply_rotary_pos_emb, eager_attention_forward, @@ -256,8 +257,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int): stride=1, ) - self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads)) - def forward( self, hidden_states: torch.Tensor, @@ -332,20 +331,20 @@ def forward( return query, key, value -class ZayaAttention(LlamaAttention): +class ZayaAttention(Phi3Attention): def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__(config, layer_idx) - self.layer_n = layer_idx + del op_size # noqa: F821 self.layer_type = config.layer_types[layer_idx] self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - del self.q_proj - del self.k_proj - del 
self.v_proj - self.qkv = ZayaCCAProjection( + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.qkv_proj = ZayaCCAProjection( config=self.config, layer_idx=layer_idx, ) @@ -364,7 +363,7 @@ def forward( causal_mask = mask_mapping.get("causal") padding_mask = mask_mapping.get("padding") - query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask) + query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) norm_eps = torch.finfo(query_states.dtype).eps head_dim_scale = self.scaling**-1 @@ -372,7 +371,7 @@ def forward( head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) ) key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) - key_states = key_states * self.qkv.temp[None, None, :, None] + key_states = key_states * self.temp[None, None, :, None] query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -382,7 +381,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) if isinstance(causal_mask, torch.Tensor): causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] @@ -414,7 +413,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.config = config self.self_attn = ZayaAttention(config, layer_idx) self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0) self.zaya_block = ZayaSparseMoeBlock( config, config.num_experts, @@ 
-424,18 +422,21 @@ def __init__(self, config: ZayaConfig, layer_idx: int): ) self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) self.post_attention_res_scale = ResidualScaling(config.hidden_size) + self.post_mlp_res_scale = ResidualScaling(config.hidden_size) def forward( self, hidden_states: torch.Tensor, - residual: torch.Tensor | None, prev_router_hidden_states: torch.Tensor | None = None, attention_mask: dict[str, Any] | None = None, past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]: - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm) + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + residual = hidden_states + # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below. 
+ residual = residual.to(torch.float32) + hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, @@ -445,46 +446,31 @@ def forward( **kwargs, ) - hidden_states, residual = _apply_residual_scaling( - hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm - ) + residual = self.post_attention_res_scale(hidden_states, residual) + hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype)) hidden_states, prev_router_hidden_states, _ = self.zaya_block( hidden_states, prev_router_hidden_states, ) - return hidden_states, self_attn_weights, residual, prev_router_hidden_states + hidden_states = self.post_mlp_res_scale(hidden_states, residual) + + return hidden_states, self_attn_weights, prev_router_hidden_states class ResidualScaling(nn.Module): - def __init__(self, hidden_size: int, has_residual_scale: bool = True): + def __init__(self, hidden_size: int): super().__init__() - self.has_residual_scale = has_residual_scale self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) - if self.has_residual_scale: - self.residual_scale = nn.Parameter(torch.ones(hidden_size)) - self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) - - def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale - if self.has_residual_scale: - residual = (residual + self.residual_bias) * self.residual_scale - return residual, hidden_states - - -def _apply_residual_scaling( - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - residual_scaling, - 
rms_norm: ZayaRMSNorm, -) -> tuple[torch.Tensor, torch.Tensor]: - residual, hidden_states = residual_scaling(residual, hidden_states) - residual = hidden_states.to(torch.float32) if residual is None else hidden_states + residual - hidden_states = rms_norm(residual.to(dtype=rms_norm.weight.dtype)) - return hidden_states, residual + residual = (residual + self.residual_bias) * self.residual_scale + return hidden_states + residual class ZayaRouterMLP(nn.Module): @@ -638,11 +624,11 @@ def _init_weights(self, module): if isinstance(module, ResidualScaling): init.ones_(module.hidden_states_scale) init.zeros_(module.hidden_states_bias) - if module.has_residual_scale: - init.ones_(module.residual_scale) - init.zeros_(module.residual_bias) - elif isinstance(module, ZayaCCAProjection): - init.ones_(module.temp) + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaModel): + init.ones_(module.input_hidden_states_scale) + init.zeros_(module.input_hidden_states_bias) elif isinstance(module, ZayaRouter): if module.use_eda: init.ones_(module.router_states_scale) @@ -674,8 +660,9 @@ def __init__(self, config: ZayaConfig): ) self.gradient_checkpointing = False - self.res_scale = ResidualScaling(config.hidden_size) + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) self.rotary_emb = ZayaRotaryEmbedding(config=config) @@ -712,8 +699,6 @@ def forward( raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") past_key_values = make_zaya_cache(self.config) - residual = None - if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange( @@ -747,6 +732,8 @@ def forward( for layer_type in set(self.config.layer_types) } + hidden_states = 
(hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + prev_router_hidden_states = None for layer_n, decoder_layer in enumerate(self.layers): @@ -755,7 +742,6 @@ def forward( mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} layer_outputs = decoder_layer( hidden_states, - residual, prev_router_hidden_states, attention_mask=mask_mapping, past_key_values=past_key_values, @@ -764,10 +750,9 @@ def forward( ) hidden_states = layer_outputs[0] - residual = layer_outputs[2] - prev_router_hidden_states = layer_outputs[3] + prev_router_hidden_states = layer_outputs[2] - hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) + hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) return MoeModelOutputWithPast( last_hidden_state=hidden_states, From 3f53fbca6674db5d58bfb5cdb23def9f61624168 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 13 May 2026 18:13:36 +0800 Subject: [PATCH 22/36] better cache --- docs/source/en/model_doc/zaya.md | 8 --- src/transformers/cache_utils.py | 54 +++++++++---------- .../models/zaya/configuration_zaya.py | 10 ++++ src/transformers/models/zaya/modeling_zaya.py | 42 ++------------- src/transformers/models/zaya/modular_zaya.py | 52 +++++------------- tests/models/zaya/test_modeling_zaya.py | 8 +-- 6 files changed, 57 insertions(+), 117 deletions(-) diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index e6a220adbecf..06beb12e2e6f 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -27,14 +27,6 @@ and Zyphra's technical reports. This model was contributed by [JJJYmmm](https://github.com/JJJYmmm). - - -When building a manual generation loop with `past_key_values`, use [`~models.zaya.modeling_zaya.make_zaya_cache`] to -create ZAYA's cache. 
ZAYA uses `config.layer_types` for full/sliding attention masks and RoPE parameters, while its -cache uses the native hybrid layout needed by the attention, convolution, and recurrent states. - - - ## Usage examples ```python diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dfef404a42f1..a9eee165b68f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,9 +23,9 @@ logger = logging.get_logger(__name__) -# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for -# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided -# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own +# Registry mapping ``config.cache_layer_types[i]`` (or ``config.layer_types[i]`` when the cache-specific field is not +# set) -> the dynamic cache layer class to build for that layer. ``DynamicCache.__init__`` consults this mapping when a +# ``config`` is provided so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own # cache-layer subclass and stop needing a model-specific ``Cache`` subclass. # # A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via @@ -34,6 +34,24 @@ LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {} +def _get_layer_types_for_cache(decoder_config: PreTrainedConfig) -> list[str]: + sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( + decoder_config, "attention_chunk_size", None + ) + layer_types = getattr(decoder_config, "cache_layer_types", None) or getattr(decoder_config, "layer_types", None) + if layer_types is None: + layer_types = [] + for _ in range(decoder_config.num_hidden_layers): + if sliding_window is not None: + layer_types.append("sliding_attention") + else: + layer_types.append("full_attention") + # Some models have shared layers thus no cache is needed for them (e.g. 
Gemma3n) + if hasattr(decoder_config, "num_kv_shared_layers"): + layer_types = layer_types[: -decoder_config.num_kv_shared_layers] + return layer_types + + class CacheLayerMixin(ABC): """Base, abstract class for a single layer's cache.""" @@ -1280,20 +1298,7 @@ def __init__( # If a config is passed, use it to infer the layer types and initialize accordingly if config is not None: decoder_config = config.get_text_config(decoder=True) - sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( - decoder_config, "attention_chunk_size", None - ) - layer_types = getattr(decoder_config, "layer_types", None) - if layer_types is None: - layer_types = [] - for _ in range(decoder_config.num_hidden_layers): - if sliding_window is not None: - layer_types.append("sliding_attention") - else: - layer_types.append("full_attention") - # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n) - if hasattr(decoder_config, "num_kv_shared_layers"): - layer_types = layer_types[: -decoder_config.num_kv_shared_layers] + layer_types = _get_layer_types_for_cache(decoder_config) for layer_type in layer_types: cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer) @@ -1382,18 +1387,7 @@ def __init__( **kwargs, ): config = config.get_text_config(decoder=True) - layer_types = getattr(config, "layer_types", None) - # If `layer_types` is not explicitly provided, infer if the model is fully sliding - if layer_types is None: - if getattr(config, "sliding_window", None) is not None: - layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)] - elif getattr(config, "attention_chunk_size", None) is not None: - layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)] - else: - layer_types = ["full_attention" for _ in range(config.num_hidden_layers)] - # Some models have shared layers thus no cache is needed for them (e.g. 
Gemma3n) - if hasattr(config, "num_kv_shared_layers"): - layer_types = layer_types[: -config.num_kv_shared_layers] + layer_types = _get_layer_types_for_cache(config) sliding_layer_types = { name @@ -1413,6 +1407,8 @@ def __init__( # LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache elif layer_type in ("mamba", "conv", "linear_attention", "moe"): layer = LinearAttentionLayer() + elif layer_type == "hybrid": + layer = LinearAttentionAndFullAttentionLayer(config) else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 26bf600d413b..9f373e9a6d21 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -49,6 +49,8 @@ class ZayaConfig(PreTrainedConfig): Second temporal parameter of the CCA projection. layer_types (`list[str]`, *optional*): Per-layer selector for standard RoPE versus SWA RoPE embeddings. + cache_layer_types (`list[str]`, *optional*): + Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer. 
```python >>> from transformers import ZayaConfig, ZayaModel @@ -90,6 +92,7 @@ class ZayaConfig(PreTrainedConfig): cca_time1: int = 2 sliding_window: int | None = None layer_types: list[str] | None = None + cache_layer_types: list[str] | None = None output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 @@ -99,6 +102,9 @@ def __post_init__(self, **kwargs): self.layer_types = ( ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) ) + self.cache_layer_types = ( + ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types) + ) default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { "full_attention": { @@ -131,6 +137,10 @@ def validate_architecture(self): raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") if len(self.layer_types) != self.num_hidden_layers: raise ValueError("`layer_types` must have one entry per hidden layer.") + if len(self.cache_layer_types) != self.num_hidden_layers: + raise ValueError("`cache_layer_types` must have one entry per hidden layer.") + if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}: + raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.") if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") if "sliding_attention" in self.layer_types and self.sliding_window is None: diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index f9e1537e4972..c958f34054bf 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -19,7 +19,6 @@ # limitations under the License. 
-import copy from collections.abc import Callable from typing import Any, Optional @@ -29,7 +28,7 @@ from torch.nn import init from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask @@ -719,44 +718,13 @@ def _init_weights(self, module): getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq) -def make_zaya_cache(config: ZayaConfig) -> DynamicCache: - """ - Create ZAYA's native hybrid cache. - - ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or - `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native - `"hybrid"` cache layer because it stores all three states used during decoding: - - - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application. - - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel - dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is - `cca_time0 + cca_time1 - 2`. - - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy - `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1` - output plus the cached delayed `val_proj2`. - - The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates - `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types. 
- """ - cache_config = copy.copy(config) - cache_config.layer_types = ["hybrid"] * config.num_hidden_layers - return DynamicCache(config=cache_config) - - -def _is_zaya_cache(past_key_values: Cache) -> bool: - return ( - isinstance(past_key_values, DynamicCache) - and len(past_key_values.layers) > 0 - and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer) - ) - - @auto_docstring class ZayaModel(ZayaPreTrainedModel): def __init__(self, config: ZayaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.cache_layer_types = config.cache_layer_types self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -797,10 +765,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): - if past_key_values is not None and past_key_values.get_seq_length() > 0: - raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") - past_key_values = make_zaya_cache(self.config) + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index f0becacb968c..7878e9ee8d16 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -14,7 +14,6 @@ """PyTorch Zaya model.""" -import copy from collections.abc import Callable from typing import Any, Literal @@ -26,7 +25,7 @@ from torch.nn import init from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer 
+from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer @@ -73,6 +72,8 @@ class ZayaConfig(PreTrainedConfig): Second temporal parameter of the CCA projection. layer_types (`list[str]`, *optional*): Per-layer selector for standard RoPE versus SWA RoPE embeddings. + cache_layer_types (`list[str]`, *optional*): + Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer. ```python >>> from transformers import ZayaConfig, ZayaModel @@ -114,6 +115,7 @@ class ZayaConfig(PreTrainedConfig): cca_time1: int = 2 sliding_window: int | None = None layer_types: list[str] | None = None + cache_layer_types: list[str] | None = None output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 @@ -123,6 +125,9 @@ def __post_init__(self, **kwargs): self.layer_types = ( ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) ) + self.cache_layer_types = ( + ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types) + ) default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { "full_attention": { @@ -155,6 +160,10 @@ def validate_architecture(self): raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") if len(self.layer_types) != self.num_hidden_layers: raise ValueError("`layer_types` must have one entry per hidden layer.") + if len(self.cache_layer_types) != self.num_hidden_layers: + raise ValueError("`cache_layer_types` must have one entry per hidden layer.") + if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}: + raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.") if invalid_layer_types := 
set(self.layer_types) - {"full_attention", "sliding_attention"}: raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") if "sliding_attention" in self.layer_types and self.sliding_window is None: @@ -171,38 +180,6 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm): pass -def make_zaya_cache(config: ZayaConfig) -> DynamicCache: - """ - Create ZAYA's native hybrid cache. - - ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or - `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native - `"hybrid"` cache layer because it stores all three states used during decoding: - - - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application. - - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel - dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is - `cca_time0 + cca_time1 - 2`. - - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy - `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1` - output plus the cached delayed `val_proj2`. - - The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates - `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types. 
- """ - cache_config = copy.copy(config) - cache_config.layer_types = ["hybrid"] * config.num_hidden_layers - return DynamicCache(config=cache_config) - - -def _is_zaya_cache(past_key_values: Cache) -> bool: - return ( - isinstance(past_key_values, DynamicCache) - and len(past_key_values.layers) > 0 - and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer) - ) - - class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's CCA path. @@ -654,6 +631,7 @@ def __init__(self, config: ZayaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.cache_layer_types = config.cache_layer_types self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -694,10 +672,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): - if past_key_values is not None and past_key_values.get_seq_length() > 0: - raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") - past_key_values = make_zaya_cache(self.config) + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 80f9112a97fc..94bd74093e15 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -27,7 +27,7 @@ from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer - from transformers.models.zaya.modeling_zaya 
import ZayaCCAProjection, make_zaya_cache + from transformers.models.zaya.modeling_zaya import ZayaCCAProjection from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester @@ -301,7 +301,7 @@ def test_cca_cache_matches_full_forward(self): with torch.no_grad(): full = cca(hidden_states, None, None) - cache = make_zaya_cache(config) + cache = DynamicCache(config=config) cca(hidden_states[:, :4], cache, None) cached = cca(hidden_states[:, 4:], cache, None) @@ -328,7 +328,7 @@ def test_cca_cache_matches_full_forward_multi_token(self): with torch.no_grad(): full = cca(hidden_states, None, None) - cache = make_zaya_cache(config) + cache = DynamicCache(config=config) cca(hidden_states[:, :3], cache, None) cached = cca(hidden_states[:, 3:], cache, None) @@ -337,7 +337,7 @@ def test_cca_cache_matches_full_forward_multi_token(self): def test_zaya_cache_reorder_and_reset(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - cache = make_zaya_cache(config) + cache = DynamicCache(config=config) conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim cache.update_conv_state( torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view( From dc7ac50dd15c4e95d0b246b534f7c19c8fcc385c Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 13 May 2026 18:30:48 +0800 Subject: [PATCH 23/36] ops forget init again --- src/transformers/models/zaya/modeling_zaya.py | 2 ++ src/transformers/models/zaya/modular_zaya.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index c958f34054bf..b162dde52e02 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -699,6 +699,8 @@ def _init_weights(self, module): elif isinstance(module, ZayaModel): init.ones_(module.input_hidden_states_scale) 
init.zeros_(module.input_hidden_states_bias) + elif isinstance(module, ZayaAttention): + init.zeros_(module.temp) elif isinstance(module, ZayaRouter): if module.use_eda: init.ones_(module.router_states_scale) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 7878e9ee8d16..04a6625b1313 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -606,6 +606,8 @@ def _init_weights(self, module): elif isinstance(module, ZayaModel): init.ones_(module.input_hidden_states_scale) init.zeros_(module.input_hidden_states_bias) + elif isinstance(module, ZayaAttention): + init.zeros_(module.temp) elif isinstance(module, ZayaRouter): if module.use_eda: init.ones_(module.router_states_scale) From 8be4b1ee9c4871398093690664bea5b97b887a25 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 14 May 2026 14:18:18 +0800 Subject: [PATCH 24/36] better naming --- src/transformers/cache_utils.py | 82 ++-- src/transformers/configuration_utils.py | 3 +- src/transformers/models/zaya/__init__.py | 2 +- .../models/zaya/configuration_zaya.py | 89 ++-- .../models/zaya/convert_zaya_weights_to_hf.py | 75 ++-- src/transformers/models/zaya/modeling_zaya.py | 298 ++++++------- src/transformers/models/zaya/modular_zaya.py | 395 +++++++----------- tests/models/zaya/test_modeling_zaya.py | 66 +-- 8 files changed, 473 insertions(+), 537 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a9eee165b68f..993643dfe390 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,9 +23,9 @@ logger = logging.get_logger(__name__) -# Registry mapping ``config.cache_layer_types[i]`` (or ``config.layer_types[i]`` when the cache-specific field is not -# set) -> the dynamic cache layer class to build for that layer. 
``DynamicCache.__init__`` consults this mapping when a -# ``config`` is provided so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own +# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for +# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided +# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own # cache-layer subclass and stop needing a model-specific ``Cache`` subclass. # # A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via @@ -34,24 +34,6 @@ LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {} -def _get_layer_types_for_cache(decoder_config: PreTrainedConfig) -> list[str]: - sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( - decoder_config, "attention_chunk_size", None - ) - layer_types = getattr(decoder_config, "cache_layer_types", None) or getattr(decoder_config, "layer_types", None) - if layer_types is None: - layer_types = [] - for _ in range(decoder_config.num_hidden_layers): - if sliding_window is not None: - layer_types.append("sliding_attention") - else: - layer_types.append("full_attention") - # Some models have shared layers thus no cache is needed for them (e.g. 
Gemma3n) - if hasattr(decoder_config, "num_kv_shared_layers"): - layer_types = layer_types[: -decoder_config.num_kv_shared_layers] - return layer_types - - class CacheLayerMixin(ABC): """Base, abstract class for a single layer's cache.""" @@ -882,6 +864,33 @@ def reorder_cache(self, beam_idx: torch.LongTensor): DynamicLayer.reorder_cache(self, beam_idx) +class LinearAttentionAndSlidingWindowAttentionLayer(LinearAttentionLayer, DynamicSlidingWindowLayer): + # The dynamic sliding attention part makes it non-compileable + is_compileable = False + + def __init__(self, config: PreTrainedConfig | None = None): + DynamicSlidingWindowLayer.__init__(self, config) + LinearAttentionLayer.__init__(self) + + def lazy_initialization(self, *args, **kwargs) -> None: + # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args + if len(args) == 2 and len(kwargs) == 0: + DynamicSlidingWindowLayer.lazy_initialization(self, *args) + # Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`, + # it's always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states) + if len(args) == 0 and len(kwargs) == 1: + LinearAttentionLayer.lazy_initialization(self, **kwargs) + + def reset(self) -> None: + LinearAttentionLayer.reset(self) + DynamicSlidingWindowLayer.reset(self) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + LinearAttentionLayer.reorder_cache(self, beam_idx) + DynamicSlidingWindowLayer.reorder_cache(self, beam_idx) + + # Pre-register the standard layer types (some classes are shared between multiple types, # e.g. 
``DynamicSlidingWindowLayer`` covers both ``"sliding_attention"`` and # ``"chunked_attention"`` — those need an explicit map entry rather than the @@ -901,6 +910,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor): "moe": LinearAttentionLayer, # Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state. "hybrid": LinearAttentionAndFullAttentionLayer, + "hybrid_sliding": LinearAttentionAndSlidingWindowAttentionLayer, } ) @@ -1298,7 +1308,20 @@ def __init__( # If a config is passed, use it to infer the layer types and initialize accordingly if config is not None: decoder_config = config.get_text_config(decoder=True) - layer_types = _get_layer_types_for_cache(decoder_config) + sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( + decoder_config, "attention_chunk_size", None + ) + layer_types = getattr(decoder_config, "layer_types", None) + if layer_types is None: + layer_types = [] + for _ in range(decoder_config.num_hidden_layers): + if sliding_window is not None: + layer_types.append("sliding_attention") + else: + layer_types.append("full_attention") + # Some models have shared layers thus no cache is needed for them (e.g. 
Gemma3n) + if hasattr(decoder_config, "num_kv_shared_layers"): + layer_types = layer_types[: -decoder_config.num_kv_shared_layers] for layer_type in layer_types: cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer) @@ -1387,7 +1410,18 @@ def __init__( **kwargs, ): config = config.get_text_config(decoder=True) - layer_types = _get_layer_types_for_cache(config) + layer_types = getattr(config, "layer_types", None) + # If `layer_types` is not explicitly provided, infer if the model is fully sliding + if layer_types is None: + if getattr(config, "sliding_window", None) is not None: + layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)] + elif getattr(config, "attention_chunk_size", None) is not None: + layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)] + else: + layer_types = ["full_attention" for _ in range(config.num_hidden_layers)] + # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n) + if hasattr(config, "num_kv_shared_layers"): + layer_types = layer_types[: -config.num_kv_shared_layers] sliding_layer_types = { name @@ -1407,8 +1441,6 @@ def __init__( # LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache elif layer_type in ("mamba", "conv", "linear_attention", "moe"): layer = LinearAttentionLayer() - elif layer_type == "hybrid": - layer = LinearAttentionAndFullAttentionLayer(config) else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 7ba033c538d8..e495bbdc69c0 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -71,7 +71,8 @@ "attention", "sparse", "dense", - "hybrid", # for layers that have both mamba and attention in zamba and zamba2 + "hybrid", # for zamba/zamba2/zaya1, which use full attention + conv states + "hybrid_sliding", # 
for zaya1, which uses swa + conv states "moe", # for nemotron_h, which uses either attention, mamba or moe ) diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py index 54cc0c89f303..c28f97af94ea 100644 --- a/src/transformers/models/zaya/__init__.py +++ b/src/transformers/models/zaya/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2025 Zyphra and The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Zyphra and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 9f373e9a6d21..6fb47ecb76a4 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -31,26 +31,14 @@ @strict class ZayaConfig(PreTrainedConfig): r""" - intermediate_size (`int`, *optional*, defaults to 4096): - Dimension of the feed-forward and expert hidden states. - num_key_value_heads (`int`, *optional*, defaults to 2): - Number of key/value groups. - partial_rotary_factor (`float`, *optional*, defaults to 0.5): - Fraction of each attention head dimension using rotary embeddings. lm_head_bias (`bool`, *optional*, defaults to `False`): Whether to add a bias to the language modeling head. - num_experts_per_tok (`int`, *optional*, defaults to 1): - Number of selected experts per token. ZAYA checkpoints use top-1 routing. - zaya_mlp_expansion (`int`, *optional*, defaults to 256): - Expansion size used by the dense ZAYA blocks. + router_hidden_size (`int`, *optional*, defaults to 256): + Hidden size used by the ZAYA router. cca_time0 (`int`, *optional*, defaults to 2): First temporal parameter of the CCA projection. cca_time1 (`int`, *optional*, defaults to 2): Second temporal parameter of the CCA projection. 
- layer_types (`list[str]`, *optional*): - Per-layer selector for standard RoPE versus SWA RoPE embeddings. - cache_layer_types (`list[str]`, *optional*): - Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer. ```python >>> from transformers import ZayaConfig, ZayaModel @@ -64,87 +52,76 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] - default_theta = 5000000.0 - default_swa_theta = 10000.0 vocab_size: int = 262272 hidden_size: int = 2048 - intermediate_size: int = 4096 num_hidden_layers: int = 40 - num_experts: int = 16 num_attention_heads: int = 8 num_key_value_heads: int = 2 hidden_act: str = "silu" - head_dim: int = 128 max_position_embeddings: int = 131072 initializer_range: float = 0.02 - norm_epsilon: float = 1e-5 + rms_norm_eps: float = 1e-5 use_cache: bool = True tie_word_embeddings: bool = True rope_parameters: RopeParameters | dict | None = None - partial_rotary_factor: float = 0.5 - attention_bias: bool = False - lm_head_bias: bool = False + sliding_window: int | None = None attention_dropout: float | int = 0.0 + moe_intermediate_size: int = 2048 + num_experts_per_tok: int = 1 - zaya_mlp_expansion: int = 256 - cca_time0: int = 2 - cca_time1: int = 2 - sliding_window: int | None = None - layer_types: list[str] | None = None - cache_layer_types: list[str] | None = None + num_experts: int = 16 output_router_logits: bool = False + layer_types: list[str] | None = None pad_token_id: int | None = 0 bos_token_id: int | None = 2 eos_token_id: int | list[int] | None = 106 + # Zaya-specific attention + head_dim: int = 128 + attention_bias: bool = False + + lm_head_bias: bool = False + router_hidden_size: int = 256 + cca_time0: int = 2 + cca_time1: int = 2 + def __post_init__(self, **kwargs): - self.layer_types = ( - ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) - ) - self.cache_layer_types = ( - 
["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types) - ) - - default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { - "full_attention": { + self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) + + default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = { + "hybrid": { "rope_type": "default", - "rope_theta": self.default_theta, - "partial_rotary_factor": self.partial_rotary_factor, + "rope_theta": 5_000_000.0, + "partial_rotary_factor": 0.5, }, - "sliding_attention": { + "hybrid_sliding": { "rope_type": "default", - "rope_theta": self.default_swa_theta, - "partial_rotary_factor": self.partial_rotary_factor, + "rope_theta": 10_000.0, + "partial_rotary_factor": 0.5, }, } if self.rope_parameters is None: - self.rope_parameters = { - layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types) - } + self.rope_parameters = default_rope_params - super().__post_init__(**kwargs) + super().__post_init__(**kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"}) def convert_rope_params_to_dict(self, **kwargs): - # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them - # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`. + # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly. 
return kwargs def validate_architecture(self): + """Part of ``@strict``-powered validation.""" if self.num_experts_per_tok != 1: raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") if len(self.layer_types) != self.num_hidden_layers: raise ValueError("`layer_types` must have one entry per hidden layer.") - if len(self.cache_layer_types) != self.num_hidden_layers: - raise ValueError("`cache_layer_types` must have one entry per hidden layer.") - if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}: - raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.") - if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: + if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}: raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") - if "sliding_attention" in self.layer_types and self.sliding_window is None: - raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.") + if "hybrid_sliding" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") if self.sliding_window is not None and self.sliding_window <= 0: raise ValueError("`sliding_window` must be a strictly positive integer.") diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py index 228532e53fd4..2ac6cb7df869 100644 --- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -27,6 +27,8 @@ from transformers import ZayaConfig +_DEFAULT_ROPE_THETA = 5_000_000.0 +_DEFAULT_SWA_ROPE_THETA = 10_000.0 
_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$") _LOCAL_EXPERT_PATTERN = re.compile( r"^model\.layers\.(\d+)\.zaya_block\.experts\.local_experts\.(\d+)\.linear_fc([12])\.weight$" @@ -35,8 +37,11 @@ _UNUSED_CONFIG_KEYS = ( "cca", "num_query_groups", + "intermediate_size", "ffn_hidden_size", "moe_router_topk", + "norm_epsilon", + "zaya_mlp_expansion", "activation_func", "normalization", "add_bias_linear", @@ -61,11 +66,18 @@ def _rename_common(rest: str) -> str: replacements = ( ("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."), ("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."), - ("self_attn.qkv.temp", "self_attn.temp"), + ("self_attn.qkv.temp", "self_attn.qk_norm.temp"), + ("self_attn.qkv.linear_q.", "self_attn.qkv_proj.q_proj."), + ("self_attn.qkv.linear_k.", "self_attn.qkv_proj.k_proj."), + ("self_attn.qkv.val_proj1.", "self_attn.qkv_proj.v_proj_current."), + ("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."), ("self_attn.qkv.", "self_attn.qkv_proj."), - ("zaya_block.router.router_mlp.0.", "zaya_block.router.router_mlp.fc1."), - ("zaya_block.router.router_mlp.2.", "zaya_block.router.router_mlp.fc2."), - ("zaya_block.router.router_mlp.4.", "zaya_block.router.router_mlp.out_proj."), + ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.rmsnorm_eda."), + ("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."), + ("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."), + ("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."), + ("zaya_block.router.", "mlp.gate."), + ("zaya_block.", "mlp."), ) for old, new in replacements: if rest.startswith(old): @@ -85,7 +97,7 @@ def _expert_target(name: str) -> tuple[str, int] | None: new_layer_idx = old_layer_idx // 2 expert_idx = int(match.group(2)) projection = "gate_up_proj" if match.group(3) == "1" else "down_proj" - target = f"model.layers.{new_layer_idx}.zaya_block.experts.{projection}" + target = 
f"model.layers.{new_layer_idx}.mlp.experts.{projection}" return target, expert_idx @@ -97,7 +109,7 @@ def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> if match is None: if old_num_hidden_layers is not None and name.startswith("model.res_scale."): new_layer_idx = old_num_hidden_layers // 2 - 1 - return f"model.layers.{new_layer_idx}.post_mlp_res_scale.{name.removeprefix('model.res_scale.')}" + return f"model.layers.{new_layer_idx}.post_mlp_residual_scale.{name.removeprefix('model.res_scale.')}" return name old_layer_idx = int(match.group(1)) @@ -106,41 +118,51 @@ def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> if old_layer_idx % 2 == 0: rest = _rename_common(rest) - if rest.startswith(("self_attn.", "input_norm.")): + if rest.startswith("self_attn."): return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.layers.{new_layer_idx}.input_layernorm.{rest.removeprefix('input_norm.')}" if rest.startswith("res_scale."): if old_layer_idx == 0: return f"model.input_{rest.removeprefix('res_scale.')}" - return f"model.layers.{new_layer_idx - 1}.post_mlp_res_scale.{rest.removeprefix('res_scale.')}" + return f"model.layers.{new_layer_idx - 1}.post_mlp_residual_scale.{rest.removeprefix('res_scale.')}" else: rest = _rename_common(rest) - if rest.startswith("zaya_block."): + if rest.startswith("mlp."): return f"model.layers.{new_layer_idx}.{rest}" if rest.startswith("input_norm."): - return f"model.layers.{new_layer_idx}.post_attention_norm.{rest.removeprefix('input_norm.')}" + return f"model.layers.{new_layer_idx}.post_attention_layernorm.{rest.removeprefix('input_norm.')}" if rest.startswith("res_scale."): - return f"model.layers.{new_layer_idx}.post_attention_res_scale.{rest.removeprefix('res_scale.')}" + return f"model.layers.{new_layer_idx}.post_attention_residual_scale.{rest.removeprefix('res_scale.')}" raise ValueError(f"Unexpected ZAYA layer weight name: {name}") +def 
_to_hybrid_layer_type(layer_type: str) -> str: + if layer_type == "full_attention": + return "hybrid" + if layer_type == "sliding_attention": + return "hybrid_sliding" + raise ValueError(f"Unsupported ZAYA layer type: {layer_type}") + + def _convert_layer_types(config_dict: dict, old_num_hidden_layers: int, new_num_hidden_layers: int) -> list[str]: layer_types = config_dict.get("layer_types") if layer_types is not None: if len(layer_types) == old_num_hidden_layers: - return layer_types[::2] + return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types[::2]] if len(layer_types) == new_num_hidden_layers: - return list(layer_types) + return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types] raise ValueError("`layer_types` must match either the original or converted number of hidden layers.") swa_layers = config_dict.get("swa_layers") if swa_layers is None: - return ["full_attention"] * new_num_hidden_layers + return ["hybrid"] * new_num_hidden_layers if len(swa_layers) == old_num_hidden_layers: swa_layers = swa_layers[::2] elif len(swa_layers) != new_num_hidden_layers: raise ValueError("`swa_layers` must match either the original or converted number of hidden layers.") - return ["full_attention" if int(window_size) == 0 else "sliding_attention" for window_size in swa_layers] + return ["hybrid" if int(window_size) == 0 else "hybrid_sliding" for window_size in swa_layers] def convert_config(input_dir: Path, output_dir: Path) -> None: @@ -151,12 +173,15 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: new_num_hidden_layers = old_num_hidden_layers // 2 layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers) - partial_rotary_factor = config_dict.get("partial_rotary_factor", ZayaConfig.partial_rotary_factor) - rope_theta = config_dict.get("rope_theta", ZayaConfig.default_theta) - swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.default_swa_theta) - intermediate_size = 
config_dict.get( - "intermediate_size", config_dict.get("ffn_hidden_size", ZayaConfig.intermediate_size) + partial_rotary_factor = 0.5 + rope_theta = config_dict.get("rope_theta", _DEFAULT_ROPE_THETA) + swa_rotary_base = config_dict.get("swa_rotary_base", _DEFAULT_SWA_ROPE_THETA) + rms_norm_eps = config_dict.get("rms_norm_eps", config_dict.get("norm_epsilon", ZayaConfig.rms_norm_eps)) + router_hidden_size = config_dict.get( + "router_hidden_size", config_dict.get("zaya_mlp_expansion", ZayaConfig.router_hidden_size) ) + expert_ffn_size = config_dict.get("intermediate_size", config_dict.get("ffn_hidden_size")) + moe_intermediate_size = expert_ffn_size // 2 if expert_ffn_size is not None else ZayaConfig.moe_intermediate_size num_experts_per_tok = config_dict.get( "num_experts_per_tok", config_dict.get("moe_router_topk", ZayaConfig.num_experts_per_tok) ) @@ -170,12 +195,12 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: sliding_window = max(positive_windows) + 1 if positive_windows else None rope_parameters = { - "full_attention": { + "hybrid": { "rope_type": "default", "rope_theta": rope_theta, "partial_rotary_factor": partial_rotary_factor, }, - "sliding_attention": { + "hybrid_sliding": { "rope_type": "default", "rope_theta": swa_rotary_base, "partial_rotary_factor": partial_rotary_factor, @@ -189,11 +214,13 @@ def convert_config(input_dir: Path, output_dir: Path) -> None: { "architectures": ["ZayaForCausalLM"], "num_hidden_layers": new_num_hidden_layers, - "intermediate_size": intermediate_size, + "moe_intermediate_size": moe_intermediate_size, "num_experts_per_tok": num_experts_per_tok, + "rms_norm_eps": rms_norm_eps, + "router_hidden_size": router_hidden_size, "layer_types": layer_types, "sliding_window": sliding_window, - "rope_parameters": {layer_type: rope_parameters[layer_type] for layer_type in set(layer_types)}, + "rope_parameters": rope_parameters, } ) ZayaConfig(**config_dict).save_pretrained(output_dir) diff --git 
a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index b162dde52e02..a9a11daf14bb 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -144,15 +144,28 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's CCA path. - `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal + `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. - Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses - `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**. + Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token + `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection + from **the recurrent cache**. 
Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys. """ @@ -166,21 +179,22 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.depthwise_kernel_size = config.cca_time0 self.grouped_kernel_size = config.cca_time1 - self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) self.num_key_value_heads = config.num_key_value_heads self.num_attention_heads = config.num_attention_heads self.head_dim = config.head_dim - self.key_value_hidden_size = self.num_key_value_heads * self.head_dim - self.query_hidden_size = self.num_attention_heads * self.head_dim self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias) - self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias) - self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias) - self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias) + query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim - conv_channels = self.key_value_hidden_size + self.query_hidden_size + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size self.conv_qk_depthwise = nn.Conv1d( 
in_channels=conv_channels, out_channels=conv_channels, @@ -210,68 +224,71 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - projected_queries = self.linear_q(hidden_states) - projected_keys = self.linear_k(hidden_states) + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) qk_states = torch.cat([projected_queries, projected_keys], dim=-1) query_residual = projected_queries.view(*hidden_shape) - key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim) - - key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2) + key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) query_residual = (query_residual + key_residual) * 0.5 - key_residual = query_residual.view( - *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim - ).mean(dim=-2) + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) qk_states = qk_states.transpose(1, 2) use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) if use_precomputed_states: cached_qk_states = past_key_values.layers[self.layer_idx].conv_states - conv_input = torch.cat([cached_qk_states, qk_states], dim=-1) + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) else: - conv_input = F.pad(qk_states, (self.total_padding, 0)) + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) if past_key_values is not None: - new_conv_state = qk_states[..., -self.total_padding :] - if new_conv_state.shape[-1] < self.total_padding: - new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0)) + new_conv_state = qk_states[..., -self.conv_kernel_size :] + if new_conv_state.shape[-1] < self.conv_kernel_size: + 
new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) past_key_values.update_conv_state(new_conv_state, self.layer_idx) - convolved_qk_states = self.conv_qk_depthwise(conv_input) - convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2) - - query = ( - convolved_qk_states[..., : self.query_hidden_size].view( - *input_shape, self.num_attention_heads, self.head_dim - ) - + query_residual - ) + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) - key = ( - convolved_qk_states[..., self.query_hidden_size :].view( - *input_shape, self.num_key_value_heads, self.head_dim - ) - + key_residual - ) + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual - value_current = self.val_proj1(hidden_states) - projected_v2 = self.val_proj2(hidden_states) + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) if use_precomputed_states: - first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) + recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) else: - first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) - value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1) if past_key_values is not None: - past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx) + past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx) - value = torch.cat([value_current, value_delayed], dim=-1).view( - *input_shape, 
self.num_key_value_heads, self.head_dim - ) + value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape) return query, key, value +class ZayaQKNorm(nn.Module): + def __init__(self, config: ZayaConfig, scaling: float): + super().__init__() + self.head_dim_scale = scaling**-1 + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + + def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + norm_eps = torch.finfo(query_states.dtype).eps + query_states = query_states * ( + self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * ( + self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * self.temp[None, None, :, None] + return query_states, key_states + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -279,18 +296,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -377,10 +382,10 @@ def __init__(self, config: ZayaConfig, layer_idx: int): layer_idx=layer_idx, ) self.layer_type = config.layer_types[layer_idx] - self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.qk_norm = ZayaQKNorm(config, self.scaling) def forward( self, @@ -389,7 +394,7 @@ def forward( past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: batch_size, seq_length, _ = hidden_states.shape mask_mapping = attention_mask or {} @@ -398,13 +403,7 @@ def forward( query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) - norm_eps = torch.finfo(query_states.dtype).eps - head_dim_scale = self.scaling**-1 - query_states = query_states * ( - head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) - ) - key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) - key_states = key_states * self.temp[None, None, :, None] + query_states, key_states = 
self.qk_norm(query_states, key_states) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -416,9 +415,6 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) - if isinstance(causal_mask, torch.Tensor): - causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) @@ -434,10 +430,10 @@ def forward( **kwargs, ) - attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim) + attn_output = attn_output.view(batch_size, seq_length, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_values + return attn_output, attn_weights class ZayaDecoderLayer(GradientCheckpointingLayer): @@ -445,17 +441,11 @@ def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config self.self_attn = ZayaAttention(config, layer_idx) - self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.zaya_block = ZayaSparseMoeBlock( - config, - config.num_experts, - config.zaya_mlp_expansion, - config.intermediate_size, - layer_idx, - ) - self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.post_attention_res_scale = ResidualScaling(config.hidden_size) - self.post_mlp_res_scale = ResidualScaling(config.hidden_size) + self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.mlp = ZayaSparseMoeBlock(config, layer_idx) + self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) + self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) def forward( self, @@ -465,13 +455,11 @@ def 
forward( past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states - # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below. - residual = residual.to(torch.float32) - hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) + hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, @@ -479,20 +467,20 @@ def forward( **kwargs, ) - residual = self.post_attention_res_scale(hidden_states, residual) - hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype)) + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) - hidden_states, prev_router_hidden_states, _ = self.zaya_block( + hidden_states, prev_router_hidden_states, _ = self.mlp( hidden_states, prev_router_hidden_states, ) - hidden_states = self.post_mlp_res_scale(hidden_states, residual) + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) - return hidden_states, self_attn_weights, prev_router_hidden_states + return hidden_states, prev_router_hidden_states -class ResidualScaling(nn.Module): +class ZayaResidualScaling(nn.Module): def __init__(self, hidden_size: int): super().__init__() self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) @@ -501,20 +489,25 @@ def __init__(self, hidden_size: int): self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): + output_dtype = 
hidden_states.dtype hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale + # Matches the original ZAYA `residual_in_fp32` path. + residual = residual.to(torch.float32) residual = (residual + self.residual_bias) * self.residual_scale - return hidden_states + residual + return (hidden_states + residual).to(output_dtype) class ZayaRouterMLP(nn.Module): - def __init__(self, hidden_size: int, num_experts: int): + def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): super().__init__() + self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) self.act_fn = nn.GELU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.rmsnorm_eda(hidden_states) hidden_states = self.act_fn(self.fc1(hidden_states)) hidden_states = self.act_fn(self.fc2(hidden_states)) return self.out_proj(hidden_states) @@ -525,30 +518,24 @@ def __init__( self, config, layer_idx: int, - num_moe_experts: int, - num_experts_per_tok: int, - mlp_expansion: int, - hidden_size: int | None = None, ) -> None: super().__init__() self.config = config - self.hidden_size = int(hidden_size or getattr(config, "hidden_size")) + self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.num_experts = num_moe_experts + 1 - self.topk = int(num_experts_per_tok) - self.mlp_expansion = int(mlp_expansion) + self.num_experts = config.num_experts + 1 + self.top_k = config.num_experts_per_tok + self.router_hidden_size = config.router_hidden_size - self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) + self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True) self.use_eda = self.layer_idx != 0 - - self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon) if self.use_eda: - 
self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) + self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size)) - self.router_mlp = ZayaRouterMLP(self.mlp_expansion, self.num_experts) + self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps) self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) self.balancing_biases[-1] = -1.0 @@ -558,27 +545,32 @@ def forward( hidden_states: torch.Tensor, router_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + final_shape = (-1, self.top_k) seq_length = hidden_states.shape[1] router_hidden_states = self.down_proj(hidden_states) - if self.use_eda and (router_states is not None): + if self.use_eda and router_states is not None: router_hidden_states = router_hidden_states + router_states * self.router_states_scale router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() - router_hidden_states = self.rmsnorm_eda(router_hidden_states) - logits = self.router_mlp(router_hidden_states) - expert_prob = torch.softmax(logits, dim=-1) + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + + biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases + _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1) + router_probs = torch.gather(router_probs, dim=2, index=router_indices) - expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases - _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1) - route_prob = torch.gather(expert_prob, dim=2, index=expert_choice) + # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask. 
+ skip_expert = router_indices == self.config.num_experts + router_probs = router_probs.masked_fill(skip_expert, 0) + router_indices = router_indices.masked_fill(skip_expert, 0) return ( - route_prob.reshape(-1, self.topk), - expert_choice.reshape(-1, self.topk), + router_logits.reshape(-1, self.num_experts), + router_probs.reshape(final_shape), + router_indices.reshape(final_shape), router_hidden_states_next, - logits.reshape(-1, self.num_experts), ) @@ -586,11 +578,11 @@ def forward( class ZayaExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" - def __init__(self, config, num_experts: int, intermediate_size: int): + def __init__(self, config): super().__init__() - self.num_experts = num_experts + self.num_experts = config.num_experts self.hidden_dim = config.hidden_size - self.intermediate_dim = intermediate_size // 2 + self.intermediate_dim = config.moe_intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_act] @@ -623,46 +615,25 @@ def forward( class ZayaSparseMoeBlock(nn.Module): - def __init__( - self, - config, - num_moe_experts: int, - mlp_expansion: int, - intermediate_size: int, - layer_idx: int, - ): + def __init__(self, config, layer_idx: int): super().__init__() self.config = config self.hidden_dim = config.hidden_size - self.num_moe_experts = num_moe_experts - self.router = ZayaRouter( - config=self.config, - layer_idx=layer_idx, - num_moe_experts=self.num_moe_experts, - num_experts_per_tok=self.config.num_experts_per_tok, - mlp_expansion=mlp_expansion, - hidden_size=self.hidden_dim, - ) - self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size) + self.gate = ZayaRouter(self.config, layer_idx) + self.experts = ZayaExperts(self.config) def forward( self, hidden_states: 
torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router( + router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate( hidden_states, router_states=prev_router_hidden_states ) - # if the router outputs num_moe_experts, just skip the tokens - # by masking them with id=0 and prob=0 to reuse the expert code - skip_expert = expert_choice == self.num_moe_experts - route_prob = route_prob.masked_fill(skip_expert, 0) - expert_choice = expert_choice.masked_fill(skip_expert, 0) - batch_size, seq_length, emb_dim = hidden_states.shape hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) - expert_output = self.experts(hidden_states_flat, expert_choice, route_prob) + expert_output = self.experts(hidden_states_flat, router_indices, router_probs) expert_output = expert_output.view(batch_size, seq_length, emb_dim) return expert_output, prev_router_hidden_states, router_logits @@ -682,7 +653,7 @@ class ZayaPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(ZayaRouter, index=3), + "router_logits": OutputRecorder(ZayaRouter, index=0), "hidden_states": ZayaDecoderLayer, "attentions": ZayaAttention, } @@ -691,7 +662,7 @@ class ZayaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - if isinstance(module, ResidualScaling): + if isinstance(module, ZayaResidualScaling): init.ones_(module.hidden_states_scale) init.zeros_(module.hidden_states_bias) init.ones_(module.residual_scale) @@ -699,7 +670,7 @@ def _init_weights(self, module): elif isinstance(module, ZayaModel): init.ones_(module.input_hidden_states_scale) init.zeros_(module.input_hidden_states_bias) - elif isinstance(module, ZayaAttention): + elif isinstance(module, 
ZayaQKNorm): init.zeros_(module.temp) elif isinstance(module, ZayaRouter): if module.use_eda: @@ -726,7 +697,6 @@ def __init__(self, config: ZayaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.cache_layer_types = config.cache_layer_types self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -736,7 +706,7 @@ def __init__(self, config: ZayaConfig): self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) - self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.rotary_emb = ZayaRotaryEmbedding(config=config) @@ -772,11 +742,8 @@ def forward( if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - position_ids = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ).unsqueeze(0) + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) if attention_mask is not None and attention_mask.ndim != 2: raise ValueError( @@ -811,7 +778,7 @@ def forward( layer_type = self.config.layer_types[layer_n] emb_to_use = position_embeddings[layer_type] mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} - layer_outputs = decoder_layer( + hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, prev_router_hidden_states, attention_mask=mask_mapping, @@ -820,9 +787,6 @@ def forward( **kwargs, ) - hidden_states = layer_outputs[0] - prev_router_hidden_states = layer_outputs[2] - hidden_states = 
self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) return MoeModelOutputWithPast( @@ -847,8 +811,8 @@ def _update_causal_mask( # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} mask_creation_functions = { - "full_attention": lambda: create_causal_mask(**mask_kwargs), - "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), } causal_mask_mapping = {} for layer_type in set(self.config.layer_types): diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 04a6625b1313..d8655390ba61 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -24,13 +24,12 @@ from torch import nn from torch.nn import init -from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeModelOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -40,8 +39,9 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ..afmoe.modeling_afmoe import AfmoeForCausalLM +from ..laguna.configuration_laguna import LagunaConfig from ..laguna.modeling_laguna import LagunaRotaryEmbedding -from ..llama.modeling_llama import LlamaPreTrainedModel +from 
..llama.modeling_llama import LlamaPreTrainedModel, repeat_kv from ..phi3.modeling_phi3 import Phi3Attention from ..qwen3_5_moe.modeling_qwen3_5_moe import ( apply_rotary_pos_emb, @@ -52,28 +52,16 @@ @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") @strict -class ZayaConfig(PreTrainedConfig): +class ZayaConfig(LagunaConfig): r""" - intermediate_size (`int`, *optional*, defaults to 4096): - Dimension of the feed-forward and expert hidden states. - num_key_value_heads (`int`, *optional*, defaults to 2): - Number of key/value groups. - partial_rotary_factor (`float`, *optional*, defaults to 0.5): - Fraction of each attention head dimension using rotary embeddings. lm_head_bias (`bool`, *optional*, defaults to `False`): Whether to add a bias to the language modeling head. - num_experts_per_tok (`int`, *optional*, defaults to 1): - Number of selected experts per token. ZAYA checkpoints use top-1 routing. - zaya_mlp_expansion (`int`, *optional*, defaults to 256): - Expansion size used by the dense ZAYA blocks. + router_hidden_size (`int`, *optional*, defaults to 256): + Hidden size used by the ZAYA router. cca_time0 (`int`, *optional*, defaults to 2): First temporal parameter of the CCA projection. cca_time1 (`int`, *optional*, defaults to 2): Second temporal parameter of the CCA projection. - layer_types (`list[str]`, *optional*): - Per-layer selector for standard RoPE versus SWA RoPE embeddings. - cache_layer_types (`list[str]`, *optional*): - Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer. 
```python >>> from transformers import ZayaConfig, ZayaModel @@ -86,71 +74,61 @@ class ZayaConfig(PreTrainedConfig): """ model_type = "zaya" - keys_to_ignore_at_inference = ["past_key_values"] - default_theta = 5000000.0 - default_swa_theta = 10000.0 vocab_size: int = 262272 - hidden_size: int = 2048 - intermediate_size: int = 4096 - num_hidden_layers: int = 40 - num_experts: int = 16 + moe_intermediate_size: int = 2048 num_attention_heads: int = 8 num_key_value_heads: int = 2 - hidden_act: str = "silu" - head_dim: int = 128 - max_position_embeddings: int = 131072 - initializer_range: float = 0.02 - norm_epsilon: float = 1e-5 - use_cache: bool = True tie_word_embeddings: bool = True - rope_parameters: RopeParameters | dict | None = None - partial_rotary_factor: float = 0.5 - attention_bias: bool = False - lm_head_bias: bool = False - attention_dropout: float | int = 0.0 - num_experts_per_tok: int = 1 - zaya_mlp_expansion: int = 256 - cca_time0: int = 2 - cca_time1: int = 2 + rms_norm_eps: float = 1e-5 sliding_window: int | None = None - layer_types: list[str] | None = None - cache_layer_types: list[str] | None = None - output_router_logits: bool = False pad_token_id: int | None = 0 bos_token_id: int | None = 2 eos_token_id: int | list[int] | None = 106 + num_experts_per_tok: int = 1 + num_experts: int = 16 + + lm_head_bias: bool = False + router_hidden_size: int = 256 + cca_time0: int = 2 + cca_time1: int = 2 + + # Fields declared by LagunaConfig but not used by ZAYA. 
+ # TP and PP are not tested yet, so remove for now + base_model_tp_plan = AttributeError() + base_model_pp_plan = AttributeError() + intermediate_size = AttributeError() + shared_expert_intermediate_size = AttributeError() + router_aux_loss_coef = AttributeError() + num_attention_heads_per_layer = AttributeError() + mlp_layer_types = AttributeError() + moe_routed_scaling_factor = AttributeError() + moe_apply_router_weight_on_input = AttributeError() + moe_router_logit_softcapping = AttributeError() + def __post_init__(self, **kwargs): - self.layer_types = ( - ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) - ) - self.cache_layer_types = ( - ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types) - ) + self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) - default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = { - "full_attention": { + default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = { + "hybrid": { "rope_type": "default", - "rope_theta": self.default_theta, - "partial_rotary_factor": self.partial_rotary_factor, + "rope_theta": 5_000_000.0, + "partial_rotary_factor": 0.5, }, - "sliding_attention": { + "hybrid_sliding": { "rope_type": "default", - "rope_theta": self.default_swa_theta, - "partial_rotary_factor": self.partial_rotary_factor, + "rope_theta": 10_000.0, + "partial_rotary_factor": 0.5, }, } if self.rope_parameters is None: - self.rope_parameters = { - layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types) - } + self.rope_parameters = default_rope_params - super().__post_init__(**kwargs) + PreTrainedConfig.__post_init__(self, **kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"}) def convert_rope_params_to_dict(self, **kwargs): - # ZAYA uses nested RoPE parameters keyed by layer type. 
Keep the base RoPE BC conversion from treating them - # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`. + # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly. return kwargs def validate_architecture(self): @@ -160,14 +138,10 @@ def validate_architecture(self): raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") if len(self.layer_types) != self.num_hidden_layers: raise ValueError("`layer_types` must have one entry per hidden layer.") - if len(self.cache_layer_types) != self.num_hidden_layers: - raise ValueError("`cache_layer_types` must have one entry per hidden layer.") - if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}: - raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.") - if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}: + if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}: raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") - if "sliding_attention" in self.layer_types and self.sliding_window is None: - raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.") + if "hybrid_sliding" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") if self.sliding_window is not None and self.sliding_window <= 0: raise ValueError("`sliding_window` must be a strictly positive integer.") @@ -184,11 +158,12 @@ class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's CCA path. - `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal + `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. 
The causal `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. - Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses - `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**. + Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token + `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection + from **the recurrent cache**. Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys. """ @@ -202,21 +177,22 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.depthwise_kernel_size = config.cca_time0 self.grouped_kernel_size = config.cca_time1 - self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) self.num_key_value_heads = config.num_key_value_heads self.num_attention_heads = config.num_attention_heads self.head_dim = config.head_dim - self.key_value_hidden_size = self.num_key_value_heads * self.head_dim - self.query_hidden_size = self.num_attention_heads * self.head_dim self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias) - self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias) - self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias) - self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias) + 
query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim - conv_channels = self.key_value_hidden_size + self.query_hidden_size + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size self.conv_qk_depthwise = nn.Conv1d( in_channels=conv_channels, out_channels=conv_channels, @@ -246,81 +222,84 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - projected_queries = self.linear_q(hidden_states) - projected_keys = self.linear_k(hidden_states) + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) qk_states = torch.cat([projected_queries, projected_keys], dim=-1) query_residual = projected_queries.view(*hidden_shape) - key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim) - - key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2) + key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) query_residual = (query_residual + key_residual) * 0.5 - key_residual = query_residual.view( - *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim - ).mean(dim=-2) + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) qk_states = qk_states.transpose(1, 2) use_precomputed_states = past_key_values is not None and 
past_key_values.has_previous_state(self.layer_idx) if use_precomputed_states: cached_qk_states = past_key_values.layers[self.layer_idx].conv_states - conv_input = torch.cat([cached_qk_states, qk_states], dim=-1) + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) else: - conv_input = F.pad(qk_states, (self.total_padding, 0)) + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) if past_key_values is not None: - new_conv_state = qk_states[..., -self.total_padding :] - if new_conv_state.shape[-1] < self.total_padding: - new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0)) + new_conv_state = qk_states[..., -self.conv_kernel_size :] + if new_conv_state.shape[-1] < self.conv_kernel_size: + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) past_key_values.update_conv_state(new_conv_state, self.layer_idx) - convolved_qk_states = self.conv_qk_depthwise(conv_input) - convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2) - - query = ( - convolved_qk_states[..., : self.query_hidden_size].view( - *input_shape, self.num_attention_heads, self.head_dim - ) - + query_residual - ) + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) - key = ( - convolved_qk_states[..., self.query_hidden_size :].view( - *input_shape, self.num_key_value_heads, self.head_dim - ) - + key_residual - ) + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual - value_current = self.val_proj1(hidden_states) - projected_v2 = self.val_proj2(hidden_states) + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) if use_precomputed_states: - first_v2 = 
past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) + recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) else: - first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) - value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1) + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1) if past_key_values is not None: - past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx) + past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx) - value = torch.cat([value_current, value_delayed], dim=-1).view( - *input_shape, self.num_key_value_heads, self.head_dim - ) + value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape) return query, key, value +class ZayaQKNorm(nn.Module): + def __init__(self, config: ZayaConfig, scaling: float): + super().__init__() + self.head_dim_scale = scaling**-1 + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + + def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + norm_eps = torch.finfo(query_states.dtype).eps + query_states = query_states * ( + self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * ( + self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * self.temp[None, None, :, None] + return query_states, key_states + + class ZayaAttention(Phi3Attention): def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__(config, layer_idx) del op_size # noqa: F821 self.layer_type = config.layer_types[layer_idx] - self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.sliding_window = 
config.sliding_window if self.layer_type == "hybrid_sliding" else None self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.qk_norm = ZayaQKNorm(config, self.scaling) self.qkv_proj = ZayaCCAProjection( config=self.config, layer_idx=layer_idx, @@ -333,7 +312,7 @@ def forward( past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: batch_size, seq_length, _ = hidden_states.shape mask_mapping = attention_mask or {} @@ -342,13 +321,7 @@ def forward( query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) - norm_eps = torch.finfo(query_states.dtype).eps - head_dim_scale = self.scaling**-1 - query_states = query_states * ( - head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) - ) - key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) - key_states = key_states * self.temp[None, None, :, None] + query_states, key_states = self.qk_norm(query_states, key_states) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -360,9 +333,6 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) - if isinstance(causal_mask, torch.Tensor): - causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]] - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) @@ -378,10 +348,10 @@ def forward( **kwargs, ) - attn_output = 
attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim) + attn_output = attn_output.view(batch_size, seq_length, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_values + return attn_output, attn_weights class ZayaDecoderLayer(GradientCheckpointingLayer): @@ -389,17 +359,11 @@ def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() self.config = config self.self_attn = ZayaAttention(config, layer_idx) - self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.zaya_block = ZayaSparseMoeBlock( - config, - config.num_experts, - config.zaya_mlp_expansion, - config.intermediate_size, - layer_idx, - ) - self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) - self.post_attention_res_scale = ResidualScaling(config.hidden_size) - self.post_mlp_res_scale = ResidualScaling(config.hidden_size) + self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.mlp = ZayaSparseMoeBlock(config, layer_idx) + self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) + self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) def forward( self, @@ -409,13 +373,11 @@ def forward( past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states - # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below. 
- residual = residual.to(torch.float32) - hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) + hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, @@ -423,20 +385,20 @@ def forward( **kwargs, ) - residual = self.post_attention_res_scale(hidden_states, residual) - hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype)) + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) - hidden_states, prev_router_hidden_states, _ = self.zaya_block( + hidden_states, prev_router_hidden_states, _ = self.mlp( hidden_states, prev_router_hidden_states, ) - hidden_states = self.post_mlp_res_scale(hidden_states, residual) + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) - return hidden_states, self_attn_weights, prev_router_hidden_states + return hidden_states, prev_router_hidden_states -class ResidualScaling(nn.Module): +class ZayaResidualScaling(nn.Module): def __init__(self, hidden_size: int): super().__init__() self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) @@ -445,20 +407,25 @@ def __init__(self, hidden_size: int): self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): + output_dtype = hidden_states.dtype hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale + # Matches the original ZAYA `residual_in_fp32` path. 
+ residual = residual.to(torch.float32) residual = (residual + self.residual_bias) * self.residual_scale - return hidden_states + residual + return (hidden_states + residual).to(output_dtype) class ZayaRouterMLP(nn.Module): - def __init__(self, hidden_size: int, num_experts: int): + def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): super().__init__() + self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) self.act_fn = nn.GELU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.rmsnorm_eda(hidden_states) hidden_states = self.act_fn(self.fc1(hidden_states)) hidden_states = self.act_fn(self.fc2(hidden_states)) return self.out_proj(hidden_states) @@ -469,30 +436,24 @@ def __init__( self, config, layer_idx: int, - num_moe_experts: int, - num_experts_per_tok: int, - mlp_expansion: int, - hidden_size: int | None = None, ) -> None: super().__init__() self.config = config - self.hidden_size = int(hidden_size or getattr(config, "hidden_size")) + self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.num_experts = num_moe_experts + 1 - self.topk = int(num_experts_per_tok) - self.mlp_expansion = int(mlp_expansion) + self.num_experts = config.num_experts + 1 + self.top_k = config.num_experts_per_tok + self.router_hidden_size = config.router_hidden_size - self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) + self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True) self.use_eda = self.layer_idx != 0 - - self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon) if self.use_eda: - self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) + self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size)) - self.router_mlp = 
ZayaRouterMLP(self.mlp_expansion, self.num_experts) + self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps) self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) self.balancing_biases[-1] = -1.0 @@ -502,82 +463,59 @@ def forward( hidden_states: torch.Tensor, router_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + final_shape = (-1, self.top_k) seq_length = hidden_states.shape[1] router_hidden_states = self.down_proj(hidden_states) - if self.use_eda and (router_states is not None): + if self.use_eda and router_states is not None: router_hidden_states = router_hidden_states + router_states * self.router_states_scale router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() - router_hidden_states = self.rmsnorm_eda(router_hidden_states) - logits = self.router_mlp(router_hidden_states) - expert_prob = torch.softmax(logits, dim=-1) + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + + biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases + _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1) + router_probs = torch.gather(router_probs, dim=2, index=router_indices) - expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases - _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1) - route_prob = torch.gather(expert_prob, dim=2, index=expert_choice) + # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask. 
+ skip_expert = router_indices == self.config.num_experts + router_probs = router_probs.masked_fill(skip_expert, 0) + router_indices = router_indices.masked_fill(skip_expert, 0) return ( - route_prob.reshape(-1, self.topk), - expert_choice.reshape(-1, self.topk), + router_logits.reshape(-1, self.num_experts), + router_probs.reshape(final_shape), + router_indices.reshape(final_shape), router_hidden_states_next, - logits.reshape(-1, self.num_experts), ) class ZayaExperts(Qwen3MoeExperts): - def __init__(self, config, num_experts: int, intermediate_size: int): - nn.Module.__init__(self) - self.num_experts = num_experts - self.hidden_dim = config.hidden_size - self.intermediate_dim = intermediate_size // 2 - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) - self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) - self.act_fn = ACT2FN[config.hidden_act] + pass class ZayaSparseMoeBlock(nn.Module): - def __init__( - self, - config, - num_moe_experts: int, - mlp_expansion: int, - intermediate_size: int, - layer_idx: int, - ): + def __init__(self, config, layer_idx: int): super().__init__() self.config = config self.hidden_dim = config.hidden_size - self.num_moe_experts = num_moe_experts - self.router = ZayaRouter( - config=self.config, - layer_idx=layer_idx, - num_moe_experts=self.num_moe_experts, - num_experts_per_tok=self.config.num_experts_per_tok, - mlp_expansion=mlp_expansion, - hidden_size=self.hidden_dim, - ) - self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size) + self.gate = ZayaRouter(self.config, layer_idx) + self.experts = ZayaExperts(self.config) def forward( self, hidden_states: torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router( + router_logits, 
router_probs, router_indices, prev_router_hidden_states = self.gate( hidden_states, router_states=prev_router_hidden_states ) - # if the router outputs num_moe_experts, just skip the tokens - # by masking them with id=0 and prob=0 to reuse the expert code - skip_expert = expert_choice == self.num_moe_experts - route_prob = route_prob.masked_fill(skip_expert, 0) - expert_choice = expert_choice.masked_fill(skip_expert, 0) - batch_size, seq_length, emb_dim = hidden_states.shape hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) - expert_output = self.experts(hidden_states_flat, expert_choice, route_prob) + expert_output = self.experts(hidden_states_flat, router_indices, router_probs) expert_output = expert_output.view(batch_size, seq_length, emb_dim) return expert_output, prev_router_hidden_states, router_logits @@ -590,7 +528,7 @@ class ZayaPreTrainedModel(LlamaPreTrainedModel): # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache. 
_can_compile_fullgraph = False _can_record_outputs = { - "router_logits": OutputRecorder(ZayaRouter, index=3), + "router_logits": OutputRecorder(ZayaRouter, index=0), "hidden_states": ZayaDecoderLayer, "attentions": ZayaAttention, } @@ -598,7 +536,7 @@ class ZayaPreTrainedModel(LlamaPreTrainedModel): @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) - if isinstance(module, ResidualScaling): + if isinstance(module, ZayaResidualScaling): init.ones_(module.hidden_states_scale) init.zeros_(module.hidden_states_bias) init.ones_(module.residual_scale) @@ -606,7 +544,7 @@ def _init_weights(self, module): elif isinstance(module, ZayaModel): init.ones_(module.input_hidden_states_scale) init.zeros_(module.input_hidden_states_bias) - elif isinstance(module, ZayaAttention): + elif isinstance(module, ZayaQKNorm): init.zeros_(module.temp) elif isinstance(module, ZayaRouter): if module.use_eda: @@ -633,7 +571,6 @@ def __init__(self, config: ZayaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.cache_layer_types = config.cache_layer_types self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -643,7 +580,7 @@ def __init__(self, config: ZayaConfig): self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) - self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.rotary_emb = ZayaRotaryEmbedding(config=config) @@ -679,11 +616,8 @@ def forward( if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - position_ids = torch.arange( - past_seen_tokens, - 
past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ).unsqueeze(0) + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) if attention_mask is not None and attention_mask.ndim != 2: raise ValueError( @@ -718,7 +652,7 @@ def forward( layer_type = self.config.layer_types[layer_n] emb_to_use = position_embeddings[layer_type] mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} - layer_outputs = decoder_layer( + hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, prev_router_hidden_states, attention_mask=mask_mapping, @@ -727,9 +661,6 @@ def forward( **kwargs, ) - hidden_states = layer_outputs[0] - prev_router_hidden_states = layer_outputs[2] - hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) return MoeModelOutputWithPast( @@ -754,8 +685,8 @@ def _update_causal_mask( # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} mask_creation_functions = { - "full_attention": lambda: create_causal_mask(**mask_kwargs), - "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), } causal_mask_mapping = {} for layer_type in set(self.config.layer_types): @@ -764,18 +695,14 @@ def _update_causal_mask( @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") -class ZayaForCausalLM(ZayaPreTrainedModel, AfmoeForCausalLM): +class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _is_stateful = True def __init__(self, config, **kwargs): super().__init__(config, **kwargs) - self.model = ZayaModel(config) - self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) - self.post_init() - __all__ = [ "ZayaConfig", diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 94bd74093e15..37a081273dcb 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -26,7 +26,11 @@ import torch from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel - from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer + from transformers.cache_utils import ( + DynamicCache, + LinearAttentionAndFullAttentionLayer, + LinearAttentionAndSlidingWindowAttentionLayer, + ) from transformers.models.zaya.modeling_zaya import ZayaCCAProjection from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester @@ -46,15 +50,15 @@ def __init__(self, parent): num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, - intermediate_size=64, + moe_intermediate_size=32, ) self.head_dim = 8 self.num_experts = 4 
self.num_experts_per_tok = 1 - self.zaya_mlp_expansion = 4 + self.router_hidden_size = 4 self.tie_word_embeddings = False self.rope_parameters = { - "full_attention": { + "hybrid": { "rope_theta": 10000, "rope_type": "default", "partial_rotary_factor": 0.5, @@ -69,7 +73,8 @@ class ZayaModelTest(CausalLMModelTest, unittest.TestCase): def _get_conv_state_shape(self, batch_size: int, config): conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim - return (batch_size, conv_state_size, config.cca_time0 + config.cca_time1 - 2) + conv_kernel_size = config.cca_time0 + config.cca_time1 - 2 + return (batch_size, conv_state_size, conv_kernel_size) def _get_recurrent_state_shape(self, batch_size: int, config): return (batch_size, config.num_key_value_heads * config.head_dim // 2) @@ -84,8 +89,13 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l conv_shape = self._get_conv_state_shape(batch_size, config) recurrent_shape = self._get_recurrent_state_shape(batch_size, config) - for layer in past_key_values.layers: - self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer) + for layer_type, layer in zip(config.layer_types, past_key_values.layers): + expected_layer_class = ( + LinearAttentionAndSlidingWindowAttentionLayer + if layer_type == "hybrid_sliding" + else LinearAttentionAndFullAttentionLayer + ) + self.assertIs(type(layer), expected_layer_class) self.assertEqual(layer.keys.shape, attention_shape) self.assertEqual(layer.values.shape, attention_shape) self.assertEqual(layer.conv_states.shape, conv_shape) @@ -153,13 +163,13 @@ def test_model_rope_scaling_frequencies(self): Copied from Laguna to adapt to per-layer-type rope configs. 
""" config, _ = self.model_tester.prepare_config_and_inputs_for_common() - config.layer_types = ["full_attention", "sliding_attention"] - partial_rotary_factor = config.partial_rotary_factor + config.layer_types = ["hybrid", "hybrid_sliding"] + partial_rotary_factor = config.rope_parameters["hybrid"]["partial_rotary_factor"] def set_rope_params(rope_params): config.rope_parameters = { - "full_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, - "sliding_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + "hybrid": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + "hybrid_sliding": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, } set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) @@ -186,15 +196,15 @@ def set_rope_params(rope_params): set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) original_rope = rope_class(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="sliding_attention") - original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="sliding_attention") + original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="hybrid_sliding") + original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="hybrid_sliding") torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) set_rope_params({"rope_type": "linear", "factor": scaling_factor, "rope_theta": 10_000.0}) linear_scaling_rope = rope_class(config=config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="sliding_attention") - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="sliding_attention") + linear_cos_short, 
linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) for new_position in range(0, long_input_length, scaling_factor): @@ -204,22 +214,20 @@ def set_rope_params(rope_params): set_rope_params({"rope_type": "dynamic", "factor": scaling_factor, "rope_theta": 10_000.0}) ntk_scaling_rope = rope_class(config=config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="sliding_attention") - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="sliding_attention") + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_sin_short, original_sin_short) with self.assertRaises(AssertionError): torch.testing.assert_close(ntk_cos_long, original_cos_long) with self.assertRaises(AssertionError): torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue( - (ntk_scaling_rope.sliding_attention_inv_freq <= original_rope.sliding_attention_inv_freq).all() - ) + self.assertTrue((ntk_scaling_rope.hybrid_sliding_inv_freq <= original_rope.hybrid_sliding_inv_freq).all()) set_rope_params({"rope_type": "yarn", "factor": scaling_factor, "rope_theta": 10_000.0}) yarn_scaling_rope = rope_class(config=config).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="sliding_attention") - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="sliding_attention") + yarn_cos_short, 
yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) with self.assertRaises(AssertionError): @@ -259,14 +267,14 @@ def test_sliding_attention_mask_is_used(self): config = ZayaConfig( vocab_size=128, hidden_size=32, - intermediate_size=64, + moe_intermediate_size=32, num_hidden_layers=4, num_experts=4, num_attention_heads=4, num_key_value_heads=2, head_dim=8, - zaya_mlp_expansion=4, - layer_types=["sliding_attention", "full_attention", "full_attention", "full_attention"], + router_hidden_size=4, + layer_types=["hybrid_sliding", "hybrid", "hybrid_sliding", "hybrid"], sliding_window=3, tie_word_embeddings=False, attn_implementation="eager", @@ -285,13 +293,13 @@ def test_cca_cache_matches_full_forward(self): config = ZayaConfig( vocab_size=128, hidden_size=32, - intermediate_size=64, + moe_intermediate_size=32, num_hidden_layers=1, num_experts=4, num_attention_heads=4, num_key_value_heads=2, head_dim=8, - zaya_mlp_expansion=4, + router_hidden_size=4, tie_word_embeddings=False, ) torch.manual_seed(0) @@ -312,13 +320,13 @@ def test_cca_cache_matches_full_forward_multi_token(self): config = ZayaConfig( vocab_size=128, hidden_size=32, - intermediate_size=64, + moe_intermediate_size=32, num_hidden_layers=1, num_experts=4, num_attention_heads=4, num_key_value_heads=2, head_dim=8, - zaya_mlp_expansion=4, + router_hidden_size=4, tie_word_embeddings=False, ) torch.manual_seed(0) From 7bb5122a923b786e892b7d49a1c021b933cc3b0c Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 14 May 2026 15:41:33 +0800 Subject: [PATCH 25/36] llama decoderlayer --- src/transformers/models/zaya/modeling_zaya.py | 9 +++++---- src/transformers/models/zaya/modular_zaya.py | 11 
+++-------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index a9a11daf14bb..d2e0712f603e 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -439,11 +439,12 @@ def forward( class ZayaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ZayaConfig, layer_idx: int): super().__init__() - self.config = config - self.self_attn = ZayaAttention(config, layer_idx) - self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.hidden_size = config.hidden_size + + self.self_attn = ZayaAttention(config=config, layer_idx=layer_idx) self.mlp = ZayaSparseMoeBlock(config, layer_idx) - self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.input_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index d8655390ba61..5e325fb78d00 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -27,7 +27,6 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask -from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -41,7 +40,7 @@ from ..afmoe.modeling_afmoe import AfmoeForCausalLM from 
..laguna.configuration_laguna import LagunaConfig from ..laguna.modeling_laguna import LagunaRotaryEmbedding -from ..llama.modeling_llama import LlamaPreTrainedModel, repeat_kv +from ..llama.modeling_llama import LlamaDecoderLayer, LlamaPreTrainedModel, repeat_kv from ..phi3.modeling_phi3 import Phi3Attention from ..qwen3_5_moe.modeling_qwen3_5_moe import ( apply_rotary_pos_emb, @@ -354,14 +353,10 @@ def forward( return attn_output, attn_weights -class ZayaDecoderLayer(GradientCheckpointingLayer): +class ZayaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: ZayaConfig, layer_idx: int): - super().__init__() - self.config = config - self.self_attn = ZayaAttention(config, layer_idx) - self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + super().__init__(config, layer_idx) self.mlp = ZayaSparseMoeBlock(config, layer_idx) - self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) From b315ae07e07ae1c50f5d841842ad7306555973d3 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 16 May 2026 11:51:23 +0800 Subject: [PATCH 26/36] improve --- .../models/auto/tokenization_auto.py | 1 + .../models/zaya/configuration_zaya.py | 6 - src/transformers/models/zaya/modeling_zaya.py | 129 +++++++-------- src/transformers/models/zaya/modular_zaya.py | 149 +++++++----------- tests/models/zaya/test_modeling_zaya.py | 49 ++---- tests/test_modeling_common.py | 1 + 6 files changed, 128 insertions(+), 207 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index db34543e63a1..75677f8ea505 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -340,6 +340,7 @@ ("xlstm", "GPTNeoXTokenizer" if 
is_tokenizers_available() else None), ("xmod", "XLMRobertaTokenizer" if is_tokenizers_available() else None), ("yoso", "AlbertTokenizer" if is_tokenizers_available() else None), + ("zaya", "GemmaTokenizer" if is_tokenizers_available() else None), ] ) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 6fb47ecb76a4..4a18bd4716f2 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -116,14 +116,8 @@ def validate_architecture(self): raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") - if len(self.layer_types) != self.num_hidden_layers: - raise ValueError("`layer_types` must have one entry per hidden layer.") - if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}: - raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") if "hybrid_sliding" in self.layer_types and self.sliding_window is None: raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") - if self.sliding_window is not None and self.sliding_window <= 0: - raise ValueError("`sliding_window` must be a strictly positive integer.") __all__ = ["ZayaConfig"] diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index d2e0712f603e..0815020f0e2e 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -166,8 +166,6 @@ class ZayaCCAProjection(nn.Module): Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection from **the recurrent 
cache**. - - Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys. """ def __init__(self, config: ZayaConfig, layer_idx: int): @@ -229,7 +227,7 @@ def forward( qk_states = torch.cat([projected_queries, projected_keys], dim=-1) query_residual = projected_queries.view(*hidden_shape) - key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) query_residual = (query_residual + key_residual) * 0.5 key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) @@ -255,6 +253,8 @@ def forward( query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + # The value path carries half of each value head from the current token and half from the previous token. + # During cached decoding, `recurrent_v_state` is the previous token's delayed projection. value_current = self.v_proj_current(hidden_states) delayed_v_state = self.v_proj_delayed(hidden_states) if use_precomputed_states: @@ -272,9 +272,13 @@ def forward( class ZayaQKNorm(nn.Module): - def __init__(self, config: ZayaConfig, scaling: float): + """ + L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale. 
+ """ + + def __init__(self, config: ZayaConfig): super().__init__() - self.head_dim_scale = scaling**-1 + self.head_dim_scale = config.head_dim**0.5 self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -385,7 +389,7 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.qk_norm = ZayaQKNorm(config, self.scaling) + self.qk_norm = ZayaQKNorm(config) def forward( self, @@ -395,7 +399,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: - batch_size, seq_length, _ = hidden_states.shape + input_shape = hidden_states.shape[:-1] mask_mapping = attention_mask or {} causal_mask = mask_mapping.get("causal") @@ -430,7 +434,7 @@ def forward( **kwargs, ) - attn_output = attn_output.view(batch_size, seq_length, -1) + attn_output = attn_output.view(*input_shape, -1) attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -469,9 +473,9 @@ def forward( ) residual = self.post_attention_residual_scale(hidden_states, residual) - hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) + hidden_states = self.post_attention_layernorm(residual) - hidden_states, prev_router_hidden_states, _ = self.mlp( + hidden_states, prev_router_hidden_states = self.mlp( hidden_states, prev_router_hidden_states, ) @@ -618,17 +622,15 @@ def forward( class ZayaSparseMoeBlock(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - self.config = config - self.hidden_dim = config.hidden_size - self.gate = ZayaRouter(self.config, layer_idx) - self.experts = ZayaExperts(self.config) + 
self.gate = ZayaRouter(config, layer_idx) + self.experts = ZayaExperts(config) def forward( self, hidden_states: torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate( + ) -> tuple[torch.Tensor, torch.Tensor | None]: + _, router_probs, router_indices, prev_router_hidden_states = self.gate( hidden_states, router_states=prev_router_hidden_states ) @@ -637,7 +639,7 @@ def forward( expert_output = self.experts(hidden_states_flat, router_indices, router_probs) expert_output = expert_output.view(batch_size, seq_length, emb_dim) - return expert_output, prev_router_hidden_states, router_logits + return expert_output, prev_router_hidden_states @auto_docstring @@ -658,7 +660,6 @@ class ZayaPreTrainedModel(PreTrainedModel): "hidden_states": ZayaDecoderLayer, "attentions": ZayaAttention, } - config_class = ZayaConfig @torch.no_grad() def _init_weights(self, module): @@ -677,7 +678,7 @@ def _init_weights(self, module): if module.use_eda: init.ones_(module.router_states_scale) init.zeros_(module.balancing_biases) - module.balancing_biases[-1] = -1.0 + module.balancing_biases[-1] = -1.0 # ignore: trf012 elif isinstance(module, ZayaExperts): std = self.config.initializer_range init.normal_(module.gate_up_proj, mean=0.0, std=std) @@ -698,27 +699,20 @@ def __init__(self, config: ZayaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - + self.rotary_emb = ZayaRotaryEmbedding(config=config) self.gradient_checkpointing = False - self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = 
nn.Parameter(torch.zeros(config.hidden_size)) self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - self.rotary_emb = ZayaRotaryEmbedding(config=config) - + # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @merge_with_config_defaults @capture_outputs @auto_docstring @@ -751,18 +745,23 @@ def forward( "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." ) - causal_mask_mapping = self._update_causal_mask( - attention_mask, - inputs_embeds, - position_ids, - past_key_values, - ) - padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None - - # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. - # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. - if inputs_embeds.shape[1] == 1: - padding_mask = None + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) hidden_states = inputs_embeds @@ -775,50 +774,36 @@ def forward( prev_router_hidden_states = None - for layer_n, decoder_layer in enumerate(self.layers): - layer_type = self.config.layer_types[layer_n] - emb_to_use = position_embeddings[layer_type] - mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} + for idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[idx] hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, prev_router_hidden_states, - attention_mask=mask_mapping, + attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask}, past_key_values=past_key_values, - position_embeddings=emb_to_use, + position_embeddings=position_embeddings[layer_type], **kwargs, ) - hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) + hidden_states = self.final_norm(hidden_states) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Cache, - ): - mask_kwargs = { - "config": self.config, - "inputs_embeds": input_tensor, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
- sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} - mask_creation_functions = { - "hybrid": lambda: create_causal_mask(**mask_kwargs), - "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), - } - causal_mask_mapping = {} - for layer_type in set(self.config.layer_types): - causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]() - return causal_mask_mapping + def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. + """ + cca_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + cca_mask = None + elif attention_mask is not None: + cca_mask = attention_mask[:, -inputs_embeds.shape[1] :] + return cca_mask @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 5e325fb78d00..1967bf6fc64a 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -39,7 +39,7 @@ from ...utils.output_capturing import OutputRecorder, capture_outputs from ..afmoe.modeling_afmoe import AfmoeForCausalLM from ..laguna.configuration_laguna import LagunaConfig -from ..laguna.modeling_laguna import LagunaRotaryEmbedding +from ..laguna.modeling_laguna import LagunaModel, LagunaRotaryEmbedding from ..llama.modeling_llama import LlamaDecoderLayer, LlamaPreTrainedModel, repeat_kv from ..phi3.modeling_phi3 import Phi3Attention from ..qwen3_5_moe.modeling_qwen3_5_moe import ( @@ -94,7 +94,7 @@ class ZayaConfig(LagunaConfig): cca_time1: int = 2 # Fields declared by LagunaConfig but not used by ZAYA. - # TP and PP are not tested yet, so remove for now + # TODO: add TP/PP plans. 
TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. base_model_tp_plan = AttributeError() base_model_pp_plan = AttributeError() intermediate_size = AttributeError() @@ -135,14 +135,8 @@ def validate_architecture(self): raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") - if len(self.layer_types) != self.num_hidden_layers: - raise ValueError("`layer_types` must have one entry per hidden layer.") - if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}: - raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.") if "hybrid_sliding" in self.layer_types and self.sliding_window is None: raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") - if self.sliding_window is not None and self.sliding_window <= 0: - raise ValueError("`sliding_window` must be a strictly positive integer.") class ZayaRotaryEmbedding(LagunaRotaryEmbedding): @@ -163,8 +157,6 @@ class ZayaCCAProjection(nn.Module): Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection from **the recurrent cache**. - - Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys. 
""" def __init__(self, config: ZayaConfig, layer_idx: int): @@ -226,7 +218,7 @@ def forward( qk_states = torch.cat([projected_queries, projected_keys], dim=-1) query_residual = projected_queries.view(*hidden_shape) - key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) query_residual = (query_residual + key_residual) * 0.5 key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) @@ -252,6 +244,8 @@ def forward( query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + # The value path carries half of each value head from the current token and half from the previous token. + # During cached decoding, `recurrent_v_state` is the previous token's delayed projection. value_current = self.v_proj_current(hidden_states) delayed_v_state = self.v_proj_delayed(hidden_states) if use_precomputed_states: @@ -269,9 +263,13 @@ def forward( class ZayaQKNorm(nn.Module): - def __init__(self, config: ZayaConfig, scaling: float): + """ + L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale. 
+ """ + + def __init__(self, config: ZayaConfig): super().__init__() - self.head_dim_scale = scaling**-1 + self.head_dim_scale = config.head_dim**0.5 self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -298,7 +296,7 @@ def __init__(self, config: ZayaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.qk_norm = ZayaQKNorm(config, self.scaling) + self.qk_norm = ZayaQKNorm(config) self.qkv_proj = ZayaCCAProjection( config=self.config, layer_idx=layer_idx, @@ -312,7 +310,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: - batch_size, seq_length, _ = hidden_states.shape + input_shape = hidden_states.shape[:-1] mask_mapping = attention_mask or {} causal_mask = mask_mapping.get("causal") @@ -347,7 +345,7 @@ def forward( **kwargs, ) - attn_output = attn_output.view(batch_size, seq_length, -1) + attn_output = attn_output.view(*input_shape, -1) attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -381,9 +379,9 @@ def forward( ) residual = self.post_attention_residual_scale(hidden_states, residual) - hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) + hidden_states = self.post_attention_layernorm(residual) - hidden_states, prev_router_hidden_states, _ = self.mlp( + hidden_states, prev_router_hidden_states = self.mlp( hidden_states, prev_router_hidden_states, ) @@ -494,17 +492,15 @@ class ZayaExperts(Qwen3MoeExperts): class ZayaSparseMoeBlock(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - self.config = config - self.hidden_dim = config.hidden_size - self.gate = ZayaRouter(self.config, layer_idx) - self.experts = 
ZayaExperts(self.config) + self.gate = ZayaRouter(config, layer_idx) + self.experts = ZayaExperts(config) def forward( self, hidden_states: torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate( + ) -> tuple[torch.Tensor, torch.Tensor | None]: + _, router_probs, router_indices, prev_router_hidden_states = self.gate( hidden_states, router_states=prev_router_hidden_states ) @@ -513,13 +509,11 @@ def forward( expert_output = self.experts(hidden_states_flat, router_indices, router_probs) expert_output = expert_output.view(batch_size, seq_length, emb_dim) - return expert_output, prev_router_hidden_states, router_logits + return expert_output, prev_router_hidden_states class ZayaPreTrainedModel(LlamaPreTrainedModel): config: ZayaConfig - config_class = ZayaConfig - _no_split_modules = ["ZayaDecoderLayer"] # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache. 
_can_compile_fullgraph = False _can_record_outputs = { @@ -545,7 +539,7 @@ def _init_weights(self, module): if module.use_eda: init.ones_(module.router_states_scale) init.zeros_(module.balancing_biases) - module.balancing_biases[-1] = -1.0 + module.balancing_biases[-1] = -1.0 # ignore: trf012 elif isinstance(module, ZayaExperts): std = self.config.initializer_range init.normal_(module.gate_up_proj, mean=0.0, std=std) @@ -561,32 +555,14 @@ def _init_weights(self, module): @auto_docstring -class ZayaModel(ZayaPreTrainedModel): +class ZayaModel(LagunaModel): def __init__(self, config: ZayaConfig): super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - - self.gradient_checkpointing = False - + del self.norm self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - self.rotary_emb = ZayaRotaryEmbedding(config=config) - - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @merge_with_config_defaults @capture_outputs @auto_docstring @@ -619,18 +595,23 @@ def forward( "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." ) - causal_mask_mapping = self._update_causal_mask( - attention_mask, - inputs_embeds, - position_ids, - past_key_values, - ) - padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None - - # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask. 
- # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state. - if inputs_embeds.shape[1] == 1: - padding_mask = None + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. + sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) hidden_states = inputs_embeds @@ -643,50 +624,36 @@ def forward( prev_router_hidden_states = None - for layer_n, decoder_layer in enumerate(self.layers): - layer_type = self.config.layer_types[layer_n] - emb_to_use = position_embeddings[layer_type] - mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask} + for idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[idx] hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, prev_router_hidden_states, - attention_mask=mask_mapping, + attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask}, past_key_values=past_key_values, - position_embeddings=emb_to_use, + position_embeddings=position_embeddings[layer_type], **kwargs, ) - hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) + hidden_states = self.final_norm(hidden_states) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) - def _update_causal_mask( - self, 
- attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Cache, - ): - mask_kwargs = { - "config": self.config, - "inputs_embeds": input_tensor, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. - sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} - mask_creation_functions = { - "hybrid": lambda: create_causal_mask(**mask_kwargs), - "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), - } - causal_mask_mapping = {} - for layer_type in set(self.config.layer_types): - causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]() - return causal_mask_mapping + def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. 
+ """ + cca_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + cca_mask = None + elif attention_mask is not None: + cca_mask = attention_mask[:, -inputs_embeds.shape[1] :] + return cca_mask @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 37a081273dcb..ceeb40fd6a06 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -40,30 +40,16 @@ class ZayaModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = ZayaModel - def __init__(self, parent): + def __init__(self, parent, **kwargs): super().__init__( parent=parent, - batch_size=2, - seq_length=7, - vocab_size=128, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, + num_hidden_layers=4, moe_intermediate_size=32, + num_experts_per_tok=1, + layer_types=["hybrid", "hybrid_sliding", "hybrid", "hybrid_sliding"], + sliding_window=64, + **kwargs, ) - self.head_dim = 8 - self.num_experts = 4 - self.num_experts_per_tok = 1 - self.router_hidden_size = 4 - self.tie_word_embeddings = False - self.rope_parameters = { - "hybrid": { - "rope_theta": 10000, - "rope_type": "default", - "partial_rotary_factor": 0.5, - }, - } @require_torch @@ -101,17 +87,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(layer.conv_states.shape, conv_shape) self.assertEqual(layer.recurrent_states.shape, recurrent_shape) - def is_pipeline_test_to_skip( - self, - pipeline_test_case_name, - config_class, - model_architecture, - tokenizer_name, - image_processor_name, - feature_extractor_name, - processor_name, - ): - return True @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") def 
test_eager_padding_matches_padding_free_with_position_ids(self): @@ -121,8 +96,11 @@ def test_eager_padding_matches_padding_free_with_position_ids(self): def test_sdpa_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip("ZAYA uses MoE routing; equivalent-output comparisons are not stable for this architecture.") - def test_model_outputs_equivalence(self, **kwargs): + @unittest.skip( + "ZAYA follows the original SWA behavior where sliding attention only applies the local causal pattern;" + "See https://github.com/huggingface/transformers/pull/45862#discussion_r3249556316" + ) + def test_left_padding_compatibility(self): pass def test_attention_outputs(self): @@ -163,7 +141,6 @@ def test_model_rope_scaling_frequencies(self): Copied from Laguna to adapt to per-layer-type rope configs. """ config, _ = self.model_tester.prepare_config_and_inputs_for_common() - config.layer_types = ["hybrid", "hybrid_sliding"] partial_rotary_factor = config.rope_parameters["hybrid"]["partial_rotary_factor"] def set_rope_params(rope_params): @@ -239,10 +216,6 @@ def set_rope_params(rope_params): with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - @unittest.skip("ZAYA needs alternating attention and MoE layers in the tiny test configuration.") - def test_num_layers_is_small(self): - pass - def test_moe_router_logits(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = self.model_tester.causal_lm_class(config) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fcd3547a06c7..41a8f5cbbbfb 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -814,6 +814,7 @@ def test_num_layers_is_small(self): "Gemma3nVision2TextModelTest": 4, # need to test KV shared layer for both types: `full_attention` and `sliding_attention` "BeitModelTest": 4, # BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers 
"ZambaModelTest": 5, # The minimum number to test beyond the initial ["mamba", "mamba", "hybrid"] in `ZambaConfig._layers_block_type` + "ZayaModelTest": 4, # needs two passes over `hybrid` and `hybrid_sliding` layer types } target_num_hidden_layers = exceptional_num_hidden_layers.get(type(self).__name__, 2) From 0df3204dbd5e00058e8c8b74283dfdb136f6e96f Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sat, 16 May 2026 12:17:32 +0800 Subject: [PATCH 27/36] update date --- docs/source/en/model_doc/zaya.md | 2 +- tests/models/zaya/test_modeling_zaya.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index 06beb12e2e6f..199cd5d2935b 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-13.* +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-16.* # ZAYA diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index ceeb40fd6a06..316d206004d0 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -87,7 +87,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(layer.conv_states.shape, conv_shape) self.assertEqual(layer.recurrent_states.shape, recurrent_shape) - @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") def test_eager_padding_matches_padding_free_with_position_ids(self): pass From d362c90c378b4b32b54513f1627b6d9d59ccc6a1 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sun, 17 May 2026 20:04:39 +0800 Subject: [PATCH 28/36] Fix ZAYA residual stream precision regression 
--- src/transformers/models/zaya/modeling_zaya.py | 17 +++++++++-------- src/transformers/models/zaya/modular_zaya.py | 17 +++++++++-------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index 0815020f0e2e..6224c0bcc6bd 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -462,7 +462,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype)) hidden_states, _ = self.self_attn( hidden_states=hidden_states, @@ -473,7 +473,7 @@ def forward( ) residual = self.post_attention_residual_scale(hidden_states, residual) - hidden_states = self.post_attention_layernorm(residual) + hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) hidden_states, prev_router_hidden_states = self.mlp( hidden_states, @@ -494,12 +494,10 @@ def __init__(self, hidden_size: int): self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): - output_dtype = hidden_states.dtype + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale - # Matches the original ZAYA `residual_in_fp32` path. 
- residual = residual.to(torch.float32) residual = (residual + self.residual_bias) * self.residual_scale - return (hidden_states + residual).to(output_dtype) + return hidden_states + residual class ZayaRouterMLP(nn.Module): @@ -770,7 +768,10 @@ def forward( for layer_type in set(self.config.layer_types) } - hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. + hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to( + torch.float32 + ) prev_router_hidden_states = None @@ -785,7 +786,7 @@ def forward( **kwargs, ) - hidden_states = self.final_norm(hidden_states) + hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) return MoeModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 1967bf6fc64a..c7b1c237a066 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -368,7 +368,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype)) hidden_states, _ = self.self_attn( hidden_states=hidden_states, @@ -379,7 +379,7 @@ def forward( ) residual = self.post_attention_residual_scale(hidden_states, residual) - hidden_states = self.post_attention_layernorm(residual) + hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) hidden_states, prev_router_hidden_states = self.mlp( hidden_states, @@ -400,12 +400,10 @@ def __init__(self, hidden_size: int): self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) def forward(self, hidden_states: 
torch.Tensor, residual: torch.Tensor): - output_dtype = hidden_states.dtype + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale - # Matches the original ZAYA `residual_in_fp32` path. - residual = residual.to(torch.float32) residual = (residual + self.residual_bias) * self.residual_scale - return (hidden_states + residual).to(output_dtype) + return hidden_states + residual class ZayaRouterMLP(nn.Module): @@ -620,7 +618,10 @@ def forward( for layer_type in set(self.config.layer_types) } - hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. + hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to( + torch.float32 + ) prev_router_hidden_states = None @@ -635,7 +636,7 @@ def forward( **kwargs, ) - hidden_states = self.final_norm(hidden_states) + hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) return MoeModelOutputWithPast( last_hidden_state=hidden_states, From f6178966056a9269e6303ec515aeadfd78773634 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 19 May 2026 20:04:07 +0800 Subject: [PATCH 29/36] code clean --- .../models/zaya/convert_zaya_weights_to_hf.py | 2 +- src/transformers/models/zaya/modeling_zaya.py | 29 ++++----- src/transformers/models/zaya/modular_zaya.py | 29 ++++----- tests/models/zaya/test_modeling_zaya.py | 60 ++----------------- tests/test_modeling_common.py | 1 - 5 files changed, 27 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py index 2ac6cb7df869..ad279541ab83 100644 --- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py +++ 
b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -72,7 +72,7 @@ def _rename_common(rest: str) -> str: ("self_attn.qkv.val_proj1.", "self_attn.qkv_proj.v_proj_current."), ("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."), ("self_attn.qkv.", "self_attn.qkv_proj."), - ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.rmsnorm_eda."), + ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.norm."), ("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."), ("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."), ("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."), diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index 6224c0bcc6bd..e10cf4388ef5 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -160,12 +160,8 @@ class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's CCA path. - `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal - `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; - for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. - Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token - `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection - from **the recurrent cache**. + This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D + convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. 
""" def __init__(self, config: ZayaConfig, layer_idx: int): @@ -242,8 +238,7 @@ def forward( if past_key_values is not None: new_conv_state = qk_states[..., -self.conv_kernel_size :] - if new_conv_state.shape[-1] < self.conv_kernel_size: - new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) past_key_values.update_conv_state(new_conv_state, self.layer_idx) qk_states = self.conv_qk_depthwise(qk_states) @@ -405,8 +400,8 @@ def forward( causal_mask = mask_mapping.get("causal") padding_mask = mask_mapping.get("padding") + # ZAYA replaces the usual independent q/k/v projections with CCA projection followed by special q/k normalization. query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) - query_states, key_states = self.qk_norm(query_states, key_states) query_states = query_states.transpose(1, 2) @@ -503,14 +498,14 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): class ZayaRouterMLP(nn.Module): def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): super().__init__() - self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) + self.norm = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) self.act_fn = nn.GELU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.rmsnorm_eda(hidden_states) + hidden_states = self.norm(hidden_states) hidden_states = self.act_fn(self.fc1(hidden_states)) hidden_states = self.act_fn(self.fc2(hidden_states)) return self.out_proj(hidden_states) @@ -676,7 +671,7 @@ def _init_weights(self, module): if module.use_eda: init.ones_(module.router_states_scale) init.zeros_(module.balancing_biases) - 
module.balancing_biases[-1] = -1.0 # ignore: trf012 + module.balancing_biases[-1] = -1.0 # trf-ignore: TRF012 elif isinstance(module, ZayaExperts): std = self.config.initializer_range init.normal_(module.gate_up_proj, mean=0.0, std=std) @@ -750,16 +745,14 @@ def forward( "past_key_values": past_key_values, "position_ids": position_ids, } - # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. - sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} mask_creation_functions = { "hybrid": lambda: create_causal_mask(**mask_kwargs), - "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**mask_kwargs), } causal_mask_mapping = { layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) } - cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) + cca_mask = self._update_cca_mask(attention_mask, past_key_values) hidden_states = inputs_embeds @@ -793,7 +786,7 @@ def forward( past_key_values=past_key_values if use_cache else None, ) - def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + def _update_cca_mask(self, attention_mask, past_key_values): """ No need to zero padding states when cached convolution states are already available or all inputs are valid. 
""" @@ -802,8 +795,6 @@ def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): attention_mask is not None and torch.all(attention_mask == 1) ): cca_mask = None - elif attention_mask is not None: - cca_mask = attention_mask[:, -inputs_embeds.shape[1] :] return cca_mask diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index c7b1c237a066..e85d6dd06591 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -151,12 +151,8 @@ class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's CCA path. - `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal - `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; - for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. - Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token - `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection - from **the recurrent cache**. + This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D + convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. 
""" def __init__(self, config: ZayaConfig, layer_idx: int): @@ -233,8 +229,7 @@ def forward( if past_key_values is not None: new_conv_state = qk_states[..., -self.conv_kernel_size :] - if new_conv_state.shape[-1] < self.conv_kernel_size: - new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) past_key_values.update_conv_state(new_conv_state, self.layer_idx) qk_states = self.conv_qk_depthwise(qk_states) @@ -316,8 +311,8 @@ def forward( causal_mask = mask_mapping.get("causal") padding_mask = mask_mapping.get("padding") + # ZAYA replaces the usual independent q/k/v projections with CCA projection followed by special q/k normalization. query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) - query_states, key_states = self.qk_norm(query_states, key_states) query_states = query_states.transpose(1, 2) @@ -409,14 +404,14 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): class ZayaRouterMLP(nn.Module): def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): super().__init__() - self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) + self.norm = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) self.act_fn = nn.GELU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.rmsnorm_eda(hidden_states) + hidden_states = self.norm(hidden_states) hidden_states = self.act_fn(self.fc1(hidden_states)) hidden_states = self.act_fn(self.fc2(hidden_states)) return self.out_proj(hidden_states) @@ -537,7 +532,7 @@ def _init_weights(self, module): if module.use_eda: init.ones_(module.router_states_scale) init.zeros_(module.balancing_biases) - 
module.balancing_biases[-1] = -1.0 # ignore: trf012 + module.balancing_biases[-1] = -1.0 # trf-ignore: TRF012 elif isinstance(module, ZayaExperts): std = self.config.initializer_range init.normal_(module.gate_up_proj, mean=0.0, std=std) @@ -600,16 +595,14 @@ def forward( "past_key_values": past_key_values, "position_ids": position_ids, } - # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. - sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} mask_creation_functions = { "hybrid": lambda: create_causal_mask(**mask_kwargs), - "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**mask_kwargs), } causal_mask_mapping = { layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) } - cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) + cca_mask = self._update_cca_mask(attention_mask, past_key_values) hidden_states = inputs_embeds @@ -643,7 +636,7 @@ def forward( past_key_values=past_key_values if use_cache else None, ) - def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + def _update_cca_mask(self, attention_mask, past_key_values): """ No need to zero padding states when cached convolution states are already available or all inputs are valid. 
""" @@ -652,8 +645,6 @@ def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): attention_mask is not None and torch.all(attention_mask == 1) ): cca_mask = None - elif attention_mask is not None: - cca_mask = attention_mask[:, -inputs_embeds.shape[1] :] return cca_mask diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index 316d206004d0..a222a957969c 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -43,10 +43,10 @@ class ZayaModelTester(CausalLMModelTester): def __init__(self, parent, **kwargs): super().__init__( parent=parent, - num_hidden_layers=4, + num_hidden_layers=2, moe_intermediate_size=32, num_experts_per_tok=1, - layer_types=["hybrid", "hybrid_sliding", "hybrid", "hybrid_sliding"], + layer_types=["hybrid", "hybrid_sliding"], sliding_window=64, **kwargs, ) @@ -95,13 +95,6 @@ def test_eager_padding_matches_padding_free_with_position_ids(self): def test_sdpa_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip( - "ZAYA follows the original SWA behavior where sliding attention only applies the local causal pattern;" - "See https://github.com/huggingface/transformers/pull/45862#discussion_r3249556316" - ) - def test_left_padding_compatibility(self): - pass - def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True @@ -215,22 +208,6 @@ def set_rope_params(rope_params): with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - def test_moe_router_logits(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = self.model_tester.causal_lm_class(config) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - outputs = model(**inputs_dict, output_router_logits=True) - - expected_moe_layers = config.num_hidden_layers - 
self.assertEqual(len(outputs.router_logits), expected_moe_layers) - self.assertEqual( - outputs.router_logits[0].shape, - (self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1), - ) - def test_num_experts_per_tok_validation(self): with self.assertRaisesRegex(StrictDataclassClassValidationError, "num_experts_per_tok=1"): ZayaConfig(num_experts_per_tok=2) @@ -240,13 +217,13 @@ def test_sliding_attention_mask_is_used(self): vocab_size=128, hidden_size=32, moe_intermediate_size=32, - num_hidden_layers=4, + num_hidden_layers=2, num_experts=4, num_attention_heads=4, num_key_value_heads=2, head_dim=8, router_hidden_size=4, - layer_types=["hybrid_sliding", "hybrid", "hybrid_sliding", "hybrid"], + layer_types=["hybrid_sliding", "hybrid"], sliding_window=3, tie_word_embeddings=False, attn_implementation="eager", @@ -261,33 +238,6 @@ def test_sliding_attention_mask_is_used(self): sliding_attention = outputs.attentions[0] self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0)) - def test_cca_cache_matches_full_forward(self): - config = ZayaConfig( - vocab_size=128, - hidden_size=32, - moe_intermediate_size=32, - num_hidden_layers=1, - num_experts=4, - num_attention_heads=4, - num_key_value_heads=2, - head_dim=8, - router_hidden_size=4, - tie_word_embeddings=False, - ) - torch.manual_seed(0) - cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device) - cca.eval() - hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) - - with torch.no_grad(): - full = cca(hidden_states, None, None) - cache = DynamicCache(config=config) - cca(hidden_states[:, :4], cache, None) - cached = cca(hidden_states[:, 4:], cache, None) - - for full_states, cached_states in zip(full, cached): - torch.testing.assert_close(full_states[:, -1:], cached_states, rtol=1e-5, atol=1e-5) - def test_cca_cache_matches_full_forward_multi_token(self): config = ZayaConfig( vocab_size=128, @@ -307,6 +257,8 @@ def 
test_cca_cache_matches_full_forward_multi_token(self): hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) with torch.no_grad(): + # Compare full CCA projection against a cached continuation. The second chunk must recover the same + # q/k/v states from the cached convolution tail and delayed recurrent value state. full = cca(hidden_states, None, None) cache = DynamicCache(config=config) cca(hidden_states[:, :3], cache, None) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 41a8f5cbbbfb..fcd3547a06c7 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -814,7 +814,6 @@ def test_num_layers_is_small(self): "Gemma3nVision2TextModelTest": 4, # need to test KV shared layer for both types: `full_attention` and `sliding_attention` "BeitModelTest": 4, # BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers "ZambaModelTest": 5, # The minimum number to test beyond the initial ["mamba", "mamba", "hybrid"] in `ZambaConfig._layers_block_type` - "ZayaModelTest": 4, # needs two passes over `hybrid` and `hybrid_sliding` layer types } target_num_hidden_layers = exceptional_num_hidden_layers.get(type(self).__name__, 2) From bc55e03537aef2a783283f8508bb6bea194207bb Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 19 May 2026 20:19:24 +0800 Subject: [PATCH 30/36] date --- docs/source/en/model_doc/zaya.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index 199cd5d2935b..de7916038254 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. 
--> -*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-16.* +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-19.* # ZAYA From 6ad8e9f077473d7ed5eaa01265ec3be0dc6690ed Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 20 May 2026 20:01:02 +0800 Subject: [PATCH 31/36] update fsdp --- src/transformers/models/zaya/configuration_zaya.py | 6 ++++++ .../models/zaya/convert_zaya_weights_to_hf.py | 2 ++ src/transformers/models/zaya/modeling_zaya.py | 5 +++-- src/transformers/models/zaya/modular_zaya.py | 13 +++++++++---- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 4a18bd4716f2..95de5cdf2919 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -86,6 +86,12 @@ class ZayaConfig(PreTrainedConfig): cca_time0: int = 2 cca_time1: int = 2 + base_model_fsdp_plan = { + "embed_tokens": "free_full_weight", + "layers.*": "free_full_weight", + "norm": "keep_full_weight", + } + def __post_init__(self, **kwargs): self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py index ad279541ab83..a359ed5bb8aa 100644 --- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -107,6 +107,8 @@ def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> match = _LAYER_PATTERN.match(name) if match is None: + if name.startswith("model.final_norm."): + return f"model.norm.{name.removeprefix('model.final_norm.')}" if old_num_hidden_layers is not None and name.startswith("model.res_scale."): new_layer_idx = old_num_hidden_layers // 2 - 
1 return f"model.layers.{new_layer_idx}.post_mlp_residual_scale.{name.removeprefix('model.res_scale.')}" diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index e10cf4388ef5..fd5fab1630aa 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -697,11 +697,11 @@ def __init__(self, config: ZayaConfig): self.layers = nn.ModuleList( [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + self.norm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = ZayaRotaryEmbedding(config=config) self.gradient_checkpointing = False self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) - self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) # Initialize weights and apply final processing self.post_init() @@ -779,7 +779,7 @@ def forward( **kwargs, ) - hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) return MoeModelOutputWithPast( last_hidden_state=hidden_states, @@ -803,6 +803,7 @@ class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_gather_output"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _fsdp_plan = {"lm_head": "keep_full_weight"} _is_stateful = True def __init__(self, config, **kwargs): diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index e85d6dd06591..83ad9199a267 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -93,8 +93,14 @@ class ZayaConfig(LagunaConfig): cca_time0: int = 2 cca_time1: int = 2 + base_model_fsdp_plan = { 
+ "embed_tokens": "free_full_weight", + "layers.*": "free_full_weight", + "norm": "keep_full_weight", + } + # Fields declared by LagunaConfig but not used by ZAYA. - # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. + # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. For TP, see discussion https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862 base_model_tp_plan = AttributeError() base_model_pp_plan = AttributeError() intermediate_size = AttributeError() @@ -551,10 +557,8 @@ def _init_weights(self, module): class ZayaModel(LagunaModel): def __init__(self, config: ZayaConfig): super().__init__(config) - del self.norm self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) - self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) @merge_with_config_defaults @capture_outputs @@ -629,7 +633,7 @@ def forward( **kwargs, ) - hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) return MoeModelOutputWithPast( last_hidden_state=hidden_states, @@ -651,6 +655,7 @@ def _update_cca_mask(self, attention_mask, past_key_values): @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _fsdp_plan = {"lm_head": "keep_full_weight"} _is_stateful = True def __init__(self, config, **kwargs): From ecb80ed181814aef33f65a39150f57a6f1c9edf8 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 20 May 2026 20:03:51 +0800 Subject: [PATCH 32/36] clean --- 
src/transformers/models/zaya/modeling_zaya.py | 2 -- src/transformers/models/zaya/modular_zaya.py | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index fd5fab1630aa..befe762155e9 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -801,8 +801,6 @@ def _update_cca_mask(self, attention_mask, past_key_values): @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _fsdp_plan = {"lm_head": "keep_full_weight"} _is_stateful = True diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 83ad9199a267..e83a97518ba4 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -658,6 +658,9 @@ class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel): _fsdp_plan = {"lm_head": "keep_full_weight"} _is_stateful = True + _tp_plan = AttributeError() + _pp_plan = AttributeError() + def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) From db1db76278affd6643da34f63c5f01c0ba509cf9 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Wed, 20 May 2026 23:04:25 +0800 Subject: [PATCH 33/36] upstream --- src/transformers/models/zaya/configuration_zaya.py | 12 ++++++------ src/transformers/models/zaya/modeling_zaya.py | 1 + src/transformers/models/zaya/modular_zaya.py | 7 ------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py index 
95de5cdf2919..690a521ee357 100644 --- a/src/transformers/models/zaya/configuration_zaya.py +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -53,6 +53,12 @@ class ZayaConfig(PreTrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] + base_model_fsdp_plan = { + "embed_tokens": "free_full_weight", + "layers.*": "free_full_weight", + "norm": "keep_full_weight", + } + vocab_size: int = 262272 hidden_size: int = 2048 num_hidden_layers: int = 40 @@ -86,12 +92,6 @@ class ZayaConfig(PreTrainedConfig): cca_time0: int = 2 cca_time1: int = 2 - base_model_fsdp_plan = { - "embed_tokens": "free_full_weight", - "layers.*": "free_full_weight", - "norm": "keep_full_weight", - } - def __post_init__(self, **kwargs): self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index befe762155e9..b401f22d190d 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -801,6 +801,7 @@ def _update_cca_mask(self, attention_mask, past_key_values): @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _sp_plan = {"lm_head": "colwise_loss_parallel"} _fsdp_plan = {"lm_head": "keep_full_weight"} _is_stateful = True diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index e83a97518ba4..7569d5e697a9 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -93,12 +93,6 @@ class ZayaConfig(LagunaConfig): cca_time0: int = 2 cca_time1: int = 2 - base_model_fsdp_plan = { - "embed_tokens": "free_full_weight", - "layers.*": "free_full_weight", - "norm": "keep_full_weight", - } - # Fields declared by LagunaConfig but not used 
by ZAYA. # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. For TP, see discussion https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862 base_model_tp_plan = AttributeError() @@ -655,7 +649,6 @@ def _update_cca_mask(self, attention_mask, past_key_values): @auto_docstring(checkpoint="Zyphra/ZAYA1-8B") class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _fsdp_plan = {"lm_head": "keep_full_weight"} _is_stateful = True _tp_plan = AttributeError() From 86215d58d5eae2ef1021365beed3dc07eb36c450 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 26 May 2026 17:42:26 +0800 Subject: [PATCH 34/36] improve comments --- src/transformers/models/zaya/modeling_zaya.py | 8 ++- src/transformers/models/zaya/modular_zaya.py | 17 +++--- tests/models/zaya/test_modeling_zaya.py | 57 +++++++++++-------- 3 files changed, 51 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index b401f22d190d..f6e2a16c82af 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -158,7 +158,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class ZayaCCAProjection(nn.Module): """ - Projects hidden states into attention q/k/v states with ZAYA's CCA path. + Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path. + See https://arxiv.org/abs/2510.04476. This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. 
@@ -457,6 +458,8 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states + # Match upstream's residual_in_fp32 path by keeping the residual stream in fp32 and avoiding extra + # fp32->bf16 round trips in the residual module. hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype)) hidden_states, _ = self.self_attn( @@ -623,6 +626,7 @@ def forward( hidden_states: torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + # ZAYA carries router hidden states across decoder layers; the next layer consumes this state in its router. _, router_probs, router_indices, prev_router_hidden_states = self.gate( hidden_states, router_states=prev_router_hidden_states ) @@ -770,6 +774,8 @@ def forward( for idx, decoder_layer in enumerate(self.layers): layer_type = self.config.layer_types[idx] + # Attention uses the prepared causal mask, while CCA projection still needs the raw 2D padding mask to + # zero padding tokens before convolution. hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, prev_router_hidden_states, diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index 7569d5e697a9..c3b5008fe988 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -35,8 +35,7 @@ TransformersKwargs, auto_docstring, ) -from ...utils.generic import merge_with_config_defaults -from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...utils.output_capturing import OutputRecorder from ..afmoe.modeling_afmoe import AfmoeForCausalLM from ..laguna.configuration_laguna import LagunaConfig from ..laguna.modeling_laguna import LagunaModel, LagunaRotaryEmbedding @@ -94,7 +93,8 @@ class ZayaConfig(LagunaConfig): cca_time1: int = 2 # Fields declared by LagunaConfig but not used by ZAYA. 
- # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. For TP, see discussion https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862 + # NOTE: TP is intentionally disabled for now because the useful degree is limited by ZAYA's 2 KV heads; see + # https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862. PP needs coverage for the cross-layer router state. base_model_tp_plan = AttributeError() base_model_pp_plan = AttributeError() intermediate_size = AttributeError() @@ -149,7 +149,8 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm): class ZayaCCAProjection(nn.Module): """ - Projects hidden states into attention q/k/v states with ZAYA's CCA path. + Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path. + See https://arxiv.org/abs/2510.04476. This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. @@ -363,6 +364,8 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states + # Match upstream's residual_in_fp32 path by keeping the residual stream in fp32 and avoiding extra + # fp32->bf16 round trips in the residual module. hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype)) hidden_states, _ = self.self_attn( @@ -493,6 +496,7 @@ def forward( hidden_states: torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + # ZAYA carries router hidden states across decoder layers; the next layer consumes this state in its router. 
_, router_probs, router_indices, prev_router_hidden_states = self.gate( hidden_states, router_states=prev_router_hidden_states ) @@ -554,9 +558,6 @@ def __init__(self, config: ZayaConfig): self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) - @merge_with_config_defaults - @capture_outputs - @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -618,6 +619,8 @@ def forward( for idx, decoder_layer in enumerate(self.layers): layer_type = self.config.layer_types[idx] + # Attention uses the prepared causal mask, while CCA projection still needs the raw 2D padding mask to + # zero padding tokens before convolution. hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, prev_router_hidden_states, diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py index a222a957969c..86448222a20e 100644 --- a/tests/models/zaya/test_modeling_zaya.py +++ b/tests/models/zaya/test_modeling_zaya.py @@ -19,7 +19,7 @@ from parameterized import parameterized from transformers import is_torch_available -from transformers.testing_utils import cleanup, require_torch, slow, torch_device +from transformers.testing_utils import Expectations, cleanup, require_torch, slow, torch_device if is_torch_available(): @@ -213,30 +213,19 @@ def test_num_experts_per_tok_validation(self): ZayaConfig(num_experts_per_tok=2) def test_sliding_attention_mask_is_used(self): - config = ZayaConfig( - vocab_size=128, - hidden_size=32, - moe_intermediate_size=32, - num_hidden_layers=2, - num_experts=4, - num_attention_heads=4, - num_key_value_heads=2, - head_dim=8, - router_hidden_size=4, - layer_types=["hybrid_sliding", "hybrid"], - sliding_window=3, - tie_word_embeddings=False, - attn_implementation="eager", - ) - model = ZayaModel(config).to(torch_device) + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + config.layer_types = ["hybrid_sliding"] + ["hybrid"] * (config.num_hidden_layers - 1) + config.sliding_window = 3 + config._attn_implementation = "eager" + + model = ZayaModel._from_config(config, attn_implementation="eager").to(torch_device) model.eval() - input_ids = torch.arange(6, device=torch_device).unsqueeze(0) with torch.no_grad(): - outputs = model(input_ids=input_ids, output_attentions=True) + outputs = model(input_ids=inputs_dict["input_ids"].to(torch_device), output_attentions=True) sliding_attention = outputs.attentions[0] - self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0)) + self.assertTrue(torch.all(sliding_attention[:, :, -1, : -config.sliding_window] == 0)) def test_cca_cache_matches_full_forward_multi_token(self): config = ZayaConfig( @@ -334,6 +323,18 @@ def test_model_logits(self): self.assertEqual(logits.shape, (1, inputs.input_ids.shape[-1], model.config.vocab_size)) self.assertTrue(torch.isfinite(logits).all().item()) + EXPECTED_LOGITS = Expectations( + { + (None, None): [ + [0.3613, 0.3633, 0.3633], + [-1.3672, -1.3672, -1.3672], + [-2.8750, -2.8750, -2.8750], + ] + } + ) # fmt: skip + expected_slice = torch.tensor(EXPECTED_LOGITS.get_expectation(), dtype=logits.dtype) + torch.testing.assert_close(logits[0, -3:, -3:], expected_slice, rtol=1e-3, atol=1e-3) + expected_argmax = torch.tensor([[105, 9731, 107, 740, 564, 1601, 611, 236881, 236881, 107, 107]]) torch.testing.assert_close(logits.argmax(-1), expected_argmax) @@ -366,6 +367,16 @@ def test_model_generation(self): inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) with torch.no_grad(): - generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=3, top_k=None, top_p=None) - - self.assertEqual(generated_ids[0, -3:].tolist(), [107, 262146, 108]) + generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=16, top_k=None, top_p=None) + + expected_generated_ids = 
Expectations( + { + (None, None): [ + 107, 262146, 108, 9259, 236888, 2088, 740, 564, + 6361, 611, 3124, 236881, 108, 236769, 10282, 236787, + ] + } + ) # fmt: skip + self.assertEqual( + generated_ids[0, inputs.input_ids.shape[-1] :].tolist(), expected_generated_ids.get_expectation() + ) From ebeb8c3fbf3e48c8a9d85c545cbf69f7ea013f64 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 26 May 2026 17:54:24 +0800 Subject: [PATCH 35/36] date --- docs/source/en/model_doc/zaya.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md index de7916038254..1ae3323416c8 100644 --- a/docs/source/en/model_doc/zaya.md +++ b/docs/source/en/model_doc/zaya.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-19.* +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-26.* # ZAYA From d71306f07a44096dccde73c8ae5f783581d9b857 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Tue, 26 May 2026 23:19:49 +0800 Subject: [PATCH 36/36] update link! --- src/transformers/models/zaya/modeling_zaya.py | 2 +- src/transformers/models/zaya/modular_zaya.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py index f6e2a16c82af..714910c8831a 100755 --- a/src/transformers/models/zaya/modeling_zaya.py +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -159,7 +159,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path. - See https://arxiv.org/abs/2510.04476. 
+ See https://huggingface.co/papers/2510.04476. This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py index c3b5008fe988..c205de489e8c 100644 --- a/src/transformers/models/zaya/modular_zaya.py +++ b/src/transformers/models/zaya/modular_zaya.py @@ -150,7 +150,7 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm): class ZayaCCAProjection(nn.Module): """ Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path. - See https://arxiv.org/abs/2510.04476. + See https://huggingface.co/papers/2510.04476. This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.