From b35c5e08f10df586eb79e1cbb4a9bc8867f21899 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 9 May 2026 20:03:28 +0800
Subject: [PATCH 01/36] zaya1 support
---
docs/source/en/_toctree.yml | 2 +
docs/source/en/model_doc/zaya.md | 55 +
src/transformers/conversion_mapping.py | 12 +
src/transformers/models/__init__.py | 1 +
src/transformers/models/auto/auto_mappings.py | 1 +
src/transformers/models/auto/modeling_auto.py | 2 +
src/transformers/models/zaya/__init__.py | 28 +
.../models/zaya/configuration_zaya.py | 183 +++
src/transformers/models/zaya/modeling_zaya.py | 1126 ++++++++++++++++
src/transformers/models/zaya/modular_zaya.py | 1133 +++++++++++++++++
tests/models/zaya/__init__.py | 1 +
11 files changed, 2544 insertions(+)
create mode 100644 docs/source/en/model_doc/zaya.md
create mode 100644 src/transformers/models/zaya/__init__.py
create mode 100644 src/transformers/models/zaya/configuration_zaya.py
create mode 100755 src/transformers/models/zaya/modeling_zaya.py
create mode 100644 src/transformers/models/zaya/modular_zaya.py
create mode 100644 tests/models/zaya/__init__.py
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 11e3c9008d56..3fc1d2ef50fd 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -883,6 +883,8 @@
title: Zamba
- local: model_doc/zamba2
title: Zamba2
+ - local: model_doc/zaya
+ title: ZAYA
title: Text models
- sections:
- local: model_doc/aimv2
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
new file mode 100644
index 000000000000..7f881a47efb9
--- /dev/null
+++ b/docs/source/en/model_doc/zaya.md
@@ -0,0 +1,55 @@
+
+*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-09.*
+
+# ZAYA
+
+## Overview
+
+ZAYA1 is a 760M active / 8.4B total parameter MoE language model trained by Zyphra. It combines Compressed
+Convolutional Attention (CCA), a nonlinear ZAYA1 router, and residual scaling.
+
+ZAYA1 uses the Gemma 3 tokenizer. For more details, see the [ZAYA1 model card](https://huggingface.co/Zyphra/ZAYA1-8B)
+and Zyphra's technical reports.
+
+## Usage examples
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+model_id = "Zyphra/ZAYA1-8B"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+inputs = tokenizer("What factors contributed to the fall of the Roman Empire?", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=100)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+
+## ZayaConfig
+
+[[autodoc]] ZayaConfig
+
+## ZayaModel
+
+[[autodoc]] ZayaModel
+ - forward
+
+## ZayaForCausalLM
+
+[[autodoc]] ZayaForCausalLM
+ - forward
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index e7937fed254f..dff0f65f5b53 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -561,6 +561,18 @@ def _build_checkpoint_conversion_mapping():
operations=[Transpose(1, 2, check_dims=True)],
),
],
+ "zaya": [
+ WeightConverter(
+ source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight",
+ target_patterns="zaya_block.experts.gate_up_proj",
+ operations=[MergeModulelist(dim=0)],
+ ),
+ WeightConverter(
+ source_patterns="zaya_block.experts.local_experts.*.linear_fc2.weight",
+ target_patterns="zaya_block.experts.down_proj",
+ operations=[MergeModulelist(dim=0)],
+ ),
+ ],
"phimoe": [
WeightRenaming(".block_sparse_moe.", ".mlp."),
WeightRenaming(".gate.weight", ".router.weight"),
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 406c5f7be0fc..b1c0412758b1 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -479,6 +479,7 @@
from .youtu import *
from .zamba import *
from .zamba2 import *
+ from .zaya import *
from .zoedepth import *
else:
import sys
diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py
index 048dd5275537..ed3866d0ed42 100644
--- a/src/transformers/models/auto/auto_mappings.py
+++ b/src/transformers/models/auto/auto_mappings.py
@@ -643,6 +643,7 @@
("youtu", "YoutuConfig"),
("zamba", "ZambaConfig"),
("zamba2", "Zamba2Config"),
+ ("zaya", "ZayaConfig"),
("zoedepth", "ZoeDepthConfig"),
]
)
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 2202cc773db0..4d90c73183e7 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -510,6 +510,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("yolos", "YolosModel"),
("yoso", "YosoModel"),
("youtu", "YoutuModel"),
+ ("zaya", "ZayaModel"),
("zamba", "ZambaModel"),
("zamba2", "Zamba2Model"),
]
@@ -772,6 +773,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("xlstm", "xLSTMForCausalLM"),
("xmod", "XmodForCausalLM"),
("youtu", "YoutuForCausalLM"),
+ ("zaya", "ZayaForCausalLM"),
("zamba", "ZambaForCausalLM"),
("zamba2", "Zamba2ForCausalLM"),
]
diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py
new file mode 100644
index 000000000000..54cc0c89f303
--- /dev/null
+++ b/src/transformers/models/zaya/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Zyphra and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_zaya import *
+ from .modeling_zaya import *
+
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
new file mode 100644
index 000000000000..506df6eee3f0
--- /dev/null
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -0,0 +1,183 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_zaya.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PreTrainedConfig
+from ...utils import auto_docstring
+
+
+@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
+class ZayaConfig(PreTrainedConfig):
+ r"""
+ num_query_groups (`int`, *optional*, defaults to 2):
+ Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
+ lm_head_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the language modeling head.
+ ffn_hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the feed-forward and expert hidden states.
+ rope_theta (`float`, *optional*, defaults to 5000000):
+ The base period of the RoPE embeddings.
+ moe_router_topk (`int`, *optional*, defaults to 1):
+ Number of selected experts per token. ZAYA checkpoints use top-1 routing.
+ zaya_mlp_expansion (`int`, *optional*, defaults to 256):
+ Expansion size used by the dense ZAYA blocks.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Fraction of each attention head dimension using rotary embeddings.
+ cca_time0 (`int`, *optional*, defaults to 2):
+ First temporal parameter of the CCA projection.
+ cca_time1 (`int`, *optional*, defaults to 2):
+ Second temporal parameter of the CCA projection.
+ swa_layers (`list[int]`, *optional*):
+ Per-layer selector for standard RoPE versus SWA RoPE embeddings.
+ swa_rotary_base (`float`, *optional*):
+ RoPE base used by SWA layers.
+
+ ```python
+ >>> from transformers import ZayaConfig, ZayaModel
+
+ >>> configuration = ZayaConfig()
+ >>> model = ZayaModel(configuration)
+
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "zaya"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ num_query_groups=2,
+ use_cache=True,
+ attention_bias=False,
+ lm_head_bias=False,
+ vocab_size=262272,
+ hidden_size=2048,
+ ffn_hidden_size=4096,
+ num_hidden_layers=80,
+ num_experts=16,
+ num_attention_heads=8,
+ hidden_act="silu",
+ head_dim=128,
+ initializer_range=0.02,
+ max_position_embeddings=131072,
+ norm_epsilon=1e-05,
+ pad_token_id=0,
+ bos_token_id=2,
+ eos_token_id=106,
+ tie_word_embeddings=True,
+ rope_theta=5000000,
+ attention_dropout=0.0,
+ moe_router_topk=1,
+ zaya_mlp_expansion=256,
+ rope_parameters=None,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=2,
+ cca_time0=2,
+ cca_time1=2,
+ swa_layers=None,
+ swa_rotary_base=None,
+ output_router_logits=False,
+ _attn_implementation="eager",
+ **kwargs,
+ ):
+ for unused_checkpoint_kwarg in (
+ "cca",
+ "activation_func",
+ "normalization",
+ "add_bias_linear",
+ "gated_linear_unit",
+ "fused_add_norm",
+ "apply_rope_fusion",
+ "bias_activation_fusion",
+ "activation_func_fp8_input_store",
+ "clamp_temp",
+ "residual_in_fp32",
+ "rope_scaling",
+ "scale_residual_merge",
+ "sliding_window",
+ "zaya_high_prec",
+ "zaya_use_mod",
+ "zaya_use_eda",
+ ):
+ kwargs.pop(unused_checkpoint_kwarg, None)
+
+ num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups
+ if head_dim is None:
+ raise ValueError("`head_dim` must be set for ZAYA.")
+ if num_query_groups != num_key_value_heads:
+ raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
+ if moe_router_topk != 1:
+ raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
+
+ self.num_query_groups = num_query_groups
+ self.use_cache = use_cache
+ self.attention_bias = attention_bias
+ self.lm_head_bias = lm_head_bias
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.ffn_hidden_size = ffn_hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_experts = num_experts
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.head_dim = head_dim
+ self.initializer_range = initializer_range
+ self.num_key_value_heads = num_key_value_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.norm_epsilon = norm_epsilon
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+ self.attention_dropout = attention_dropout
+ self.moe_router_topk = moe_router_topk
+ self.zaya_mlp_expansion = zaya_mlp_expansion
+ self.partial_rotary_factor = partial_rotary_factor
+ self.rope_theta = rope_theta
+ rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"}
+ rope_parameters.setdefault("rope_theta", rope_theta)
+ rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor)
+ self.rope_parameters = rope_parameters
+ cca_time0 = 2 if cca_time0 is None else cca_time0
+ cca_time1 = 2 if cca_time1 is None else cca_time1
+ if (cca_time0, cca_time1) != (2, 2):
+ raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
+ if swa_layers is not None and len(swa_layers) != num_hidden_layers:
+ raise ValueError("`swa_layers` must have one entry per hidden layer.")
+ if swa_layers is not None and swa_rotary_base is None:
+ raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")
+
+ self.cca_time0 = cca_time0
+ self.cca_time1 = cca_time1
+ self.swa_layers = swa_layers
+ self.swa_rotary_base = swa_rotary_base
+ self.output_router_logits = output_router_logits
+ self._attn_implementation = _attn_implementation
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["ZayaConfig"]
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
new file mode 100755
index 000000000000..bbbecaeb1907
--- /dev/null
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -0,0 +1,1126 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_zaya.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+from collections.abc import Callable
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import init
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast, merge_with_config_defaults
+from ...utils.output_capturing import OutputRecorder, capture_outputs
+from .configuration_zaya import ZayaConfig
+
+
+class ZayaRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: ZayaConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: ZayaConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class ZayaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+ """
+ ZayaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class ZayaDynamicCache(DynamicCache):
+ """
+ Cache that includes both the KV cache and the CCA cache.
+ """
+
+ def __init__(
+ self,
+ config: ZayaConfig,
+ batch_size: int,
+ dtype: torch.dtype = torch.float16,
+ device: str | None = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.batch_size = batch_size
+ self.dtype = dtype
+ self.device = device
+ self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1)
+ self.num_layers = config.num_hidden_layers
+ self.key_value_hidden_size = config.num_query_groups * config.head_dim
+ self.query_hidden_size = config.num_attention_heads * config.head_dim
+ self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size
+ self.has_previous_state = False
+
+ self.conv_states = [None for _ in range(self.num_layers)]
+ self.prev_v2 = [None for _ in range(self.num_layers)]
+
+ def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
+ if new_conv_state.shape[1] < self.conv_kernel_size:
+ new_conv_state = F.pad(
+ new_conv_state.transpose(1, 2), (self.conv_kernel_size - new_conv_state.shape[1], 0)
+ )
+ else:
+ new_conv_state = new_conv_state[:, -self.conv_kernel_size :, :].transpose(1, 2)
+
+ if self.conv_states[layer_idx] is None:
+ self.conv_states[layer_idx] = torch.zeros_like(new_conv_state)
+
+ if not self.has_previous_state:
+ self.conv_states[layer_idx].copy_(new_conv_state)
+ else:
+ conv_state = torch.cat([self.conv_states[layer_idx], new_conv_state], dim=-1)[
+ :, :, -self.conv_kernel_size :
+ ]
+ self.conv_states[layer_idx].copy_(conv_state)
+ return self.conv_states[layer_idx]
+
+ def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor:
+ if self.prev_v2[layer_idx] is None:
+ self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2)
+ self.prev_v2[layer_idx].copy_(new_prev_v2)
+ return self.prev_v2[layer_idx]
+
+ def reset(self):
+ super().reset()
+ for conv_state in self.conv_states:
+ if conv_state is not None:
+ conv_state.zero_()
+ for prev_v2 in self.prev_v2:
+ if prev_v2 is not None:
+ prev_v2.zero_()
+ self.has_previous_state = False
+
+ def _reorder_auxiliary_states(self, indices: torch.LongTensor):
+ for layer_idx, conv_state in enumerate(self.conv_states):
+ if conv_state is not None:
+ self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device))
+ for layer_idx, prev_v2 in enumerate(self.prev_v2):
+ if prev_v2 is not None:
+ self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device))
+ self.batch_size = indices.shape[0]
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ super().reorder_cache(beam_idx)
+ self._reorder_auxiliary_states(beam_idx)
+
+ def batch_repeat_interleave(self, repeats: int):
+ super().batch_repeat_interleave(repeats)
+ for layer_idx, conv_state in enumerate(self.conv_states):
+ if conv_state is not None:
+ self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0)
+ for layer_idx, prev_v2 in enumerate(self.prev_v2):
+ if prev_v2 is not None:
+ self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0)
+ self.batch_size *= repeats
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ super().batch_select_indices(indices)
+ self._reorder_auxiliary_states(indices)
+
+
+class CCA(nn.Module):
+ def __init__(
+ self,
+ config: ZayaConfig,
+ num_key_value_heads: int = 2,
+ num_attention_heads: int = 8,
+ hidden_size: int | None = None,
+ head_dim: int = 128,
+ cca_time0: int = 2,
+ cca_time1: int = 2,
+ layer_number: int = 0,
+ ):
+ super().__init__()
+ self.config = config
+ self.layer_number = layer_number
+
+ self.hidden_size = int(hidden_size or config.hidden_size)
+
+ self.depthwise_kernel_size = cca_time0
+ self.grouped_kernel_size = cca_time1
+ self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
+
+ self.num_key_value_heads = int(num_key_value_heads)
+ self.num_attention_heads = int(num_attention_heads)
+
+ self.head_dim = int(head_dim)
+ self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
+ self.query_hidden_size = self.num_attention_heads * self.head_dim
+ self.sqrt_head_dim = self.head_dim**0.5
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+ if self.num_attention_heads % self.num_key_value_heads != 0:
+ raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
+
+ self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
+ self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
+ self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
+ self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
+
+ conv_channels = self.key_value_hidden_size + self.query_hidden_size
+ self.conv_qk = nn.Sequential(
+ nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.depthwise_kernel_size,
+ groups=conv_channels,
+ padding=0,
+ stride=1,
+ ),
+ nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.grouped_kernel_size,
+ groups=(self.num_key_value_heads + self.num_attention_heads),
+ padding=0,
+ stride=1,
+ ),
+ )
+
+ self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ past_key_values: ZayaDynamicCache | None,
+ attention_mask: torch.Tensor | None = None,
+ ):
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)
+
+ batch_size, seq_length, _ = hidden_states.shape
+
+ projected_queries = self.linear_q(hidden_states)
+ projected_keys = self.linear_k(hidden_states)
+ qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
+
+ query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+ key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
+
+ key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1)
+ key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+ query_residual = (query_residual + key_residual) * 0.5
+ key_residual = query_residual.view(
+ batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
+ ).mean(dim=-2)
+
+ qk_states = qk_states.transpose(1, 2)
+ use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state
+ if use_precomputed_states:
+ cached_qk_states = past_key_values.conv_states[self.layer_number]
+ conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
+ else:
+ conv_input = F.pad(qk_states, (self.total_padding, 0))
+
+ if past_key_values is not None:
+ past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2))
+
+ convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2)
+
+ query = (
+ convolved_qk_states[..., : self.query_hidden_size].view(
+ batch_size, seq_length, self.num_attention_heads, self.head_dim
+ )
+ + query_residual
+ )
+
+ key = (
+ convolved_qk_states[..., self.query_hidden_size :].view(
+ batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ )
+ + key_residual
+ )
+
+ value_current = self.val_proj1(hidden_states)
+ projected_v2 = self.val_proj2(hidden_states)
+ if use_precomputed_states:
+ first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1)
+ else:
+ first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
+ value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
+
+ if past_key_values is not None:
+ past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :])
+
+ value = torch.cat([value_current, value_delayed], dim=-1).view(
+ batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ )
+
+ norm_eps = torch.finfo(query.dtype).eps
+ query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+
+ key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
+ query = query * (self.sqrt_head_dim / query_norm)
+
+ query = query.reshape(batch_size, seq_length, self.query_hidden_size)
+ key = key.reshape(batch_size, seq_length, self.key_value_hidden_size)
+ value = value.reshape(batch_size, seq_length, self.key_value_hidden_size)
+ return query, key, value
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Removes the interleaving of cos and sin from GLM
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class ZayaAttention(nn.Module):
+ def __init__(self, config: ZayaConfig, layer_n):
+ super().__init__()
+ self.config = config
+ self.layer_n = layer_n
+ self.layer_idx = layer_n
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.head_dim = config.head_dim
+ self.scaling = self.head_dim**-0.5
+
+ self.o_proj = nn.Linear(
+ self.num_attention_heads * self.head_dim,
+ self.hidden_size,
+ bias=self.config.attention_bias,
+ )
+ self.qkv = CCA(
+ config=self.config,
+ num_attention_heads=self.config.num_attention_heads,
+ num_key_value_heads=self.config.num_query_groups,
+ hidden_size=self.hidden_size,
+ head_dim=self.config.head_dim,
+ cca_time0=self.config.cca_time0,
+ cca_time1=self.config.cca_time1,
+ layer_number=layer_n,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ attention_mask_2d: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ output_attentions: bool = False,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d)
+ query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim)
+ key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
+ value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)
+
+ causal_mask = attention_mask
+ if causal_mask is not None:
+ causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ causal_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ )
+
+ attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_values
+
+
+def _apply_residual_scaling(
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ residual_scaling,
+ rms_norm: ZayaRMSNorm,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ residual, hidden_states = residual_scaling(residual, hidden_states)
+ residual = hidden_states.to(torch.float32) if residual is None else hidden_states + residual
+ hidden_states = rms_norm(residual.to(dtype=rms_norm.weight.dtype))
+ return hidden_states, residual
+
+
+class ZayaDecoderATTLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ZayaConfig, layer_n: int):
+
+ super().__init__()
+ self.config = config
+ self.self_attn = ZayaAttention(config, layer_n)
+
+ self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.res_scale = ResidualScaling(config, layer_n)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ attention_mask_2d: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ output_attentions: bool | None = False,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+ hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
+
+ hidden_states, self_attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ attention_mask_2d=attention_mask_2d,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ position_embeddings=position_embeddings,
+ )
+
+ return hidden_states, self_attn_weights if output_attentions else None, residual, prev_router_hidden_states
+
+
+class ResidualScaling(nn.Module):
+ def __init__(self, config, layer_n):
+ super().__init__()
+ self.not_first_layer = layer_n != 0
+ self.hidden_states_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
+ self.hidden_states_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+
+ if self.not_first_layer:
+ self.residual_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
+ self.residual_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+
+ def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
+ hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
+ if self.not_first_layer:
+ residual = (residual + self.residual_bias) * self.residual_scale
+ return residual, hidden_states
+
+
+class ZayaRouter(nn.Module):
+ def __init__(
+ self,
+ config,
+ layer_idx: int,
+ num_moe_experts: int,
+ moe_router_topk: int,
+ mlp_expansion: int,
+ hidden_size: int | None = None,
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.hidden_size = int(hidden_size or getattr(config, "hidden_size"))
+ self.layer_idx = layer_idx
+
+ self.num_experts = num_moe_experts + 1
+ self.topk = int(moe_router_topk)
+ self.mlp_expansion = int(mlp_expansion)
+
+ self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
+
+ zaya_first_layer = 1
+ self.use_eda = self.layer_idx != zaya_first_layer
+
+ ln_eps = float(getattr(config, "norm_epsilon", 1e-5))
+ self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps)
+ if self.use_eda:
+ self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
+
+ self.non_linearity = nn.GELU()
+ self.router_mlp = nn.Sequential(
+ nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
+ self.non_linearity,
+ nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
+ self.non_linearity,
+ nn.Linear(self.mlp_expansion, self.num_experts, bias=False),
+ )
+
+ self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32))
+ self.balancing_biases[-1] = -1.0
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ router_states: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ seq_length = hidden_states.shape[1]
+
+ router_hidden_states = self.down_proj(hidden_states)
+
+ if self.use_eda and (router_states is not None):
+ router_hidden_states = router_hidden_states + router_states * self.router_states_scale
+
+ router_hidden_states_next = router_hidden_states[:, -seq_length:].clone()
+ router_hidden_states = self.rmsnorm_eda(router_hidden_states)
+ logits = self.router_mlp(router_hidden_states)
+ expert_prob = torch.softmax(logits, dim=-1)
+
+ expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases
+ _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1)
+ route_prob = torch.gather(expert_prob, dim=2, index=expert_choice)
+
+ return (
+ route_prob.reshape(-1, self.topk),
+ expert_choice.reshape(-1, self.topk),
+ router_hidden_states_next,
+ logits.reshape(-1, self.num_experts),
+ )
+
+
+@use_experts_implementation
+class ZayaExperts(nn.Module):
+ """Collection of expert weights stored as 3D tensors."""
+
+ def __init__(self, config, num_experts: int, ffn_hidden_size: int):
+ super().__init__()
+ self.num_experts = num_experts
+ self.hidden_dim = config.hidden_size
+ self.intermediate_dim = ffn_hidden_size // 2
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+ self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ top_k_index: torch.Tensor,
+ top_k_weights: torch.Tensor,
+ ) -> torch.Tensor:
+ final_hidden_states = torch.zeros_like(hidden_states)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+ for expert_idx in expert_hit:
+ expert_idx = expert_idx[0]
+ if expert_idx == self.num_experts:
+ continue
+ top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
+ current_hidden_states = self.act_fn(gate) * up
+ current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+ current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+ final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+ return final_hidden_states
+
+
+class ZayaBlock(nn.Module):
+ def __init__(
+ self,
+ config,
+ num_moe_experts: int,
+ mlp_expansion: int,
+ ffn_hidden_size: int,
+ layer_n: int,
+ ):
+
+ super().__init__()
+ self.config = config
+ self.hidden_dim = config.hidden_size
+ self.num_moe_experts = num_moe_experts
+ self.router = ZayaRouter(
+ config=self.config,
+ layer_idx=layer_n,
+ num_moe_experts=self.num_moe_experts,
+ moe_router_topk=getattr(self.config, "moe_router_topk", 1),
+ mlp_expansion=mlp_expansion,
+ hidden_size=self.hidden_dim,
+ )
+ self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
+ route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
+ hidden_states, router_states=prev_router_hidden_states
+ )
+ batch_size, seq_length, emb_dim = hidden_states.shape
+ hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim)
+ expert_output = self.experts(hidden_states_flat, expert_choice, route_prob)
+ expert_output = expert_output.view(batch_size, seq_length, emb_dim)
+
+ return expert_output, prev_router_hidden_states, router_logits
+
+
+class ZayaDecoderMLPLayer(GradientCheckpointingLayer):
+ def __init__(
+ self,
+ config: ZayaConfig,
+ num_moe_experts: int,
+ mlp_expansion: int,
+ ffn_hidden_size: int,
+ layer_n: int,
+ ):
+
+ super().__init__()
+ self.config = config
+ self.zaya_block = ZayaBlock(
+ config,
+ num_moe_experts,
+ mlp_expansion,
+ ffn_hidden_size,
+ layer_n,
+ )
+ self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.res_scale = ResidualScaling(config, layer_n)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ output_router_logits: bool = False,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
+ hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
+
+ hidden_states, prev_router_hidden_states, router_logits = self.zaya_block(
+ hidden_states,
+ prev_router_hidden_states,
+ )
+
+ return (
+ hidden_states,
+ router_logits if output_router_logits else None,
+ residual,
+ prev_router_hidden_states,
+ )
+
+
+class ZayaPreTrainedModel(PreTrainedModel):
+ config: ZayaConfig
+ config_class = ZayaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(ZayaRouter, index=3),
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, ResidualScaling):
+ init.ones_(module.hidden_states_scale)
+ init.zeros_(module.hidden_states_bias)
+ if module.not_first_layer:
+ init.ones_(module.residual_scale)
+ init.zeros_(module.residual_bias)
+ elif isinstance(module, ZayaRouter):
+ if module.use_eda:
+ init.ones_(module.router_states_scale)
+ init.zeros_(module.balancing_biases)
+ module.balancing_biases[-1] = -1.0
+ elif isinstance(module, ZayaExperts):
+ std = self.config.initializer_range
+ init.normal_(module.gate_up_proj, mean=0.0, std=std)
+ init.normal_(module.down_proj, mean=0.0, std=std)
+
+
+@auto_docstring
+class ZayaModel(ZayaPreTrainedModel):
+ def __init__(self, config: ZayaConfig):
+
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = []
+
+ for layer_n in range(config.num_hidden_layers):
+ if layer_n % 2 == 1:
+ self.layers.append(
+ ZayaDecoderMLPLayer(
+ config,
+ config.num_experts,
+ config.zaya_mlp_expansion,
+ config.ffn_hidden_size,
+ layer_n,
+ )
+ )
+ else:
+ self.layers.append(ZayaDecoderATTLayer(config, layer_n))
+ self.layers = nn.ModuleList(self.layers)
+
+ self.gradient_checkpointing = False
+ self.res_scale = ResidualScaling(config, config.num_hidden_layers)
+
+ self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+
+ self.rotary_emb = ZayaRotaryEmbedding(config=config)
+ if self.config.swa_layers is not None:
+ swa_config = copy.copy(config)
+ swa_config.rope_parameters = {
+ **config.rope_parameters,
+ "rope_theta": swa_config.swa_rotary_base,
+ }
+ self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_router_logits: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = ZayaDynamicCache(
+ self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device
+ )
+
+ residual = None
+
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ ).unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ position_ids,
+ past_key_values,
+ )
+ if attention_mask is not None and attention_mask.ndim != 2:
+ raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.")
+ # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
+ # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
+ attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+ if inputs_embeds.shape[1] == 1:
+ attention_mask_2d = None
+
+ hidden_states = inputs_embeds
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ if self.config.swa_layers is not None:
+ swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ prev_router_hidden_states = None
+
+ for layer_n, decoder_layer in enumerate(self.layers):
+ if self.config.swa_layers is not None:
+ emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings
+ else:
+ emb_to_use = position_embeddings
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ residual,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ position_embeddings=emb_to_use,
+ prev_router_hidden_states=prev_router_hidden_states,
+ attention_mask_2d=attention_mask_2d,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+ residual = layer_outputs[2]
+ prev_router_hidden_states = layer_outputs[3]
+
+ if isinstance(decoder_layer, ZayaDecoderATTLayer):
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if past_key_values and not past_key_values.has_previous_state:
+ past_key_values.has_previous_state = True
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ return create_causal_mask(
+ config=self.config,
+ inputs_embeds=input_tensor,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+
+@auto_docstring
+class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+ _is_stateful = True
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+ self.model = ZayaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+
+ self.post_init()
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_router_logits: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_router_logits=output_router_logits,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=None,
+ logits=logits,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ position_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache):
+ raise ValueError(
+ f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
+ )
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+ return model_inputs
+
+ def _prepare_cache_for_generation(
+ self,
+ generation_config,
+ model_kwargs: dict,
+ generation_mode,
+ batch_size: int,
+ max_cache_length: int,
+ ):
+ if generation_config.use_cache is False:
+ return
+
+ if "past_key_values" not in model_kwargs:
+ cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences)
+ model_kwargs["past_key_values"] = ZayaDynamicCache(
+ self.config, cache_batch_size, dtype=self.dtype, device=self.device
+ )
+ generation_config.cache_implementation = None
+ return super()._prepare_cache_for_generation(
+ generation_config=generation_config,
+ model_kwargs=model_kwargs,
+ generation_mode=generation_mode,
+ batch_size=batch_size,
+ max_cache_length=max_cache_length,
+ )
+
+
+__all__ = ["ZayaPreTrainedModel", "ZayaModel", "ZayaForCausalLM"]
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
new file mode 100644
index 000000000000..60bb870c73a5
--- /dev/null
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -0,0 +1,1133 @@
+# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Zaya model."""
+
+import copy
+from collections.abc import Callable
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import init
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PreTrainedConfig
+from ...generation import GenerationMixin
+from ...integrations import use_experts_implementation
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+)
+from ...utils.generic import merge_with_config_defaults
+from ...utils.output_capturing import OutputRecorder, capture_outputs
+from ..glm4.modeling_glm4 import Glm4RotaryEmbedding
+from ..qwen3_5_moe.modeling_qwen3_5_moe import (
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm
+
+
+@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
+class ZayaConfig(PreTrainedConfig):
+ r"""
+ num_query_groups (`int`, *optional*, defaults to 2):
+ Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
+ lm_head_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the language modeling head.
+ ffn_hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the feed-forward and expert hidden states.
+ rope_theta (`float`, *optional*, defaults to 5000000):
+ The base period of the RoPE embeddings.
+ moe_router_topk (`int`, *optional*, defaults to 1):
+ Number of selected experts per token. ZAYA checkpoints use top-1 routing.
+ zaya_mlp_expansion (`int`, *optional*, defaults to 256):
+ Expansion size used by the dense ZAYA blocks.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Fraction of each attention head dimension using rotary embeddings.
+ cca_time0 (`int`, *optional*, defaults to 2):
+ First temporal parameter of the CCA projection.
+ cca_time1 (`int`, *optional*, defaults to 2):
+ Second temporal parameter of the CCA projection.
+ swa_layers (`list[int]`, *optional*):
+ Per-layer selector for standard RoPE versus SWA RoPE embeddings.
+ swa_rotary_base (`float`, *optional*):
+ RoPE base used by SWA layers.
+
+ ```python
+ >>> from transformers import ZayaConfig, ZayaModel
+
+ >>> configuration = ZayaConfig()
+ >>> model = ZayaModel(configuration)
+
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "zaya"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ num_query_groups=2,
+ use_cache=True,
+ attention_bias=False,
+ lm_head_bias=False,
+ vocab_size=262272,
+ hidden_size=2048,
+ ffn_hidden_size=4096,
+ num_hidden_layers=80,
+ num_experts=16,
+ num_attention_heads=8,
+ hidden_act="silu",
+ head_dim=128,
+ initializer_range=0.02,
+ max_position_embeddings=131072,
+ norm_epsilon=1e-05,
+ pad_token_id=0,
+ bos_token_id=2,
+ eos_token_id=106,
+ tie_word_embeddings=True,
+ rope_theta=5000000,
+ attention_dropout=0.0,
+ moe_router_topk=1,
+ zaya_mlp_expansion=256,
+ rope_parameters=None,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=2,
+ cca_time0=2,
+ cca_time1=2,
+ swa_layers=None,
+ swa_rotary_base=None,
+ output_router_logits=False,
+ _attn_implementation="eager",
+ **kwargs,
+ ):
+ for unused_checkpoint_kwarg in (
+ "cca",
+ "activation_func",
+ "normalization",
+ "add_bias_linear",
+ "gated_linear_unit",
+ "fused_add_norm",
+ "apply_rope_fusion",
+ "bias_activation_fusion",
+ "activation_func_fp8_input_store",
+ "clamp_temp",
+ "residual_in_fp32",
+ "rope_scaling",
+ "scale_residual_merge",
+ "sliding_window",
+ "zaya_high_prec",
+ "zaya_use_mod",
+ "zaya_use_eda",
+ ):
+ kwargs.pop(unused_checkpoint_kwarg, None)
+
+ num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups
+ if head_dim is None:
+ raise ValueError("`head_dim` must be set for ZAYA.")
+ if num_query_groups != num_key_value_heads:
+ raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
+ if moe_router_topk != 1:
+ raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
+
+ self.num_query_groups = num_query_groups
+ self.use_cache = use_cache
+ self.attention_bias = attention_bias
+ self.lm_head_bias = lm_head_bias
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.ffn_hidden_size = ffn_hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_experts = num_experts
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.head_dim = head_dim
+ self.initializer_range = initializer_range
+ self.num_key_value_heads = num_key_value_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.norm_epsilon = norm_epsilon
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+ self.attention_dropout = attention_dropout
+ self.moe_router_topk = moe_router_topk
+ self.zaya_mlp_expansion = zaya_mlp_expansion
+ self.partial_rotary_factor = partial_rotary_factor
+ self.rope_theta = rope_theta
+ rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"}
+ rope_parameters.setdefault("rope_theta", rope_theta)
+ rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor)
+ self.rope_parameters = rope_parameters
+ cca_time0 = 2 if cca_time0 is None else cca_time0
+ cca_time1 = 2 if cca_time1 is None else cca_time1
+ if (cca_time0, cca_time1) != (2, 2):
+ raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
+ if swa_layers is not None and len(swa_layers) != num_hidden_layers:
+ raise ValueError("`swa_layers` must have one entry per hidden layer.")
+ if swa_layers is not None and swa_rotary_base is None:
+ raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")
+
+ self.cca_time0 = cca_time0
+ self.cca_time1 = cca_time1
+ self.swa_layers = swa_layers
+ self.swa_rotary_base = swa_rotary_base
+ self.output_router_logits = output_router_logits
+ self._attn_implementation = _attn_implementation
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class ZayaRotaryEmbedding(Glm4RotaryEmbedding):
+ pass
+
+
+class ZayaRMSNorm(Qwen3MoeRMSNorm):
+ pass
+
+
+class ZayaDynamicCache(DynamicCache):
+ """
+ Cache that includes both the KV cache and the CCA cache.
+ """
+
+ def __init__(
+ self,
+ config: ZayaConfig,
+ batch_size: int,
+ dtype: torch.dtype = torch.float16,
+ device: str | None = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.batch_size = batch_size
+ self.dtype = dtype
+ self.device = device
+ self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1)
+ self.num_layers = config.num_hidden_layers
+ self.key_value_hidden_size = config.num_query_groups * config.head_dim
+ self.query_hidden_size = config.num_attention_heads * config.head_dim
+ self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size
+ self.has_previous_state = False
+
+ self.conv_states = [None for _ in range(self.num_layers)]
+ self.prev_v2 = [None for _ in range(self.num_layers)]
+
+ def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
+ if new_conv_state.shape[1] < self.conv_kernel_size:
+ new_conv_state = F.pad(
+ new_conv_state.transpose(1, 2), (self.conv_kernel_size - new_conv_state.shape[1], 0)
+ )
+ else:
+ new_conv_state = new_conv_state[:, -self.conv_kernel_size :, :].transpose(1, 2)
+
+ if self.conv_states[layer_idx] is None:
+ self.conv_states[layer_idx] = torch.zeros_like(new_conv_state)
+
+ if not self.has_previous_state:
+ self.conv_states[layer_idx].copy_(new_conv_state)
+ else:
+ conv_state = torch.cat([self.conv_states[layer_idx], new_conv_state], dim=-1)[
+ :, :, -self.conv_kernel_size :
+ ]
+ self.conv_states[layer_idx].copy_(conv_state)
+ return self.conv_states[layer_idx]
+
+ def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor:
+ if self.prev_v2[layer_idx] is None:
+ self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2)
+ self.prev_v2[layer_idx].copy_(new_prev_v2)
+ return self.prev_v2[layer_idx]
+
+ def reset(self):
+ super().reset()
+ for conv_state in self.conv_states:
+ if conv_state is not None:
+ conv_state.zero_()
+ for prev_v2 in self.prev_v2:
+ if prev_v2 is not None:
+ prev_v2.zero_()
+ self.has_previous_state = False
+
+ def _reorder_auxiliary_states(self, indices: torch.LongTensor):
+ for layer_idx, conv_state in enumerate(self.conv_states):
+ if conv_state is not None:
+ self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device))
+ for layer_idx, prev_v2 in enumerate(self.prev_v2):
+ if prev_v2 is not None:
+ self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device))
+ self.batch_size = indices.shape[0]
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ super().reorder_cache(beam_idx)
+ self._reorder_auxiliary_states(beam_idx)
+
+ def batch_repeat_interleave(self, repeats: int):
+ super().batch_repeat_interleave(repeats)
+ for layer_idx, conv_state in enumerate(self.conv_states):
+ if conv_state is not None:
+ self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0)
+ for layer_idx, prev_v2 in enumerate(self.prev_v2):
+ if prev_v2 is not None:
+ self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0)
+ self.batch_size *= repeats
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ super().batch_select_indices(indices)
+ self._reorder_auxiliary_states(indices)
+
+
+class CCA(nn.Module):
+ def __init__(
+ self,
+ config: ZayaConfig,
+ num_key_value_heads: int = 2,
+ num_attention_heads: int = 8,
+ hidden_size: int | None = None,
+ head_dim: int = 128,
+ cca_time0: int = 2,
+ cca_time1: int = 2,
+ layer_number: int = 0,
+ ):
+ super().__init__()
+ self.config = config
+ self.layer_number = layer_number
+
+ self.hidden_size = int(hidden_size or config.hidden_size)
+
+ self.depthwise_kernel_size = cca_time0
+ self.grouped_kernel_size = cca_time1
+ self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
+
+ self.num_key_value_heads = int(num_key_value_heads)
+ self.num_attention_heads = int(num_attention_heads)
+
+ self.head_dim = int(head_dim)
+ self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
+ self.query_hidden_size = self.num_attention_heads * self.head_dim
+ self.sqrt_head_dim = self.head_dim**0.5
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+ if self.num_attention_heads % self.num_key_value_heads != 0:
+ raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
+
+ self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
+ self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
+ self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
+ self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
+
+ conv_channels = self.key_value_hidden_size + self.query_hidden_size
+ self.conv_qk = nn.Sequential(
+ nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.depthwise_kernel_size,
+ groups=conv_channels,
+ padding=0,
+ stride=1,
+ ),
+ nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.grouped_kernel_size,
+ groups=(self.num_key_value_heads + self.num_attention_heads),
+ padding=0,
+ stride=1,
+ ),
+ )
+
+ self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ past_key_values: ZayaDynamicCache | None,
+ attention_mask: torch.Tensor | None = None,
+ ):
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)
+
+ batch_size, seq_length, _ = hidden_states.shape
+
+ projected_queries = self.linear_q(hidden_states)
+ projected_keys = self.linear_k(hidden_states)
+ qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
+
+ query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+ key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
+
+ key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1)
+ key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+ query_residual = (query_residual + key_residual) * 0.5
+ key_residual = query_residual.view(
+ batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
+ ).mean(dim=-2)
+
+ qk_states = qk_states.transpose(1, 2)
+ use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state
+ if use_precomputed_states:
+ cached_qk_states = past_key_values.conv_states[self.layer_number]
+ conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
+ else:
+ conv_input = F.pad(qk_states, (self.total_padding, 0))
+
+ if past_key_values is not None:
+ past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2))
+
+ convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2)
+
+ query = (
+ convolved_qk_states[..., : self.query_hidden_size].view(
+ batch_size, seq_length, self.num_attention_heads, self.head_dim
+ )
+ + query_residual
+ )
+
+ key = (
+ convolved_qk_states[..., self.query_hidden_size :].view(
+ batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ )
+ + key_residual
+ )
+
+ value_current = self.val_proj1(hidden_states)
+ projected_v2 = self.val_proj2(hidden_states)
+ if use_precomputed_states:
+ first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1)
+ else:
+ first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
+ value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
+
+ if past_key_values is not None:
+ past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :])
+
+ value = torch.cat([value_current, value_delayed], dim=-1).view(
+ batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ )
+
+ norm_eps = torch.finfo(query.dtype).eps
+ query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+
+ key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
+ query = query * (self.sqrt_head_dim / query_norm)
+
+ query = query.reshape(batch_size, seq_length, self.query_hidden_size)
+ key = key.reshape(batch_size, seq_length, self.key_value_hidden_size)
+ value = value.reshape(batch_size, seq_length, self.key_value_hidden_size)
+ return query, key, value
+
+
+class ZayaAttention(nn.Module):
+ def __init__(self, config: ZayaConfig, layer_n):
+ super().__init__()
+ self.config = config
+ self.layer_n = layer_n
+ self.layer_idx = layer_n
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.head_dim = config.head_dim
+ self.scaling = self.head_dim**-0.5
+
+ self.o_proj = nn.Linear(
+ self.num_attention_heads * self.head_dim,
+ self.hidden_size,
+ bias=self.config.attention_bias,
+ )
+ self.qkv = CCA(
+ config=self.config,
+ num_attention_heads=self.config.num_attention_heads,
+ num_key_value_heads=self.config.num_query_groups,
+ hidden_size=self.hidden_size,
+ head_dim=self.config.head_dim,
+ cca_time0=self.config.cca_time0,
+ cca_time1=self.config.cca_time1,
+ layer_number=layer_n,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ attention_mask_2d: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ output_attentions: bool = False,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d)
+ query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim)
+ key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
+ value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)
+
+ causal_mask = attention_mask
+ if causal_mask is not None:
+ causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ causal_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ )
+
+ attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_values
+
+
+class ZayaDecoderATTLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ZayaConfig, layer_n: int):
+
+ super().__init__()
+ self.config = config
+ self.self_attn = ZayaAttention(config, layer_n)
+
+ self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.res_scale = ResidualScaling(config, layer_n)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ attention_mask_2d: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ output_attentions: bool | None = False,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+ hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
+
+ hidden_states, self_attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ attention_mask_2d=attention_mask_2d,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ position_embeddings=position_embeddings,
+ )
+
+ return hidden_states, self_attn_weights if output_attentions else None, residual, prev_router_hidden_states
+
+
+class ResidualScaling(nn.Module):
+ def __init__(self, config, layer_n):
+ super().__init__()
+ self.not_first_layer = layer_n != 0
+ self.hidden_states_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
+ self.hidden_states_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+
+ if self.not_first_layer:
+ self.residual_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
+ self.residual_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+
+ def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
+ hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
+ if self.not_first_layer:
+ residual = (residual + self.residual_bias) * self.residual_scale
+ return residual, hidden_states
+
+
+def _apply_residual_scaling(
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ residual_scaling,
+ rms_norm: ZayaRMSNorm,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ residual, hidden_states = residual_scaling(residual, hidden_states)
+ residual = hidden_states.to(torch.float32) if residual is None else hidden_states + residual
+ hidden_states = rms_norm(residual.to(dtype=rms_norm.weight.dtype))
+ return hidden_states, residual
+
+
+class ZayaRouter(nn.Module):
+ def __init__(
+ self,
+ config,
+ layer_idx: int,
+ num_moe_experts: int,
+ moe_router_topk: int,
+ mlp_expansion: int,
+ hidden_size: int | None = None,
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.hidden_size = int(hidden_size or getattr(config, "hidden_size"))
+ self.layer_idx = layer_idx
+
+ self.num_experts = num_moe_experts + 1
+ self.topk = int(moe_router_topk)
+ self.mlp_expansion = int(mlp_expansion)
+
+ self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
+
+ zaya_first_layer = 1
+ self.use_eda = self.layer_idx != zaya_first_layer
+
+ ln_eps = float(getattr(config, "norm_epsilon", 1e-5))
+ self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps)
+ if self.use_eda:
+ self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
+
+ self.non_linearity = nn.GELU()
+ self.router_mlp = nn.Sequential(
+ nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
+ self.non_linearity,
+ nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
+ self.non_linearity,
+ nn.Linear(self.mlp_expansion, self.num_experts, bias=False),
+ )
+
+ self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32))
+ self.balancing_biases[-1] = -1.0
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ router_states: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ seq_length = hidden_states.shape[1]
+
+ router_hidden_states = self.down_proj(hidden_states)
+
+ if self.use_eda and (router_states is not None):
+ router_hidden_states = router_hidden_states + router_states * self.router_states_scale
+
+ router_hidden_states_next = router_hidden_states[:, -seq_length:].clone()
+ router_hidden_states = self.rmsnorm_eda(router_hidden_states)
+ logits = self.router_mlp(router_hidden_states)
+ expert_prob = torch.softmax(logits, dim=-1)
+
+ expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases
+ _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1)
+ route_prob = torch.gather(expert_prob, dim=2, index=expert_choice)
+
+ return (
+ route_prob.reshape(-1, self.topk),
+ expert_choice.reshape(-1, self.topk),
+ router_hidden_states_next,
+ logits.reshape(-1, self.num_experts),
+ )
+
+
+@use_experts_implementation
+class ZayaExperts(nn.Module):
+ """Collection of expert weights stored as 3D tensors."""
+
+ def __init__(self, config, num_experts: int, ffn_hidden_size: int):
+ super().__init__()
+ self.num_experts = num_experts
+ self.hidden_dim = config.hidden_size
+ self.intermediate_dim = ffn_hidden_size // 2
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+ self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ top_k_index: torch.Tensor,
+ top_k_weights: torch.Tensor,
+ ) -> torch.Tensor:
+ final_hidden_states = torch.zeros_like(hidden_states)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+ for expert_idx in expert_hit:
+ expert_idx = expert_idx[0]
+ if expert_idx == self.num_experts:
+ continue
+ top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
+ current_hidden_states = self.act_fn(gate) * up
+ current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+ current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+ final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+ return final_hidden_states
+
+
+class ZayaBlock(nn.Module):
+ def __init__(
+ self,
+ config,
+ num_moe_experts: int,
+ mlp_expansion: int,
+ ffn_hidden_size: int,
+ layer_n: int,
+ ):
+
+ super().__init__()
+ self.config = config
+ self.hidden_dim = config.hidden_size
+ self.num_moe_experts = num_moe_experts
+ self.router = ZayaRouter(
+ config=self.config,
+ layer_idx=layer_n,
+ num_moe_experts=self.num_moe_experts,
+ moe_router_topk=getattr(self.config, "moe_router_topk", 1),
+ mlp_expansion=mlp_expansion,
+ hidden_size=self.hidden_dim,
+ )
+ self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
+ route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
+ hidden_states, router_states=prev_router_hidden_states
+ )
+ batch_size, seq_length, emb_dim = hidden_states.shape
+ hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim)
+ expert_output = self.experts(hidden_states_flat, expert_choice, route_prob)
+ expert_output = expert_output.view(batch_size, seq_length, emb_dim)
+
+ return expert_output, prev_router_hidden_states, router_logits
+
+
+class ZayaDecoderMLPLayer(GradientCheckpointingLayer):
+ def __init__(
+ self,
+ config: ZayaConfig,
+ num_moe_experts: int,
+ mlp_expansion: int,
+ ffn_hidden_size: int,
+ layer_n: int,
+ ):
+
+ super().__init__()
+ self.config = config
+ self.zaya_block = ZayaBlock(
+ config,
+ num_moe_experts,
+ mlp_expansion,
+ ffn_hidden_size,
+ layer_n,
+ )
+ self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.res_scale = ResidualScaling(config, layer_n)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ output_router_logits: bool = False,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
+ hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
+
+ hidden_states, prev_router_hidden_states, router_logits = self.zaya_block(
+ hidden_states,
+ prev_router_hidden_states,
+ )
+
+ return (
+ hidden_states,
+ router_logits if output_router_logits else None,
+ residual,
+ prev_router_hidden_states,
+ )
+
+
+class ZayaPreTrainedModel(PreTrainedModel):
+ config: ZayaConfig
+ config_class = ZayaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(ZayaRouter, index=3),
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, ResidualScaling):
+ init.ones_(module.hidden_states_scale)
+ init.zeros_(module.hidden_states_bias)
+ if module.not_first_layer:
+ init.ones_(module.residual_scale)
+ init.zeros_(module.residual_bias)
+ elif isinstance(module, ZayaRouter):
+ if module.use_eda:
+ init.ones_(module.router_states_scale)
+ init.zeros_(module.balancing_biases)
+ module.balancing_biases[-1] = -1.0
+ elif isinstance(module, ZayaExperts):
+ std = self.config.initializer_range
+ init.normal_(module.gate_up_proj, mean=0.0, std=std)
+ init.normal_(module.down_proj, mean=0.0, std=std)
+
+
+@auto_docstring
+class ZayaModel(ZayaPreTrainedModel):
+ def __init__(self, config: ZayaConfig):
+
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = []
+
+ for layer_n in range(config.num_hidden_layers):
+ if layer_n % 2 == 1:
+ self.layers.append(
+ ZayaDecoderMLPLayer(
+ config,
+ config.num_experts,
+ config.zaya_mlp_expansion,
+ config.ffn_hidden_size,
+ layer_n,
+ )
+ )
+ else:
+ self.layers.append(ZayaDecoderATTLayer(config, layer_n))
+ self.layers = nn.ModuleList(self.layers)
+
+ self.gradient_checkpointing = False
+ self.res_scale = ResidualScaling(config, config.num_hidden_layers)
+
+ self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+
+ self.rotary_emb = ZayaRotaryEmbedding(config=config)
+ if self.config.swa_layers is not None:
+ swa_config = copy.copy(config)
+ swa_config.rope_parameters = {
+ **config.rope_parameters,
+ "rope_theta": swa_config.swa_rotary_base,
+ }
+ self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_router_logits: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = ZayaDynamicCache(
+ self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device
+ )
+
+ residual = None
+
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ ).unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ position_ids,
+ past_key_values,
+ )
+ if attention_mask is not None and attention_mask.ndim != 2:
+ raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.")
+ # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
+ # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
+ attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+ if inputs_embeds.shape[1] == 1:
+ attention_mask_2d = None
+
+ hidden_states = inputs_embeds
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ if self.config.swa_layers is not None:
+ swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ prev_router_hidden_states = None
+
+ for layer_n, decoder_layer in enumerate(self.layers):
+ if self.config.swa_layers is not None:
+ emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings
+ else:
+ emb_to_use = position_embeddings
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ residual,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ position_embeddings=emb_to_use,
+ prev_router_hidden_states=prev_router_hidden_states,
+ attention_mask_2d=attention_mask_2d,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+ residual = layer_outputs[2]
+ prev_router_hidden_states = layer_outputs[3]
+
+ if isinstance(decoder_layer, ZayaDecoderATTLayer):
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if past_key_values and not past_key_values.has_previous_state:
+ past_key_values.has_previous_state = True
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ return create_causal_mask(
+ config=self.config,
+ inputs_embeds=input_tensor,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+
+@auto_docstring
+class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+ _is_stateful = True
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+ self.model = ZayaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+
+ self.post_init()
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_router_logits: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_router_logits=output_router_logits,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=None,
+ logits=logits,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ position_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache):
+ raise ValueError(
+ f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
+ )
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+ return model_inputs
+
+ def _prepare_cache_for_generation(
+ self,
+ generation_config,
+ model_kwargs: dict,
+ generation_mode,
+ batch_size: int,
+ max_cache_length: int,
+ ):
+ if generation_config.use_cache is False:
+ return
+
+ if "past_key_values" not in model_kwargs:
+ cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences)
+ model_kwargs["past_key_values"] = ZayaDynamicCache(
+ self.config, cache_batch_size, dtype=self.dtype, device=self.device
+ )
+ generation_config.cache_implementation = None
+ return super()._prepare_cache_for_generation(
+ generation_config=generation_config,
+ model_kwargs=model_kwargs,
+ generation_mode=generation_mode,
+ batch_size=batch_size,
+ max_cache_length=max_cache_length,
+ )
+
+
+__all__ = [
+ "ZayaConfig",
+ "ZayaPreTrainedModel",
+ "ZayaModel",
+ "ZayaForCausalLM",
+]
diff --git a/tests/models/zaya/__init__.py b/tests/models/zaya/__init__.py
new file mode 100644
index 000000000000..8b137891791f
--- /dev/null
+++ b/tests/models/zaya/__init__.py
@@ -0,0 +1 @@
+
From d26fffc9a241c89886b19b083f1223e1462b3c5b Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 9 May 2026 20:09:44 +0800
Subject: [PATCH 02/36] add test
---
tests/models/zaya/test_modeling_zaya.py | 349 ++++++++++++++++++++++++
1 file changed, 349 insertions(+)
create mode 100644 tests/models/zaya/test_modeling_zaya.py
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
new file mode 100644
index 000000000000..2338d07675af
--- /dev/null
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -0,0 +1,349 @@
+# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch ZAYA model."""
+
+import unittest
+
+from parameterized import parameterized
+
+from transformers import is_torch_available
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel
+ from transformers.models.zaya.modeling_zaya import CCA, ZayaDynamicCache
+
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+
+
+class ZayaModelTester(CausalLMModelTester):
+ if is_torch_available():
+ base_model_class = ZayaModel
+
+ def __init__(self, parent):
+ super().__init__(
+ parent=parent,
+ batch_size=2,
+ seq_length=7,
+ vocab_size=128,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ intermediate_size=64,
+ )
+ self.head_dim = 8
+ self.ffn_hidden_size = 64
+ self.num_query_groups = 2
+ self.num_experts = 4
+ self.moe_router_topk = 1
+ self.zaya_mlp_expansion = 4
+ self.tie_word_embeddings = False
+ self.rope_parameters = {
+ "rope_theta": 10000,
+ "rope_type": "default",
+ }
+
+
+@require_torch
+class ZayaModelTest(CausalLMModelTest, unittest.TestCase):
+ model_tester_class = ZayaModelTester
+ test_all_params_have_gradient = False
+
+ def is_pipeline_test_to_skip(
+ self,
+ pipeline_test_case_name,
+ config_class,
+ model_architecture,
+ tokenizer_name,
+ image_processor_name,
+ feature_extractor_name,
+ processor_name,
+ ):
+ return True
+
+ @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.")
+ def test_eager_padding_matches_padding_free_with_position_ids(self):
+ pass
+
+ @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.")
+ def test_sdpa_padding_matches_padding_free_with_position_ids(self):
+ pass
+
+ @unittest.skip("ZAYA uses MoE routing; equivalent-output comparisons are not stable for this architecture.")
+ def test_model_outputs_equivalence(self, **kwargs):
+ pass
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+ config._attn_implementation = "eager"
+
+ for model_class in self.all_model_classes:
+ model = model_class._from_config(config, attn_implementation="eager")
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class({**inputs_dict, "output_attentions": True}, model_class))
+
+ expected_attn_layers = (config.num_hidden_layers + 1) // 2
+ self.assertEqual(len(outputs.attentions), expected_attn_layers)
+ self.assertEqual(
+ outputs.attentions[0].shape,
+ (
+ self.model_tester.batch_size,
+ config.num_attention_heads,
+ self.model_tester.seq_length,
+ self.model_tester.seq_length,
+ ),
+ )
+
+ @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
+ @unittest.skip(
+ "ZAYA uses partial rotary embeddings with CCA, which is not compatible with this generic RoPE test."
+ )
+ def test_model_rope_scaling_from_config(self, scaling_type):
+ pass
+
+ @unittest.skip("ZAYA needs alternating attention and MoE layers in the tiny test configuration.")
+ def test_num_layers_is_small(self):
+ pass
+
+ @unittest.skip("ZAYA uses a custom cache carrying CCA convolution state in addition to KV tensors.")
+ def test_past_key_values_format(self):
+ pass
+
+ @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.")
+ def test_greedy_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.")
+ def test_beam_search_generate_dict_outputs_use_cache(self):
+ pass
+
+ def test_moe_router_logits(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = self.model_tester.causal_lm_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**inputs_dict, output_router_logits=True)
+
+ expected_moe_layers = config.num_hidden_layers // 2
+ self.assertEqual(len(outputs.router_logits), expected_moe_layers)
+ self.assertEqual(
+ outputs.router_logits[0].shape,
+ (self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1),
+ )
+
+ def test_moe_router_topk_validation(self):
+ with self.assertRaisesRegex(ValueError, "moe_router_topk=1"):
+ ZayaConfig(moe_router_topk=2)
+
+ def test_cca_cache_matches_full_forward(self):
+ config = ZayaConfig(
+ vocab_size=128,
+ hidden_size=32,
+ ffn_hidden_size=64,
+ num_hidden_layers=1,
+ num_experts=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ num_query_groups=2,
+ head_dim=8,
+ zaya_mlp_expansion=4,
+ tie_word_embeddings=False,
+ )
+ torch.manual_seed(0)
+ cca = CCA(
+ config,
+ num_key_value_heads=config.num_key_value_heads,
+ num_attention_heads=config.num_attention_heads,
+ hidden_size=config.hidden_size,
+ head_dim=config.head_dim,
+ layer_number=0,
+ ).to(torch_device)
+ cca.eval()
+ hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device)
+
+ with torch.no_grad():
+ full = cca(hidden_states, None, None)
+ cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device)
+ cca(hidden_states[:, :4], cache, None)
+ cache.has_previous_state = True
+ cached = cca(hidden_states[:, 4:], cache, None)
+
+ for full_states, cached_states in zip(full, cached):
+ torch.testing.assert_close(full_states[:, -1:], cached_states, rtol=1e-5, atol=1e-5)
+
+ def test_cca_cache_matches_full_forward_multi_token(self):
+ config = ZayaConfig(
+ vocab_size=128,
+ hidden_size=32,
+ ffn_hidden_size=64,
+ num_hidden_layers=1,
+ num_experts=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ num_query_groups=2,
+ head_dim=8,
+ zaya_mlp_expansion=4,
+ tie_word_embeddings=False,
+ )
+ torch.manual_seed(0)
+ cca = CCA(
+ config,
+ num_key_value_heads=config.num_key_value_heads,
+ num_attention_heads=config.num_attention_heads,
+ hidden_size=config.hidden_size,
+ head_dim=config.head_dim,
+ layer_number=0,
+ ).to(torch_device)
+ cca.eval()
+ hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device)
+
+ with torch.no_grad():
+ full = cca(hidden_states, None, None)
+ cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device)
+ cca(hidden_states[:, :3], cache, None)
+ cache.has_previous_state = True
+ cached = cca(hidden_states[:, 3:], cache, None)
+
+ for full_states, cached_states in zip(full, cached):
+ torch.testing.assert_close(full_states[:, 3:], cached_states, rtol=1e-5, atol=1e-5)
+
+ def test_zaya_cache_batch_methods(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ cache = ZayaDynamicCache(config, batch_size=2, dtype=torch.float32, device=torch_device)
+ cache.update_conv_state(
+ 0,
+ torch.arange(2 * 2 * cache.conv_state_size, device=torch_device, dtype=torch.float32).view(
+ 2, 2, cache.conv_state_size
+ ),
+ )
+ cache.update_prev_v2(
+ 0,
+ torch.arange(
+ 2 * config.num_key_value_heads * config.head_dim // 2, device=torch_device, dtype=torch.float32
+ ).view(2, config.num_key_value_heads * config.head_dim // 2),
+ )
+ self.assertEqual(cache.prev_v2[0].shape[-1], config.num_key_value_heads * config.head_dim // 2)
+
+ cache.batch_repeat_interleave(2)
+ self.assertEqual(cache.conv_states[0].shape[0], 4)
+ self.assertEqual(cache.prev_v2[0].shape[0], 4)
+
+ cache.batch_select_indices(torch.tensor([3, 1], device=torch_device))
+ self.assertEqual(cache.conv_states[0].shape[0], 2)
+ self.assertEqual(cache.prev_v2[0].shape[0], 2)
+
+ cache.reorder_cache(torch.tensor([1, 0], device=torch_device))
+ self.assertEqual(cache.batch_size, 2)
+
+ cache.has_previous_state = True
+ cache.reset()
+ self.assertFalse(cache.has_previous_state)
+ self.assertEqual(cache.conv_states[0].sum().item(), 0)
+ self.assertEqual(cache.prev_v2[0].sum().item(), 0)
+
+
+@require_torch
+class ZayaIntegrationTest(unittest.TestCase):
+ model = None
+ model_id = "Zyphra/ZAYA1-8B"
+
+ @classmethod
+ def get_model(cls):
+ if cls.model is None:
+ cls.model = ZayaForCausalLM.from_pretrained(cls.model_id, device_map="auto", dtype=torch.bfloat16)
+ return cls.model
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls.model is not None:
+ del cls.model
+ cleanup(torch_device, gc_collect=True)
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ def get_inputs(self):
+ tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+ inputs = tokenizer("Hello! How can I assist you today?", return_tensors="pt")
+ self.assertEqual(
+ inputs.input_ids.tolist(),
+ [[2, 9259, 236888, 2088, 740, 564, 6361, 611, 3124, 236881, 106]],
+ )
+ return inputs
+
+ @slow
+ def test_model_logits(self):
+ model = self.get_model()
+ inputs = self.get_inputs().to(model.model.embed_tokens.weight.device)
+
+ with torch.no_grad():
+ outputs = model(**inputs, use_cache=False, output_hidden_states=True, return_dict=True)
+
+ logits = outputs.logits.float().cpu()
+ hidden_states = outputs.hidden_states[-1].float().cpu()
+
+ EXPECTED_HIDDEN_MEAN = torch.tensor(
+ [[0.0399, -0.0123, -0.0560, -0.0480, -0.0627, -0.0362, -0.0220, 0.0004, -0.0321, -0.0263, 0.0046]]
+ )
+ torch.testing.assert_close(hidden_states.mean(-1), EXPECTED_HIDDEN_MEAN, rtol=1e-2, atol=1e-2)
+
+ EXPECTED_HIDDEN_SLICE = torch.tensor([-2.7812, 0.3320, 4.1562, -0.4395, 1.6406, 1.3359, -1.4531, -2.6719, 5.5000, -4.7500, 2.0625, 0.2930, -2.2344, -2.6094, 2.0781, 2.5000, 0.7969, 0.6836, -0.5469, 1.3906]) # fmt: skip
+ torch.testing.assert_close(hidden_states[0, 0, :20], EXPECTED_HIDDEN_SLICE, rtol=1e-2, atol=1e-2)
+
+ EXPECTED_LOGITS_SLICE = torch.tensor([-2.3438, 1.7344, 3.7656, -3.8750, 0.4707, -0.7422, -2.5938, -2.7188, -2.9375, -2.9844, -3.0000, -3.0000, -3.0000, -3.0000, -3.0156, -3.0000, -3.0000, -3.0000, -3.0000, -3.0000]) # fmt: skip
+ torch.testing.assert_close(logits[0, -1, :20], EXPECTED_LOGITS_SLICE, rtol=1e-2, atol=1e-2)
+ self.assertEqual(logits[0, -1].argmax().item(), 107)
+
+ @slow
+ def test_model_cache_matches_full_forward(self):
+ model = self.get_model()
+ inputs = self.get_inputs().to(model.model.embed_tokens.weight.device)
+
+ with torch.no_grad():
+ full_logits = model(**inputs, use_cache=False).logits[:, -1]
+ prefill_outputs = model(
+ input_ids=inputs.input_ids[:, :-1],
+ attention_mask=inputs.attention_mask[:, :-1],
+ use_cache=True,
+ return_dict=True,
+ )
+ cached_logits = model(
+ input_ids=inputs.input_ids[:, -1:],
+ attention_mask=inputs.attention_mask,
+ past_key_values=prefill_outputs.past_key_values,
+ use_cache=True,
+ return_dict=True,
+ ).logits[:, -1]
+
+ torch.testing.assert_close(cached_logits.float().cpu(), full_logits.float().cpu(), rtol=1e-4, atol=1e-4)
+
+ @slow
+ def test_model_generation(self):
+ model = self.get_model()
+ inputs = self.get_inputs().to(model.model.embed_tokens.weight.device)
+
+ with torch.no_grad():
+ generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=3, top_k=None, top_p=None)
+
+ self.assertEqual(generated_ids[0, -3:].tolist(), [107, 262146, 108])
From 8191d39741e5cb55347ce78772ac14a5de1a335f Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 9 May 2026 20:11:19 +0800
Subject: [PATCH 03/36] update example
---
docs/source/en/model_doc/zaya.md | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index 7f881a47efb9..01e7a8504e1d 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -35,8 +35,9 @@ model_id = "Zyphra/ZAYA1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
-inputs = tokenizer("What factors contributed to the fall of the Roman Empire?", return_tensors="pt").to(model.device)
-outputs = model.generate(**inputs, max_new_tokens=100)
+inputs = tokenizer.apply_chat_template([{"role": "user", "content": "Write a haiku about recursion in programming."}], tokenize=True, add_generation_prompt=True, enable_thinking=False, return_tensors="pt")
+inputs = inputs.to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=2048)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
From c125ef3124143bede61c1928cdecf75fb1ed7ae3 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 9 May 2026 20:43:13 +0800
Subject: [PATCH 04/36] new config
---
.../models/zaya/configuration_zaya.py | 157 +++++++-----------
src/transformers/models/zaya/modular_zaya.py | 156 +++++++----------
2 files changed, 125 insertions(+), 188 deletions(-)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 506df6eee3f0..12a7c2999abc 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -18,27 +18,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from huggingface_hub.dataclasses import strict
+
from ...configuration_utils import PreTrainedConfig
+from ...modeling_rope_utils import RopeParameters
from ...utils import auto_docstring
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
+@strict
class ZayaConfig(PreTrainedConfig):
r"""
- num_query_groups (`int`, *optional*, defaults to 2):
- Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
- lm_head_bias (`bool`, *optional*, defaults to `False`):
- Whether to add a bias to the language modeling head.
ffn_hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the feed-forward and expert hidden states.
+ num_query_groups (`int`, *optional*, defaults to 2):
+ Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
rope_theta (`float`, *optional*, defaults to 5000000):
The base period of the RoPE embeddings.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Fraction of each attention head dimension using rotary embeddings.
+ lm_head_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the language modeling head.
moe_router_topk (`int`, *optional*, defaults to 1):
Number of selected experts per token. ZAYA checkpoints use top-1 routing.
zaya_mlp_expansion (`int`, *optional*, defaults to 256):
Expansion size used by the dense ZAYA blocks.
- partial_rotary_factor (`float`, *optional*, defaults to 0.5):
- Fraction of each attention head dimension using rotary embeddings.
cca_time0 (`int`, *optional*, defaults to 2):
First temporal parameter of the CCA projection.
cca_time1 (`int`, *optional*, defaults to 2):
@@ -61,42 +65,39 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
- def __init__(
- self,
- num_query_groups=2,
- use_cache=True,
- attention_bias=False,
- lm_head_bias=False,
- vocab_size=262272,
- hidden_size=2048,
- ffn_hidden_size=4096,
- num_hidden_layers=80,
- num_experts=16,
- num_attention_heads=8,
- hidden_act="silu",
- head_dim=128,
- initializer_range=0.02,
- max_position_embeddings=131072,
- norm_epsilon=1e-05,
- pad_token_id=0,
- bos_token_id=2,
- eos_token_id=106,
- tie_word_embeddings=True,
- rope_theta=5000000,
- attention_dropout=0.0,
- moe_router_topk=1,
- zaya_mlp_expansion=256,
- rope_parameters=None,
- partial_rotary_factor=0.5,
- num_key_value_heads=2,
- cca_time0=2,
- cca_time1=2,
- swa_layers=None,
- swa_rotary_base=None,
- output_router_logits=False,
- _attn_implementation="eager",
- **kwargs,
- ):
+ vocab_size: int = 262272
+ hidden_size: int = 2048
+ ffn_hidden_size: int = 4096
+ num_hidden_layers: int = 80
+ num_experts: int = 16
+ num_attention_heads: int = 8
+ num_key_value_heads: int | None = 2
+ num_query_groups: int | None = 2
+ hidden_act: str = "silu"
+ head_dim: int = 128
+ max_position_embeddings: int = 131072
+ initializer_range: float = 0.02
+ norm_epsilon: float = 1e-5
+ use_cache: bool = True
+ tie_word_embeddings: bool = True
+ rope_parameters: RopeParameters | dict | None = None
+ rope_theta: float | int = 5000000
+ partial_rotary_factor: float = 0.5
+ attention_bias: bool = False
+ lm_head_bias: bool = False
+ attention_dropout: float | int = 0.0
+ moe_router_topk: int = 1
+ zaya_mlp_expansion: int = 256
+ cca_time0: int | None = 2
+ cca_time1: int | None = 2
+ swa_layers: list[int] | None = None
+ swa_rotary_base: float | int | None = None
+ output_router_logits: bool = False
+ pad_token_id: int | None = 0
+ bos_token_id: int | None = 2
+ eos_token_id: int | list[int] | None = 106
+
+ def __post_init__(self, **kwargs):
for unused_checkpoint_kwarg in (
"cca",
"activation_func",
@@ -108,6 +109,8 @@ def __init__(
"bias_activation_fusion",
"activation_func_fp8_input_store",
"clamp_temp",
+ "kv_channels",
+ "mamba_cache_dtype",
"residual_in_fp32",
"rope_scaling",
"scale_residual_merge",
@@ -118,66 +121,32 @@ def __init__(
):
kwargs.pop(unused_checkpoint_kwarg, None)
- num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups
- if head_dim is None:
+ self.num_key_value_heads = (
+ self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads
+ )
+ self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups
+ if self.head_dim is None:
raise ValueError("`head_dim` must be set for ZAYA.")
- if num_query_groups != num_key_value_heads:
+ if self.num_query_groups != self.num_key_value_heads:
raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
- if moe_router_topk != 1:
+ if self.moe_router_topk != 1:
raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
- self.num_query_groups = num_query_groups
- self.use_cache = use_cache
- self.attention_bias = attention_bias
- self.lm_head_bias = lm_head_bias
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.ffn_hidden_size = ffn_hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_experts = num_experts
- self.num_attention_heads = num_attention_heads
- self.hidden_act = hidden_act
- self.head_dim = head_dim
- self.initializer_range = initializer_range
- self.num_key_value_heads = num_key_value_heads
- self.max_position_embeddings = max_position_embeddings
- self.norm_epsilon = norm_epsilon
- self.pad_token_id = pad_token_id
- self.bos_token_id = bos_token_id
- self.eos_token_id = eos_token_id
- self.tie_word_embeddings = tie_word_embeddings
- self.attention_dropout = attention_dropout
- self.moe_router_topk = moe_router_topk
- self.zaya_mlp_expansion = zaya_mlp_expansion
- self.partial_rotary_factor = partial_rotary_factor
- self.rope_theta = rope_theta
- rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"}
- rope_parameters.setdefault("rope_theta", rope_theta)
- rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor)
- self.rope_parameters = rope_parameters
- cca_time0 = 2 if cca_time0 is None else cca_time0
- cca_time1 = 2 if cca_time1 is None else cca_time1
- if (cca_time0, cca_time1) != (2, 2):
+ self.rope_parameters = (
+ dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"}
+ )
+ self.rope_parameters.setdefault("rope_theta", self.rope_theta)
+ self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor)
+ self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0
+ self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1
+ if (self.cca_time0, self.cca_time1) != (2, 2):
raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
- if swa_layers is not None and len(swa_layers) != num_hidden_layers:
+ if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers:
raise ValueError("`swa_layers` must have one entry per hidden layer.")
- if swa_layers is not None and swa_rotary_base is None:
+ if self.swa_layers is not None and self.swa_rotary_base is None:
raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")
- self.cca_time0 = cca_time0
- self.cca_time1 = cca_time1
- self.swa_layers = swa_layers
- self.swa_rotary_base = swa_rotary_base
- self.output_router_logits = output_router_logits
- self._attn_implementation = _attn_implementation
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=self.tie_word_embeddings,
- **kwargs,
- )
+ super().__post_init__(**kwargs)
__all__ = ["ZayaConfig"]
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 60bb870c73a5..ee6f44c840f3 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -20,6 +20,7 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
+from huggingface_hub.dataclasses import strict
from torch import nn
from torch.nn import init
@@ -34,6 +35,7 @@
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
+from ...modeling_rope_utils import RopeParameters
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -52,22 +54,23 @@
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
+@strict
class ZayaConfig(PreTrainedConfig):
r"""
- num_query_groups (`int`, *optional*, defaults to 2):
- Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
- lm_head_bias (`bool`, *optional*, defaults to `False`):
- Whether to add a bias to the language modeling head.
ffn_hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the feed-forward and expert hidden states.
+ num_query_groups (`int`, *optional*, defaults to 2):
+ Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
rope_theta (`float`, *optional*, defaults to 5000000):
The base period of the RoPE embeddings.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Fraction of each attention head dimension using rotary embeddings.
+ lm_head_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the language modeling head.
moe_router_topk (`int`, *optional*, defaults to 1):
Number of selected experts per token. ZAYA checkpoints use top-1 routing.
zaya_mlp_expansion (`int`, *optional*, defaults to 256):
Expansion size used by the dense ZAYA blocks.
- partial_rotary_factor (`float`, *optional*, defaults to 0.5):
- Fraction of each attention head dimension using rotary embeddings.
cca_time0 (`int`, *optional*, defaults to 2):
First temporal parameter of the CCA projection.
cca_time1 (`int`, *optional*, defaults to 2):
@@ -90,42 +93,39 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
- def __init__(
- self,
- num_query_groups=2,
- use_cache=True,
- attention_bias=False,
- lm_head_bias=False,
- vocab_size=262272,
- hidden_size=2048,
- ffn_hidden_size=4096,
- num_hidden_layers=80,
- num_experts=16,
- num_attention_heads=8,
- hidden_act="silu",
- head_dim=128,
- initializer_range=0.02,
- max_position_embeddings=131072,
- norm_epsilon=1e-05,
- pad_token_id=0,
- bos_token_id=2,
- eos_token_id=106,
- tie_word_embeddings=True,
- rope_theta=5000000,
- attention_dropout=0.0,
- moe_router_topk=1,
- zaya_mlp_expansion=256,
- rope_parameters=None,
- partial_rotary_factor=0.5,
- num_key_value_heads=2,
- cca_time0=2,
- cca_time1=2,
- swa_layers=None,
- swa_rotary_base=None,
- output_router_logits=False,
- _attn_implementation="eager",
- **kwargs,
- ):
+ vocab_size: int = 262272
+ hidden_size: int = 2048
+ ffn_hidden_size: int = 4096
+ num_hidden_layers: int = 80
+ num_experts: int = 16
+ num_attention_heads: int = 8
+ num_key_value_heads: int | None = 2
+ num_query_groups: int | None = 2
+ hidden_act: str = "silu"
+ head_dim: int = 128
+ max_position_embeddings: int = 131072
+ initializer_range: float = 0.02
+ norm_epsilon: float = 1e-5
+ use_cache: bool = True
+ tie_word_embeddings: bool = True
+ rope_parameters: RopeParameters | dict | None = None
+ rope_theta: float | int = 5000000
+ partial_rotary_factor: float = 0.5
+ attention_bias: bool = False
+ lm_head_bias: bool = False
+ attention_dropout: float | int = 0.0
+ moe_router_topk: int = 1
+ zaya_mlp_expansion: int = 256
+ cca_time0: int | None = 2
+ cca_time1: int | None = 2
+ swa_layers: list[int] | None = None
+ swa_rotary_base: float | int | None = None
+ output_router_logits: bool = False
+ pad_token_id: int | None = 0
+ bos_token_id: int | None = 2
+ eos_token_id: int | list[int] | None = 106
+
+ def __post_init__(self, **kwargs):
for unused_checkpoint_kwarg in (
"cca",
"activation_func",
@@ -137,6 +137,8 @@ def __init__(
"bias_activation_fusion",
"activation_func_fp8_input_store",
"clamp_temp",
+ "kv_channels",
+ "mamba_cache_dtype",
"residual_in_fp32",
"rope_scaling",
"scale_residual_merge",
@@ -147,66 +149,32 @@ def __init__(
):
kwargs.pop(unused_checkpoint_kwarg, None)
- num_query_groups = num_key_value_heads if num_query_groups is None else num_query_groups
- if head_dim is None:
+ self.num_key_value_heads = (
+ self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads
+ )
+ self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups
+ if self.head_dim is None:
raise ValueError("`head_dim` must be set for ZAYA.")
- if num_query_groups != num_key_value_heads:
+ if self.num_query_groups != self.num_key_value_heads:
raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
- if moe_router_topk != 1:
+ if self.moe_router_topk != 1:
raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
- self.num_query_groups = num_query_groups
- self.use_cache = use_cache
- self.attention_bias = attention_bias
- self.lm_head_bias = lm_head_bias
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.ffn_hidden_size = ffn_hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_experts = num_experts
- self.num_attention_heads = num_attention_heads
- self.hidden_act = hidden_act
- self.head_dim = head_dim
- self.initializer_range = initializer_range
- self.num_key_value_heads = num_key_value_heads
- self.max_position_embeddings = max_position_embeddings
- self.norm_epsilon = norm_epsilon
- self.pad_token_id = pad_token_id
- self.bos_token_id = bos_token_id
- self.eos_token_id = eos_token_id
- self.tie_word_embeddings = tie_word_embeddings
- self.attention_dropout = attention_dropout
- self.moe_router_topk = moe_router_topk
- self.zaya_mlp_expansion = zaya_mlp_expansion
- self.partial_rotary_factor = partial_rotary_factor
- self.rope_theta = rope_theta
- rope_parameters = dict(rope_parameters) if rope_parameters is not None else {"rope_type": "default"}
- rope_parameters.setdefault("rope_theta", rope_theta)
- rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor)
- self.rope_parameters = rope_parameters
- cca_time0 = 2 if cca_time0 is None else cca_time0
- cca_time1 = 2 if cca_time1 is None else cca_time1
- if (cca_time0, cca_time1) != (2, 2):
+ self.rope_parameters = (
+ dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"}
+ )
+ self.rope_parameters.setdefault("rope_theta", self.rope_theta)
+ self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor)
+ self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0
+ self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1
+ if (self.cca_time0, self.cca_time1) != (2, 2):
raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
- if swa_layers is not None and len(swa_layers) != num_hidden_layers:
+ if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers:
raise ValueError("`swa_layers` must have one entry per hidden layer.")
- if swa_layers is not None and swa_rotary_base is None:
+ if self.swa_layers is not None and self.swa_rotary_base is None:
raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")
- self.cca_time0 = cca_time0
- self.cca_time1 = cca_time1
- self.swa_layers = swa_layers
- self.swa_rotary_base = swa_rotary_base
- self.output_router_logits = output_router_logits
- self._attn_implementation = _attn_implementation
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=self.tie_word_embeddings,
- **kwargs,
- )
+ super().__post_init__(**kwargs)
class ZayaRotaryEmbedding(Glm4RotaryEmbedding):
From c90df6f33e910c666115264a9ad22fe0381cd7df Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 9 May 2026 20:58:11 +0800
Subject: [PATCH 05/36] remove empty line
---
src/transformers/models/zaya/modeling_zaya.py | 4 ----
src/transformers/models/zaya/modular_zaya.py | 4 ----
2 files changed, 8 deletions(-)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index bbbecaeb1907..ab68cbc73d36 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -543,7 +543,6 @@ def _apply_residual_scaling(
class ZayaDecoderATTLayer(GradientCheckpointingLayer):
def __init__(self, config: ZayaConfig, layer_n: int):
-
super().__init__()
self.config = config
self.self_attn = ZayaAttention(config, layer_n)
@@ -715,7 +714,6 @@ def __init__(
ffn_hidden_size: int,
layer_n: int,
):
-
super().__init__()
self.config = config
self.hidden_dim = config.hidden_size
@@ -755,7 +753,6 @@ def __init__(
ffn_hidden_size: int,
layer_n: int,
):
-
super().__init__()
self.config = config
self.zaya_block = ZayaBlock(
@@ -829,7 +826,6 @@ def _init_weights(self, module):
@auto_docstring
class ZayaModel(ZayaPreTrainedModel):
def __init__(self, config: ZayaConfig):
-
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index ee6f44c840f3..d5bb4efd767e 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -501,7 +501,6 @@ def forward(
class ZayaDecoderATTLayer(GradientCheckpointingLayer):
def __init__(self, config: ZayaConfig, layer_n: int):
-
super().__init__()
self.config = config
self.self_attn = ZayaAttention(config, layer_n)
@@ -685,7 +684,6 @@ def __init__(
ffn_hidden_size: int,
layer_n: int,
):
-
super().__init__()
self.config = config
self.hidden_dim = config.hidden_size
@@ -725,7 +723,6 @@ def __init__(
ffn_hidden_size: int,
layer_n: int,
):
-
super().__init__()
self.config = config
self.zaya_block = ZayaBlock(
@@ -799,7 +796,6 @@ def _init_weights(self, module):
@auto_docstring
class ZayaModel(ZayaPreTrainedModel):
def __init__(self, config: ZayaConfig):
-
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
From b90759f1de9b000ab16abdc9c6cc7190ad60381f Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 9 May 2026 21:04:32 +0800
Subject: [PATCH 06/36] pass ci
---
src/transformers/models/auto/modeling_auto.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 4d90c73183e7..6ced59e8556f 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -510,9 +510,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("yolos", "YolosModel"),
("yoso", "YosoModel"),
("youtu", "YoutuModel"),
- ("zaya", "ZayaModel"),
("zamba", "ZambaModel"),
("zamba2", "Zamba2Model"),
+ ("zaya", "ZayaModel"),
]
)
@@ -773,9 +773,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("xlstm", "xLSTMForCausalLM"),
("xmod", "XmodForCausalLM"),
("youtu", "YoutuForCausalLM"),
- ("zaya", "ZayaForCausalLM"),
("zamba", "ZambaForCausalLM"),
("zamba2", "Zamba2ForCausalLM"),
+ ("zaya", "ZayaForCausalLM"),
]
)
From 7e2999929315219eb0fa2bd9ed5dbd00962deb7e Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 11:39:48 +0800
Subject: [PATCH 07/36] modify config, laguna-sytle rope
---
docs/source/en/model_doc/zaya.md | 13 +-
src/transformers/models/zaya/modular_zaya.py | 131 +++++++++++--------
tests/models/zaya/test_modeling_zaya.py | 99 +++++++++++++-
3 files changed, 180 insertions(+), 63 deletions(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index 01e7a8504e1d..468f7327dd86 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -25,19 +25,26 @@ Convolutional Attention (CCA), a nonlinear ZAYA1 router, and residual scaling.
ZAYA1 uses the Gemma 3 tokenizer. For more details, see the [ZAYA1 model card](https://huggingface.co/Zyphra/ZAYA1-8B)
and Zyphra's technical reports.
+This model was contributed by [JJJYmmm](https://github.com/JJJYmmm).
+
## Usage examples
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
-
model_id = "Zyphra/ZAYA1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
-inputs = tokenizer.apply_chat_template([{"role": "user", "content": "Write a haiku about recursion in programming."}], tokenize=True, add_generation_prompt=True, enable_thinking=False, return_tensors="pt")
+inputs = tokenizer.apply_chat_template(
+ [{"role": "user", "content": "Write a haiku about recursion in programming."}],
+ tokenize=True,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ return_tensors="pt",
+)
inputs = inputs.to(model.device)
-outputs = model.generate(**inputs, max_new_tokens=2048)
+outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index d5bb4efd767e..6b7af760e37e 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
"""PyTorch Zaya model."""
-import copy
from collections.abc import Callable
+from typing import Any, Literal
import torch
import torch.nn.functional as F
@@ -45,7 +45,7 @@
)
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
-from ..glm4.modeling_glm4 import Glm4RotaryEmbedding
+from ..laguna.modeling_laguna import LagunaRotaryEmbedding
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
apply_rotary_pos_emb,
eager_attention_forward,
@@ -58,11 +58,9 @@
class ZayaConfig(PreTrainedConfig):
r"""
ffn_hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the feed-forward and expert hidden states.
- num_query_groups (`int`, *optional*, defaults to 2):
- Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
- rope_theta (`float`, *optional*, defaults to 5000000):
- The base period of the RoPE embeddings.
+ Dimension of the feed-forward and expert hidden states, translate it to `intermediate_size`.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ Number of key/value groups.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Fraction of each attention head dimension using rotary embeddings.
lm_head_bias (`bool`, *optional*, defaults to `False`):
@@ -75,7 +73,7 @@ class ZayaConfig(PreTrainedConfig):
First temporal parameter of the CCA projection.
cca_time1 (`int`, *optional*, defaults to 2):
Second temporal parameter of the CCA projection.
- swa_layers (`list[int]`, *optional*):
+ layer_types (`list[str]`, *optional*):
Per-layer selector for standard RoPE versus SWA RoPE embeddings.
swa_rotary_base (`float`, *optional*):
RoPE base used by SWA layers.
@@ -92,6 +90,7 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
+ default_theta = 5000000.0
vocab_size: int = 262272
hidden_size: int = 2048
@@ -100,7 +99,6 @@ class ZayaConfig(PreTrainedConfig):
num_experts: int = 16
num_attention_heads: int = 8
num_key_value_heads: int | None = 2
- num_query_groups: int | None = 2
hidden_act: str = "silu"
head_dim: int = 128
max_position_embeddings: int = 131072
@@ -109,7 +107,6 @@ class ZayaConfig(PreTrainedConfig):
use_cache: bool = True
tie_word_embeddings: bool = True
rope_parameters: RopeParameters | dict | None = None
- rope_theta: float | int = 5000000
partial_rotary_factor: float = 0.5
attention_bias: bool = False
lm_head_bias: bool = False
@@ -118,8 +115,8 @@ class ZayaConfig(PreTrainedConfig):
zaya_mlp_expansion: int = 256
cca_time0: int | None = 2
cca_time1: int | None = 2
- swa_layers: list[int] | None = None
- swa_rotary_base: float | int | None = None
+ layer_types: list[str] | None = None
+ swa_rotary_base: float | int = 10000.0
output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
@@ -128,6 +125,7 @@ class ZayaConfig(PreTrainedConfig):
def __post_init__(self, **kwargs):
for unused_checkpoint_kwarg in (
"cca",
+ "num_query_groups",
"activation_func",
"normalization",
"add_bias_linear",
@@ -149,35 +147,68 @@ def __post_init__(self, **kwargs):
):
kwargs.pop(unused_checkpoint_kwarg, None)
+ self.intermediate_size = self.ffn_hidden_size
+ self.num_experts_per_tok = self.moe_router_topk
+
self.num_key_value_heads = (
self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads
)
- self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups
- if self.head_dim is None:
- raise ValueError("`head_dim` must be set for ZAYA.")
- if self.num_query_groups != self.num_key_value_heads:
- raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
- if self.moe_router_topk != 1:
- raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
- self.rope_parameters = (
- dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"}
- )
- self.rope_parameters.setdefault("rope_theta", self.rope_theta)
- self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor)
+ legacy_swa_layers = kwargs.pop("swa_layers", None)
+ if self.layer_types is None:
+ if legacy_swa_layers is None:
+ self.layer_types = ["full_attention"] * self.num_hidden_layers
+ else:
+ self.layer_types = [
+ "full_attention" if layer_type == 0 else "sliding_attention" for layer_type in legacy_swa_layers
+ ]
+ else:
+ self.layer_types = list(self.layer_types)
+
self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0
self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1
- if (self.cca_time0, self.cca_time1) != (2, 2):
- raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
- if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers:
- raise ValueError("`swa_layers` must have one entry per hidden layer.")
- if self.swa_layers is not None and self.swa_rotary_base is None:
- raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")
super().__post_init__(**kwargs)
+ def convert_rope_params_to_dict(self, **kwargs):
+ default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
+ "full_attention": {
+ "rope_type": "default",
+ "rope_theta": kwargs.pop("rope_theta", self.default_theta),
+ "partial_rotary_factor": self.partial_rotary_factor,
+ },
+ "sliding_attention": {
+ "rope_type": "default",
+ "rope_theta": self.swa_rotary_base,
+ "partial_rotary_factor": self.partial_rotary_factor,
+ },
+ }
+ layer_types = set(self.layer_types)
+
+ if self.rope_parameters is None:
+ self.rope_parameters = {layer_type: default_rope_params[layer_type] for layer_type in layer_types}
+ else:
+ self.rope_parameters = {
+ layer_type: {**default_rope_params[layer_type], **(self.rope_parameters.get(layer_type) or {})}
+ for layer_type in layer_types
+ }
+
+ return kwargs
-class ZayaRotaryEmbedding(Glm4RotaryEmbedding):
+ def validate_architecture(self):
+ if self.head_dim is None:
+ raise ValueError("`head_dim` must be set for ZAYA.")
+ if self.num_experts_per_tok != 1:
+ raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
+ if len(self.layer_types) != self.num_hidden_layers:
+ raise ValueError("`layer_types` must have one entry per hidden layer.")
+ if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
+ raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
+ if (self.cca_time0, self.cca_time1) != (2, 2):
+ raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
+
+
+class ZayaRotaryEmbedding(LagunaRotaryEmbedding):
pass
@@ -204,7 +235,7 @@ def __init__(
self.device = device
self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1)
self.num_layers = config.num_hidden_layers
- self.key_value_hidden_size = config.num_query_groups * config.head_dim
+ self.key_value_hidden_size = config.num_key_value_heads * config.head_dim
self.query_hidden_size = config.num_attention_heads * config.head_dim
self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size
self.has_previous_state = False
@@ -439,7 +470,7 @@ def __init__(self, config: ZayaConfig, layer_n):
self.qkv = CCA(
config=self.config,
num_attention_heads=self.config.num_attention_heads,
- num_key_value_heads=self.config.num_query_groups,
+ num_key_value_heads=self.config.num_key_value_heads,
hidden_size=self.hidden_size,
head_dim=self.config.head_dim,
cca_time0=self.config.cca_time0,
@@ -639,11 +670,11 @@ def forward(
class ZayaExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
- def __init__(self, config, num_experts: int, ffn_hidden_size: int):
+ def __init__(self, config, num_experts: int, intermediate_size: int):
super().__init__()
self.num_experts = num_experts
self.hidden_dim = config.hidden_size
- self.intermediate_dim = ffn_hidden_size // 2
+ self.intermediate_dim = intermediate_size // 2
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
@@ -681,7 +712,7 @@ def __init__(
config,
num_moe_experts: int,
mlp_expansion: int,
- ffn_hidden_size: int,
+ intermediate_size: int,
layer_n: int,
):
super().__init__()
@@ -696,7 +727,7 @@ def __init__(
mlp_expansion=mlp_expansion,
hidden_size=self.hidden_dim,
)
- self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size)
+ self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size)
def forward(
self,
@@ -720,7 +751,7 @@ def __init__(
config: ZayaConfig,
num_moe_experts: int,
mlp_expansion: int,
- ffn_hidden_size: int,
+ intermediate_size: int,
layer_n: int,
):
super().__init__()
@@ -729,7 +760,7 @@ def __init__(
config,
num_moe_experts,
mlp_expansion,
- ffn_hidden_size,
+ intermediate_size,
layer_n,
)
self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
@@ -809,7 +840,7 @@ def __init__(self, config: ZayaConfig):
config,
config.num_experts,
config.zaya_mlp_expansion,
- config.ffn_hidden_size,
+ config.intermediate_size,
layer_n,
)
)
@@ -823,13 +854,6 @@ def __init__(self, config: ZayaConfig):
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
- if self.config.swa_layers is not None:
- swa_config = copy.copy(config)
- swa_config.rope_parameters = {
- **config.rope_parameters,
- "rope_theta": swa_config.swa_rotary_base,
- }
- self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config)
self.post_init()
@@ -896,19 +920,16 @@ def forward(
hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- if self.config.swa_layers is not None:
- swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
+ position_embeddings = {
+ layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) for layer_type in set(self.config.layer_types)
+ }
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
prev_router_hidden_states = None
for layer_n, decoder_layer in enumerate(self.layers):
- if self.config.swa_layers is not None:
- emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings
- else:
- emb_to_use = position_embeddings
+ emb_to_use = position_embeddings[self.config.layer_types[layer_n]]
if output_hidden_states:
all_hidden_states += (hidden_states,)
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 2338d07675af..6264c05421b3 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -15,6 +15,7 @@
import unittest
+from huggingface_hub.errors import StrictDataclassClassValidationError
from parameterized import parameterized
from transformers import is_torch_available
@@ -48,7 +49,6 @@ def __init__(self, parent):
)
self.head_dim = 8
self.ffn_hidden_size = 64
- self.num_query_groups = 2
self.num_experts = 4
self.moe_router_topk = 1
self.zaya_mlp_expansion = 4
@@ -115,11 +115,95 @@ def test_attention_outputs(self):
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
@unittest.skip(
- "ZAYA uses partial rotary embeddings with CCA, which is not compatible with this generic RoPE test."
+ "RoPE-scaling-from-config test doesn't match ZAYA's nested per-layer-type rope_parameters (same as e.g. Laguna, Gemma3)."
)
def test_model_rope_scaling_from_config(self, scaling_type):
pass
+ def test_model_rope_scaling_frequencies(self):
+ """
+ Tests the frequency properties of the different RoPE scaling types on the model RoPE layer.
+ Copied from Laguna to adapt to per-layer-type rope configs.
+ """
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ config.layer_types = ["full_attention", "sliding_attention"]
+ partial_rotary_factor = config.partial_rotary_factor
+
+ def set_rope_params(rope_params):
+ config.rope_parameters = {
+ "full_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor},
+ "sliding_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor},
+ }
+
+ set_rope_params({"rope_type": "default", "rope_theta": 10_000.0})
+
+ base_model = self.model_tester.base_model_class(config)
+ possible_rope_attributes = [
+ "pos_emb",
+ "rotary_emb",
+ "global_rotary_emb",
+ "local_rotary_emb",
+ ]
+ for name, module in base_model.named_modules():
+ if any(potential_name in name for potential_name in possible_rope_attributes):
+ rope_class = type(module)
+ break
+
+ scaling_factor = 10
+ short_input_length = 10
+ long_input_length = int(config.max_position_embeddings * 1.5)
+
+ x = torch.randn(1, dtype=torch.float32, device=torch_device)
+ position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device).unsqueeze(0)
+ position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device).unsqueeze(0)
+
+ set_rope_params({"rope_type": "default", "rope_theta": 10_000.0})
+ original_rope = rope_class(config=config).to(torch_device)
+ original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="sliding_attention")
+ original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="sliding_attention")
+ torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
+ torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
+
+ set_rope_params({"rope_type": "linear", "factor": scaling_factor, "rope_theta": 10_000.0})
+ linear_scaling_rope = rope_class(config=config).to(torch_device)
+ linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="sliding_attention")
+ linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="sliding_attention")
+ torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
+ torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
+ for new_position in range(0, long_input_length, scaling_factor):
+ original_position = int(new_position // scaling_factor)
+ torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
+ torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
+
+ set_rope_params({"rope_type": "dynamic", "factor": scaling_factor, "rope_theta": 10_000.0})
+ ntk_scaling_rope = rope_class(config=config).to(torch_device)
+ ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="sliding_attention")
+ ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="sliding_attention")
+ torch.testing.assert_close(ntk_cos_short, original_cos_short)
+ torch.testing.assert_close(ntk_sin_short, original_sin_short)
+ with self.assertRaises(AssertionError):
+ torch.testing.assert_close(ntk_cos_long, original_cos_long)
+ with self.assertRaises(AssertionError):
+ torch.testing.assert_close(ntk_sin_long, original_sin_long)
+ self.assertTrue(
+ (ntk_scaling_rope.sliding_attention_inv_freq <= original_rope.sliding_attention_inv_freq).all()
+ )
+
+ set_rope_params({"rope_type": "yarn", "factor": scaling_factor, "rope_theta": 10_000.0})
+ yarn_scaling_rope = rope_class(config=config).to(torch_device)
+ yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="sliding_attention")
+ yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="sliding_attention")
+ torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
+ torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
+ with self.assertRaises(AssertionError):
+ torch.testing.assert_close(yarn_cos_short, original_cos_short)
+ with self.assertRaises(AssertionError):
+ torch.testing.assert_close(yarn_sin_short, original_sin_short)
+ with self.assertRaises(AssertionError):
+ torch.testing.assert_close(yarn_cos_long, original_cos_long)
+ with self.assertRaises(AssertionError):
+ torch.testing.assert_close(yarn_sin_long, original_sin_long)
+
@unittest.skip("ZAYA needs alternating attention and MoE layers in the tiny test configuration.")
def test_num_layers_is_small(self):
pass
@@ -153,9 +237,16 @@ def test_moe_router_logits(self):
)
def test_moe_router_topk_validation(self):
- with self.assertRaisesRegex(ValueError, "moe_router_topk=1"):
+ with self.assertRaisesRegex(StrictDataclassClassValidationError, "moe_router_topk=1"):
ZayaConfig(moe_router_topk=2)
+ def test_legacy_swa_layers_translate_to_layer_types(self):
+ config = ZayaConfig(num_hidden_layers=4, swa_layers=[0, 1, 0, 1], swa_rotary_base=10000)
+
+ self.assertEqual(config.layer_types, ["full_attention", "sliding_attention", "full_attention", "sliding_attention"])
+ self.assertEqual(config.rope_parameters["full_attention"]["rope_theta"], config.default_theta)
+ self.assertEqual(config.rope_parameters["sliding_attention"]["rope_theta"], 10000)
+
def test_cca_cache_matches_full_forward(self):
config = ZayaConfig(
vocab_size=128,
@@ -165,7 +256,6 @@ def test_cca_cache_matches_full_forward(self):
num_experts=4,
num_attention_heads=4,
num_key_value_heads=2,
- num_query_groups=2,
head_dim=8,
zaya_mlp_expansion=4,
tie_word_embeddings=False,
@@ -201,7 +291,6 @@ def test_cca_cache_matches_full_forward_multi_token(self):
num_experts=4,
num_attention_heads=4,
num_key_value_heads=2,
- num_query_groups=2,
head_dim=8,
zaya_mlp_expansion=4,
tie_word_embeddings=False,
From cf083aa17fbea23edab44cf62d114fa430eca9ed Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 12:01:07 +0800
Subject: [PATCH 08/36] use existing cache
---
src/transformers/models/zaya/modular_zaya.py | 130 +++----------------
tests/models/zaya/test_modeling_zaya.py | 85 ++++++------
2 files changed, 66 insertions(+), 149 deletions(-)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 6b7af760e37e..a48be2edf7ee 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -14,6 +14,7 @@
"""PyTorch Zaya model."""
+import copy
from collections.abc import Callable
from typing import Any, Literal
@@ -216,95 +217,12 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm):
pass
-class ZayaDynamicCache(DynamicCache):
- """
- Cache that includes both the KV cache and the CCA cache.
- """
-
- def __init__(
- self,
- config: ZayaConfig,
- batch_size: int,
- dtype: torch.dtype = torch.float16,
- device: str | None = None,
- ):
- super().__init__()
- self.config = config
- self.batch_size = batch_size
- self.dtype = dtype
- self.device = device
- self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1)
- self.num_layers = config.num_hidden_layers
- self.key_value_hidden_size = config.num_key_value_heads * config.head_dim
- self.query_hidden_size = config.num_attention_heads * config.head_dim
- self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size
- self.has_previous_state = False
-
- self.conv_states = [None for _ in range(self.num_layers)]
- self.prev_v2 = [None for _ in range(self.num_layers)]
-
- def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
- if new_conv_state.shape[1] < self.conv_kernel_size:
- new_conv_state = F.pad(
- new_conv_state.transpose(1, 2), (self.conv_kernel_size - new_conv_state.shape[1], 0)
- )
- else:
- new_conv_state = new_conv_state[:, -self.conv_kernel_size :, :].transpose(1, 2)
-
- if self.conv_states[layer_idx] is None:
- self.conv_states[layer_idx] = torch.zeros_like(new_conv_state)
-
- if not self.has_previous_state:
- self.conv_states[layer_idx].copy_(new_conv_state)
- else:
- conv_state = torch.cat([self.conv_states[layer_idx], new_conv_state], dim=-1)[
- :, :, -self.conv_kernel_size :
- ]
- self.conv_states[layer_idx].copy_(conv_state)
- return self.conv_states[layer_idx]
-
- def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor:
- if self.prev_v2[layer_idx] is None:
- self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2)
- self.prev_v2[layer_idx].copy_(new_prev_v2)
- return self.prev_v2[layer_idx]
-
- def reset(self):
- super().reset()
- for conv_state in self.conv_states:
- if conv_state is not None:
- conv_state.zero_()
- for prev_v2 in self.prev_v2:
- if prev_v2 is not None:
- prev_v2.zero_()
- self.has_previous_state = False
-
- def _reorder_auxiliary_states(self, indices: torch.LongTensor):
- for layer_idx, conv_state in enumerate(self.conv_states):
- if conv_state is not None:
- self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device))
- for layer_idx, prev_v2 in enumerate(self.prev_v2):
- if prev_v2 is not None:
- self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device))
- self.batch_size = indices.shape[0]
-
- def reorder_cache(self, beam_idx: torch.LongTensor):
- super().reorder_cache(beam_idx)
- self._reorder_auxiliary_states(beam_idx)
-
- def batch_repeat_interleave(self, repeats: int):
- super().batch_repeat_interleave(repeats)
- for layer_idx, conv_state in enumerate(self.conv_states):
- if conv_state is not None:
- self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0)
- for layer_idx, prev_v2 in enumerate(self.prev_v2):
- if prev_v2 is not None:
- self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0)
- self.batch_size *= repeats
-
- def batch_select_indices(self, indices: torch.Tensor):
- super().batch_select_indices(indices)
- self._reorder_auxiliary_states(indices)
+def _make_zaya_cache(config: ZayaConfig) -> DynamicCache:
+ cache_config = copy.copy(config)
+ # layer_types is used to distinct the rope_type (full or swa)
+ # so need to construct a new layer_types to construct cache
+ cache_config.layer_types = ["hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers)]
+ return DynamicCache(config=cache_config)
class CCA(nn.Module):
@@ -370,7 +288,7 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
- past_key_values: ZayaDynamicCache | None,
+ past_key_values: Cache | None,
attention_mask: torch.Tensor | None = None,
):
if attention_mask is not None:
@@ -393,15 +311,18 @@ def forward(
).mean(dim=-2)
qk_states = qk_states.transpose(1, 2)
- use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state
+ use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_number)
if use_precomputed_states:
- cached_qk_states = past_key_values.conv_states[self.layer_number]
+ cached_qk_states = past_key_values.layers[self.layer_number].conv_states
conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
else:
conv_input = F.pad(qk_states, (self.total_padding, 0))
if past_key_values is not None:
- past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2))
+ new_conv_state = qk_states[..., -self.total_padding :]
+ if new_conv_state.shape[-1] < self.total_padding:
+ new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0))
+ past_key_values.update_conv_state(new_conv_state, self.layer_number)
convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2)
@@ -422,13 +343,13 @@ def forward(
value_current = self.val_proj1(hidden_states)
projected_v2 = self.val_proj2(hidden_states)
if use_precomputed_states:
- first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1)
+ first_v2 = past_key_values.layers[self.layer_number].recurrent_states.unsqueeze(1)
else:
first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
if past_key_values is not None:
- past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :])
+ past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_number)
value = torch.cat([value_current, value_delayed], dim=-1).view(
batch_size, seq_length, self.num_key_value_heads, self.head_dim
@@ -890,9 +811,7 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
- past_key_values = ZayaDynamicCache(
- self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device
- )
+ past_key_values = _make_zaya_cache(self.config)
residual = None
@@ -912,7 +831,7 @@ def forward(
)
if attention_mask is not None and attention_mask.ndim != 2:
raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.")
- # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
+ # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
# CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
if inputs_embeds.shape[1] == 1:
@@ -959,9 +878,6 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- if past_key_values and not past_key_values.has_previous_state:
- past_key_values.has_previous_state = True
-
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
@@ -1067,11 +983,6 @@ def prepare_inputs_for_generation(
logits_to_keep=None,
**kwargs,
):
- if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache):
- raise ValueError(
- f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
- )
-
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
@@ -1096,10 +1007,7 @@ def _prepare_cache_for_generation(
return
if "past_key_values" not in model_kwargs:
- cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences)
- model_kwargs["past_key_values"] = ZayaDynamicCache(
- self.config, cache_batch_size, dtype=self.dtype, device=self.device
- )
+ model_kwargs["past_key_values"] = _make_zaya_cache(self.config)
generation_config.cache_implementation = None
return super()._prepare_cache_for_generation(
generation_config=generation_config,
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 6264c05421b3..b710d136a554 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -26,7 +26,8 @@
import torch
from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel
- from transformers.models.zaya.modeling_zaya import CCA, ZayaDynamicCache
+ from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer
+ from transformers.models.zaya.modeling_zaya import CCA, _make_zaya_cache
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
@@ -64,6 +65,36 @@ class ZayaModelTest(CausalLMModelTest, unittest.TestCase):
model_tester_class = ZayaModelTester
test_all_params_have_gradient = False
+ def _get_conv_state_shape(self, batch_size: int, config):
+ conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim
+ return (batch_size, conv_state_size, config.cca_time0 + config.cca_time1 - 2)
+
+ def _get_recurrent_state_shape(self, batch_size: int, config):
+ return (batch_size, config.num_key_value_heads * config.head_dim // 2)
+
+ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
+ if not isinstance(past_key_values, DynamicCache):
+ raise ValueError("The cache does not use the correct Cache")
+
+ config = config.get_text_config(decoder=True)
+ self.assertEqual(config.num_hidden_layers, len(past_key_values))
+ attention_shape = (batch_size, config.num_key_value_heads, seq_length, config.head_dim)
+ conv_shape = self._get_conv_state_shape(batch_size, config)
+ recurrent_shape = self._get_recurrent_state_shape(batch_size, config)
+
+ for layer_idx, layer in enumerate(past_key_values.layers):
+ if layer_idx % 2 == 0:
+ self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer)
+ self.assertEqual(layer.keys.shape, attention_shape)
+ self.assertEqual(layer.values.shape, attention_shape)
+ self.assertEqual(layer.conv_states.shape, conv_shape)
+ self.assertEqual(layer.recurrent_states.shape, recurrent_shape)
+ else:
+ self.assertIsNone(layer.keys)
+ self.assertIsNone(layer.values)
+ self.assertIsNone(layer.conv_states)
+ self.assertIsNone(layer.recurrent_states)
+
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
@@ -208,18 +239,6 @@ def set_rope_params(rope_params):
def test_num_layers_is_small(self):
pass
- @unittest.skip("ZAYA uses a custom cache carrying CCA convolution state in addition to KV tensors.")
- def test_past_key_values_format(self):
- pass
-
- @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.")
- def test_greedy_generate_dict_outputs_use_cache(self):
- pass
-
- @unittest.skip("ZAYA's custom CCA cache is not a standard per-layer KV cache.")
- def test_beam_search_generate_dict_outputs_use_cache(self):
- pass
-
def test_moe_router_logits(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = self.model_tester.causal_lm_class(config)
@@ -274,9 +293,8 @@ def test_cca_cache_matches_full_forward(self):
with torch.no_grad():
full = cca(hidden_states, None, None)
- cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device)
+ cache = _make_zaya_cache(config)
cca(hidden_states[:, :4], cache, None)
- cache.has_previous_state = True
cached = cca(hidden_states[:, 4:], cache, None)
for full_states, cached_states in zip(full, cached):
@@ -309,47 +327,38 @@ def test_cca_cache_matches_full_forward_multi_token(self):
with torch.no_grad():
full = cca(hidden_states, None, None)
- cache = ZayaDynamicCache(config, batch_size=1, dtype=hidden_states.dtype, device=torch_device)
+ cache = _make_zaya_cache(config)
cca(hidden_states[:, :3], cache, None)
- cache.has_previous_state = True
cached = cca(hidden_states[:, 3:], cache, None)
for full_states, cached_states in zip(full, cached):
torch.testing.assert_close(full_states[:, 3:], cached_states, rtol=1e-5, atol=1e-5)
- def test_zaya_cache_batch_methods(self):
+ def test_zaya_cache_reorder_and_reset(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- cache = ZayaDynamicCache(config, batch_size=2, dtype=torch.float32, device=torch_device)
+ cache = _make_zaya_cache(config)
+ conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim
cache.update_conv_state(
- 0,
- torch.arange(2 * 2 * cache.conv_state_size, device=torch_device, dtype=torch.float32).view(
- 2, 2, cache.conv_state_size
+ torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view(
+ 2, conv_state_size, 2
),
- )
- cache.update_prev_v2(
0,
+ )
+ cache.update_recurrent_state(
torch.arange(
2 * config.num_key_value_heads * config.head_dim // 2, device=torch_device, dtype=torch.float32
).view(2, config.num_key_value_heads * config.head_dim // 2),
+ 0,
)
- self.assertEqual(cache.prev_v2[0].shape[-1], config.num_key_value_heads * config.head_dim // 2)
-
- cache.batch_repeat_interleave(2)
- self.assertEqual(cache.conv_states[0].shape[0], 4)
- self.assertEqual(cache.prev_v2[0].shape[0], 4)
-
- cache.batch_select_indices(torch.tensor([3, 1], device=torch_device))
- self.assertEqual(cache.conv_states[0].shape[0], 2)
- self.assertEqual(cache.prev_v2[0].shape[0], 2)
+ self.assertEqual(cache.layers[0].recurrent_states.shape[-1], config.num_key_value_heads * config.head_dim // 2)
cache.reorder_cache(torch.tensor([1, 0], device=torch_device))
- self.assertEqual(cache.batch_size, 2)
+ self.assertEqual(cache.layers[0].conv_states.shape[0], 2)
- cache.has_previous_state = True
cache.reset()
- self.assertFalse(cache.has_previous_state)
- self.assertEqual(cache.conv_states[0].sum().item(), 0)
- self.assertEqual(cache.prev_v2[0].sum().item(), 0)
+ self.assertFalse(cache.has_previous_state(0))
+ self.assertEqual(cache.layers[0].conv_states.sum().item(), 0)
+ self.assertEqual(cache.layers[0].recurrent_states.sum().item(), 0)
@require_torch
From 69d09f3f56692bcfd917856dea6ee98d311aa51b Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 14:05:24 +0800
Subject: [PATCH 09/36] cca refine + use llama attn
---
src/transformers/conversion_mapping.py | 2 +
src/transformers/models/zaya/modular_zaya.py | 140 +++++++++----------
2 files changed, 67 insertions(+), 75 deletions(-)
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index dff0f65f5b53..0bf2c311845b 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -562,6 +562,8 @@ def _build_checkpoint_conversion_mapping():
),
],
"zaya": [
+ WeightRenaming(r"self_attn\.qkv\.conv_qk\.0\.", "self_attn.qkv.conv_qk_depthwise."),
+ WeightRenaming(r"self_attn\.qkv\.conv_qk\.1\.", "self_attn.qkv.conv_qk_grouped."),
WeightConverter(
source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight",
target_patterns="zaya_block.experts.gate_up_proj",
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index a48be2edf7ee..14c97c4dbc56 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -47,6 +47,7 @@
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..laguna.modeling_laguna import LagunaRotaryEmbedding
+from ..llama.modeling_llama import LlamaAttention
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
apply_rotary_pos_emb,
eager_attention_forward,
@@ -201,6 +202,8 @@ def validate_architecture(self):
raise ValueError("`head_dim` must be set for ZAYA.")
if self.num_experts_per_tok != 1:
raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
+ if self.num_attention_heads % self.num_key_value_heads != 0:
+ raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
if len(self.layer_types) != self.num_hidden_layers:
raise ValueError("`layer_types` must have one entry per hidden layer.")
if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
@@ -221,42 +224,43 @@ def _make_zaya_cache(config: ZayaConfig) -> DynamicCache:
cache_config = copy.copy(config)
# layer_types is used to distinct the rope_type (full or swa)
# so need to construct a new layer_types to construct cache
- cache_config.layer_types = ["hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers)]
+ cache_config.layer_types = [
+ "hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers)
+ ]
return DynamicCache(config=cache_config)
-class CCA(nn.Module):
- def __init__(
- self,
- config: ZayaConfig,
- num_key_value_heads: int = 2,
- num_attention_heads: int = 8,
- hidden_size: int | None = None,
- head_dim: int = 128,
- cca_time0: int = 2,
- cca_time1: int = 2,
- layer_number: int = 0,
- ):
+class ZayaCCAProjection(nn.Module):
+ """
+ Projects hidden states into attention q/k/v states with ZAYA's CCA path.
+
+ `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal
+ `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail;
+ for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
+ Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses
+ `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**.
+
+ The final q/k states are L2-normalized. `temp` is the learned per-KV-head scale applied to keys.
+ """
+
+ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
- self.layer_number = layer_number
+ self.layer_idx = layer_idx
- self.hidden_size = int(hidden_size or config.hidden_size)
+ self.hidden_size = config.hidden_size
- self.depthwise_kernel_size = cca_time0
- self.grouped_kernel_size = cca_time1
+ self.depthwise_kernel_size = config.cca_time0
+ self.grouped_kernel_size = config.cca_time1
self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
- self.num_key_value_heads = int(num_key_value_heads)
- self.num_attention_heads = int(num_attention_heads)
-
- self.head_dim = int(head_dim)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
self.query_hidden_size = self.num_attention_heads * self.head_dim
self.sqrt_head_dim = self.head_dim**0.5
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
- if self.num_attention_heads % self.num_key_value_heads != 0:
- raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
@@ -264,23 +268,21 @@ def __init__(
self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
conv_channels = self.key_value_hidden_size + self.query_hidden_size
- self.conv_qk = nn.Sequential(
- nn.Conv1d(
- in_channels=conv_channels,
- out_channels=conv_channels,
- kernel_size=self.depthwise_kernel_size,
- groups=conv_channels,
- padding=0,
- stride=1,
- ),
- nn.Conv1d(
- in_channels=conv_channels,
- out_channels=conv_channels,
- kernel_size=self.grouped_kernel_size,
- groups=(self.num_key_value_heads + self.num_attention_heads),
- padding=0,
- stride=1,
- ),
+ self.conv_qk_depthwise = nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.depthwise_kernel_size,
+ groups=conv_channels,
+ padding=0,
+ stride=1,
+ )
+ self.conv_qk_grouped = nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.grouped_kernel_size,
+ groups=(self.num_key_value_heads + self.num_attention_heads),
+ padding=0,
+ stride=1,
)
self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))
@@ -311,9 +313,9 @@ def forward(
).mean(dim=-2)
qk_states = qk_states.transpose(1, 2)
- use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_number)
+ use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx)
if use_precomputed_states:
- cached_qk_states = past_key_values.layers[self.layer_number].conv_states
+ cached_qk_states = past_key_values.layers[self.layer_idx].conv_states
conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
else:
conv_input = F.pad(qk_states, (self.total_padding, 0))
@@ -322,9 +324,10 @@ def forward(
new_conv_state = qk_states[..., -self.total_padding :]
if new_conv_state.shape[-1] < self.total_padding:
new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0))
- past_key_values.update_conv_state(new_conv_state, self.layer_number)
+ past_key_values.update_conv_state(new_conv_state, self.layer_idx)
- convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2)
+ convolved_qk_states = self.conv_qk_depthwise(conv_input)
+ convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2)
query = (
convolved_qk_states[..., : self.query_hidden_size].view(
@@ -343,13 +346,13 @@ def forward(
value_current = self.val_proj1(hidden_states)
projected_v2 = self.val_proj2(hidden_states)
if use_precomputed_states:
- first_v2 = past_key_values.layers[self.layer_number].recurrent_states.unsqueeze(1)
+ first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
else:
first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
if past_key_values is not None:
- past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_number)
+ past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx)
value = torch.cat([value_current, value_delayed], dim=-1).view(
batch_size, seq_length, self.num_key_value_heads, self.head_dim
@@ -368,35 +371,20 @@ def forward(
return query, key, value
-class ZayaAttention(nn.Module):
- def __init__(self, config: ZayaConfig, layer_n):
- super().__init__()
- self.config = config
- self.layer_n = layer_n
- self.layer_idx = layer_n
+class ZayaAttention(LlamaAttention):
+ def __init__(self, config: ZayaConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.layer_n = layer_idx
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
- self.is_causal = True
- self.attention_dropout = config.attention_dropout
- self.head_dim = config.head_dim
- self.scaling = self.head_dim**-0.5
- self.o_proj = nn.Linear(
- self.num_attention_heads * self.head_dim,
- self.hidden_size,
- bias=self.config.attention_bias,
- )
- self.qkv = CCA(
+ del self.q_proj
+ del self.k_proj
+ del self.v_proj
+ self.qkv = ZayaCCAProjection(
config=self.config,
- num_attention_heads=self.config.num_attention_heads,
- num_key_value_heads=self.config.num_key_value_heads,
- hidden_size=self.hidden_size,
- head_dim=self.config.head_dim,
- cca_time0=self.config.cca_time0,
- cca_time1=self.config.cca_time1,
- layer_number=layer_n,
+ layer_idx=layer_idx,
)
def forward(
@@ -541,8 +529,7 @@ def __init__(
zaya_first_layer = 1
self.use_eda = self.layer_idx != zaya_first_layer
- ln_eps = float(getattr(config, "norm_epsilon", 1e-5))
- self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps)
+ self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon)
if self.use_eda:
self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
@@ -830,9 +817,11 @@ def forward(
past_key_values,
)
if attention_mask is not None and attention_mask.ndim != 2:
- raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.")
+ raise ValueError(
+ "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution."
+ )
# ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
- # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
+ # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
if inputs_embeds.shape[1] == 1:
attention_mask_2d = None
@@ -840,7 +829,8 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = {
- layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) for layer_type in set(self.config.layer_types)
+ layer_type: self.rotary_emb(hidden_states, position_ids, layer_type)
+ for layer_type in set(self.config.layer_types)
}
all_hidden_states = () if output_hidden_states else None
From d936d54a15f496cb4a85c0780db7885cf6cb9306 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 14:10:30 +0800
Subject: [PATCH 10/36] use dict for 2d/4d mask
---
src/transformers/models/zaya/modular_zaya.py | 28 +++++++++++---------
1 file changed, 16 insertions(+), 12 deletions(-)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 14c97c4dbc56..5863d125f619 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -390,14 +390,21 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- attention_mask_2d: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None,
past_key_values: Cache | None = None,
output_attentions: bool = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
batch_size, seq_length, _ = hidden_states.shape
- query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d)
+
+ if isinstance(attention_mask, dict):
+ causal_mask = attention_mask.get("causal")
+ padding_mask = attention_mask.get("padding")
+ else:
+ causal_mask = attention_mask
+ padding_mask = None
+
+ query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask)
query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim)
key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
@@ -412,8 +419,7 @@ def forward(
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)
- causal_mask = attention_mask
- if causal_mask is not None:
+ if isinstance(causal_mask, torch.Tensor):
causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
@@ -452,8 +458,7 @@ def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- attention_mask_2d: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None,
past_key_values: Cache | None = None,
output_attentions: bool | None = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
@@ -465,7 +470,6 @@ def forward(
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
- attention_mask_2d=attention_mask_2d,
past_key_values=past_key_values,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
@@ -822,9 +826,10 @@ def forward(
)
# ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
# CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
- attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+ padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
if inputs_embeds.shape[1] == 1:
- attention_mask_2d = None
+ padding_mask = None
+ attention_masks = {"causal": causal_mask, "padding": padding_mask}
hidden_states = inputs_embeds
@@ -845,13 +850,12 @@ def forward(
layer_outputs = decoder_layer(
hidden_states,
residual,
- attention_mask=causal_mask,
+ attention_mask=attention_masks,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
position_embeddings=emb_to_use,
prev_router_hidden_states=prev_router_hidden_states,
- attention_mask_2d=attention_mask_2d,
**kwargs,
)
From 733e687cf069ea12078dfd51f8f03b57c97b9595 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 15:39:12 +0800
Subject: [PATCH 11/36] optimize, reuse existing code
---
src/transformers/models/zaya/modular_zaya.py | 230 ++++++++-----------
tests/models/zaya/test_modeling_zaya.py | 57 +++--
2 files changed, 135 insertions(+), 152 deletions(-)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 5863d125f619..14f35f909634 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -29,8 +29,7 @@
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
-from ...integrations import use_experts_implementation
-from ...masking_utils import create_causal_mask
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
@@ -47,12 +46,12 @@
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..laguna.modeling_laguna import LagunaRotaryEmbedding
-from ..llama.modeling_llama import LlamaAttention
+from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
apply_rotary_pos_emb,
eager_attention_forward,
)
-from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm
+from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeExperts, Qwen3MoeRMSNorm
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
@@ -117,6 +116,7 @@ class ZayaConfig(PreTrainedConfig):
zaya_mlp_expansion: int = 256
cca_time0: int | None = 2
cca_time1: int | None = 2
+ sliding_window: int | None = None
layer_types: list[str] | None = None
swa_rotary_base: float | int = 10000.0
output_router_logits: bool = False
@@ -142,7 +142,6 @@ def __post_init__(self, **kwargs):
"residual_in_fp32",
"rope_scaling",
"scale_residual_merge",
- "sliding_window",
"zaya_high_prec",
"zaya_use_mod",
"zaya_use_eda",
@@ -157,6 +156,9 @@ def __post_init__(self, **kwargs):
)
legacy_swa_layers = kwargs.pop("swa_layers", None)
+ swa_window_sizes = {int(window_size) for window_size in (legacy_swa_layers or []) if int(window_size) > 0}
+ if self.sliding_window is None and swa_window_sizes:
+ self.sliding_window = max(swa_window_sizes)
if self.layer_types is None:
if legacy_swa_layers is None:
self.layer_types = ["full_attention"] * self.num_hidden_layers
@@ -201,13 +203,17 @@ def validate_architecture(self):
if self.head_dim is None:
raise ValueError("`head_dim` must be set for ZAYA.")
if self.num_experts_per_tok != 1:
- raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
+ raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
if self.num_attention_heads % self.num_key_value_heads != 0:
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
if len(self.layer_types) != self.num_hidden_layers:
raise ValueError("`layer_types` must have one entry per hidden layer.")
if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
+ if "sliding_attention" in self.layer_types and self.sliding_window is None:
+ raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.")
+ if self.sliding_window is not None and self.sliding_window <= 0:
+ raise ValueError("`sliding_window` must be a strictly positive integer.")
if (self.cca_time0, self.cca_time1) != (2, 2):
raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
@@ -239,8 +245,6 @@ class ZayaCCAProjection(nn.Module):
for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses
`val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**.
-
- The final q/k states are L2-normalized. `temp` is the learned per-KV-head scale applied to keys.
"""
def __init__(self, config: ZayaConfig, layer_idx: int):
@@ -259,7 +263,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.head_dim = config.head_dim
self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
self.query_hidden_size = self.num_attention_heads * self.head_dim
- self.sqrt_head_dim = self.head_dim**0.5
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
@@ -296,20 +299,20 @@ def forward(
if attention_mask is not None:
hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)
- batch_size, seq_length, _ = hidden_states.shape
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
projected_queries = self.linear_q(hidden_states)
projected_keys = self.linear_k(hidden_states)
qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
- query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
- key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
+ query_residual = projected_queries.view(*hidden_shape)
+ key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim)
- key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1)
- key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+ key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2)
query_residual = (query_residual + key_residual) * 0.5
key_residual = query_residual.view(
- batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
+ *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
).mean(dim=-2)
qk_states = qk_states.transpose(1, 2)
@@ -331,14 +334,14 @@ def forward(
query = (
convolved_qk_states[..., : self.query_hidden_size].view(
- batch_size, seq_length, self.num_attention_heads, self.head_dim
+ *input_shape, self.num_attention_heads, self.head_dim
)
+ query_residual
)
key = (
convolved_qk_states[..., self.query_hidden_size :].view(
- batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ *input_shape, self.num_key_value_heads, self.head_dim
)
+ key_residual
)
@@ -348,26 +351,16 @@ def forward(
if use_precomputed_states:
first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
else:
- first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
+ first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size))
value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
if past_key_values is not None:
past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx)
value = torch.cat([value_current, value_delayed], dim=-1).view(
- batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ *input_shape, self.num_key_value_heads, self.head_dim
)
- norm_eps = torch.finfo(query.dtype).eps
- query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
- key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
-
- key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
- query = query * (self.sqrt_head_dim / query_norm)
-
- query = query.reshape(batch_size, seq_length, self.query_hidden_size)
- key = key.reshape(batch_size, seq_length, self.key_value_hidden_size)
- value = value.reshape(batch_size, seq_length, self.key_value_hidden_size)
return query, key, value
@@ -375,6 +368,8 @@ class ZayaAttention(LlamaAttention):
def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.layer_n = layer_idx
+ self.layer_type = config.layer_types[layer_idx]
+ self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
@@ -405,9 +400,14 @@ def forward(
padding_mask = None
query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask)
- query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim)
- key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
- value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
+
+ norm_eps = torch.finfo(query_states.dtype).eps
+ head_dim_scale = self.scaling**-1
+ query_states = query_states * (
+ head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ )
+ key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps))
+ key_states = key_states * self.qkv.temp[None, None, :, None]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -433,15 +433,13 @@ def forward(
causal_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
+ sliding_window=self.sliding_window,
output_attentions=output_attentions,
)
attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
- if not output_attentions:
- attn_weights = None
-
return attn_output, attn_weights, past_key_values
@@ -457,14 +455,14 @@ def __init__(self, config: ZayaConfig, layer_n: int):
def forward(
self,
hidden_states: torch.Tensor,
- residual: torch.Tensor,
+ residual: torch.Tensor | None,
attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None,
past_key_values: Cache | None = None,
output_attentions: bool | None = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
prev_router_hidden_states: torch.Tensor | None = None,
**kwargs,
- ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
hidden_states, self_attn_weights, _ = self.self_attn(
@@ -508,13 +506,27 @@ def _apply_residual_scaling(
return hidden_states, residual
+class ZayaRouterMLP(nn.Module):
+ def __init__(self, hidden_size: int, num_experts: int):
+ super().__init__()
+ self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
+ self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
+ self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
+ self.act_fn = nn.GELU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.act_fn(self.fc1(hidden_states))
+ hidden_states = self.act_fn(self.fc2(hidden_states))
+ return self.out_proj(hidden_states)
+
+
class ZayaRouter(nn.Module):
def __init__(
self,
config,
layer_idx: int,
num_moe_experts: int,
- moe_router_topk: int,
+ num_experts_per_tok: int,
mlp_expansion: int,
hidden_size: int | None = None,
) -> None:
@@ -525,7 +537,7 @@ def __init__(
self.layer_idx = layer_idx
self.num_experts = num_moe_experts + 1
- self.topk = int(moe_router_topk)
+ self.topk = int(num_experts_per_tok)
self.mlp_expansion = int(mlp_expansion)
self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
@@ -537,14 +549,7 @@ def __init__(
if self.use_eda:
self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
- self.non_linearity = nn.GELU()
- self.router_mlp = nn.Sequential(
- nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
- self.non_linearity,
- nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
- self.non_linearity,
- nn.Linear(self.mlp_expansion, self.num_experts, bias=False),
- )
+ self.router_mlp = ZayaRouterMLP(self.mlp_expansion, self.num_experts)
self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32))
self.balancing_biases[-1] = -1.0
@@ -578,12 +583,9 @@ def forward(
)
-@use_experts_implementation
-class ZayaExperts(nn.Module):
- """Collection of expert weights stored as 3D tensors."""
-
+class ZayaExperts(Qwen3MoeExperts):
def __init__(self, config, num_experts: int, intermediate_size: int):
- super().__init__()
+ nn.Module.__init__(self)
self.num_experts = num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = intermediate_size // 2
@@ -591,34 +593,8 @@ def __init__(self, config, num_experts: int, intermediate_size: int):
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
- def forward(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
- ) -> torch.Tensor:
- final_hidden_states = torch.zeros_like(hidden_states)
- with torch.no_grad():
- expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1)
- expert_mask = expert_mask.permute(2, 1, 0)
- expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
- for expert_idx in expert_hit:
- expert_idx = expert_idx[0]
- if expert_idx == self.num_experts:
- continue
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
- current_state = hidden_states[token_idx]
- gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
- current_hidden_states = self.act_fn(gate) * up
- current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
- current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
- final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
-
- return final_hidden_states
-
-
-class ZayaBlock(nn.Module):
+
+class ZayaSparseMoeBlock(nn.Module):
def __init__(
self,
config,
@@ -635,7 +611,7 @@ def __init__(
config=self.config,
layer_idx=layer_n,
num_moe_experts=self.num_moe_experts,
- moe_router_topk=getattr(self.config, "moe_router_topk", 1),
+ num_experts_per_tok=self.config.num_experts_per_tok,
mlp_expansion=mlp_expansion,
hidden_size=self.hidden_dim,
)
@@ -649,6 +625,10 @@ def forward(
route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
hidden_states, router_states=prev_router_hidden_states
)
+ skip_expert = expert_choice == self.num_moe_experts
+ route_prob = route_prob.masked_fill(skip_expert, 0)
+ expert_choice = expert_choice.masked_fill(skip_expert, 0)
+
batch_size, seq_length, emb_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim)
expert_output = self.experts(hidden_states_flat, expert_choice, route_prob)
@@ -668,7 +648,7 @@ def __init__(
):
super().__init__()
self.config = config
- self.zaya_block = ZayaBlock(
+ self.zaya_block = ZayaSparseMoeBlock(
config,
num_moe_experts,
mlp_expansion,
@@ -701,24 +681,21 @@ def forward(
)
-class ZayaPreTrainedModel(PreTrainedModel):
+class ZayaPreTrainedModel(LlamaPreTrainedModel):
config: ZayaConfig
config_class = ZayaConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
_no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"]
- _skip_keys_device_placement = ["past_key_values"]
- _supports_flash_attn = True
- _supports_sdpa = True
- _supports_flex_attn = True
- _supports_attention_backend = True
+ # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache.
+ _can_compile_fullgraph = False
_can_record_outputs = {
"router_logits": OutputRecorder(ZayaRouter, index=3),
+ "hidden_states": [ZayaDecoderATTLayer, ZayaDecoderMLPLayer],
+ "attentions": ZayaAttention,
}
@torch.no_grad()
def _init_weights(self, module):
- super()._init_weights(module)
+ PreTrainedModel._init_weights(self, module)
if isinstance(module, ResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
@@ -786,18 +763,11 @@ def forward(
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- output_router_logits: bool | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -814,22 +784,25 @@ def forward(
device=inputs_embeds.device,
).unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask,
- inputs_embeds,
- position_ids,
- past_key_values,
- )
- if attention_mask is not None and attention_mask.ndim != 2:
+ if isinstance(attention_mask, dict):
+ causal_mask_mapping = attention_mask
+ padding_mask = None
+ else:
+ causal_mask_mapping = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ position_ids,
+ past_key_values,
+ )
+ padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+ if attention_mask is not None and not isinstance(attention_mask, dict) and attention_mask.ndim != 2:
raise ValueError(
"ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution."
)
# ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
# CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
- padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
if inputs_embeds.shape[1] == 1:
padding_mask = None
- attention_masks = {"causal": causal_mask, "padding": padding_mask}
hidden_states = inputs_embeds
@@ -838,22 +811,18 @@ def forward(
for layer_type in set(self.config.layer_types)
}
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
prev_router_hidden_states = None
for layer_n, decoder_layer in enumerate(self.layers):
- emb_to_use = position_embeddings[self.config.layer_types[layer_n]]
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
+ layer_type = self.config.layer_types[layer_n]
+ emb_to_use = position_embeddings[layer_type]
+ attention_mask = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
layer_outputs = decoder_layer(
hidden_states,
residual,
- attention_mask=attention_masks,
+ attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
- output_attentions=output_attentions,
position_embeddings=emb_to_use,
prev_router_hidden_states=prev_router_hidden_states,
**kwargs,
@@ -863,20 +832,11 @@ def forward(
residual = layer_outputs[2]
prev_router_hidden_states = layer_outputs[3]
- if isinstance(decoder_layer, ZayaDecoderATTLayer):
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
)
def _update_causal_mask(
@@ -886,13 +846,21 @@ def _update_causal_mask(
position_ids: torch.Tensor,
past_key_values: Cache,
):
- return create_causal_mask(
- config=self.config,
- inputs_embeds=input_tensor,
- attention_mask=attention_mask,
- past_key_values=past_key_values,
- position_ids=position_ids,
- )
+ mask_kwargs = {
+ "config": self.config,
+ "inputs_embeds": input_tensor,
+ "attention_mask": attention_mask,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ mask_creation_functions = {
+ "full_attention": lambda: create_causal_mask(**mask_kwargs),
+ "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
+ }
+ causal_mask_mapping = {}
+ for layer_type in set(self.config.layer_types):
+ causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
+ return causal_mask_mapping
@auto_docstring
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index b710d136a554..5e16b744c989 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -27,7 +27,7 @@
from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel
from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer
- from transformers.models.zaya.modeling_zaya import CCA, _make_zaya_cache
+ from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, _make_zaya_cache
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
@@ -90,8 +90,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
self.assertEqual(layer.conv_states.shape, conv_shape)
self.assertEqual(layer.recurrent_states.shape, recurrent_shape)
else:
- self.assertIsNone(layer.keys)
- self.assertIsNone(layer.values)
+ self.assertIsNone(getattr(layer, "keys", None))
+ self.assertIsNone(getattr(layer, "values", None))
self.assertIsNone(layer.conv_states)
self.assertIsNone(layer.recurrent_states)
@@ -260,12 +260,41 @@ def test_moe_router_topk_validation(self):
ZayaConfig(moe_router_topk=2)
def test_legacy_swa_layers_translate_to_layer_types(self):
- config = ZayaConfig(num_hidden_layers=4, swa_layers=[0, 1, 0, 1], swa_rotary_base=10000)
+ config = ZayaConfig(num_hidden_layers=4, swa_layers=[4096, 0, 4096, 0], swa_rotary_base=10000)
- self.assertEqual(config.layer_types, ["full_attention", "sliding_attention", "full_attention", "sliding_attention"])
+ self.assertEqual(
+ config.layer_types, ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]
+ )
+ self.assertEqual(config.sliding_window, 4096)
self.assertEqual(config.rope_parameters["full_attention"]["rope_theta"], config.default_theta)
self.assertEqual(config.rope_parameters["sliding_attention"]["rope_theta"], 10000)
+ def test_sliding_attention_mask_is_used(self):
+ config = ZayaConfig(
+ vocab_size=128,
+ hidden_size=32,
+ ffn_hidden_size=64,
+ num_hidden_layers=4,
+ num_experts=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ head_dim=8,
+ zaya_mlp_expansion=4,
+ layer_types=["sliding_attention", "full_attention", "full_attention", "full_attention"],
+ sliding_window=3,
+ tie_word_embeddings=False,
+ attn_implementation="eager",
+ )
+ model = ZayaModel(config).to(torch_device)
+ model.eval()
+ input_ids = torch.arange(6, device=torch_device).unsqueeze(0)
+
+ with torch.no_grad():
+ outputs = model(input_ids=input_ids, output_attentions=True)
+
+ sliding_attention = outputs.attentions[0]
+ self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0))
+
def test_cca_cache_matches_full_forward(self):
config = ZayaConfig(
vocab_size=128,
@@ -280,14 +309,7 @@ def test_cca_cache_matches_full_forward(self):
tie_word_embeddings=False,
)
torch.manual_seed(0)
- cca = CCA(
- config,
- num_key_value_heads=config.num_key_value_heads,
- num_attention_heads=config.num_attention_heads,
- hidden_size=config.hidden_size,
- head_dim=config.head_dim,
- layer_number=0,
- ).to(torch_device)
+ cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device)
cca.eval()
hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device)
@@ -314,14 +336,7 @@ def test_cca_cache_matches_full_forward_multi_token(self):
tie_word_embeddings=False,
)
torch.manual_seed(0)
- cca = CCA(
- config,
- num_key_value_heads=config.num_key_value_heads,
- num_attention_heads=config.num_attention_heads,
- hidden_size=config.hidden_size,
- head_dim=config.head_dim,
- layer_number=0,
- ).to(torch_device)
+ cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device)
cca.eval()
hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device)
From eb7c8cc7cc5224fe32b8e735402f5659ea708da3 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 16:20:13 +0800
Subject: [PATCH 12/36] inherit from AfmoeForCausalLM, but need to construct
cache from _make_zaya_cache
---
src/transformers/models/zaya/modular_zaya.py | 129 +++----------------
1 file changed, 16 insertions(+), 113 deletions(-)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 14f35f909634..50cad3bd10ea 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -26,25 +26,21 @@
from torch.nn import init
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer
from ...configuration_utils import PreTrainedConfig
-from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import (
- MoeCausalLMOutputWithPast,
- MoeModelOutputWithPast,
-)
+from ...modeling_outputs import MoeModelOutputWithPast
from ...modeling_rope_utils import RopeParameters
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
- can_return_tuple,
)
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
+from ..afmoe.modeling_afmoe import AfmoeForCausalLM
from ..laguna.modeling_laguna import LagunaRotaryEmbedding
from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
@@ -236,6 +232,14 @@ def _make_zaya_cache(config: ZayaConfig) -> DynamicCache:
return DynamicCache(config=cache_config)
+def _is_zaya_cache(past_key_values: Cache) -> bool:
+ return (
+ isinstance(past_key_values, DynamicCache)
+ and len(past_key_values.layers) > 0
+ and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer)
+ )
+
+
class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
@@ -771,7 +775,9 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
+ if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
+ if past_key_values is not None and past_key_values.get_seq_length() > 0:
+ raise ValueError("ZAYA requires a native hybrid cache created from `_make_zaya_cache`.")
past_key_values = _make_zaya_cache(self.config)
residual = None
@@ -863,8 +869,8 @@ def _update_causal_mask(
return causal_mask_mapping
-@auto_docstring
-class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
+@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
+class ZayaForCausalLM(ZayaPreTrainedModel, AfmoeForCausalLM):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_is_stateful = True
@@ -873,112 +879,9 @@ def __init__(self, config, **kwargs):
self.model = ZayaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
- if self.config.tie_word_embeddings:
- self.lm_head.weight = self.model.embed_tokens.weight
self.post_init()
- def set_decoder(self, decoder):
- self.model = decoder
-
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- use_cache: bool | None = None,
- output_router_logits: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **kwargs: Unpack[TransformersKwargs],
- ) -> MoeCausalLMOutputWithPast:
- output_router_logits = (
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
- )
-
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_router_logits=output_router_logits,
- **kwargs,
- )
-
- hidden_states = outputs.last_hidden_state
-
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
- loss = None
- if labels is not None:
- loss = self.loss_function(
- logits=logits,
- labels=labels,
- vocab_size=self.config.vocab_size,
- **kwargs,
- )
-
- return MoeCausalLMOutputWithPast(
- loss=loss,
- aux_loss=None,
- logits=logits,
- past_key_values=outputs.past_key_values if use_cache else None,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- router_logits=outputs.router_logits,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- position_ids=None,
- use_cache=True,
- logits_to_keep=None,
- **kwargs,
- ):
- model_inputs = super().prepare_inputs_for_generation(
- input_ids=input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- position_ids=position_ids,
- use_cache=use_cache,
- logits_to_keep=logits_to_keep,
- **kwargs,
- )
- return model_inputs
-
- def _prepare_cache_for_generation(
- self,
- generation_config,
- model_kwargs: dict,
- generation_mode,
- batch_size: int,
- max_cache_length: int,
- ):
- if generation_config.use_cache is False:
- return
-
- if "past_key_values" not in model_kwargs:
- model_kwargs["past_key_values"] = _make_zaya_cache(self.config)
- generation_config.cache_implementation = None
- return super()._prepare_cache_for_generation(
- generation_config=generation_config,
- model_kwargs=model_kwargs,
- generation_mode=generation_mode,
- batch_size=batch_size,
- max_cache_length=max_cache_length,
- )
-
__all__ = [
"ZayaConfig",
From 4d5bda4e29b7d5c6cd9ef062fb9e1f540b77300c Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 18:43:45 +0800
Subject: [PATCH 13/36] checkpoint conversion
---
docs/source/en/model_doc/zaya.md | 8 +
src/transformers/conversion_mapping.py | 3 +
.../models/zaya/configuration_zaya.py | 120 ++-
.../models/zaya/convert_zaya_weights_to_hf.py | 335 ++++++++
src/transformers/models/zaya/modeling_zaya.py | 761 +++++++-----------
src/transformers/models/zaya/modular_zaya.py | 300 +++----
tests/models/zaya/test_modeling_zaya.py | 62 +-
7 files changed, 838 insertions(+), 751 deletions(-)
create mode 100644 src/transformers/models/zaya/convert_zaya_weights_to_hf.py
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index 468f7327dd86..24468b8df65f 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -27,6 +27,14 @@ and Zyphra's technical reports.
This model was contributed by [JJJYmmm](https://github.com/JJJYmmm).
+
+
+When building a manual generation loop with `past_key_values`, use [`~models.zaya.modeling_zaya.make_zaya_cache`] to
+create ZAYA's cache. ZAYA uses `config.layer_types` for full/sliding attention masks and RoPE parameters, while its
+cache uses the native hybrid layout needed by the attention, convolution, and recurrent states.
+
+
+
## Usage examples
```python
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 0bf2c311845b..bb333bdcb4ce 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -564,6 +564,9 @@ def _build_checkpoint_conversion_mapping():
"zaya": [
WeightRenaming(r"self_attn\.qkv\.conv_qk\.0\.", "self_attn.qkv.conv_qk_depthwise."),
WeightRenaming(r"self_attn\.qkv\.conv_qk\.1\.", "self_attn.qkv.conv_qk_grouped."),
+ WeightRenaming(r"zaya_block\.router\.router_mlp\.0\.", "zaya_block.router.router_mlp.fc1."),
+ WeightRenaming(r"zaya_block\.router\.router_mlp\.2\.", "zaya_block.router.router_mlp.fc2."),
+ WeightRenaming(r"zaya_block\.router\.router_mlp\.4\.", "zaya_block.router.router_mlp.out_proj."),
WeightConverter(
source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight",
target_patterns="zaya_block.experts.gate_up_proj",
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 12a7c2999abc..479d07dea7d4 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -4,7 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_zaya.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Literal
+
from huggingface_hub.dataclasses import strict
from ...configuration_utils import PreTrainedConfig
@@ -29,17 +31,15 @@
@strict
class ZayaConfig(PreTrainedConfig):
r"""
- ffn_hidden_size (`int`, *optional*, defaults to 4096):
+ intermediate_size (`int`, *optional*, defaults to 4096):
Dimension of the feed-forward and expert hidden states.
- num_query_groups (`int`, *optional*, defaults to 2):
- Number of query groups. For ZAYA checkpoints this matches `num_key_value_heads`.
- rope_theta (`float`, *optional*, defaults to 5000000):
- The base period of the RoPE embeddings.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ Number of key/value groups.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Fraction of each attention head dimension using rotary embeddings.
lm_head_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the language modeling head.
- moe_router_topk (`int`, *optional*, defaults to 1):
+ num_experts_per_tok (`int`, *optional*, defaults to 1):
Number of selected experts per token. ZAYA checkpoints use top-1 routing.
zaya_mlp_expansion (`int`, *optional*, defaults to 256):
Expansion size used by the dense ZAYA blocks.
@@ -47,7 +47,7 @@ class ZayaConfig(PreTrainedConfig):
First temporal parameter of the CCA projection.
cca_time1 (`int`, *optional*, defaults to 2):
Second temporal parameter of the CCA projection.
- swa_layers (`list[int]`, *optional*):
+ layer_types (`list[str]`, *optional*):
Per-layer selector for standard RoPE versus SWA RoPE embeddings.
swa_rotary_base (`float`, *optional*):
RoPE base used by SWA layers.
@@ -64,15 +64,15 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
+ default_theta = 5000000.0
vocab_size: int = 262272
hidden_size: int = 2048
- ffn_hidden_size: int = 4096
- num_hidden_layers: int = 80
+ intermediate_size: int = 4096
+ num_hidden_layers: int = 40
num_experts: int = 16
num_attention_heads: int = 8
- num_key_value_heads: int | None = 2
- num_query_groups: int | None = 2
+ num_key_value_heads: int = 2
hidden_act: str = "silu"
head_dim: int = 128
max_position_embeddings: int = 131072
@@ -81,72 +81,64 @@ class ZayaConfig(PreTrainedConfig):
use_cache: bool = True
tie_word_embeddings: bool = True
rope_parameters: RopeParameters | dict | None = None
- rope_theta: float | int = 5000000
partial_rotary_factor: float = 0.5
attention_bias: bool = False
lm_head_bias: bool = False
attention_dropout: float | int = 0.0
- moe_router_topk: int = 1
+ num_experts_per_tok: int = 1
zaya_mlp_expansion: int = 256
- cca_time0: int | None = 2
- cca_time1: int | None = 2
- swa_layers: list[int] | None = None
- swa_rotary_base: float | int | None = None
+ cca_time0: int = 2
+ cca_time1: int = 2
+ sliding_window: int | None = None
+ layer_types: list[str] | None = None
+ swa_rotary_base: float | int = 10000.0
output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
eos_token_id: int | list[int] | None = 106
def __post_init__(self, **kwargs):
- for unused_checkpoint_kwarg in (
- "cca",
- "activation_func",
- "normalization",
- "add_bias_linear",
- "gated_linear_unit",
- "fused_add_norm",
- "apply_rope_fusion",
- "bias_activation_fusion",
- "activation_func_fp8_input_store",
- "clamp_temp",
- "kv_channels",
- "mamba_cache_dtype",
- "residual_in_fp32",
- "rope_scaling",
- "scale_residual_merge",
- "sliding_window",
- "zaya_high_prec",
- "zaya_use_mod",
- "zaya_use_eda",
- ):
- kwargs.pop(unused_checkpoint_kwarg, None)
-
- self.num_key_value_heads = (
- self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads
- )
- self.num_query_groups = self.num_key_value_heads if self.num_query_groups is None else self.num_query_groups
- if self.head_dim is None:
- raise ValueError("`head_dim` must be set for ZAYA.")
- if self.num_query_groups != self.num_key_value_heads:
- raise ValueError("`num_query_groups` must be equal to `num_key_value_heads` for ZAYA.")
- if self.moe_router_topk != 1:
- raise ValueError("ZAYA currently supports `moe_router_topk=1` only.")
-
- self.rope_parameters = (
- dict(self.rope_parameters) if self.rope_parameters is not None else {"rope_type": "default"}
+ self.layer_types = (
+ ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
)
- self.rope_parameters.setdefault("rope_theta", self.rope_theta)
- self.rope_parameters.setdefault("partial_rotary_factor", self.partial_rotary_factor)
- self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0
- self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1
- if (self.cca_time0, self.cca_time1) != (2, 2):
- raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
- if self.swa_layers is not None and len(self.swa_layers) != self.num_hidden_layers:
- raise ValueError("`swa_layers` must have one entry per hidden layer.")
- if self.swa_layers is not None and self.swa_rotary_base is None:
- raise ValueError("`swa_rotary_base` must be set when `swa_layers` is provided.")
+
+ default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
+ "full_attention": {
+ "rope_type": "default",
+ "rope_theta": self.default_theta,
+ "partial_rotary_factor": self.partial_rotary_factor,
+ },
+ "sliding_attention": {
+ "rope_type": "default",
+ "rope_theta": self.swa_rotary_base,
+ "partial_rotary_factor": self.partial_rotary_factor,
+ },
+ }
+ if self.rope_parameters is None:
+ self.rope_parameters = {
+ layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types)
+ }
super().__post_init__(**kwargs)
+ def convert_rope_params_to_dict(self, **kwargs):
+ # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them
+ # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`.
+ return kwargs
+
+ def validate_architecture(self):
+ if self.num_experts_per_tok != 1:
+ raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
+ if self.num_attention_heads % self.num_key_value_heads != 0:
+ raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
+ if len(self.layer_types) != self.num_hidden_layers:
+ raise ValueError("`layer_types` must have one entry per hidden layer.")
+ if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
+ raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
+ if "sliding_attention" in self.layer_types and self.sliding_window is None:
+ raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.")
+ if self.sliding_window is not None and self.sliding_window <= 0:
+ raise ValueError("`sliding_window` must be a strictly positive integer.")
+
__all__ = ["ZayaConfig"]
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
new file mode 100644
index 000000000000..ba9198b9c666
--- /dev/null
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -0,0 +1,335 @@
+# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert original alternating-layer ZAYA checkpoints to the Transformers-native decoder-layer layout."""
+
+import argparse
+import json
+import re
+import shutil
+from collections import defaultdict
+from pathlib import Path
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+from transformers import ZayaConfig
+
+
+_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$")
+_LOCAL_EXPERT_PATTERN = re.compile(
+ r"^model\.layers\.(\d+)\.zaya_block\.experts\.local_experts\.(\d+)\.linear_fc([12])\.weight$"
+)
+
+_UNUSED_CONFIG_KEYS = (
+ "cca",
+ "num_query_groups",
+ "ffn_hidden_size",
+ "moe_router_topk",
+ "activation_func",
+ "normalization",
+ "add_bias_linear",
+ "gated_linear_unit",
+ "fused_add_norm",
+ "apply_rope_fusion",
+ "bias_activation_fusion",
+ "activation_func_fp8_input_store",
+ "clamp_temp",
+ "kv_channels",
+ "mamba_cache_dtype",
+ "residual_in_fp32",
+ "rope_scaling",
+ "scale_residual_merge",
+ "zaya_high_prec",
+ "zaya_use_mod",
+ "zaya_use_eda",
+)
+
+
+def _rename_common(rest: str) -> str:
+ replacements = (
+ ("self_attn.qkv.conv_qk.0.", "self_attn.qkv.conv_qk_depthwise."),
+ ("self_attn.qkv.conv_qk.1.", "self_attn.qkv.conv_qk_grouped."),
+ ("zaya_block.router.router_mlp.0.", "zaya_block.router.router_mlp.fc1."),
+ ("zaya_block.router.router_mlp.2.", "zaya_block.router.router_mlp.fc2."),
+ ("zaya_block.router.router_mlp.4.", "zaya_block.router.router_mlp.out_proj."),
+ )
+ for old, new in replacements:
+ if rest.startswith(old):
+ return new + rest.removeprefix(old)
+ return rest
+
+
+def _expert_target(name: str) -> tuple[str, int] | None:
+ match = _LOCAL_EXPERT_PATTERN.match(name)
+ if match is None:
+ return None
+
+ old_layer_idx = int(match.group(1))
+ if old_layer_idx % 2 != 1:
+ raise ValueError(f"Expert weights are expected on odd ZAYA layers, got: {name}")
+
+ new_layer_idx = old_layer_idx // 2
+ expert_idx = int(match.group(2))
+ projection = "gate_up_proj" if match.group(3) == "1" else "down_proj"
+ target = f"model.layers.{new_layer_idx}.zaya_block.experts.{projection}"
+ return target, expert_idx
+
+
+def convert_weight_name(name: str) -> str | None:
+ if _expert_target(name) is not None:
+ return None
+
+ match = _LAYER_PATTERN.match(name)
+ if match is None:
+ return name
+
+ old_layer_idx = int(match.group(1))
+ rest = match.group(2)
+ new_layer_idx = old_layer_idx // 2
+
+ if old_layer_idx % 2 == 0:
+ rest = _rename_common(rest)
+ if rest.startswith(("self_attn.", "input_norm.", "res_scale.")):
+ return f"model.layers.{new_layer_idx}.{rest}"
+ else:
+ rest = _rename_common(rest)
+ if rest.startswith("zaya_block."):
+ return f"model.layers.{new_layer_idx}.{rest}"
+ if rest.startswith("input_norm."):
+ return f"model.layers.{new_layer_idx}.post_attention_norm.{rest.removeprefix('input_norm.')}"
+ if rest.startswith("res_scale."):
+ return f"model.layers.{new_layer_idx}.post_attention_res_scale.{rest.removeprefix('res_scale.')}"
+
+ raise ValueError(f"Unexpected ZAYA layer weight name: {name}")
+
+
+def _convert_layer_types(config_dict: dict, old_num_hidden_layers: int, new_num_hidden_layers: int) -> list[str]:
+ layer_types = config_dict.get("layer_types")
+ if layer_types is not None:
+ if len(layer_types) == old_num_hidden_layers:
+ return layer_types[::2]
+ if len(layer_types) == new_num_hidden_layers:
+ return list(layer_types)
+ raise ValueError("`layer_types` must match either the original or converted number of hidden layers.")
+
+ swa_layers = config_dict.get("swa_layers")
+ if swa_layers is None:
+ return ["full_attention"] * new_num_hidden_layers
+ if len(swa_layers) == old_num_hidden_layers:
+ swa_layers = swa_layers[::2]
+ elif len(swa_layers) != new_num_hidden_layers:
+ raise ValueError("`swa_layers` must match either the original or converted number of hidden layers.")
+ return ["full_attention" if int(window_size) == 0 else "sliding_attention" for window_size in swa_layers]
+
+
+def convert_config(input_dir: Path, output_dir: Path) -> None:
+ config_dict = json.loads((input_dir / "config.json").read_text())
+ old_num_hidden_layers = int(config_dict["num_hidden_layers"])
+ if old_num_hidden_layers % 2 != 0:
+ raise ValueError("Original ZAYA checkpoints must have an even number of alternating attention/MoE layers.")
+
+ new_num_hidden_layers = old_num_hidden_layers // 2
+ layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers)
+ partial_rotary_factor = config_dict.get("partial_rotary_factor", ZayaConfig.partial_rotary_factor)
+ rope_theta = config_dict.get("rope_theta", ZayaConfig.default_theta)
+ swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.swa_rotary_base)
+ intermediate_size = config_dict.get(
+ "intermediate_size", config_dict.get("ffn_hidden_size", ZayaConfig.intermediate_size)
+ )
+ num_experts_per_tok = config_dict.get(
+ "num_experts_per_tok", config_dict.get("moe_router_topk", ZayaConfig.num_experts_per_tok)
+ )
+
+ swa_layers = config_dict.get("swa_layers") or []
+ sliding_window = config_dict.get("sliding_window")
+ if sliding_window is None:
+ positive_windows = [int(window_size) for window_size in swa_layers if int(window_size) > 0]
+ sliding_window = max(positive_windows) if positive_windows else None
+
+ rope_parameters = {
+ "full_attention": {
+ "rope_type": "default",
+ "rope_theta": rope_theta,
+ "partial_rotary_factor": partial_rotary_factor,
+ },
+ "sliding_attention": {
+ "rope_type": "default",
+ "rope_theta": swa_rotary_base,
+ "partial_rotary_factor": partial_rotary_factor,
+ },
+ }
+
+ for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta"):
+ config_dict.pop(key, None)
+
+ config_dict.update(
+ {
+ "architectures": ["ZayaForCausalLM"],
+ "num_hidden_layers": new_num_hidden_layers,
+ "intermediate_size": intermediate_size,
+ "num_experts_per_tok": num_experts_per_tok,
+ "layer_types": layer_types,
+ "sliding_window": sliding_window,
+ "swa_rotary_base": swa_rotary_base,
+ "rope_parameters": {layer_type: rope_parameters[layer_type] for layer_type in set(layer_types)},
+ }
+ )
+ ZayaConfig(**config_dict).save_pretrained(output_dir)
+
+
+def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None:
+ for path in input_dir.iterdir():
+ if path.name == "config.json":
+ continue
+ if path.name.endswith(".safetensors") or path.name.endswith(".bin"):
+ continue
+ if path.name in {"model.safetensors.index.json", "pytorch_model.bin.index.json"}:
+ continue
+
+ output_path = output_dir / path.name
+ if path.is_dir():
+ shutil.copytree(path, output_path, dirs_exist_ok=True)
+ else:
+ shutil.copy2(path, output_path)
+
+
+def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[str]], dict[str, str], dict]:
+ index = json.loads((input_dir / "model.safetensors.index.json").read_text())
+ old_weight_map = index["weight_map"]
+ converted_weight_map = {}
+ normal_sources_by_output_file = defaultdict(list)
+ expert_sources_by_target = defaultdict(list)
+ output_file_by_target = {}
+
+ for source_key, filename in old_weight_map.items():
+ expert_info = _expert_target(source_key)
+ if expert_info is not None:
+ target_key, expert_idx = expert_info
+ expert_sources_by_target[target_key].append((expert_idx, source_key))
+ output_file_by_target.setdefault(target_key, filename)
+ converted_weight_map[target_key] = output_file_by_target[target_key]
+ continue
+
+ target_key = convert_weight_name(source_key)
+ if target_key in converted_weight_map:
+ raise ValueError(f"Duplicate converted weight name: {target_key}")
+ converted_weight_map[target_key] = filename
+ normal_sources_by_output_file[filename].append((source_key, target_key))
+
+ index["weight_map"] = converted_weight_map
+ return normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, index
+
+
+def _load_sources(input_dir: Path, source_keys: list[str], old_weight_map: dict[str, str]) -> dict[str, torch.Tensor]:
+ sources_by_file = defaultdict(list)
+ for source_key in source_keys:
+ sources_by_file[old_weight_map[source_key]].append(source_key)
+
+ tensors = {}
+ for filename, keys in sources_by_file.items():
+ with safe_open(input_dir / filename, framework="pt", device="cpu") as f:
+ for key in keys:
+ tensors[key] = f.get_tensor(key)
+ return tensors
+
+
+def convert_safetensors(input_dir: Path, output_dir: Path) -> None:
+ index_path = input_dir / "model.safetensors.index.json"
+ if not index_path.exists():
+ safetensors_path = input_dir / "model.safetensors"
+ if not safetensors_path.exists():
+ raise FileNotFoundError("Only safetensors ZAYA checkpoints are supported by this converter.")
+
+ with safe_open(safetensors_path, framework="pt", device="cpu") as f:
+ metadata = f.metadata()
+ state_dict = {}
+ expert_groups = defaultdict(list)
+ for key in f.keys():
+ expert_info = _expert_target(key)
+ if expert_info is not None:
+ target_key, expert_idx = expert_info
+ expert_groups[target_key].append((expert_idx, f.get_tensor(key)))
+ continue
+ state_dict[convert_weight_name(key)] = f.get_tensor(key)
+ for target_key, expert_tensors in expert_groups.items():
+ state_dict[target_key] = torch.stack([tensor for _, tensor in sorted(expert_tensors)], dim=0)
+ save_file(state_dict, output_dir / "model.safetensors", metadata=metadata)
+ return
+
+ old_index = json.loads(index_path.read_text())
+ old_weight_map = old_index["weight_map"]
+ normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, converted_index = (
+ _build_weight_plan(input_dir)
+ )
+ output_filenames = sorted(set(converted_index["weight_map"].values()))
+
+ metadata_by_file = {}
+ for filename in sorted(set(old_weight_map.values())):
+ with safe_open(input_dir / filename, framework="pt", device="cpu") as f:
+ metadata_by_file[filename] = f.metadata()
+
+ for output_filename in output_filenames:
+ shard = {}
+ normal_sources = normal_sources_by_output_file.get(output_filename, [])
+ source_keys = [source_key for source_key, _ in normal_sources]
+
+ expert_groups_for_shard = {
+ target_key: sorted(sources)
+ for target_key, sources in expert_sources_by_target.items()
+ if output_file_by_target[target_key] == output_filename
+ }
+ for sources in expert_groups_for_shard.values():
+ source_keys.extend(source_key for _, source_key in sources)
+
+ loaded_tensors = _load_sources(input_dir, source_keys, old_weight_map)
+ for source_key, target_key in normal_sources:
+ shard[target_key] = loaded_tensors[source_key]
+ for target_key, sources in expert_groups_for_shard.items():
+ shard[target_key] = torch.stack([loaded_tensors[source_key] for _, source_key in sources], dim=0)
+
+ save_file(shard, output_dir / output_filename, metadata=metadata_by_file.get(output_filename))
+
+ (output_dir / "model.safetensors.index.json").write_text(
+ json.dumps(converted_index, indent=2, sort_keys=True) + "\n"
+ )
+
+
+def convert_checkpoint(input_dir: str, output_dir: str, overwrite: bool = False) -> None:
+ input_path = Path(input_dir).expanduser().resolve()
+ output_path = Path(output_dir).expanduser().resolve()
+ if input_path == output_path:
+ raise ValueError("Please write the converted checkpoint to a different output directory.")
+ if output_path.exists() and any(output_path.iterdir()):
+ if not overwrite:
+ raise FileExistsError(f"{output_path} already exists and is not empty. Pass --overwrite to replace it.")
+ shutil.rmtree(output_path)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ copy_non_weight_files(input_path, output_path)
+ convert_config(input_path, output_path)
+ convert_safetensors(input_path, output_path)
+
+
+def main():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("--input_dir", required=True, help="Path to the original alternating-layer ZAYA checkpoint.")
+ parser.add_argument("--output_dir", required=True, help="Path where the converted checkpoint should be written.")
+ parser.add_argument("--overwrite", action="store_true", help="Overwrite a non-empty output directory.")
+ args = parser.parse_args()
+ convert_checkpoint(args.input_dir, args.output_dir, overwrite=args.overwrite)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index ab68cbc73d36..20662110b172 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -4,7 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_zaya.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
import copy
from collections.abc import Callable
-from typing import Optional
+from typing import Any, Optional
import torch
import torch.nn.functional as F
@@ -29,10 +29,10 @@
from torch.nn import init
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer
from ...generation import GenerationMixin
-from ...integrations import use_experts_implementation, use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -47,27 +47,33 @@
class ZayaRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
- def __init__(self, config: ZayaConfig, device=None):
+ def __init__(self, config: ZayaConfig):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
-
self.config = config
+ self.layer_types = list(set(config.layer_types))
+ self.rope_type = {}
+ for layer_type in self.layer_types:
+ rope_params = self.config.rope_parameters[layer_type]
+ if rope_params is None:
+ continue
- self.rope_type = self.config.rope_parameters["rope_type"]
- rope_init_fn: Callable = self.compute_default_rope_parameters
- if self.rope_type != "default":
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
-
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+ self.rope_type[layer_type] = rope_params["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type[layer_type] != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
+ curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type)
+ self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
+ self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
+ setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
@staticmethod
def compute_default_rope_parameters(
config: ZayaConfig | None = None,
device: Optional["torch.device"] = None,
seq_len: int | None = None,
+ layer_type: str | None = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
@@ -78,12 +84,16 @@ def compute_default_rope_parameters(
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
+ layer_type (`str`, *optional*):
+ The current layer type if the model has different RoPE parameters per type.
+ Should not be used unless `config.layer_types is not None`
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
- base = config.rope_parameters["rope_theta"]
- partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+ base = config.rope_parameters[layer_type]["rope_theta"]
+ # key difference to gemma3: partial rope
+ partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0)
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
dim = int(head_dim * partial_rotary_factor)
@@ -97,16 +107,19 @@ def compute_default_rope_parameters(
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
- def forward(self, x, position_ids):
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ def forward(self, x, position_ids, layer_type=None):
+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+
+ inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos() * self.attention_scaling
- sin = emb.sin() * self.attention_scaling
+ cos = emb.cos() * attention_scaling
+ sin = emb.sin() * attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -132,129 +145,36 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-class ZayaDynamicCache(DynamicCache):
- """
- Cache that includes both the KV cache and the CCA cache.
+class ZayaCCAProjection(nn.Module):
"""
+ Projects hidden states into attention q/k/v states with ZAYA's CCA path.
- def __init__(
- self,
- config: ZayaConfig,
- batch_size: int,
- dtype: torch.dtype = torch.float16,
- device: str | None = None,
- ):
- super().__init__()
- self.config = config
- self.batch_size = batch_size
- self.dtype = dtype
- self.device = device
- self.conv_kernel_size = (config.cca_time0 - 1) + (config.cca_time1 - 1)
- self.num_layers = config.num_hidden_layers
- self.key_value_hidden_size = config.num_query_groups * config.head_dim
- self.query_hidden_size = config.num_attention_heads * config.head_dim
- self.conv_state_size = self.key_value_hidden_size + self.query_hidden_size
- self.has_previous_state = False
-
- self.conv_states = [None for _ in range(self.num_layers)]
- self.prev_v2 = [None for _ in range(self.num_layers)]
-
- def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
- if new_conv_state.shape[1] < self.conv_kernel_size:
- new_conv_state = F.pad(
- new_conv_state.transpose(1, 2), (self.conv_kernel_size - new_conv_state.shape[1], 0)
- )
- else:
- new_conv_state = new_conv_state[:, -self.conv_kernel_size :, :].transpose(1, 2)
+ `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal
+ `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail;
+ for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
+ Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses
+ `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**.
- if self.conv_states[layer_idx] is None:
- self.conv_states[layer_idx] = torch.zeros_like(new_conv_state)
+ Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys.
+ """
- if not self.has_previous_state:
- self.conv_states[layer_idx].copy_(new_conv_state)
- else:
- conv_state = torch.cat([self.conv_states[layer_idx], new_conv_state], dim=-1)[
- :, :, -self.conv_kernel_size :
- ]
- self.conv_states[layer_idx].copy_(conv_state)
- return self.conv_states[layer_idx]
-
- def update_prev_v2(self, layer_idx: int, new_prev_v2: torch.Tensor) -> torch.Tensor:
- if self.prev_v2[layer_idx] is None:
- self.prev_v2[layer_idx] = torch.zeros_like(new_prev_v2)
- self.prev_v2[layer_idx].copy_(new_prev_v2)
- return self.prev_v2[layer_idx]
-
- def reset(self):
- super().reset()
- for conv_state in self.conv_states:
- if conv_state is not None:
- conv_state.zero_()
- for prev_v2 in self.prev_v2:
- if prev_v2 is not None:
- prev_v2.zero_()
- self.has_previous_state = False
-
- def _reorder_auxiliary_states(self, indices: torch.LongTensor):
- for layer_idx, conv_state in enumerate(self.conv_states):
- if conv_state is not None:
- self.conv_states[layer_idx] = conv_state.index_select(0, indices.to(conv_state.device))
- for layer_idx, prev_v2 in enumerate(self.prev_v2):
- if prev_v2 is not None:
- self.prev_v2[layer_idx] = prev_v2.index_select(0, indices.to(prev_v2.device))
- self.batch_size = indices.shape[0]
-
- def reorder_cache(self, beam_idx: torch.LongTensor):
- super().reorder_cache(beam_idx)
- self._reorder_auxiliary_states(beam_idx)
-
- def batch_repeat_interleave(self, repeats: int):
- super().batch_repeat_interleave(repeats)
- for layer_idx, conv_state in enumerate(self.conv_states):
- if conv_state is not None:
- self.conv_states[layer_idx] = conv_state.repeat_interleave(repeats, dim=0)
- for layer_idx, prev_v2 in enumerate(self.prev_v2):
- if prev_v2 is not None:
- self.prev_v2[layer_idx] = prev_v2.repeat_interleave(repeats, dim=0)
- self.batch_size *= repeats
-
- def batch_select_indices(self, indices: torch.Tensor):
- super().batch_select_indices(indices)
- self._reorder_auxiliary_states(indices)
-
-
-class CCA(nn.Module):
- def __init__(
- self,
- config: ZayaConfig,
- num_key_value_heads: int = 2,
- num_attention_heads: int = 8,
- hidden_size: int | None = None,
- head_dim: int = 128,
- cca_time0: int = 2,
- cca_time1: int = 2,
- layer_number: int = 0,
- ):
+ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
- self.layer_number = layer_number
+ self.layer_idx = layer_idx
- self.hidden_size = int(hidden_size or config.hidden_size)
+ self.hidden_size = config.hidden_size
- self.depthwise_kernel_size = cca_time0
- self.grouped_kernel_size = cca_time1
+ self.depthwise_kernel_size = config.cca_time0
+ self.grouped_kernel_size = config.cca_time1
self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
- self.num_key_value_heads = int(num_key_value_heads)
- self.num_attention_heads = int(num_attention_heads)
-
- self.head_dim = int(head_dim)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
self.query_hidden_size = self.num_attention_heads * self.head_dim
- self.sqrt_head_dim = self.head_dim**0.5
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
- if self.num_attention_heads % self.num_key_value_heads != 0:
- raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
@@ -262,23 +182,21 @@ def __init__(
self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
conv_channels = self.key_value_hidden_size + self.query_hidden_size
- self.conv_qk = nn.Sequential(
- nn.Conv1d(
- in_channels=conv_channels,
- out_channels=conv_channels,
- kernel_size=self.depthwise_kernel_size,
- groups=conv_channels,
- padding=0,
- stride=1,
- ),
- nn.Conv1d(
- in_channels=conv_channels,
- out_channels=conv_channels,
- kernel_size=self.grouped_kernel_size,
- groups=(self.num_key_value_heads + self.num_attention_heads),
- padding=0,
- stride=1,
- ),
+ self.conv_qk_depthwise = nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.depthwise_kernel_size,
+ groups=conv_channels,
+ padding=0,
+ stride=1,
+ )
+ self.conv_qk_grouped = nn.Conv1d(
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=self.grouped_kernel_size,
+ groups=(self.num_key_value_heads + self.num_attention_heads),
+ padding=0,
+ stride=1,
)
self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))
@@ -286,51 +204,55 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
- past_key_values: ZayaDynamicCache | None,
- attention_mask: torch.Tensor | None = None,
+ past_key_values: Cache | None,
+ padding_mask: torch.Tensor | None = None,
):
- if attention_mask is not None:
- hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)
+ if padding_mask is not None:
+ hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype)
- batch_size, seq_length, _ = hidden_states.shape
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
projected_queries = self.linear_q(hidden_states)
projected_keys = self.linear_k(hidden_states)
qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
- query_residual = projected_queries.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
- key_residual = projected_keys.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
+ query_residual = projected_queries.view(*hidden_shape)
+ key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim)
- key_residual = key_residual.unsqueeze(-2).expand(-1, -1, -1, self.num_key_value_groups, -1)
- key_residual = key_residual.reshape(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+ key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2)
query_residual = (query_residual + key_residual) * 0.5
key_residual = query_residual.view(
- batch_size, seq_length, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
+ *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
).mean(dim=-2)
qk_states = qk_states.transpose(1, 2)
- use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state
+ use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx)
if use_precomputed_states:
- cached_qk_states = past_key_values.conv_states[self.layer_number]
+ cached_qk_states = past_key_values.layers[self.layer_idx].conv_states
conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
else:
conv_input = F.pad(qk_states, (self.total_padding, 0))
if past_key_values is not None:
- past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_states.transpose(1, 2))
+ new_conv_state = qk_states[..., -self.total_padding :]
+ if new_conv_state.shape[-1] < self.total_padding:
+ new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0))
+ past_key_values.update_conv_state(new_conv_state, self.layer_idx)
- convolved_qk_states = self.conv_qk(conv_input).transpose(1, 2)
+ convolved_qk_states = self.conv_qk_depthwise(conv_input)
+ convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2)
query = (
convolved_qk_states[..., : self.query_hidden_size].view(
- batch_size, seq_length, self.num_attention_heads, self.head_dim
+ *input_shape, self.num_attention_heads, self.head_dim
)
+ query_residual
)
key = (
convolved_qk_states[..., self.query_hidden_size :].view(
- batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ *input_shape, self.num_key_value_heads, self.head_dim
)
+ key_residual
)
@@ -338,28 +260,18 @@ def forward(
value_current = self.val_proj1(hidden_states)
projected_v2 = self.val_proj2(hidden_states)
if use_precomputed_states:
- first_v2 = past_key_values.prev_v2[self.layer_number].unsqueeze(1)
+ first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
else:
- first_v2 = self.val_proj2(hidden_states.new_zeros(batch_size, 1, self.hidden_size))
+ first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size))
value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
if past_key_values is not None:
- past_key_values.update_prev_v2(self.layer_number, projected_v2[:, -1, :])
+ past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx)
value = torch.cat([value_current, value_delayed], dim=-1).view(
- batch_size, seq_length, self.num_key_value_heads, self.head_dim
+ *input_shape, self.num_key_value_heads, self.head_dim
)
- norm_eps = torch.finfo(query.dtype).eps
- query_norm = query.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
- key_norm = key.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
-
- key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
- query = query * (self.sqrt_head_dim / query_norm)
-
- query = query.reshape(batch_size, seq_length, self.query_hidden_size)
- key = key.reshape(batch_size, seq_length, self.key_value_hidden_size)
- value = value.reshape(batch_size, seq_length, self.key_value_hidden_size)
return query, key, value
@@ -446,51 +358,56 @@ def eager_attention_forward(
return attn_output, attn_weights
+@use_kernelized_func(apply_rotary_pos_emb)
class ZayaAttention(nn.Module):
- def __init__(self, config: ZayaConfig, layer_n):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
- self.layer_n = layer_n
- self.layer_idx = layer_n
- self.hidden_size = config.hidden_size
- self.num_attention_heads = config.num_attention_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
- self.is_causal = True
- self.attention_dropout = config.attention_dropout
- self.head_dim = config.head_dim
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
-
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
self.o_proj = nn.Linear(
- self.num_attention_heads * self.head_dim,
- self.hidden_size,
- bias=self.config.attention_bias,
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
- self.qkv = CCA(
+ self.layer_n = layer_idx
+ self.layer_type = config.layer_types[layer_idx]
+ self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.qkv = ZayaCCAProjection(
config=self.config,
- num_attention_heads=self.config.num_attention_heads,
- num_key_value_heads=self.config.num_query_groups,
- hidden_size=self.hidden_size,
- head_dim=self.config.head_dim,
- cca_time0=self.config.cca_time0,
- cca_time1=self.config.cca_time1,
- layer_number=layer_n,
+ layer_idx=layer_idx,
)
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- attention_mask_2d: torch.Tensor | None = None,
+ attention_mask: dict[str, Any] | None = None,
past_key_values: Cache | None = None,
- output_attentions: bool = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
batch_size, seq_length, _ = hidden_states.shape
- query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, attention_mask_2d)
- query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads, self.head_dim)
- key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
- value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim)
+
+ mask_mapping = attention_mask or {}
+ causal_mask = mask_mapping.get("causal")
+ padding_mask = mask_mapping.get("padding")
+
+ query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask)
+
+ norm_eps = torch.finfo(query_states.dtype).eps
+ head_dim_scale = self.scaling**-1
+ query_states = query_states * (
+ head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ )
+ key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps))
+ key_states = key_states * self.qkv.temp[None, None, :, None]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -502,8 +419,7 @@ def forward(
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)
- causal_mask = attention_mask
- if causal_mask is not None:
+ if isinstance(causal_mask, torch.Tensor):
causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
@@ -517,15 +433,13 @@ def forward(
causal_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
- output_attentions=output_attentions,
+ sliding_window=self.sliding_window,
+ **kwargs,
)
attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
- if not output_attentions:
- attn_weights = None
-
return attn_output, attn_weights, past_key_values
@@ -541,66 +455,94 @@ def _apply_residual_scaling(
return hidden_states, residual
-class ZayaDecoderATTLayer(GradientCheckpointingLayer):
- def __init__(self, config: ZayaConfig, layer_n: int):
+class ZayaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
- self.self_attn = ZayaAttention(config, layer_n)
-
+ self.self_attn = ZayaAttention(config, layer_idx)
self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.res_scale = ResidualScaling(config, layer_n)
+ self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0)
+ self.zaya_block = ZayaSparseMoeBlock(
+ config,
+ config.num_experts,
+ config.zaya_mlp_expansion,
+ config.intermediate_size,
+ layer_idx,
+ )
+ self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.post_attention_res_scale = ResidualScaling(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
- residual: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- attention_mask_2d: torch.Tensor | None = None,
+ residual: torch.Tensor | None,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ attention_mask: dict[str, Any] | None = None,
past_key_values: Cache | None = None,
- output_attentions: bool | None = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
- prev_router_hidden_states: torch.Tensor | None = None,
- **kwargs,
- ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
- attention_mask_2d=attention_mask_2d,
past_key_values=past_key_values,
- output_attentions=output_attentions,
position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states, residual = _apply_residual_scaling(
+ hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm
+ )
+
+ hidden_states, prev_router_hidden_states, _ = self.zaya_block(
+ hidden_states,
+ prev_router_hidden_states,
)
- return hidden_states, self_attn_weights if output_attentions else None, residual, prev_router_hidden_states
+ return hidden_states, self_attn_weights, residual, prev_router_hidden_states
class ResidualScaling(nn.Module):
- def __init__(self, config, layer_n):
+ def __init__(self, hidden_size: int, has_residual_scale: bool = True):
super().__init__()
- self.not_first_layer = layer_n != 0
- self.hidden_states_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
- self.hidden_states_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+ self.has_residual_scale = has_residual_scale
+ self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
+ self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size))
- if self.not_first_layer:
- self.residual_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
- self.residual_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+ if self.has_residual_scale:
+ self.residual_scale = nn.Parameter(torch.ones(hidden_size))
+ self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
- if self.not_first_layer:
+ if self.has_residual_scale:
residual = (residual + self.residual_bias) * self.residual_scale
return residual, hidden_states
+class ZayaRouterMLP(nn.Module):
+ def __init__(self, hidden_size: int, num_experts: int):
+ super().__init__()
+ self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
+ self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
+ self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
+ self.act_fn = nn.GELU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.act_fn(self.fc1(hidden_states))
+ hidden_states = self.act_fn(self.fc2(hidden_states))
+ return self.out_proj(hidden_states)
+
+
class ZayaRouter(nn.Module):
def __init__(
self,
config,
layer_idx: int,
num_moe_experts: int,
- moe_router_topk: int,
+ num_experts_per_tok: int,
mlp_expansion: int,
hidden_size: int | None = None,
) -> None:
@@ -611,27 +553,18 @@ def __init__(
self.layer_idx = layer_idx
self.num_experts = num_moe_experts + 1
- self.topk = int(moe_router_topk)
+ self.topk = int(num_experts_per_tok)
self.mlp_expansion = int(mlp_expansion)
self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
- zaya_first_layer = 1
- self.use_eda = self.layer_idx != zaya_first_layer
+ self.use_eda = self.layer_idx != 0
- ln_eps = float(getattr(config, "norm_epsilon", 1e-5))
- self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps)
+ self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon)
if self.use_eda:
self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
- self.non_linearity = nn.GELU()
- self.router_mlp = nn.Sequential(
- nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
- self.non_linearity,
- nn.Linear(self.mlp_expansion, self.mlp_expansion, bias=True),
- self.non_linearity,
- nn.Linear(self.mlp_expansion, self.num_experts, bias=False),
- )
+ self.router_mlp = ZayaRouterMLP(self.mlp_expansion, self.num_experts)
self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32))
self.balancing_biases[-1] = -1.0
@@ -669,11 +602,11 @@ def forward(
class ZayaExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
- def __init__(self, config, num_experts: int, ffn_hidden_size: int):
+ def __init__(self, config, num_experts: int, intermediate_size: int):
super().__init__()
self.num_experts = num_experts
self.hidden_dim = config.hidden_size
- self.intermediate_dim = ffn_hidden_size // 2
+ self.intermediate_dim = intermediate_size // 2
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
@@ -686,7 +619,7 @@ def forward(
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
- expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts + 1)
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
@@ -705,14 +638,14 @@ def forward(
return final_hidden_states
-class ZayaBlock(nn.Module):
+class ZayaSparseMoeBlock(nn.Module):
def __init__(
self,
config,
num_moe_experts: int,
mlp_expansion: int,
- ffn_hidden_size: int,
- layer_n: int,
+ intermediate_size: int,
+ layer_idx: int,
):
super().__init__()
self.config = config
@@ -720,13 +653,13 @@ def __init__(
self.num_moe_experts = num_moe_experts
self.router = ZayaRouter(
config=self.config,
- layer_idx=layer_n,
+ layer_idx=layer_idx,
num_moe_experts=self.num_moe_experts,
- moe_router_topk=getattr(self.config, "moe_router_topk", 1),
+ num_experts_per_tok=self.config.num_experts_per_tok,
mlp_expansion=mlp_expansion,
hidden_size=self.hidden_dim,
)
- self.experts = ZayaExperts(self.config, self.num_moe_experts, ffn_hidden_size=ffn_hidden_size)
+ self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size)
def forward(
self,
@@ -736,6 +669,13 @@ def forward(
route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
hidden_states, router_states=prev_router_hidden_states
)
+
+ # if the router outputs num_moe_experts, just skip the tokens
+ # by masking them with id=0 and prob=0 to reuse the expert code
+ skip_expert = expert_choice == self.num_moe_experts
+ route_prob = route_prob.masked_fill(skip_expert, 0)
+ expert_choice = expert_choice.masked_fill(skip_expert, 0)
+
batch_size, seq_length, emb_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim)
expert_output = self.experts(hidden_states_flat, expert_choice, route_prob)
@@ -744,64 +684,25 @@ def forward(
return expert_output, prev_router_hidden_states, router_logits
-class ZayaDecoderMLPLayer(GradientCheckpointingLayer):
- def __init__(
- self,
- config: ZayaConfig,
- num_moe_experts: int,
- mlp_expansion: int,
- ffn_hidden_size: int,
- layer_n: int,
- ):
- super().__init__()
- self.config = config
- self.zaya_block = ZayaBlock(
- config,
- num_moe_experts,
- mlp_expansion,
- ffn_hidden_size,
- layer_n,
- )
- self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.res_scale = ResidualScaling(config, layer_n)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- prev_router_hidden_states: torch.Tensor | None = None,
- output_router_logits: bool = False,
- **kwargs,
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
- hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
-
- hidden_states, prev_router_hidden_states, router_logits = self.zaya_block(
- hidden_states,
- prev_router_hidden_states,
- )
-
- return (
- hidden_states,
- router_logits if output_router_logits else None,
- residual,
- prev_router_hidden_states,
- )
-
-
+@auto_docstring
class ZayaPreTrainedModel(PreTrainedModel):
config: ZayaConfig
- config_class = ZayaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
- _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"]
+ _no_split_modules = ["ZayaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
+ # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache.
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(ZayaRouter, index=3),
+ "hidden_states": ZayaDecoderLayer,
+ "attentions": ZayaAttention,
}
+ config_class = ZayaConfig
@torch.no_grad()
def _init_weights(self, module):
@@ -809,7 +710,7 @@ def _init_weights(self, module):
if isinstance(module, ResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
- if module.not_first_layer:
+ if module.has_residual_scale:
init.ones_(module.residual_scale)
init.zeros_(module.residual_bias)
elif isinstance(module, ZayaRouter):
@@ -821,6 +722,33 @@ def _init_weights(self, module):
std = self.config.initializer_range
init.normal_(module.gate_up_proj, mean=0.0, std=std)
init.normal_(module.down_proj, mean=0.0, std=std)
+ elif isinstance(module, ZayaRotaryEmbedding):
+ for layer_type in module.layer_types:
+ rope_init_fn = module.compute_default_rope_parameters
+ if module.rope_type[layer_type] != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+ curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+ getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq)
+ getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq)
+
+
+def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
+ """
+ Create ZAYA's native hybrid cache.
+
+ `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states.
+ """
+ cache_config = copy.copy(config)
+ cache_config.layer_types = ["hybrid"] * config.num_hidden_layers
+ return DynamicCache(config=cache_config)
+
+
+def _is_zaya_cache(past_key_values: Cache) -> bool:
+ return (
+ isinstance(past_key_values, DynamicCache)
+ and len(past_key_values.layers) > 0
+ and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer)
+ )
@auto_docstring
@@ -830,36 +758,16 @@ def __init__(self, config: ZayaConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = []
-
- for layer_n in range(config.num_hidden_layers):
- if layer_n % 2 == 1:
- self.layers.append(
- ZayaDecoderMLPLayer(
- config,
- config.num_experts,
- config.zaya_mlp_expansion,
- config.ffn_hidden_size,
- layer_n,
- )
- )
- else:
- self.layers.append(ZayaDecoderATTLayer(config, layer_n))
- self.layers = nn.ModuleList(self.layers)
+ self.layers = nn.ModuleList(
+ [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
self.gradient_checkpointing = False
- self.res_scale = ResidualScaling(config, config.num_hidden_layers)
+ self.res_scale = ResidualScaling(config.hidden_size)
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
- if self.config.swa_layers is not None:
- swa_config = copy.copy(config)
- swa_config.rope_parameters = {
- **config.rope_parameters,
- "rope_theta": swa_config.swa_rotary_base,
- }
- self.swa_rotary_emb = ZayaRotaryEmbedding(config=swa_config)
self.post_init()
@@ -880,25 +788,18 @@ def forward(
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- output_router_logits: bool | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = ZayaDynamicCache(
- self.config, inputs_embeds.shape[0], dtype=self.dtype, device=self.device
- )
+ if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
+ if past_key_values is not None and past_key_values.get_seq_length() > 0:
+ raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
+ past_key_values = make_zaya_cache(self.config)
residual = None
@@ -910,48 +811,44 @@ def forward(
device=inputs_embeds.device,
).unsqueeze(0)
- causal_mask = self._update_causal_mask(
+ if attention_mask is not None and attention_mask.ndim != 2:
+ raise ValueError(
+ "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution."
+ )
+
+ causal_mask_mapping = self._update_causal_mask(
attention_mask,
inputs_embeds,
position_ids,
past_key_values,
)
- if attention_mask is not None and attention_mask.ndim != 2:
- raise ValueError("ZAYA CCA requires a 2D `attention_mask` to mask padding tokens before convolution.")
- # ZayaDynamicCache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
- # CCA only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
- attention_mask_2d = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+ padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+
+ # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
+ # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
if inputs_embeds.shape[1] == 1:
- attention_mask_2d = None
+ padding_mask = None
hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- if self.config.swa_layers is not None:
- swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
+ position_embeddings = {
+ layer_type: self.rotary_emb(hidden_states, position_ids, layer_type)
+ for layer_type in set(self.config.layer_types)
+ }
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
prev_router_hidden_states = None
for layer_n, decoder_layer in enumerate(self.layers):
- if self.config.swa_layers is not None:
- emb_to_use = position_embeddings if self.config.swa_layers[layer_n] == 0 else swa_position_embeddings
- else:
- emb_to_use = position_embeddings
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
+ layer_type = self.config.layer_types[layer_n]
+ emb_to_use = position_embeddings[layer_type]
+ mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
layer_outputs = decoder_layer(
hidden_states,
residual,
- attention_mask=causal_mask,
- position_ids=position_ids,
+ prev_router_hidden_states,
+ attention_mask=mask_mapping,
past_key_values=past_key_values,
- output_attentions=output_attentions,
position_embeddings=emb_to_use,
- prev_router_hidden_states=prev_router_hidden_states,
- attention_mask_2d=attention_mask_2d,
**kwargs,
)
@@ -959,23 +856,11 @@ def forward(
residual = layer_outputs[2]
prev_router_hidden_states = layer_outputs[3]
- if isinstance(decoder_layer, ZayaDecoderATTLayer):
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if past_key_values and not past_key_values.has_previous_state:
- past_key_values.has_previous_state = True
-
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
)
def _update_causal_mask(
@@ -985,33 +870,37 @@ def _update_causal_mask(
position_ids: torch.Tensor,
past_key_values: Cache,
):
- return create_causal_mask(
- config=self.config,
- inputs_embeds=input_tensor,
- attention_mask=attention_mask,
- past_key_values=past_key_values,
- position_ids=position_ids,
- )
-
-
-@auto_docstring
+ mask_kwargs = {
+ "config": self.config,
+ "inputs_embeds": input_tensor,
+ "attention_mask": attention_mask,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ mask_creation_functions = {
+ "full_attention": lambda: create_causal_mask(**mask_kwargs),
+ "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
+ }
+ causal_mask_mapping = {}
+ for layer_type in set(self.config.layer_types):
+ causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
+ return causal_mask_mapping
+
+
+@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+ _tp_plan = {"lm_head": "colwise_gather_output"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
_is_stateful = True
def __init__(self, config, **kwargs):
- super().__init__(config, **kwargs)
+ super().__init__(config)
self.model = ZayaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
- if self.config.tie_word_embeddings:
- self.lm_head.weight = self.model.embed_tokens.weight
-
self.post_init()
- def set_decoder(self, decoder):
- self.model = decoder
-
@can_return_tuple
@auto_docstring
def forward(
@@ -1027,11 +916,28 @@ def forward(
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
) -> MoeCausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ZayaForCausalLM
+
+ >>> model = ZayaForCausalLM.from_pretrained("meta-zaya/Zaya-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-zaya/Zaya-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
- outputs = self.model(
+ outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -1043,80 +949,21 @@ def forward(
)
hidden_states = outputs.last_hidden_state
-
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
+
loss = None
if labels is not None:
- loss = self.loss_function(
- logits=logits,
- labels=labels,
- vocab_size=self.config.vocab_size,
- **kwargs,
- )
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return MoeCausalLMOutputWithPast(
loss=loss,
- aux_loss=None,
logits=logits,
- past_key_values=outputs.past_key_values if use_cache else None,
+ past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- position_ids=None,
- use_cache=True,
- logits_to_keep=None,
- **kwargs,
- ):
- if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache):
- raise ValueError(
- f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
- )
-
- model_inputs = super().prepare_inputs_for_generation(
- input_ids=input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- position_ids=position_ids,
- use_cache=use_cache,
- logits_to_keep=logits_to_keep,
- **kwargs,
- )
- return model_inputs
-
- def _prepare_cache_for_generation(
- self,
- generation_config,
- model_kwargs: dict,
- generation_mode,
- batch_size: int,
- max_cache_length: int,
- ):
- if generation_config.use_cache is False:
- return
-
- if "past_key_values" not in model_kwargs:
- cache_batch_size = batch_size * max(generation_config.num_beams, generation_config.num_return_sequences)
- model_kwargs["past_key_values"] = ZayaDynamicCache(
- self.config, cache_batch_size, dtype=self.dtype, device=self.device
- )
- generation_config.cache_implementation = None
- return super()._prepare_cache_for_generation(
- generation_config=generation_config,
- model_kwargs=model_kwargs,
- generation_mode=generation_mode,
- batch_size=batch_size,
- max_cache_length=max_cache_length,
- )
-
__all__ = ["ZayaPreTrainedModel", "ZayaModel", "ZayaForCausalLM"]
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 50cad3bd10ea..cbafc3200146 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -31,7 +31,7 @@
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeModelOutputWithPast
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -54,15 +54,15 @@
@strict
class ZayaConfig(PreTrainedConfig):
r"""
- ffn_hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the feed-forward and expert hidden states, translate it to `intermediate_size`.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimension of the feed-forward and expert hidden states.
num_key_value_heads (`int`, *optional*, defaults to 2):
Number of key/value groups.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Fraction of each attention head dimension using rotary embeddings.
lm_head_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the language modeling head.
- moe_router_topk (`int`, *optional*, defaults to 1):
+ num_experts_per_tok (`int`, *optional*, defaults to 1):
Number of selected experts per token. ZAYA checkpoints use top-1 routing.
zaya_mlp_expansion (`int`, *optional*, defaults to 256):
Expansion size used by the dense ZAYA blocks.
@@ -91,11 +91,11 @@ class ZayaConfig(PreTrainedConfig):
vocab_size: int = 262272
hidden_size: int = 2048
- ffn_hidden_size: int = 4096
- num_hidden_layers: int = 80
+ intermediate_size: int = 4096
+ num_hidden_layers: int = 40
num_experts: int = 16
num_attention_heads: int = 8
- num_key_value_heads: int | None = 2
+ num_key_value_heads: int = 2
hidden_act: str = "silu"
head_dim: int = 128
max_position_embeddings: int = 131072
@@ -108,10 +108,10 @@ class ZayaConfig(PreTrainedConfig):
attention_bias: bool = False
lm_head_bias: bool = False
attention_dropout: float | int = 0.0
- moe_router_topk: int = 1
+ num_experts_per_tok: int = 1
zaya_mlp_expansion: int = 256
- cca_time0: int | None = 2
- cca_time1: int | None = 2
+ cca_time0: int = 2
+ cca_time1: int = 2
sliding_window: int | None = None
layer_types: list[str] | None = None
swa_rotary_base: float | int = 10000.0
@@ -121,60 +121,14 @@ class ZayaConfig(PreTrainedConfig):
eos_token_id: int | list[int] | None = 106
def __post_init__(self, **kwargs):
- for unused_checkpoint_kwarg in (
- "cca",
- "num_query_groups",
- "activation_func",
- "normalization",
- "add_bias_linear",
- "gated_linear_unit",
- "fused_add_norm",
- "apply_rope_fusion",
- "bias_activation_fusion",
- "activation_func_fp8_input_store",
- "clamp_temp",
- "kv_channels",
- "mamba_cache_dtype",
- "residual_in_fp32",
- "rope_scaling",
- "scale_residual_merge",
- "zaya_high_prec",
- "zaya_use_mod",
- "zaya_use_eda",
- ):
- kwargs.pop(unused_checkpoint_kwarg, None)
-
- self.intermediate_size = self.ffn_hidden_size
- self.num_experts_per_tok = self.moe_router_topk
-
- self.num_key_value_heads = (
- self.num_attention_heads if self.num_key_value_heads is None else self.num_key_value_heads
+ self.layer_types = (
+ ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
)
- legacy_swa_layers = kwargs.pop("swa_layers", None)
- swa_window_sizes = {int(window_size) for window_size in (legacy_swa_layers or []) if int(window_size) > 0}
- if self.sliding_window is None and swa_window_sizes:
- self.sliding_window = max(swa_window_sizes)
- if self.layer_types is None:
- if legacy_swa_layers is None:
- self.layer_types = ["full_attention"] * self.num_hidden_layers
- else:
- self.layer_types = [
- "full_attention" if layer_type == 0 else "sliding_attention" for layer_type in legacy_swa_layers
- ]
- else:
- self.layer_types = list(self.layer_types)
-
- self.cca_time0 = 2 if self.cca_time0 is None else self.cca_time0
- self.cca_time1 = 2 if self.cca_time1 is None else self.cca_time1
-
- super().__post_init__(**kwargs)
-
- def convert_rope_params_to_dict(self, **kwargs):
default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
"full_attention": {
"rope_type": "default",
- "rope_theta": kwargs.pop("rope_theta", self.default_theta),
+ "rope_theta": self.default_theta,
"partial_rotary_factor": self.partial_rotary_factor,
},
"sliding_attention": {
@@ -183,21 +137,19 @@ def convert_rope_params_to_dict(self, **kwargs):
"partial_rotary_factor": self.partial_rotary_factor,
},
}
- layer_types = set(self.layer_types)
-
if self.rope_parameters is None:
- self.rope_parameters = {layer_type: default_rope_params[layer_type] for layer_type in layer_types}
- else:
self.rope_parameters = {
- layer_type: {**default_rope_params[layer_type], **(self.rope_parameters.get(layer_type) or {})}
- for layer_type in layer_types
+ layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types)
}
+ super().__post_init__(**kwargs)
+
+ def convert_rope_params_to_dict(self, **kwargs):
+ # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them
+ # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`.
return kwargs
def validate_architecture(self):
- if self.head_dim is None:
- raise ValueError("`head_dim` must be set for ZAYA.")
if self.num_experts_per_tok != 1:
raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
if self.num_attention_heads % self.num_key_value_heads != 0:
@@ -210,8 +162,6 @@ def validate_architecture(self):
raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.")
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError("`sliding_window` must be a strictly positive integer.")
- if (self.cca_time0, self.cca_time1) != (2, 2):
- raise ValueError("ZAYA currently supports `cca_time0=2` and `cca_time1=2` only.")
class ZayaRotaryEmbedding(LagunaRotaryEmbedding):
@@ -222,13 +172,14 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm):
pass
-def _make_zaya_cache(config: ZayaConfig) -> DynamicCache:
+def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
+ """
+ Create ZAYA's native hybrid cache.
+
+ `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states.
+ """
cache_config = copy.copy(config)
- # layer_types is used to distinct the rope_type (full or swa)
- # so need to construct a new layer_types to construct cache
- cache_config.layer_types = [
- "hybrid" if layer_idx % 2 == 0 else "moe" for layer_idx in range(config.num_hidden_layers)
- ]
+ cache_config.layer_types = ["hybrid"] * config.num_hidden_layers
return DynamicCache(config=cache_config)
@@ -249,6 +200,8 @@ class ZayaCCAProjection(nn.Module):
for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses
`val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**.
+
+ Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys.
"""
def __init__(self, config: ZayaConfig, layer_idx: int):
@@ -298,10 +251,10 @@ def forward(
self,
hidden_states: torch.Tensor,
past_key_values: Cache | None,
- attention_mask: torch.Tensor | None = None,
+ padding_mask: torch.Tensor | None = None,
):
- if attention_mask is not None:
- hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)
+ if padding_mask is not None:
+ hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
@@ -389,19 +342,16 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None,
+ attention_mask: dict[str, Any] | None = None,
past_key_values: Cache | None = None,
- output_attentions: bool = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
batch_size, seq_length, _ = hidden_states.shape
- if isinstance(attention_mask, dict):
- causal_mask = attention_mask.get("causal")
- padding_mask = attention_mask.get("padding")
- else:
- causal_mask = attention_mask
- padding_mask = None
+ mask_mapping = attention_mask or {}
+ causal_mask = mask_mapping.get("causal")
+ padding_mask = mask_mapping.get("padding")
query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask)
@@ -438,7 +388,7 @@ def forward(
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window,
- output_attentions=output_attentions,
+ **kwargs,
)
attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
@@ -447,25 +397,32 @@ def forward(
return attn_output, attn_weights, past_key_values
-class ZayaDecoderATTLayer(GradientCheckpointingLayer):
- def __init__(self, config: ZayaConfig, layer_n: int):
+class ZayaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
- self.self_attn = ZayaAttention(config, layer_n)
-
+ self.self_attn = ZayaAttention(config, layer_idx)
self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.res_scale = ResidualScaling(config, layer_n)
+ self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0)
+ self.zaya_block = ZayaSparseMoeBlock(
+ config,
+ config.num_experts,
+ config.zaya_mlp_expansion,
+ config.intermediate_size,
+ layer_idx,
+ )
+ self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.post_attention_res_scale = ResidualScaling(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
- attention_mask: torch.Tensor | dict[str, torch.Tensor | None] | None = None,
+ prev_router_hidden_states: torch.Tensor | None = None,
+ attention_mask: dict[str, Any] | None = None,
past_key_values: Cache | None = None,
- output_attentions: bool | None = False,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
- prev_router_hidden_states: torch.Tensor | None = None,
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
@@ -473,27 +430,36 @@ def forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
- output_attentions=output_attentions,
position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states, residual = _apply_residual_scaling(
+ hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm
+ )
+
+ hidden_states, prev_router_hidden_states, _ = self.zaya_block(
+ hidden_states,
+ prev_router_hidden_states,
)
- return hidden_states, self_attn_weights if output_attentions else None, residual, prev_router_hidden_states
+ return hidden_states, self_attn_weights, residual, prev_router_hidden_states
class ResidualScaling(nn.Module):
- def __init__(self, config, layer_n):
+ def __init__(self, hidden_size: int, has_residual_scale: bool = True):
super().__init__()
- self.not_first_layer = layer_n != 0
- self.hidden_states_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
- self.hidden_states_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+ self.has_residual_scale = has_residual_scale
+ self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
+ self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size))
- if self.not_first_layer:
- self.residual_scale = torch.nn.Parameter(torch.ones(config.hidden_size))
- self.residual_bias = torch.nn.Parameter(torch.zeros(config.hidden_size))
+ if self.has_residual_scale:
+ self.residual_scale = nn.Parameter(torch.ones(hidden_size))
+ self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
- if self.not_first_layer:
+ if self.has_residual_scale:
residual = (residual + self.residual_bias) * self.residual_scale
return residual, hidden_states
@@ -546,8 +512,7 @@ def __init__(
self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
- zaya_first_layer = 1
- self.use_eda = self.layer_idx != zaya_first_layer
+ self.use_eda = self.layer_idx != 0
self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon)
if self.use_eda:
@@ -605,7 +570,7 @@ def __init__(
num_moe_experts: int,
mlp_expansion: int,
intermediate_size: int,
- layer_n: int,
+ layer_idx: int,
):
super().__init__()
self.config = config
@@ -613,7 +578,7 @@ def __init__(
self.num_moe_experts = num_moe_experts
self.router = ZayaRouter(
config=self.config,
- layer_idx=layer_n,
+ layer_idx=layer_idx,
num_moe_experts=self.num_moe_experts,
num_experts_per_tok=self.config.num_experts_per_tok,
mlp_expansion=mlp_expansion,
@@ -629,6 +594,9 @@ def forward(
route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
hidden_states, router_states=prev_router_hidden_states
)
+
+ # if the router outputs num_moe_experts, just skip the tokens
+ # by masking them with id=0 and prob=0 to reuse the expert code
skip_expert = expert_choice == self.num_moe_experts
route_prob = route_prob.masked_fill(skip_expert, 0)
expert_choice = expert_choice.masked_fill(skip_expert, 0)
@@ -641,59 +609,15 @@ def forward(
return expert_output, prev_router_hidden_states, router_logits
-class ZayaDecoderMLPLayer(GradientCheckpointingLayer):
- def __init__(
- self,
- config: ZayaConfig,
- num_moe_experts: int,
- mlp_expansion: int,
- intermediate_size: int,
- layer_n: int,
- ):
- super().__init__()
- self.config = config
- self.zaya_block = ZayaSparseMoeBlock(
- config,
- num_moe_experts,
- mlp_expansion,
- intermediate_size,
- layer_n,
- )
- self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.res_scale = ResidualScaling(config, layer_n)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- prev_router_hidden_states: torch.Tensor | None = None,
- output_router_logits: bool = False,
- **kwargs,
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
- hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
-
- hidden_states, prev_router_hidden_states, router_logits = self.zaya_block(
- hidden_states,
- prev_router_hidden_states,
- )
-
- return (
- hidden_states,
- router_logits if output_router_logits else None,
- residual,
- prev_router_hidden_states,
- )
-
-
class ZayaPreTrainedModel(LlamaPreTrainedModel):
config: ZayaConfig
config_class = ZayaConfig
- _no_split_modules = ["ZayaDecoderATTLayer", "ZayaDecoderMLPLayer"]
+ _no_split_modules = ["ZayaDecoderLayer"]
# ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache.
_can_compile_fullgraph = False
_can_record_outputs = {
"router_logits": OutputRecorder(ZayaRouter, index=3),
- "hidden_states": [ZayaDecoderATTLayer, ZayaDecoderMLPLayer],
+ "hidden_states": ZayaDecoderLayer,
"attentions": ZayaAttention,
}
@@ -703,7 +627,7 @@ def _init_weights(self, module):
if isinstance(module, ResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
- if module.not_first_layer:
+ if module.has_residual_scale:
init.ones_(module.residual_scale)
init.zeros_(module.residual_bias)
elif isinstance(module, ZayaRouter):
@@ -715,6 +639,14 @@ def _init_weights(self, module):
std = self.config.initializer_range
init.normal_(module.gate_up_proj, mean=0.0, std=std)
init.normal_(module.down_proj, mean=0.0, std=std)
+ elif isinstance(module, ZayaRotaryEmbedding):
+ for layer_type in module.layer_types:
+ rope_init_fn = module.compute_default_rope_parameters
+ if module.rope_type[layer_type] != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+ curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+ getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq)
+ getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq)
@auto_docstring
@@ -724,25 +656,12 @@ def __init__(self, config: ZayaConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = []
-
- for layer_n in range(config.num_hidden_layers):
- if layer_n % 2 == 1:
- self.layers.append(
- ZayaDecoderMLPLayer(
- config,
- config.num_experts,
- config.zaya_mlp_expansion,
- config.intermediate_size,
- layer_n,
- )
- )
- else:
- self.layers.append(ZayaDecoderATTLayer(config, layer_n))
- self.layers = nn.ModuleList(self.layers)
+ self.layers = nn.ModuleList(
+ [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
self.gradient_checkpointing = False
- self.res_scale = ResidualScaling(config, config.num_hidden_layers)
+ self.res_scale = ResidualScaling(config.hidden_size)
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
@@ -777,8 +696,8 @@ def forward(
if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
if past_key_values is not None and past_key_values.get_seq_length() > 0:
- raise ValueError("ZAYA requires a native hybrid cache created from `_make_zaya_cache`.")
- past_key_values = _make_zaya_cache(self.config)
+ raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
+ past_key_values = make_zaya_cache(self.config)
residual = None
@@ -790,21 +709,19 @@ def forward(
device=inputs_embeds.device,
).unsqueeze(0)
- if isinstance(attention_mask, dict):
- causal_mask_mapping = attention_mask
- padding_mask = None
- else:
- causal_mask_mapping = self._update_causal_mask(
- attention_mask,
- inputs_embeds,
- position_ids,
- past_key_values,
- )
- padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
- if attention_mask is not None and not isinstance(attention_mask, dict) and attention_mask.ndim != 2:
+ if attention_mask is not None and attention_mask.ndim != 2:
raise ValueError(
"ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution."
)
+
+ causal_mask_mapping = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ position_ids,
+ past_key_values,
+ )
+ padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
+
# ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
# CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
if inputs_embeds.shape[1] == 1:
@@ -822,15 +739,14 @@ def forward(
for layer_n, decoder_layer in enumerate(self.layers):
layer_type = self.config.layer_types[layer_n]
emb_to_use = position_embeddings[layer_type]
- attention_mask = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
+ mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
layer_outputs = decoder_layer(
hidden_states,
residual,
- attention_mask=attention_mask,
- position_ids=position_ids,
+ prev_router_hidden_states,
+ attention_mask=mask_mapping,
past_key_values=past_key_values,
position_embeddings=emb_to_use,
- prev_router_hidden_states=prev_router_hidden_states,
**kwargs,
)
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 5e16b744c989..027901d2ffce 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -27,7 +27,7 @@
from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel
from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer
- from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, _make_zaya_cache
+ from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, make_zaya_cache
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
@@ -49,14 +49,16 @@ def __init__(self, parent):
intermediate_size=64,
)
self.head_dim = 8
- self.ffn_hidden_size = 64
self.num_experts = 4
- self.moe_router_topk = 1
+ self.num_experts_per_tok = 1
self.zaya_mlp_expansion = 4
self.tie_word_embeddings = False
self.rope_parameters = {
- "rope_theta": 10000,
- "rope_type": "default",
+ "full_attention": {
+ "rope_theta": 10000,
+ "rope_type": "default",
+ "partial_rotary_factor": 0.5,
+ },
}
@@ -82,18 +84,12 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
conv_shape = self._get_conv_state_shape(batch_size, config)
recurrent_shape = self._get_recurrent_state_shape(batch_size, config)
- for layer_idx, layer in enumerate(past_key_values.layers):
- if layer_idx % 2 == 0:
- self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer)
- self.assertEqual(layer.keys.shape, attention_shape)
- self.assertEqual(layer.values.shape, attention_shape)
- self.assertEqual(layer.conv_states.shape, conv_shape)
- self.assertEqual(layer.recurrent_states.shape, recurrent_shape)
- else:
- self.assertIsNone(getattr(layer, "keys", None))
- self.assertIsNone(getattr(layer, "values", None))
- self.assertIsNone(layer.conv_states)
- self.assertIsNone(layer.recurrent_states)
+ for layer in past_key_values.layers:
+ self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer)
+ self.assertEqual(layer.keys.shape, attention_shape)
+ self.assertEqual(layer.values.shape, attention_shape)
+ self.assertEqual(layer.conv_states.shape, conv_shape)
+ self.assertEqual(layer.recurrent_states.shape, recurrent_shape)
def is_pipeline_test_to_skip(
self,
@@ -132,7 +128,7 @@ def test_attention_outputs(self):
with torch.no_grad():
outputs = model(**self._prepare_for_class({**inputs_dict, "output_attentions": True}, model_class))
- expected_attn_layers = (config.num_hidden_layers + 1) // 2
+ expected_attn_layers = config.num_hidden_layers
self.assertEqual(len(outputs.attentions), expected_attn_layers)
self.assertEqual(
outputs.attentions[0].shape,
@@ -248,32 +244,22 @@ def test_moe_router_logits(self):
with torch.no_grad():
outputs = model(**inputs_dict, output_router_logits=True)
- expected_moe_layers = config.num_hidden_layers // 2
+ expected_moe_layers = config.num_hidden_layers
self.assertEqual(len(outputs.router_logits), expected_moe_layers)
self.assertEqual(
outputs.router_logits[0].shape,
(self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1),
)
- def test_moe_router_topk_validation(self):
- with self.assertRaisesRegex(StrictDataclassClassValidationError, "moe_router_topk=1"):
- ZayaConfig(moe_router_topk=2)
-
- def test_legacy_swa_layers_translate_to_layer_types(self):
- config = ZayaConfig(num_hidden_layers=4, swa_layers=[4096, 0, 4096, 0], swa_rotary_base=10000)
-
- self.assertEqual(
- config.layer_types, ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]
- )
- self.assertEqual(config.sliding_window, 4096)
- self.assertEqual(config.rope_parameters["full_attention"]["rope_theta"], config.default_theta)
- self.assertEqual(config.rope_parameters["sliding_attention"]["rope_theta"], 10000)
+ def test_num_experts_per_tok_validation(self):
+ with self.assertRaisesRegex(StrictDataclassClassValidationError, "num_experts_per_tok=1"):
+ ZayaConfig(num_experts_per_tok=2)
def test_sliding_attention_mask_is_used(self):
config = ZayaConfig(
vocab_size=128,
hidden_size=32,
- ffn_hidden_size=64,
+ intermediate_size=64,
num_hidden_layers=4,
num_experts=4,
num_attention_heads=4,
@@ -299,7 +285,7 @@ def test_cca_cache_matches_full_forward(self):
config = ZayaConfig(
vocab_size=128,
hidden_size=32,
- ffn_hidden_size=64,
+ intermediate_size=64,
num_hidden_layers=1,
num_experts=4,
num_attention_heads=4,
@@ -315,7 +301,7 @@ def test_cca_cache_matches_full_forward(self):
with torch.no_grad():
full = cca(hidden_states, None, None)
- cache = _make_zaya_cache(config)
+ cache = make_zaya_cache(config)
cca(hidden_states[:, :4], cache, None)
cached = cca(hidden_states[:, 4:], cache, None)
@@ -326,7 +312,7 @@ def test_cca_cache_matches_full_forward_multi_token(self):
config = ZayaConfig(
vocab_size=128,
hidden_size=32,
- ffn_hidden_size=64,
+ intermediate_size=64,
num_hidden_layers=1,
num_experts=4,
num_attention_heads=4,
@@ -342,7 +328,7 @@ def test_cca_cache_matches_full_forward_multi_token(self):
with torch.no_grad():
full = cca(hidden_states, None, None)
- cache = _make_zaya_cache(config)
+ cache = make_zaya_cache(config)
cca(hidden_states[:, :3], cache, None)
cached = cca(hidden_states[:, 3:], cache, None)
@@ -351,7 +337,7 @@ def test_cca_cache_matches_full_forward_multi_token(self):
def test_zaya_cache_reorder_and_reset(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- cache = _make_zaya_cache(config)
+ cache = make_zaya_cache(config)
conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim
cache.update_conv_state(
torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view(
From f3e8e02c7b87632dc40ded0e062d93cec888e33a Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 19:05:14 +0800
Subject: [PATCH 14/36] align with official implement, check 74b conversion
---
src/transformers/models/zaya/convert_zaya_weights_to_hf.py | 4 +++-
src/transformers/models/zaya/modeling_zaya.py | 4 +++-
src/transformers/models/zaya/modular_zaya.py | 4 +++-
3 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
index ba9198b9c666..a1b9b357dc52 100644
--- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -156,7 +156,9 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
sliding_window = config_dict.get("sliding_window")
if sliding_window is None:
positive_windows = [int(window_size) for window_size in swa_layers if int(window_size) > 0]
- sliding_window = max(positive_windows) if positive_windows else None
+ # Original ZAYA stores the number of previous tokens attended by SWA layers. Transformers' sliding window
+ # is the total local attention span, including the current token.
+ sliding_window = max(positive_windows) + 1 if positive_windows else None
rope_parameters = {
"full_attention": {
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index 20662110b172..3f59f4fbee2c 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -877,9 +877,11 @@ def _update_causal_mask(
"past_key_values": past_key_values,
"position_ids": position_ids,
}
+ # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
mask_creation_functions = {
"full_attention": lambda: create_causal_mask(**mask_kwargs),
- "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
+ "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
}
causal_mask_mapping = {}
for layer_type in set(self.config.layer_types):
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index cbafc3200146..7c1fd957cab6 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -775,9 +775,11 @@ def _update_causal_mask(
"past_key_values": past_key_values,
"position_ids": position_ids,
}
+ # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
mask_creation_functions = {
"full_attention": lambda: create_causal_mask(**mask_kwargs),
- "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
+ "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
}
causal_mask_mapping = {}
for layer_type in set(self.config.layer_types):
From f4f206c576e9b873aef15692172406fd365e49a6 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 19:20:33 +0800
Subject: [PATCH 15/36] easier test
---
tests/models/zaya/test_modeling_zaya.py | 19 +++++--------------
1 file changed, 5 insertions(+), 14 deletions(-)
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 027901d2ffce..80f9112a97fc 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -397,22 +397,13 @@ def test_model_logits(self):
inputs = self.get_inputs().to(model.model.embed_tokens.weight.device)
with torch.no_grad():
- outputs = model(**inputs, use_cache=False, output_hidden_states=True, return_dict=True)
+ logits = model(**inputs, use_cache=False, return_dict=True).logits.float().cpu()
- logits = outputs.logits.float().cpu()
- hidden_states = outputs.hidden_states[-1].float().cpu()
+ self.assertEqual(logits.shape, (1, inputs.input_ids.shape[-1], model.config.vocab_size))
+ self.assertTrue(torch.isfinite(logits).all().item())
- EXPECTED_HIDDEN_MEAN = torch.tensor(
- [[0.0399, -0.0123, -0.0560, -0.0480, -0.0627, -0.0362, -0.0220, 0.0004, -0.0321, -0.0263, 0.0046]]
- )
- torch.testing.assert_close(hidden_states.mean(-1), EXPECTED_HIDDEN_MEAN, rtol=1e-2, atol=1e-2)
-
- EXPECTED_HIDDEN_SLICE = torch.tensor([-2.7812, 0.3320, 4.1562, -0.4395, 1.6406, 1.3359, -1.4531, -2.6719, 5.5000, -4.7500, 2.0625, 0.2930, -2.2344, -2.6094, 2.0781, 2.5000, 0.7969, 0.6836, -0.5469, 1.3906]) # fmt: skip
- torch.testing.assert_close(hidden_states[0, 0, :20], EXPECTED_HIDDEN_SLICE, rtol=1e-2, atol=1e-2)
-
- EXPECTED_LOGITS_SLICE = torch.tensor([-2.3438, 1.7344, 3.7656, -3.8750, 0.4707, -0.7422, -2.5938, -2.7188, -2.9375, -2.9844, -3.0000, -3.0000, -3.0000, -3.0000, -3.0156, -3.0000, -3.0000, -3.0000, -3.0000, -3.0000]) # fmt: skip
- torch.testing.assert_close(logits[0, -1, :20], EXPECTED_LOGITS_SLICE, rtol=1e-2, atol=1e-2)
- self.assertEqual(logits[0, -1].argmax().item(), 107)
+ expected_argmax = torch.tensor([[105, 9731, 107, 740, 564, 1601, 611, 236881, 236881, 107, 107]])
+ torch.testing.assert_close(logits.argmax(-1), expected_argmax)
@slow
def test_model_cache_matches_full_forward(self):
From 7c48ee10eadc7465bd08a51812c320458b5dd1cc Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 20:09:31 +0800
Subject: [PATCH 16/36] remove mapping since we convert the ckpt
---
src/transformers/conversion_mapping.py | 17 -----------------
src/transformers/models/zaya/modeling_zaya.py | 15 ++++++++++++++-
src/transformers/models/zaya/modular_zaya.py | 15 ++++++++++++++-
3 files changed, 28 insertions(+), 19 deletions(-)
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 4c0a7942a698..09bea78c96d6 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -561,23 +561,6 @@ def _build_checkpoint_conversion_mapping():
operations=[Transpose(1, 2, check_dims=True)],
),
],
- "zaya": [
- WeightRenaming(r"self_attn\.qkv\.conv_qk\.0\.", "self_attn.qkv.conv_qk_depthwise."),
- WeightRenaming(r"self_attn\.qkv\.conv_qk\.1\.", "self_attn.qkv.conv_qk_grouped."),
- WeightRenaming(r"zaya_block\.router\.router_mlp\.0\.", "zaya_block.router.router_mlp.fc1."),
- WeightRenaming(r"zaya_block\.router\.router_mlp\.2\.", "zaya_block.router.router_mlp.fc2."),
- WeightRenaming(r"zaya_block\.router\.router_mlp\.4\.", "zaya_block.router.router_mlp.out_proj."),
- WeightConverter(
- source_patterns="zaya_block.experts.local_experts.*.linear_fc1.weight",
- target_patterns="zaya_block.experts.gate_up_proj",
- operations=[MergeModulelist(dim=0)],
- ),
- WeightConverter(
- source_patterns="zaya_block.experts.local_experts.*.linear_fc2.weight",
- target_patterns="zaya_block.experts.down_proj",
- operations=[MergeModulelist(dim=0)],
- ),
- ],
"phimoe": [
WeightRenaming(".block_sparse_moe.", ".mlp."),
WeightRenaming(".gate.weight", ".router.weight"),
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index 3f59f4fbee2c..58d2201ce747 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -736,7 +736,20 @@ def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
"""
Create ZAYA's native hybrid cache.
- `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states.
+ ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or
+ `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native
+ `"hybrid"` cache layer because it stores all three states used during decoding:
+
+ - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application.
+ - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel
+ dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is
+ `cca_time0 + cca_time1 - 2`.
+ - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy
+ `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1`
+ output plus the cached delayed `val_proj2`.
+
+ The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates
+ `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types.
"""
cache_config = copy.copy(config)
cache_config.layer_types = ["hybrid"] * config.num_hidden_layers
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 7c1fd957cab6..423848c7f01d 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -176,7 +176,20 @@ def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
"""
Create ZAYA's native hybrid cache.
- `config.layer_types` is reserved for full/sliding attention masks and RoPE parameters. Cache layers use the native hybrid layout because every ZAYA decoder layer has attention, convolution, and recurrent states.
+ ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or
+ `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native
+ `"hybrid"` cache layer because it stores all three states used during decoding:
+
+ - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application.
+ - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel
+ dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is
+ `cca_time0 + cca_time1 - 2`.
+ - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy
+ `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1`
+ output plus the cached delayed `val_proj2`.
+
+ The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates
+ `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types.
"""
cache_config = copy.copy(config)
cache_config.layer_types = ["hybrid"] * config.num_hidden_layers
From 498c2522c4cfae79b3f4fe0abeeb4946435eda74 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 12 May 2026 21:17:11 +0800
Subject: [PATCH 17/36] use default_swa_theta
---
src/transformers/models/zaya/configuration_zaya.py | 6 ++----
src/transformers/models/zaya/convert_zaya_weights_to_hf.py | 5 ++---
src/transformers/models/zaya/modular_zaya.py | 6 ++----
3 files changed, 6 insertions(+), 11 deletions(-)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 479d07dea7d4..26bf600d413b 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -49,8 +49,6 @@ class ZayaConfig(PreTrainedConfig):
Second temporal parameter of the CCA projection.
layer_types (`list[str]`, *optional*):
Per-layer selector for standard RoPE versus SWA RoPE embeddings.
- swa_rotary_base (`float`, *optional*):
- RoPE base used by SWA layers.
```python
>>> from transformers import ZayaConfig, ZayaModel
@@ -65,6 +63,7 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
default_theta = 5000000.0
+ default_swa_theta = 10000.0
vocab_size: int = 262272
hidden_size: int = 2048
@@ -91,7 +90,6 @@ class ZayaConfig(PreTrainedConfig):
cca_time1: int = 2
sliding_window: int | None = None
layer_types: list[str] | None = None
- swa_rotary_base: float | int = 10000.0
output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
@@ -110,7 +108,7 @@ def __post_init__(self, **kwargs):
},
"sliding_attention": {
"rope_type": "default",
- "rope_theta": self.swa_rotary_base,
+ "rope_theta": self.default_swa_theta,
"partial_rotary_factor": self.partial_rotary_factor,
},
}
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
index a1b9b357dc52..63a0bd94142f 100644
--- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -144,7 +144,7 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers)
partial_rotary_factor = config_dict.get("partial_rotary_factor", ZayaConfig.partial_rotary_factor)
rope_theta = config_dict.get("rope_theta", ZayaConfig.default_theta)
- swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.swa_rotary_base)
+ swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.default_swa_theta)
intermediate_size = config_dict.get(
"intermediate_size", config_dict.get("ffn_hidden_size", ZayaConfig.intermediate_size)
)
@@ -173,7 +173,7 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
},
}
- for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta"):
+ for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta", "swa_rotary_base"):
config_dict.pop(key, None)
config_dict.update(
@@ -184,7 +184,6 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
"num_experts_per_tok": num_experts_per_tok,
"layer_types": layer_types,
"sliding_window": sliding_window,
- "swa_rotary_base": swa_rotary_base,
"rope_parameters": {layer_type: rope_parameters[layer_type] for layer_type in set(layer_types)},
}
)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 423848c7f01d..55d2219fac3f 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -72,8 +72,6 @@ class ZayaConfig(PreTrainedConfig):
Second temporal parameter of the CCA projection.
layer_types (`list[str]`, *optional*):
Per-layer selector for standard RoPE versus SWA RoPE embeddings.
- swa_rotary_base (`float`, *optional*):
- RoPE base used by SWA layers.
```python
>>> from transformers import ZayaConfig, ZayaModel
@@ -88,6 +86,7 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
default_theta = 5000000.0
+ default_swa_theta = 10000.0
vocab_size: int = 262272
hidden_size: int = 2048
@@ -114,7 +113,6 @@ class ZayaConfig(PreTrainedConfig):
cca_time1: int = 2
sliding_window: int | None = None
layer_types: list[str] | None = None
- swa_rotary_base: float | int = 10000.0
output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
@@ -133,7 +131,7 @@ def __post_init__(self, **kwargs):
},
"sliding_attention": {
"rope_type": "default",
- "rope_theta": self.swa_rotary_base,
+ "rope_theta": self.default_swa_theta,
"partial_rotary_factor": self.partial_rotary_factor,
},
}
From 3d6306129c568f8dd32b9dd3936dc9c8340db55a Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 13 May 2026 10:51:29 +0800
Subject: [PATCH 18/36] update date
---
docs/source/en/model_doc/zaya.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index 24468b8df65f..e6a220adbecf 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-09.*
+*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-13.*
# ZAYA
From 4d742969fb022a7485d767b31d907f121fdd396d Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 13 May 2026 10:54:01 +0800
Subject: [PATCH 19/36] temp init
---
src/transformers/models/zaya/modular_zaya.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 55d2219fac3f..fb66f6b0522f 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -641,6 +641,8 @@ def _init_weights(self, module):
if module.has_residual_scale:
init.ones_(module.residual_scale)
init.zeros_(module.residual_bias)
+ elif isinstance(module, ZayaCCAProjection):
+ init.ones_(module.temp)
elif isinstance(module, ZayaRouter):
if module.use_eda:
init.ones_(module.router_states_scale)
From d77d5d47e7a554ece9cc85a05545aa3034e164c8 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 13 May 2026 11:11:32 +0800
Subject: [PATCH 20/36] modular
---
src/transformers/models/zaya/modeling_zaya.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index 58d2201ce747..008e166b416e 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -713,6 +713,8 @@ def _init_weights(self, module):
if module.has_residual_scale:
init.ones_(module.residual_scale)
init.zeros_(module.residual_bias)
+ elif isinstance(module, ZayaCCAProjection):
+ init.ones_(module.temp)
elif isinstance(module, ZayaRouter):
if module.use_eda:
init.ones_(module.router_states_scale)
From 1c16fecb90f8a0ea54ba68db51e541ec697557e9 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 13 May 2026 18:05:14 +0800
Subject: [PATCH 21/36] better residual scaling
---
.../models/zaya/convert_zaya_weights_to_hf.py | 23 ++-
src/transformers/models/zaya/modeling_zaya.py | 168 ++++++++----------
src/transformers/models/zaya/modular_zaya.py | 95 +++++-----
3 files changed, 133 insertions(+), 153 deletions(-)
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
index 63a0bd94142f..228532e53fd4 100644
--- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -59,8 +59,10 @@
def _rename_common(rest: str) -> str:
replacements = (
- ("self_attn.qkv.conv_qk.0.", "self_attn.qkv.conv_qk_depthwise."),
- ("self_attn.qkv.conv_qk.1.", "self_attn.qkv.conv_qk_grouped."),
+ ("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."),
+ ("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."),
+ ("self_attn.qkv.temp", "self_attn.temp"),
+ ("self_attn.qkv.", "self_attn.qkv_proj."),
("zaya_block.router.router_mlp.0.", "zaya_block.router.router_mlp.fc1."),
("zaya_block.router.router_mlp.2.", "zaya_block.router.router_mlp.fc2."),
("zaya_block.router.router_mlp.4.", "zaya_block.router.router_mlp.out_proj."),
@@ -87,12 +89,15 @@ def _expert_target(name: str) -> tuple[str, int] | None:
return target, expert_idx
-def convert_weight_name(name: str) -> str | None:
+def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> str | None:
if _expert_target(name) is not None:
return None
match = _LAYER_PATTERN.match(name)
if match is None:
+ if old_num_hidden_layers is not None and name.startswith("model.res_scale."):
+ new_layer_idx = old_num_hidden_layers // 2 - 1
+ return f"model.layers.{new_layer_idx}.post_mlp_res_scale.{name.removeprefix('model.res_scale.')}"
return name
old_layer_idx = int(match.group(1))
@@ -101,8 +106,12 @@ def convert_weight_name(name: str) -> str | None:
if old_layer_idx % 2 == 0:
rest = _rename_common(rest)
- if rest.startswith(("self_attn.", "input_norm.", "res_scale.")):
+ if rest.startswith(("self_attn.", "input_norm.")):
return f"model.layers.{new_layer_idx}.{rest}"
+ if rest.startswith("res_scale."):
+ if old_layer_idx == 0:
+ return f"model.input_{rest.removeprefix('res_scale.')}"
+ return f"model.layers.{new_layer_idx - 1}.post_mlp_res_scale.{rest.removeprefix('res_scale.')}"
else:
rest = _rename_common(rest)
if rest.startswith("zaya_block."):
@@ -209,6 +218,7 @@ def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None:
def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[str]], dict[str, str], dict]:
index = json.loads((input_dir / "model.safetensors.index.json").read_text())
old_weight_map = index["weight_map"]
+ old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"])
converted_weight_map = {}
normal_sources_by_output_file = defaultdict(list)
expert_sources_by_target = defaultdict(list)
@@ -223,7 +233,7 @@ def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[
converted_weight_map[target_key] = output_file_by_target[target_key]
continue
- target_key = convert_weight_name(source_key)
+ target_key = convert_weight_name(source_key, old_num_hidden_layers)
if target_key in converted_weight_map:
raise ValueError(f"Duplicate converted weight name: {target_key}")
converted_weight_map[target_key] = filename
@@ -253,6 +263,7 @@ def convert_safetensors(input_dir: Path, output_dir: Path) -> None:
if not safetensors_path.exists():
raise FileNotFoundError("Only safetensors ZAYA checkpoints are supported by this converter.")
+ old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"])
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
state_dict = {}
@@ -263,7 +274,7 @@ def convert_safetensors(input_dir: Path, output_dir: Path) -> None:
target_key, expert_idx = expert_info
expert_groups[target_key].append((expert_idx, f.get_tensor(key)))
continue
- state_dict[convert_weight_name(key)] = f.get_tensor(key)
+ state_dict[convert_weight_name(key, old_num_hidden_layers)] = f.get_tensor(key)
for target_key, expert_tensors in expert_groups.items():
state_dict[target_key] = torch.stack([tensor for _, tensor in sorted(expert_tensors)], dim=0)
save_file(state_dict, output_dir / "model.safetensors", metadata=metadata)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index 008e166b416e..f9e1537e4972 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -31,7 +31,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer
from ...generation import GenerationMixin
-from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -199,8 +199,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
stride=1,
)
- self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -282,6 +280,43 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -321,44 +356,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
return q_embed, k_embed
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
- **kwargs: Unpack[TransformersKwargs],
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
-@use_kernelized_func(apply_rotary_pos_emb)
class ZayaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -368,22 +365,23 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.num_key_value_heads = config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
+
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
- self.layer_n = layer_idx
+ self.qkv_proj = ZayaCCAProjection(
+ config=self.config,
+ layer_idx=layer_idx,
+ )
self.layer_type = config.layer_types[layer_idx]
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.qkv = ZayaCCAProjection(
- config=self.config,
- layer_idx=layer_idx,
- )
+ self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
def forward(
self,
@@ -399,7 +397,7 @@ def forward(
causal_mask = mask_mapping.get("causal")
padding_mask = mask_mapping.get("padding")
- query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask)
+ query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask)
norm_eps = torch.finfo(query_states.dtype).eps
head_dim_scale = self.scaling**-1
@@ -407,7 +405,7 @@ def forward(
head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
)
key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps))
- key_states = key_states * self.qkv.temp[None, None, :, None]
+ key_states = key_states * self.temp[None, None, :, None]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -417,7 +415,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
if isinstance(causal_mask, torch.Tensor):
causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
@@ -443,25 +441,12 @@ def forward(
return attn_output, attn_weights, past_key_values
-def _apply_residual_scaling(
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- residual_scaling,
- rms_norm: ZayaRMSNorm,
-) -> tuple[torch.Tensor, torch.Tensor]:
- residual, hidden_states = residual_scaling(residual, hidden_states)
- residual = hidden_states.to(torch.float32) if residual is None else hidden_states + residual
- hidden_states = rms_norm(residual.to(dtype=rms_norm.weight.dtype))
- return hidden_states, residual
-
-
class ZayaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
self.self_attn = ZayaAttention(config, layer_idx)
self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0)
self.zaya_block = ZayaSparseMoeBlock(
config,
config.num_experts,
@@ -471,18 +456,21 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
)
self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
self.post_attention_res_scale = ResidualScaling(config.hidden_size)
+ self.post_mlp_res_scale = ResidualScaling(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
prev_router_hidden_states: torch.Tensor | None = None,
attention_mask: dict[str, Any] | None = None,
past_key_values: Cache | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
- hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+ residual = hidden_states
+ # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below.
+ residual = residual.to(torch.float32)
+ hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype))
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
@@ -492,34 +480,31 @@ def forward(
**kwargs,
)
- hidden_states, residual = _apply_residual_scaling(
- hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm
- )
+ residual = self.post_attention_res_scale(hidden_states, residual)
+ hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype))
hidden_states, prev_router_hidden_states, _ = self.zaya_block(
hidden_states,
prev_router_hidden_states,
)
- return hidden_states, self_attn_weights, residual, prev_router_hidden_states
+ hidden_states = self.post_mlp_res_scale(hidden_states, residual)
+
+ return hidden_states, self_attn_weights, prev_router_hidden_states
class ResidualScaling(nn.Module):
- def __init__(self, hidden_size: int, has_residual_scale: bool = True):
+ def __init__(self, hidden_size: int):
super().__init__()
- self.has_residual_scale = has_residual_scale
self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size))
+ self.residual_scale = nn.Parameter(torch.ones(hidden_size))
+ self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
- if self.has_residual_scale:
- self.residual_scale = nn.Parameter(torch.ones(hidden_size))
- self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
-
- def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
- if self.has_residual_scale:
- residual = (residual + self.residual_bias) * self.residual_scale
- return residual, hidden_states
+ residual = (residual + self.residual_bias) * self.residual_scale
+ return hidden_states + residual
class ZayaRouterMLP(nn.Module):
@@ -710,11 +695,11 @@ def _init_weights(self, module):
if isinstance(module, ResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
- if module.has_residual_scale:
- init.ones_(module.residual_scale)
- init.zeros_(module.residual_bias)
- elif isinstance(module, ZayaCCAProjection):
- init.ones_(module.temp)
+ init.ones_(module.residual_scale)
+ init.zeros_(module.residual_bias)
+ elif isinstance(module, ZayaModel):
+ init.ones_(module.input_hidden_states_scale)
+ init.zeros_(module.input_hidden_states_bias)
elif isinstance(module, ZayaRouter):
if module.use_eda:
init.ones_(module.router_states_scale)
@@ -778,8 +763,9 @@ def __init__(self, config: ZayaConfig):
)
self.gradient_checkpointing = False
- self.res_scale = ResidualScaling(config.hidden_size)
+ self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
+ self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
@@ -816,8 +802,6 @@ def forward(
raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
past_key_values = make_zaya_cache(self.config)
- residual = None
-
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
position_ids = torch.arange(
@@ -851,6 +835,8 @@ def forward(
for layer_type in set(self.config.layer_types)
}
+ hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale
+
prev_router_hidden_states = None
for layer_n, decoder_layer in enumerate(self.layers):
@@ -859,7 +845,6 @@ def forward(
mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
layer_outputs = decoder_layer(
hidden_states,
- residual,
prev_router_hidden_states,
attention_mask=mask_mapping,
past_key_values=past_key_values,
@@ -868,10 +853,9 @@ def forward(
)
hidden_states = layer_outputs[0]
- residual = layer_outputs[2]
- prev_router_hidden_states = layer_outputs[3]
+ prev_router_hidden_states = layer_outputs[2]
- hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
+ hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index fb66f6b0522f..f0becacb968c 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -42,7 +42,8 @@
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..afmoe.modeling_afmoe import AfmoeForCausalLM
from ..laguna.modeling_laguna import LagunaRotaryEmbedding
-from ..llama.modeling_llama import LlamaAttention, LlamaPreTrainedModel
+from ..llama.modeling_llama import LlamaPreTrainedModel
+from ..phi3.modeling_phi3 import Phi3Attention
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
apply_rotary_pos_emb,
eager_attention_forward,
@@ -256,8 +257,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
stride=1,
)
- self.temp = nn.Parameter(torch.zeros(self.num_key_value_heads))
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -332,20 +331,20 @@ def forward(
return query, key, value
-class ZayaAttention(LlamaAttention):
+class ZayaAttention(Phi3Attention):
def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__(config, layer_idx)
- self.layer_n = layer_idx
+ del op_size # noqa: F821
self.layer_type = config.layer_types[layer_idx]
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
- self.num_key_value_heads = config.num_key_value_heads
- del self.q_proj
- del self.k_proj
- del self.v_proj
- self.qkv = ZayaCCAProjection(
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
+ self.qkv_proj = ZayaCCAProjection(
config=self.config,
layer_idx=layer_idx,
)
@@ -364,7 +363,7 @@ def forward(
causal_mask = mask_mapping.get("causal")
padding_mask = mask_mapping.get("padding")
- query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, padding_mask)
+ query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask)
norm_eps = torch.finfo(query_states.dtype).eps
head_dim_scale = self.scaling**-1
@@ -372,7 +371,7 @@ def forward(
head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
)
key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps))
- key_states = key_states * self.qkv.temp[None, None, :, None]
+ key_states = key_states * self.temp[None, None, :, None]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -382,7 +381,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n)
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
if isinstance(causal_mask, torch.Tensor):
causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
@@ -414,7 +413,6 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.config = config
self.self_attn = ZayaAttention(config, layer_idx)
self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.res_scale = ResidualScaling(config.hidden_size, has_residual_scale=layer_idx != 0)
self.zaya_block = ZayaSparseMoeBlock(
config,
config.num_experts,
@@ -424,18 +422,21 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
)
self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
self.post_attention_res_scale = ResidualScaling(config.hidden_size)
+ self.post_mlp_res_scale = ResidualScaling(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
prev_router_hidden_states: torch.Tensor | None = None,
attention_mask: dict[str, Any] | None = None,
past_key_values: Cache | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
- hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.input_norm)
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+ residual = hidden_states
+ # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below.
+ residual = residual.to(torch.float32)
+ hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype))
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
@@ -445,46 +446,31 @@ def forward(
**kwargs,
)
- hidden_states, residual = _apply_residual_scaling(
- hidden_states, residual, self.post_attention_res_scale, self.post_attention_norm
- )
+ residual = self.post_attention_res_scale(hidden_states, residual)
+ hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype))
hidden_states, prev_router_hidden_states, _ = self.zaya_block(
hidden_states,
prev_router_hidden_states,
)
- return hidden_states, self_attn_weights, residual, prev_router_hidden_states
+ hidden_states = self.post_mlp_res_scale(hidden_states, residual)
+
+ return hidden_states, self_attn_weights, prev_router_hidden_states
class ResidualScaling(nn.Module):
- def __init__(self, hidden_size: int, has_residual_scale: bool = True):
+ def __init__(self, hidden_size: int):
super().__init__()
- self.has_residual_scale = has_residual_scale
self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size))
+ self.residual_scale = nn.Parameter(torch.ones(hidden_size))
+ self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
- if self.has_residual_scale:
- self.residual_scale = nn.Parameter(torch.ones(hidden_size))
- self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
-
- def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor):
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
- if self.has_residual_scale:
- residual = (residual + self.residual_bias) * self.residual_scale
- return residual, hidden_states
-
-
-def _apply_residual_scaling(
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- residual_scaling,
- rms_norm: ZayaRMSNorm,
-) -> tuple[torch.Tensor, torch.Tensor]:
- residual, hidden_states = residual_scaling(residual, hidden_states)
- residual = hidden_states.to(torch.float32) if residual is None else hidden_states + residual
- hidden_states = rms_norm(residual.to(dtype=rms_norm.weight.dtype))
- return hidden_states, residual
+ residual = (residual + self.residual_bias) * self.residual_scale
+ return hidden_states + residual
class ZayaRouterMLP(nn.Module):
@@ -638,11 +624,11 @@ def _init_weights(self, module):
if isinstance(module, ResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
- if module.has_residual_scale:
- init.ones_(module.residual_scale)
- init.zeros_(module.residual_bias)
- elif isinstance(module, ZayaCCAProjection):
- init.ones_(module.temp)
+ init.ones_(module.residual_scale)
+ init.zeros_(module.residual_bias)
+ elif isinstance(module, ZayaModel):
+ init.ones_(module.input_hidden_states_scale)
+ init.zeros_(module.input_hidden_states_bias)
elif isinstance(module, ZayaRouter):
if module.use_eda:
init.ones_(module.router_states_scale)
@@ -674,8 +660,9 @@ def __init__(self, config: ZayaConfig):
)
self.gradient_checkpointing = False
- self.res_scale = ResidualScaling(config.hidden_size)
+ self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
+ self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
@@ -712,8 +699,6 @@ def forward(
raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
past_key_values = make_zaya_cache(self.config)
- residual = None
-
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
position_ids = torch.arange(
@@ -747,6 +732,8 @@ def forward(
for layer_type in set(self.config.layer_types)
}
+ hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale
+
prev_router_hidden_states = None
for layer_n, decoder_layer in enumerate(self.layers):
@@ -755,7 +742,6 @@ def forward(
mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
layer_outputs = decoder_layer(
hidden_states,
- residual,
prev_router_hidden_states,
attention_mask=mask_mapping,
past_key_values=past_key_values,
@@ -764,10 +750,9 @@ def forward(
)
hidden_states = layer_outputs[0]
- residual = layer_outputs[2]
- prev_router_hidden_states = layer_outputs[3]
+ prev_router_hidden_states = layer_outputs[2]
- hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
+ hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
From 3f53fbca6674db5d58bfb5cdb23def9f61624168 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 13 May 2026 18:13:36 +0800
Subject: [PATCH 22/36] better cache
---
docs/source/en/model_doc/zaya.md | 8 ---
src/transformers/cache_utils.py | 54 +++++++++----------
.../models/zaya/configuration_zaya.py | 10 ++++
src/transformers/models/zaya/modeling_zaya.py | 42 ++-------------
src/transformers/models/zaya/modular_zaya.py | 52 +++++-------------
tests/models/zaya/test_modeling_zaya.py | 8 +--
6 files changed, 57 insertions(+), 117 deletions(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index e6a220adbecf..06beb12e2e6f 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -27,14 +27,6 @@ and Zyphra's technical reports.
This model was contributed by [JJJYmmm](https://github.com/JJJYmmm).
-
-
-When building a manual generation loop with `past_key_values`, use [`~models.zaya.modeling_zaya.make_zaya_cache`] to
-create ZAYA's cache. ZAYA uses `config.layer_types` for full/sliding attention masks and RoPE parameters, while its
-cache uses the native hybrid layout needed by the attention, convolution, and recurrent states.
-
-
-
## Usage examples
```python
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index dfef404a42f1..a9eee165b68f 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -23,9 +23,9 @@
logger = logging.get_logger(__name__)
-# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for
-# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided
-# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
+# Registry mapping ``config.cache_layer_types[i]`` (or ``config.layer_types[i]`` when the cache-specific field is not
+# set) -> the dynamic cache layer class to build for that layer. ``DynamicCache.__init__`` consults this mapping when a
+# ``config`` is provided so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
# cache-layer subclass and stop needing a model-specific ``Cache`` subclass.
#
# A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via
@@ -34,6 +34,24 @@
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}
+def _get_layer_types_for_cache(decoder_config: PreTrainedConfig) -> list[str]:
+ sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+ decoder_config, "attention_chunk_size", None
+ )
+ layer_types = getattr(decoder_config, "cache_layer_types", None) or getattr(decoder_config, "layer_types", None)
+ if layer_types is None:
+ layer_types = []
+ for _ in range(decoder_config.num_hidden_layers):
+ if sliding_window is not None:
+ layer_types.append("sliding_attention")
+ else:
+ layer_types.append("full_attention")
+ # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+ if hasattr(decoder_config, "num_kv_shared_layers"):
+ layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
+ return layer_types
+
+
class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache."""
@@ -1280,20 +1298,7 @@ def __init__(
# If a config is passed, use it to infer the layer types and initialize accordingly
if config is not None:
decoder_config = config.get_text_config(decoder=True)
- sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
- decoder_config, "attention_chunk_size", None
- )
- layer_types = getattr(decoder_config, "layer_types", None)
- if layer_types is None:
- layer_types = []
- for _ in range(decoder_config.num_hidden_layers):
- if sliding_window is not None:
- layer_types.append("sliding_attention")
- else:
- layer_types.append("full_attention")
- # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
- if hasattr(decoder_config, "num_kv_shared_layers"):
- layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
+ layer_types = _get_layer_types_for_cache(decoder_config)
for layer_type in layer_types:
cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
@@ -1382,18 +1387,7 @@ def __init__(
**kwargs,
):
config = config.get_text_config(decoder=True)
- layer_types = getattr(config, "layer_types", None)
- # If `layer_types` is not explicitly provided, infer if the model is fully sliding
- if layer_types is None:
- if getattr(config, "sliding_window", None) is not None:
- layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
- elif getattr(config, "attention_chunk_size", None) is not None:
- layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
- else:
- layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
- # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
- if hasattr(config, "num_kv_shared_layers"):
- layer_types = layer_types[: -config.num_kv_shared_layers]
+ layer_types = _get_layer_types_for_cache(config)
sliding_layer_types = {
name
@@ -1413,6 +1407,8 @@ def __init__(
# LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
layer = LinearAttentionLayer()
+ elif layer_type == "hybrid":
+ layer = LinearAttentionAndFullAttentionLayer(config)
else:
layer = StaticLayer(max_cache_len=max_cache_len)
layers.append(layer)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 26bf600d413b..9f373e9a6d21 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -49,6 +49,8 @@ class ZayaConfig(PreTrainedConfig):
Second temporal parameter of the CCA projection.
layer_types (`list[str]`, *optional*):
Per-layer selector for standard RoPE versus SWA RoPE embeddings.
+ cache_layer_types (`list[str]`, *optional*):
+ Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer.
```python
>>> from transformers import ZayaConfig, ZayaModel
@@ -90,6 +92,7 @@ class ZayaConfig(PreTrainedConfig):
cca_time1: int = 2
sliding_window: int | None = None
layer_types: list[str] | None = None
+ cache_layer_types: list[str] | None = None
output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
@@ -99,6 +102,9 @@ def __post_init__(self, **kwargs):
self.layer_types = (
["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
)
+ self.cache_layer_types = (
+ ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types)
+ )
default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
"full_attention": {
@@ -131,6 +137,10 @@ def validate_architecture(self):
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
if len(self.layer_types) != self.num_hidden_layers:
raise ValueError("`layer_types` must have one entry per hidden layer.")
+ if len(self.cache_layer_types) != self.num_hidden_layers:
+ raise ValueError("`cache_layer_types` must have one entry per hidden layer.")
+ if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}:
+ raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.")
if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
if "sliding_attention" in self.layer_types and self.sliding_window is None:
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index f9e1537e4972..c958f34054bf 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -19,7 +19,6 @@
# limitations under the License.
-import copy
from collections.abc import Callable
from typing import Any, Optional
@@ -29,7 +28,7 @@
from torch.nn import init
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer
+from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_experts_implementation, use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
@@ -719,44 +718,13 @@ def _init_weights(self, module):
getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq)
-def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
- """
- Create ZAYA's native hybrid cache.
-
- ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or
- `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native
- `"hybrid"` cache layer because it stores all three states used during decoding:
-
- - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application.
- - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel
- dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is
- `cca_time0 + cca_time1 - 2`.
- - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy
- `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1`
- output plus the cached delayed `val_proj2`.
-
- The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates
- `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types.
- """
- cache_config = copy.copy(config)
- cache_config.layer_types = ["hybrid"] * config.num_hidden_layers
- return DynamicCache(config=cache_config)
-
-
-def _is_zaya_cache(past_key_values: Cache) -> bool:
- return (
- isinstance(past_key_values, DynamicCache)
- and len(past_key_values.layers) > 0
- and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer)
- )
-
-
@auto_docstring
class ZayaModel(ZayaPreTrainedModel):
def __init__(self, config: ZayaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
+ self.cache_layer_types = config.cache_layer_types
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -797,10 +765,8 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
- if past_key_values is not None and past_key_values.get_seq_length() > 0:
- raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
- past_key_values = make_zaya_cache(self.config)
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index f0becacb968c..7878e9ee8d16 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -14,7 +14,6 @@
"""PyTorch Zaya model."""
-import copy
from collections.abc import Callable
from typing import Any, Literal
@@ -26,7 +25,7 @@
from torch.nn import init
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, LinearAttentionAndFullAttentionLayer
+from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
@@ -73,6 +72,8 @@ class ZayaConfig(PreTrainedConfig):
Second temporal parameter of the CCA projection.
layer_types (`list[str]`, *optional*):
Per-layer selector for standard RoPE versus SWA RoPE embeddings.
+ cache_layer_types (`list[str]`, *optional*):
+ Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer.
```python
>>> from transformers import ZayaConfig, ZayaModel
@@ -114,6 +115,7 @@ class ZayaConfig(PreTrainedConfig):
cca_time1: int = 2
sliding_window: int | None = None
layer_types: list[str] | None = None
+ cache_layer_types: list[str] | None = None
output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
@@ -123,6 +125,9 @@ def __post_init__(self, **kwargs):
self.layer_types = (
["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
)
+ self.cache_layer_types = (
+ ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types)
+ )
default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
"full_attention": {
@@ -155,6 +160,10 @@ def validate_architecture(self):
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
if len(self.layer_types) != self.num_hidden_layers:
raise ValueError("`layer_types` must have one entry per hidden layer.")
+ if len(self.cache_layer_types) != self.num_hidden_layers:
+ raise ValueError("`cache_layer_types` must have one entry per hidden layer.")
+ if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}:
+ raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.")
if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
if "sliding_attention" in self.layer_types and self.sliding_window is None:
@@ -171,38 +180,6 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm):
pass
-def make_zaya_cache(config: ZayaConfig) -> DynamicCache:
- """
- Create ZAYA's native hybrid cache.
-
- ZAYA uses `config.layer_types` for the attention mask and RoPE variant of each layer (`"full_attention"` or
- `"sliding_attention"`). That is separate from the cache layout: every ZAYA decoder layer needs the native
- `"hybrid"` cache layer because it stores all three states used during decoding:
-
- - The regular dynamic attention KV cache, updated after the CCA projection and RoPE application.
- - `conv_states`, the pre-convolution q/k tail used by `ZayaCCAProjection` on the next decoding step. Its channel
- dimension is `num_attention_heads * head_dim + num_key_value_heads * head_dim`, and its time dimension is
- `cca_time0 + cca_time1 - 2`.
- - `recurrent_states`, ZAYA's delayed value state. It stores the previous token's `val_proj2` output (the legacy
- `prev_h2`/second value projection state), so the next token can build its value from the current `val_proj1`
- output plus the cached delayed `val_proj2`.
-
- The copied config only changes `layer_types` to `"hybrid"` so `DynamicCache` instantiates
- `LinearAttentionAndFullAttentionLayer`; it does not alter the model's mask or RoPE layer types.
- """
- cache_config = copy.copy(config)
- cache_config.layer_types = ["hybrid"] * config.num_hidden_layers
- return DynamicCache(config=cache_config)
-
-
-def _is_zaya_cache(past_key_values: Cache) -> bool:
- return (
- isinstance(past_key_values, DynamicCache)
- and len(past_key_values.layers) > 0
- and isinstance(past_key_values.layers[0], LinearAttentionAndFullAttentionLayer)
- )
-
-
class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
@@ -654,6 +631,7 @@ def __init__(self, config: ZayaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
+ self.cache_layer_types = config.cache_layer_types
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -694,10 +672,8 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
- if past_key_values is not None and past_key_values.get_seq_length() > 0:
- raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
- past_key_values = make_zaya_cache(self.config)
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 80f9112a97fc..94bd74093e15 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -27,7 +27,7 @@
from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel
from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer
- from transformers.models.zaya.modeling_zaya import ZayaCCAProjection, make_zaya_cache
+ from transformers.models.zaya.modeling_zaya import ZayaCCAProjection
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
@@ -301,7 +301,7 @@ def test_cca_cache_matches_full_forward(self):
with torch.no_grad():
full = cca(hidden_states, None, None)
- cache = make_zaya_cache(config)
+ cache = DynamicCache(config=config)
cca(hidden_states[:, :4], cache, None)
cached = cca(hidden_states[:, 4:], cache, None)
@@ -328,7 +328,7 @@ def test_cca_cache_matches_full_forward_multi_token(self):
with torch.no_grad():
full = cca(hidden_states, None, None)
- cache = make_zaya_cache(config)
+ cache = DynamicCache(config=config)
cca(hidden_states[:, :3], cache, None)
cached = cca(hidden_states[:, 3:], cache, None)
@@ -337,7 +337,7 @@ def test_cca_cache_matches_full_forward_multi_token(self):
def test_zaya_cache_reorder_and_reset(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- cache = make_zaya_cache(config)
+ cache = DynamicCache(config=config)
conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim
cache.update_conv_state(
torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view(
From dc7ac50dd15c4e95d0b246b534f7c19c8fcc385c Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 13 May 2026 18:30:48 +0800
Subject: [PATCH 23/36] ops forget init again
---
src/transformers/models/zaya/modeling_zaya.py | 2 ++
src/transformers/models/zaya/modular_zaya.py | 2 ++
2 files changed, 4 insertions(+)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index c958f34054bf..b162dde52e02 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -699,6 +699,8 @@ def _init_weights(self, module):
elif isinstance(module, ZayaModel):
init.ones_(module.input_hidden_states_scale)
init.zeros_(module.input_hidden_states_bias)
+ elif isinstance(module, ZayaAttention):
+ init.zeros_(module.temp)
elif isinstance(module, ZayaRouter):
if module.use_eda:
init.ones_(module.router_states_scale)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 7878e9ee8d16..04a6625b1313 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -606,6 +606,8 @@ def _init_weights(self, module):
elif isinstance(module, ZayaModel):
init.ones_(module.input_hidden_states_scale)
init.zeros_(module.input_hidden_states_bias)
+ elif isinstance(module, ZayaAttention):
+ init.zeros_(module.temp)
elif isinstance(module, ZayaRouter):
if module.use_eda:
init.ones_(module.router_states_scale)
From 8be4b1ee9c4871398093690664bea5b97b887a25 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Thu, 14 May 2026 14:18:18 +0800
Subject: [PATCH 24/36] better naming
---
src/transformers/cache_utils.py | 82 ++--
src/transformers/configuration_utils.py | 3 +-
src/transformers/models/zaya/__init__.py | 2 +-
.../models/zaya/configuration_zaya.py | 89 ++--
.../models/zaya/convert_zaya_weights_to_hf.py | 75 ++--
src/transformers/models/zaya/modeling_zaya.py | 298 ++++++-------
src/transformers/models/zaya/modular_zaya.py | 395 +++++++-----------
tests/models/zaya/test_modeling_zaya.py | 66 +--
8 files changed, 473 insertions(+), 537 deletions(-)
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index a9eee165b68f..993643dfe390 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -23,9 +23,9 @@
logger = logging.get_logger(__name__)
-# Registry mapping ``config.cache_layer_types[i]`` (or ``config.layer_types[i]`` when the cache-specific field is not
-# set) -> the dynamic cache layer class to build for that layer. ``DynamicCache.__init__`` consults this mapping when a
-# ``config`` is provided so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
+# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for
+# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided
+# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
# cache-layer subclass and stop needing a model-specific ``Cache`` subclass.
#
# A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via
@@ -34,24 +34,6 @@
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}
-def _get_layer_types_for_cache(decoder_config: PreTrainedConfig) -> list[str]:
- sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
- decoder_config, "attention_chunk_size", None
- )
- layer_types = getattr(decoder_config, "cache_layer_types", None) or getattr(decoder_config, "layer_types", None)
- if layer_types is None:
- layer_types = []
- for _ in range(decoder_config.num_hidden_layers):
- if sliding_window is not None:
- layer_types.append("sliding_attention")
- else:
- layer_types.append("full_attention")
- # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
- if hasattr(decoder_config, "num_kv_shared_layers"):
- layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
- return layer_types
-
-
class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache."""
@@ -882,6 +864,33 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
DynamicLayer.reorder_cache(self, beam_idx)
+class LinearAttentionAndSlidingWindowAttentionLayer(LinearAttentionLayer, DynamicSlidingWindowLayer):
+ # The dynamic sliding attention part makes it non-compileable
+ is_compileable = False
+
+ def __init__(self, config: PreTrainedConfig | None = None):
+ DynamicSlidingWindowLayer.__init__(self, config)
+ LinearAttentionLayer.__init__(self)
+
+ def lazy_initialization(self, *args, **kwargs) -> None:
+ # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args
+ if len(args) == 2 and len(kwargs) == 0:
+ DynamicSlidingWindowLayer.lazy_initialization(self, *args)
+ # Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`,
+ # it's always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states)
+ if len(args) == 0 and len(kwargs) == 1:
+ LinearAttentionLayer.lazy_initialization(self, **kwargs)
+
+ def reset(self) -> None:
+ LinearAttentionLayer.reset(self)
+ DynamicSlidingWindowLayer.reset(self)
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ LinearAttentionLayer.reorder_cache(self, beam_idx)
+ DynamicSlidingWindowLayer.reorder_cache(self, beam_idx)
+
+
# Pre-register the standard layer types (some classes are shared between multiple types,
# e.g. ``DynamicSlidingWindowLayer`` covers both ``"sliding_attention"`` and
# ``"chunked_attention"`` — those need an explicit map entry rather than the
@@ -901,6 +910,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
"moe": LinearAttentionLayer,
# Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state.
"hybrid": LinearAttentionAndFullAttentionLayer,
+ "hybrid_sliding": LinearAttentionAndSlidingWindowAttentionLayer,
}
)
@@ -1298,7 +1308,20 @@ def __init__(
# If a config is passed, use it to infer the layer types and initialize accordingly
if config is not None:
decoder_config = config.get_text_config(decoder=True)
- layer_types = _get_layer_types_for_cache(decoder_config)
+ sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+ decoder_config, "attention_chunk_size", None
+ )
+ layer_types = getattr(decoder_config, "layer_types", None)
+ if layer_types is None:
+ layer_types = []
+ for _ in range(decoder_config.num_hidden_layers):
+ if sliding_window is not None:
+ layer_types.append("sliding_attention")
+ else:
+ layer_types.append("full_attention")
+ # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+ if hasattr(decoder_config, "num_kv_shared_layers"):
+ layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
for layer_type in layer_types:
cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
@@ -1387,7 +1410,18 @@ def __init__(
**kwargs,
):
config = config.get_text_config(decoder=True)
- layer_types = _get_layer_types_for_cache(config)
+ layer_types = getattr(config, "layer_types", None)
+ # If `layer_types` is not explicitly provided, infer if the model is fully sliding
+ if layer_types is None:
+ if getattr(config, "sliding_window", None) is not None:
+ layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
+ elif getattr(config, "attention_chunk_size", None) is not None:
+ layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
+ else:
+ layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
+ # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+ if hasattr(config, "num_kv_shared_layers"):
+ layer_types = layer_types[: -config.num_kv_shared_layers]
sliding_layer_types = {
name
@@ -1407,8 +1441,6 @@ def __init__(
# LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
layer = LinearAttentionLayer()
- elif layer_type == "hybrid":
- layer = LinearAttentionAndFullAttentionLayer(config)
else:
layer = StaticLayer(max_cache_len=max_cache_len)
layers.append(layer)
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 7ba033c538d8..e495bbdc69c0 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -71,7 +71,8 @@
"attention",
"sparse",
"dense",
- "hybrid", # for layers that have both mamba and attention in zamba and zamba2
+ "hybrid", # for zamba/zamba2/zaya1, which use full attention + conv states
+ "hybrid_sliding", # for zaya1, which uses swa + conv states
"moe", # for nemotron_h, which uses either attention, mamba or moe
)
diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py
index 54cc0c89f303..c28f97af94ea 100644
--- a/src/transformers/models/zaya/__init__.py
+++ b/src/transformers/models/zaya/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Zyphra and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Zyphra and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 9f373e9a6d21..6fb47ecb76a4 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -31,26 +31,14 @@
@strict
class ZayaConfig(PreTrainedConfig):
r"""
- intermediate_size (`int`, *optional*, defaults to 4096):
- Dimension of the feed-forward and expert hidden states.
- num_key_value_heads (`int`, *optional*, defaults to 2):
- Number of key/value groups.
- partial_rotary_factor (`float`, *optional*, defaults to 0.5):
- Fraction of each attention head dimension using rotary embeddings.
lm_head_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the language modeling head.
- num_experts_per_tok (`int`, *optional*, defaults to 1):
- Number of selected experts per token. ZAYA checkpoints use top-1 routing.
- zaya_mlp_expansion (`int`, *optional*, defaults to 256):
- Expansion size used by the dense ZAYA blocks.
+ router_hidden_size (`int`, *optional*, defaults to 256):
+ Hidden size used by the ZAYA router.
cca_time0 (`int`, *optional*, defaults to 2):
First temporal parameter of the CCA projection.
cca_time1 (`int`, *optional*, defaults to 2):
Second temporal parameter of the CCA projection.
- layer_types (`list[str]`, *optional*):
- Per-layer selector for standard RoPE versus SWA RoPE embeddings.
- cache_layer_types (`list[str]`, *optional*):
- Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer.
```python
>>> from transformers import ZayaConfig, ZayaModel
@@ -64,87 +52,76 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
- default_theta = 5000000.0
- default_swa_theta = 10000.0
vocab_size: int = 262272
hidden_size: int = 2048
- intermediate_size: int = 4096
num_hidden_layers: int = 40
- num_experts: int = 16
num_attention_heads: int = 8
num_key_value_heads: int = 2
hidden_act: str = "silu"
- head_dim: int = 128
max_position_embeddings: int = 131072
initializer_range: float = 0.02
- norm_epsilon: float = 1e-5
+ rms_norm_eps: float = 1e-5
use_cache: bool = True
tie_word_embeddings: bool = True
rope_parameters: RopeParameters | dict | None = None
- partial_rotary_factor: float = 0.5
- attention_bias: bool = False
- lm_head_bias: bool = False
+ sliding_window: int | None = None
attention_dropout: float | int = 0.0
+ moe_intermediate_size: int = 2048
+
num_experts_per_tok: int = 1
- zaya_mlp_expansion: int = 256
- cca_time0: int = 2
- cca_time1: int = 2
- sliding_window: int | None = None
- layer_types: list[str] | None = None
- cache_layer_types: list[str] | None = None
+ num_experts: int = 16
output_router_logits: bool = False
+ layer_types: list[str] | None = None
pad_token_id: int | None = 0
bos_token_id: int | None = 2
eos_token_id: int | list[int] | None = 106
+ # Zaya-specific attention
+ head_dim: int = 128
+ attention_bias: bool = False
+
+ lm_head_bias: bool = False
+ router_hidden_size: int = 256
+ cca_time0: int = 2
+ cca_time1: int = 2
+
def __post_init__(self, **kwargs):
- self.layer_types = (
- ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
- )
- self.cache_layer_types = (
- ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types)
- )
-
- default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
- "full_attention": {
+ self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
+
+ default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = {
+ "hybrid": {
"rope_type": "default",
- "rope_theta": self.default_theta,
- "partial_rotary_factor": self.partial_rotary_factor,
+ "rope_theta": 5_000_000.0,
+ "partial_rotary_factor": 0.5,
},
- "sliding_attention": {
+ "hybrid_sliding": {
"rope_type": "default",
- "rope_theta": self.default_swa_theta,
- "partial_rotary_factor": self.partial_rotary_factor,
+ "rope_theta": 10_000.0,
+ "partial_rotary_factor": 0.5,
},
}
if self.rope_parameters is None:
- self.rope_parameters = {
- layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types)
- }
+ self.rope_parameters = default_rope_params
- super().__post_init__(**kwargs)
+ super().__post_init__(**kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"})
def convert_rope_params_to_dict(self, **kwargs):
- # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them
- # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`.
+ # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly.
return kwargs
def validate_architecture(self):
+ """Part of ``@strict``-powered validation."""
if self.num_experts_per_tok != 1:
raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
if self.num_attention_heads % self.num_key_value_heads != 0:
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
if len(self.layer_types) != self.num_hidden_layers:
raise ValueError("`layer_types` must have one entry per hidden layer.")
- if len(self.cache_layer_types) != self.num_hidden_layers:
- raise ValueError("`cache_layer_types` must have one entry per hidden layer.")
- if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}:
- raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.")
- if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
+ if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}:
raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
- if "sliding_attention" in self.layer_types and self.sliding_window is None:
- raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.")
+ if "hybrid_sliding" in self.layer_types and self.sliding_window is None:
+ raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.")
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError("`sliding_window` must be a strictly positive integer.")
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
index 228532e53fd4..2ac6cb7df869 100644
--- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -27,6 +27,8 @@
from transformers import ZayaConfig
+_DEFAULT_ROPE_THETA = 5_000_000.0
+_DEFAULT_SWA_ROPE_THETA = 10_000.0
_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$")
_LOCAL_EXPERT_PATTERN = re.compile(
r"^model\.layers\.(\d+)\.zaya_block\.experts\.local_experts\.(\d+)\.linear_fc([12])\.weight$"
@@ -35,8 +37,11 @@
_UNUSED_CONFIG_KEYS = (
"cca",
"num_query_groups",
+ "intermediate_size",
"ffn_hidden_size",
"moe_router_topk",
+ "norm_epsilon",
+ "zaya_mlp_expansion",
"activation_func",
"normalization",
"add_bias_linear",
@@ -61,11 +66,18 @@ def _rename_common(rest: str) -> str:
replacements = (
("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."),
("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."),
- ("self_attn.qkv.temp", "self_attn.temp"),
+ ("self_attn.qkv.temp", "self_attn.qk_norm.temp"),
+ ("self_attn.qkv.linear_q.", "self_attn.qkv_proj.q_proj."),
+ ("self_attn.qkv.linear_k.", "self_attn.qkv_proj.k_proj."),
+ ("self_attn.qkv.val_proj1.", "self_attn.qkv_proj.v_proj_current."),
+ ("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."),
("self_attn.qkv.", "self_attn.qkv_proj."),
- ("zaya_block.router.router_mlp.0.", "zaya_block.router.router_mlp.fc1."),
- ("zaya_block.router.router_mlp.2.", "zaya_block.router.router_mlp.fc2."),
- ("zaya_block.router.router_mlp.4.", "zaya_block.router.router_mlp.out_proj."),
+ ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.rmsnorm_eda."),
+ ("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."),
+ ("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."),
+ ("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."),
+ ("zaya_block.router.", "mlp.gate."),
+ ("zaya_block.", "mlp."),
)
for old, new in replacements:
if rest.startswith(old):
@@ -85,7 +97,7 @@ def _expert_target(name: str) -> tuple[str, int] | None:
new_layer_idx = old_layer_idx // 2
expert_idx = int(match.group(2))
projection = "gate_up_proj" if match.group(3) == "1" else "down_proj"
- target = f"model.layers.{new_layer_idx}.zaya_block.experts.{projection}"
+ target = f"model.layers.{new_layer_idx}.mlp.experts.{projection}"
return target, expert_idx
@@ -97,7 +109,7 @@ def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) ->
if match is None:
if old_num_hidden_layers is not None and name.startswith("model.res_scale."):
new_layer_idx = old_num_hidden_layers // 2 - 1
- return f"model.layers.{new_layer_idx}.post_mlp_res_scale.{name.removeprefix('model.res_scale.')}"
+ return f"model.layers.{new_layer_idx}.post_mlp_residual_scale.{name.removeprefix('model.res_scale.')}"
return name
old_layer_idx = int(match.group(1))
@@ -106,41 +118,51 @@ def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) ->
if old_layer_idx % 2 == 0:
rest = _rename_common(rest)
- if rest.startswith(("self_attn.", "input_norm.")):
+ if rest.startswith("self_attn."):
return f"model.layers.{new_layer_idx}.{rest}"
+ if rest.startswith("input_norm."):
+ return f"model.layers.{new_layer_idx}.input_layernorm.{rest.removeprefix('input_norm.')}"
if rest.startswith("res_scale."):
if old_layer_idx == 0:
return f"model.input_{rest.removeprefix('res_scale.')}"
- return f"model.layers.{new_layer_idx - 1}.post_mlp_res_scale.{rest.removeprefix('res_scale.')}"
+ return f"model.layers.{new_layer_idx - 1}.post_mlp_residual_scale.{rest.removeprefix('res_scale.')}"
else:
rest = _rename_common(rest)
- if rest.startswith("zaya_block."):
+ if rest.startswith("mlp."):
return f"model.layers.{new_layer_idx}.{rest}"
if rest.startswith("input_norm."):
- return f"model.layers.{new_layer_idx}.post_attention_norm.{rest.removeprefix('input_norm.')}"
+ return f"model.layers.{new_layer_idx}.post_attention_layernorm.{rest.removeprefix('input_norm.')}"
if rest.startswith("res_scale."):
- return f"model.layers.{new_layer_idx}.post_attention_res_scale.{rest.removeprefix('res_scale.')}"
+ return f"model.layers.{new_layer_idx}.post_attention_residual_scale.{rest.removeprefix('res_scale.')}"
raise ValueError(f"Unexpected ZAYA layer weight name: {name}")
+def _to_hybrid_layer_type(layer_type: str) -> str:
+ if layer_type == "full_attention":
+ return "hybrid"
+ if layer_type == "sliding_attention":
+ return "hybrid_sliding"
+ raise ValueError(f"Unsupported ZAYA layer type: {layer_type}")
+
+
def _convert_layer_types(config_dict: dict, old_num_hidden_layers: int, new_num_hidden_layers: int) -> list[str]:
layer_types = config_dict.get("layer_types")
if layer_types is not None:
if len(layer_types) == old_num_hidden_layers:
- return layer_types[::2]
+ return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types[::2]]
if len(layer_types) == new_num_hidden_layers:
- return list(layer_types)
+ return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types]
raise ValueError("`layer_types` must match either the original or converted number of hidden layers.")
swa_layers = config_dict.get("swa_layers")
if swa_layers is None:
- return ["full_attention"] * new_num_hidden_layers
+ return ["hybrid"] * new_num_hidden_layers
if len(swa_layers) == old_num_hidden_layers:
swa_layers = swa_layers[::2]
elif len(swa_layers) != new_num_hidden_layers:
raise ValueError("`swa_layers` must match either the original or converted number of hidden layers.")
- return ["full_attention" if int(window_size) == 0 else "sliding_attention" for window_size in swa_layers]
+ return ["hybrid" if int(window_size) == 0 else "hybrid_sliding" for window_size in swa_layers]
def convert_config(input_dir: Path, output_dir: Path) -> None:
@@ -151,12 +173,15 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
new_num_hidden_layers = old_num_hidden_layers // 2
layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers)
- partial_rotary_factor = config_dict.get("partial_rotary_factor", ZayaConfig.partial_rotary_factor)
- rope_theta = config_dict.get("rope_theta", ZayaConfig.default_theta)
- swa_rotary_base = config_dict.get("swa_rotary_base", ZayaConfig.default_swa_theta)
- intermediate_size = config_dict.get(
- "intermediate_size", config_dict.get("ffn_hidden_size", ZayaConfig.intermediate_size)
+ partial_rotary_factor = 0.5
+ rope_theta = config_dict.get("rope_theta", _DEFAULT_ROPE_THETA)
+ swa_rotary_base = config_dict.get("swa_rotary_base", _DEFAULT_SWA_ROPE_THETA)
+ rms_norm_eps = config_dict.get("rms_norm_eps", config_dict.get("norm_epsilon", ZayaConfig.rms_norm_eps))
+ router_hidden_size = config_dict.get(
+ "router_hidden_size", config_dict.get("zaya_mlp_expansion", ZayaConfig.router_hidden_size)
)
+ expert_ffn_size = config_dict.get("intermediate_size", config_dict.get("ffn_hidden_size"))
+ moe_intermediate_size = expert_ffn_size // 2 if expert_ffn_size is not None else ZayaConfig.moe_intermediate_size
num_experts_per_tok = config_dict.get(
"num_experts_per_tok", config_dict.get("moe_router_topk", ZayaConfig.num_experts_per_tok)
)
@@ -170,12 +195,12 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
sliding_window = max(positive_windows) + 1 if positive_windows else None
rope_parameters = {
- "full_attention": {
+ "hybrid": {
"rope_type": "default",
"rope_theta": rope_theta,
"partial_rotary_factor": partial_rotary_factor,
},
- "sliding_attention": {
+ "hybrid_sliding": {
"rope_type": "default",
"rope_theta": swa_rotary_base,
"partial_rotary_factor": partial_rotary_factor,
@@ -189,11 +214,13 @@ def convert_config(input_dir: Path, output_dir: Path) -> None:
{
"architectures": ["ZayaForCausalLM"],
"num_hidden_layers": new_num_hidden_layers,
- "intermediate_size": intermediate_size,
+ "moe_intermediate_size": moe_intermediate_size,
"num_experts_per_tok": num_experts_per_tok,
+ "rms_norm_eps": rms_norm_eps,
+ "router_hidden_size": router_hidden_size,
"layer_types": layer_types,
"sliding_window": sliding_window,
- "rope_parameters": {layer_type: rope_parameters[layer_type] for layer_type in set(layer_types)},
+ "rope_parameters": rope_parameters,
}
)
ZayaConfig(**config_dict).save_pretrained(output_dir)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index b162dde52e02..a9a11daf14bb 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -144,15 +144,28 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
- `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal
+ `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal
`conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail;
for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
- Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses
- `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**.
+ Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token
+ `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection
+ from **the recurrent cache**.
Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys.
"""
@@ -166,21 +179,22 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.depthwise_kernel_size = config.cca_time0
self.grouped_kernel_size = config.cca_time1
- self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
+ self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
self.num_key_value_heads = config.num_key_value_heads
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
- self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
- self.query_hidden_size = self.num_attention_heads * self.head_dim
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
- self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
- self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
- self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
- self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
+ query_hidden_size = self.num_attention_heads * self.head_dim
+ key_value_hidden_size = self.num_key_value_heads * self.head_dim
- conv_channels = self.key_value_hidden_size + self.query_hidden_size
+ self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias)
+ self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias)
+ self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias)
+
+ conv_channels = key_value_hidden_size + query_hidden_size
self.conv_qk_depthwise = nn.Conv1d(
in_channels=conv_channels,
out_channels=conv_channels,
@@ -210,68 +224,71 @@ def forward(
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
- projected_queries = self.linear_q(hidden_states)
- projected_keys = self.linear_k(hidden_states)
+ projected_queries = self.q_proj(hidden_states)
+ projected_keys = self.k_proj(hidden_states)
qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
query_residual = projected_queries.view(*hidden_shape)
- key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim)
-
- key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2)
+ key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2)
query_residual = (query_residual + key_residual) * 0.5
- key_residual = query_residual.view(
- *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
- ).mean(dim=-2)
+ key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2)
qk_states = qk_states.transpose(1, 2)
use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx)
if use_precomputed_states:
cached_qk_states = past_key_values.layers[self.layer_idx].conv_states
- conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
+ qk_states = torch.cat([cached_qk_states, qk_states], dim=-1)
else:
- conv_input = F.pad(qk_states, (self.total_padding, 0))
+ qk_states = F.pad(qk_states, (self.conv_kernel_size, 0))
if past_key_values is not None:
- new_conv_state = qk_states[..., -self.total_padding :]
- if new_conv_state.shape[-1] < self.total_padding:
- new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0))
+ new_conv_state = qk_states[..., -self.conv_kernel_size :]
+ if new_conv_state.shape[-1] < self.conv_kernel_size:
+ new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0))
past_key_values.update_conv_state(new_conv_state, self.layer_idx)
- convolved_qk_states = self.conv_qk_depthwise(conv_input)
- convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2)
-
- query = (
- convolved_qk_states[..., : self.query_hidden_size].view(
- *input_shape, self.num_attention_heads, self.head_dim
- )
- + query_residual
- )
+ qk_states = self.conv_qk_depthwise(qk_states)
+ qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2)
- key = (
- convolved_qk_states[..., self.query_hidden_size :].view(
- *input_shape, self.num_key_value_heads, self.head_dim
- )
- + key_residual
- )
+ query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1]
+ query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual
+ key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual
- value_current = self.val_proj1(hidden_states)
- projected_v2 = self.val_proj2(hidden_states)
+ value_current = self.v_proj_current(hidden_states)
+ delayed_v_state = self.v_proj_delayed(hidden_states)
if use_precomputed_states:
- first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
+ recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
else:
- first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size))
- value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
+ recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size))
+ value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1)
if past_key_values is not None:
- past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx)
+ past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx)
- value = torch.cat([value_current, value_delayed], dim=-1).view(
- *input_shape, self.num_key_value_heads, self.head_dim
- )
+ value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape)
return query, key, value
+class ZayaQKNorm(nn.Module):
+ def __init__(self, config: ZayaConfig, scaling: float):
+ super().__init__()
+ self.head_dim_scale = scaling**-1
+ self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
+
+ def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ norm_eps = torch.finfo(query_states.dtype).eps
+ query_states = query_states * (
+ self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ )
+ key_states = key_states * (
+ self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ )
+ key_states = key_states * self.temp[None, None, :, None]
+ return query_states, key_states
+
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -279,18 +296,6 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@@ -377,10 +382,10 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
layer_idx=layer_idx,
)
self.layer_type = config.layer_types[layer_idx]
- self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+ self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
- self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
+ self.qk_norm = ZayaQKNorm(config, self.scaling)
def forward(
self,
@@ -389,7 +394,7 @@ def forward(
past_key_values: Cache | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
batch_size, seq_length, _ = hidden_states.shape
mask_mapping = attention_mask or {}
@@ -398,13 +403,7 @@ def forward(
query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask)
- norm_eps = torch.finfo(query_states.dtype).eps
- head_dim_scale = self.scaling**-1
- query_states = query_states * (
- head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
- )
- key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps))
- key_states = key_states * self.temp[None, None, :, None]
+ query_states, key_states = self.qk_norm(query_states, key_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -416,9 +415,6 @@ def forward(
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
- if isinstance(causal_mask, torch.Tensor):
- causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
-
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
@@ -434,10 +430,10 @@ def forward(
**kwargs,
)
- attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
+ attn_output = attn_output.view(batch_size, seq_length, -1)
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights, past_key_values
+ return attn_output, attn_weights
class ZayaDecoderLayer(GradientCheckpointingLayer):
@@ -445,17 +441,11 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
self.self_attn = ZayaAttention(config, layer_idx)
- self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.zaya_block = ZayaSparseMoeBlock(
- config,
- config.num_experts,
- config.zaya_mlp_expansion,
- config.intermediate_size,
- layer_idx,
- )
- self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.post_attention_res_scale = ResidualScaling(config.hidden_size)
- self.post_mlp_res_scale = ResidualScaling(config.hidden_size)
+ self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.mlp = ZayaSparseMoeBlock(config, layer_idx)
+ self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size)
+ self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size)
def forward(
self,
@@ -465,13 +455,11 @@ def forward(
past_key_values: Cache | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
- # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below.
- residual = residual.to(torch.float32)
- hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype))
+ hidden_states = self.input_layernorm(hidden_states)
- hidden_states, self_attn_weights, _ = self.self_attn(
+ hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
@@ -479,20 +467,20 @@ def forward(
**kwargs,
)
- residual = self.post_attention_res_scale(hidden_states, residual)
- hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype))
+ residual = self.post_attention_residual_scale(hidden_states, residual)
+ hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype))
- hidden_states, prev_router_hidden_states, _ = self.zaya_block(
+ hidden_states, prev_router_hidden_states, _ = self.mlp(
hidden_states,
prev_router_hidden_states,
)
- hidden_states = self.post_mlp_res_scale(hidden_states, residual)
+ hidden_states = self.post_mlp_residual_scale(hidden_states, residual)
- return hidden_states, self_attn_weights, prev_router_hidden_states
+ return hidden_states, prev_router_hidden_states
-class ResidualScaling(nn.Module):
+class ZayaResidualScaling(nn.Module):
def __init__(self, hidden_size: int):
super().__init__()
self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
@@ -501,20 +489,25 @@ def __init__(self, hidden_size: int):
self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
+ output_dtype = hidden_states.dtype
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
+ # Matches the original ZAYA `residual_in_fp32` path.
+ residual = residual.to(torch.float32)
residual = (residual + self.residual_bias) * self.residual_scale
- return hidden_states + residual
+ return (hidden_states + residual).to(output_dtype)
class ZayaRouterMLP(nn.Module):
- def __init__(self, hidden_size: int, num_experts: int):
+ def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float):
super().__init__()
+ self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps)
self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
self.act_fn = nn.GELU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.rmsnorm_eda(hidden_states)
hidden_states = self.act_fn(self.fc1(hidden_states))
hidden_states = self.act_fn(self.fc2(hidden_states))
return self.out_proj(hidden_states)
@@ -525,30 +518,24 @@ def __init__(
self,
config,
layer_idx: int,
- num_moe_experts: int,
- num_experts_per_tok: int,
- mlp_expansion: int,
- hidden_size: int | None = None,
) -> None:
super().__init__()
self.config = config
- self.hidden_size = int(hidden_size or getattr(config, "hidden_size"))
+ self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
- self.num_experts = num_moe_experts + 1
- self.topk = int(num_experts_per_tok)
- self.mlp_expansion = int(mlp_expansion)
+ self.num_experts = config.num_experts + 1
+ self.top_k = config.num_experts_per_tok
+ self.router_hidden_size = config.router_hidden_size
- self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
+ self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True)
self.use_eda = self.layer_idx != 0
-
- self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon)
if self.use_eda:
- self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
+ self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size))
- self.router_mlp = ZayaRouterMLP(self.mlp_expansion, self.num_experts)
+ self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps)
self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32))
self.balancing_biases[-1] = -1.0
@@ -558,27 +545,32 @@ def forward(
hidden_states: torch.Tensor,
router_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ final_shape = (-1, self.top_k)
seq_length = hidden_states.shape[1]
router_hidden_states = self.down_proj(hidden_states)
- if self.use_eda and (router_states is not None):
+ if self.use_eda and router_states is not None:
router_hidden_states = router_hidden_states + router_states * self.router_states_scale
router_hidden_states_next = router_hidden_states[:, -seq_length:].clone()
- router_hidden_states = self.rmsnorm_eda(router_hidden_states)
- logits = self.router_mlp(router_hidden_states)
- expert_prob = torch.softmax(logits, dim=-1)
+ router_logits = self.router_mlp(router_hidden_states)
+ router_probs = torch.softmax(router_logits, dim=-1)
+
+ biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases
+ _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1)
+ router_probs = torch.gather(router_probs, dim=2, index=router_indices)
- expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases
- _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1)
- route_prob = torch.gather(expert_prob, dim=2, index=expert_choice)
+ # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask.
+ skip_expert = router_indices == self.config.num_experts
+ router_probs = router_probs.masked_fill(skip_expert, 0)
+ router_indices = router_indices.masked_fill(skip_expert, 0)
return (
- route_prob.reshape(-1, self.topk),
- expert_choice.reshape(-1, self.topk),
+ router_logits.reshape(-1, self.num_experts),
+ router_probs.reshape(final_shape),
+ router_indices.reshape(final_shape),
router_hidden_states_next,
- logits.reshape(-1, self.num_experts),
)
@@ -586,11 +578,11 @@ def forward(
class ZayaExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
- def __init__(self, config, num_experts: int, intermediate_size: int):
+ def __init__(self, config):
super().__init__()
- self.num_experts = num_experts
+ self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
- self.intermediate_dim = intermediate_size // 2
+ self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
@@ -623,46 +615,25 @@ def forward(
class ZayaSparseMoeBlock(nn.Module):
- def __init__(
- self,
- config,
- num_moe_experts: int,
- mlp_expansion: int,
- intermediate_size: int,
- layer_idx: int,
- ):
+ def __init__(self, config, layer_idx: int):
super().__init__()
self.config = config
self.hidden_dim = config.hidden_size
- self.num_moe_experts = num_moe_experts
- self.router = ZayaRouter(
- config=self.config,
- layer_idx=layer_idx,
- num_moe_experts=self.num_moe_experts,
- num_experts_per_tok=self.config.num_experts_per_tok,
- mlp_expansion=mlp_expansion,
- hidden_size=self.hidden_dim,
- )
- self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size)
+ self.gate = ZayaRouter(self.config, layer_idx)
+ self.experts = ZayaExperts(self.config)
def forward(
self,
hidden_states: torch.Tensor,
prev_router_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
- route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
+ router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate(
hidden_states, router_states=prev_router_hidden_states
)
- # if the router outputs num_moe_experts, just skip the tokens
- # by masking them with id=0 and prob=0 to reuse the expert code
- skip_expert = expert_choice == self.num_moe_experts
- route_prob = route_prob.masked_fill(skip_expert, 0)
- expert_choice = expert_choice.masked_fill(skip_expert, 0)
-
batch_size, seq_length, emb_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim)
- expert_output = self.experts(hidden_states_flat, expert_choice, route_prob)
+ expert_output = self.experts(hidden_states_flat, router_indices, router_probs)
expert_output = expert_output.view(batch_size, seq_length, emb_dim)
return expert_output, prev_router_hidden_states, router_logits
@@ -682,7 +653,7 @@ class ZayaPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
- "router_logits": OutputRecorder(ZayaRouter, index=3),
+ "router_logits": OutputRecorder(ZayaRouter, index=0),
"hidden_states": ZayaDecoderLayer,
"attentions": ZayaAttention,
}
@@ -691,7 +662,7 @@ class ZayaPreTrainedModel(PreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
- if isinstance(module, ResidualScaling):
+ if isinstance(module, ZayaResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
init.ones_(module.residual_scale)
@@ -699,7 +670,7 @@ def _init_weights(self, module):
elif isinstance(module, ZayaModel):
init.ones_(module.input_hidden_states_scale)
init.zeros_(module.input_hidden_states_bias)
- elif isinstance(module, ZayaAttention):
+ elif isinstance(module, ZayaQKNorm):
init.zeros_(module.temp)
elif isinstance(module, ZayaRouter):
if module.use_eda:
@@ -726,7 +697,6 @@ def __init__(self, config: ZayaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- self.cache_layer_types = config.cache_layer_types
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -736,7 +706,7 @@ def __init__(self, config: ZayaConfig):
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
- self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
@@ -772,11 +742,8 @@ def forward(
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- position_ids = torch.arange(
- past_seen_tokens,
- past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device,
- ).unsqueeze(0)
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ position_ids = position_ids.unsqueeze(0)
if attention_mask is not None and attention_mask.ndim != 2:
raise ValueError(
@@ -811,7 +778,7 @@ def forward(
layer_type = self.config.layer_types[layer_n]
emb_to_use = position_embeddings[layer_type]
mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
- layer_outputs = decoder_layer(
+ hidden_states, prev_router_hidden_states = decoder_layer(
hidden_states,
prev_router_hidden_states,
attention_mask=mask_mapping,
@@ -820,9 +787,6 @@ def forward(
**kwargs,
)
- hidden_states = layer_outputs[0]
- prev_router_hidden_states = layer_outputs[2]
-
hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
return MoeModelOutputWithPast(
@@ -847,8 +811,8 @@ def _update_causal_mask(
# Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
mask_creation_functions = {
- "full_attention": lambda: create_causal_mask(**mask_kwargs),
- "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ "hybrid": lambda: create_causal_mask(**mask_kwargs),
+ "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
}
causal_mask_mapping = {}
for layer_type in set(self.config.layer_types):
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 04a6625b1313..d8655390ba61 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -24,13 +24,12 @@
from torch import nn
from torch.nn import init
-from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeModelOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -40,8 +39,9 @@
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..afmoe.modeling_afmoe import AfmoeForCausalLM
+from ..laguna.configuration_laguna import LagunaConfig
from ..laguna.modeling_laguna import LagunaRotaryEmbedding
-from ..llama.modeling_llama import LlamaPreTrainedModel
+from ..llama.modeling_llama import LlamaPreTrainedModel, repeat_kv
from ..phi3.modeling_phi3 import Phi3Attention
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
apply_rotary_pos_emb,
@@ -52,28 +52,16 @@
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
@strict
-class ZayaConfig(PreTrainedConfig):
+class ZayaConfig(LagunaConfig):
r"""
- intermediate_size (`int`, *optional*, defaults to 4096):
- Dimension of the feed-forward and expert hidden states.
- num_key_value_heads (`int`, *optional*, defaults to 2):
- Number of key/value groups.
- partial_rotary_factor (`float`, *optional*, defaults to 0.5):
- Fraction of each attention head dimension using rotary embeddings.
lm_head_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the language modeling head.
- num_experts_per_tok (`int`, *optional*, defaults to 1):
- Number of selected experts per token. ZAYA checkpoints use top-1 routing.
- zaya_mlp_expansion (`int`, *optional*, defaults to 256):
- Expansion size used by the dense ZAYA blocks.
+ router_hidden_size (`int`, *optional*, defaults to 256):
+ Hidden size used by the ZAYA router.
cca_time0 (`int`, *optional*, defaults to 2):
First temporal parameter of the CCA projection.
cca_time1 (`int`, *optional*, defaults to 2):
Second temporal parameter of the CCA projection.
- layer_types (`list[str]`, *optional*):
- Per-layer selector for standard RoPE versus SWA RoPE embeddings.
- cache_layer_types (`list[str]`, *optional*):
- Per-layer selector for cache layout. ZAYA uses the native `"hybrid"` cache layer for every decoder layer.
```python
>>> from transformers import ZayaConfig, ZayaModel
@@ -86,71 +74,61 @@ class ZayaConfig(PreTrainedConfig):
"""
model_type = "zaya"
- keys_to_ignore_at_inference = ["past_key_values"]
- default_theta = 5000000.0
- default_swa_theta = 10000.0
vocab_size: int = 262272
- hidden_size: int = 2048
- intermediate_size: int = 4096
- num_hidden_layers: int = 40
- num_experts: int = 16
+ moe_intermediate_size: int = 2048
num_attention_heads: int = 8
num_key_value_heads: int = 2
- hidden_act: str = "silu"
- head_dim: int = 128
- max_position_embeddings: int = 131072
- initializer_range: float = 0.02
- norm_epsilon: float = 1e-5
- use_cache: bool = True
tie_word_embeddings: bool = True
- rope_parameters: RopeParameters | dict | None = None
- partial_rotary_factor: float = 0.5
- attention_bias: bool = False
- lm_head_bias: bool = False
- attention_dropout: float | int = 0.0
- num_experts_per_tok: int = 1
- zaya_mlp_expansion: int = 256
- cca_time0: int = 2
- cca_time1: int = 2
+ rms_norm_eps: float = 1e-5
sliding_window: int | None = None
- layer_types: list[str] | None = None
- cache_layer_types: list[str] | None = None
- output_router_logits: bool = False
pad_token_id: int | None = 0
bos_token_id: int | None = 2
eos_token_id: int | list[int] | None = 106
+ num_experts_per_tok: int = 1
+ num_experts: int = 16
+
+ lm_head_bias: bool = False
+ router_hidden_size: int = 256
+ cca_time0: int = 2
+ cca_time1: int = 2
+
+ # Fields declared by LagunaConfig but not used by ZAYA.
+ # TP and PP are not tested yet, so remove for now
+ base_model_tp_plan = AttributeError()
+ base_model_pp_plan = AttributeError()
+ intermediate_size = AttributeError()
+ shared_expert_intermediate_size = AttributeError()
+ router_aux_loss_coef = AttributeError()
+ num_attention_heads_per_layer = AttributeError()
+ mlp_layer_types = AttributeError()
+ moe_routed_scaling_factor = AttributeError()
+ moe_apply_router_weight_on_input = AttributeError()
+ moe_router_logit_softcapping = AttributeError()
+
def __post_init__(self, **kwargs):
- self.layer_types = (
- ["full_attention"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
- )
- self.cache_layer_types = (
- ["hybrid"] * self.num_hidden_layers if self.cache_layer_types is None else list(self.cache_layer_types)
- )
+ self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
- default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
- "full_attention": {
+ default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = {
+ "hybrid": {
"rope_type": "default",
- "rope_theta": self.default_theta,
- "partial_rotary_factor": self.partial_rotary_factor,
+ "rope_theta": 5_000_000.0,
+ "partial_rotary_factor": 0.5,
},
- "sliding_attention": {
+ "hybrid_sliding": {
"rope_type": "default",
- "rope_theta": self.default_swa_theta,
- "partial_rotary_factor": self.partial_rotary_factor,
+ "rope_theta": 10_000.0,
+ "partial_rotary_factor": 0.5,
},
}
if self.rope_parameters is None:
- self.rope_parameters = {
- layer_type: default_rope_params[layer_type] for layer_type in set(self.layer_types)
- }
+ self.rope_parameters = default_rope_params
- super().__post_init__(**kwargs)
+ PreTrainedConfig.__post_init__(self, **kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"})
def convert_rope_params_to_dict(self, **kwargs):
- # ZAYA uses nested RoPE parameters keyed by layer type. Keep the base RoPE BC conversion from treating them
- # like a single flat RoPE dict and injecting top-level keys such as `rope_theta`.
+ # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly.
return kwargs
def validate_architecture(self):
@@ -160,14 +138,10 @@ def validate_architecture(self):
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
if len(self.layer_types) != self.num_hidden_layers:
raise ValueError("`layer_types` must have one entry per hidden layer.")
- if len(self.cache_layer_types) != self.num_hidden_layers:
- raise ValueError("`cache_layer_types` must have one entry per hidden layer.")
- if invalid_cache_layer_types := set(self.cache_layer_types) - {"hybrid"}:
- raise ValueError(f"`cache_layer_types` contains unsupported values: {sorted(invalid_cache_layer_types)}.")
- if invalid_layer_types := set(self.layer_types) - {"full_attention", "sliding_attention"}:
+ if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}:
raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
- if "sliding_attention" in self.layer_types and self.sliding_window is None:
- raise ValueError("`sliding_window` must be set when `layer_types` contains `sliding_attention`.")
+ if "hybrid_sliding" in self.layer_types and self.sliding_window is None:
+ raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.")
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError("`sliding_window` must be a strictly positive integer.")
@@ -184,11 +158,12 @@ class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
- `linear_q` and `linear_k` produce the residual q/k states and are concatenated into `qk_states`. The causal
+ `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal
`conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail;
for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
- Values are built from `val_proj1(hidden_states[:, t])` and a delayed `val_proj2`: during prefill token `t` uses
- `val_proj2(hidden_states[:, t - 1])`, while decoding reads the previous `val_proj2` from **the recurrent cache**.
+ Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token
+ `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection
+ from **the recurrent cache**.
Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys.
"""
@@ -202,21 +177,22 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.depthwise_kernel_size = config.cca_time0
self.grouped_kernel_size = config.cca_time1
- self.total_padding = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
+ self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1)
self.num_key_value_heads = config.num_key_value_heads
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
- self.key_value_hidden_size = self.num_key_value_heads * self.head_dim
- self.query_hidden_size = self.num_attention_heads * self.head_dim
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
- self.linear_q = nn.Linear(self.hidden_size, self.query_hidden_size, bias=self.config.attention_bias)
- self.linear_k = nn.Linear(self.hidden_size, self.key_value_hidden_size, bias=self.config.attention_bias)
- self.val_proj1 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
- self.val_proj2 = nn.Linear(self.hidden_size, self.key_value_hidden_size // 2, bias=self.config.attention_bias)
+ query_hidden_size = self.num_attention_heads * self.head_dim
+ key_value_hidden_size = self.num_key_value_heads * self.head_dim
- conv_channels = self.key_value_hidden_size + self.query_hidden_size
+ self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias)
+ self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias)
+ self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias)
+
+ conv_channels = key_value_hidden_size + query_hidden_size
self.conv_qk_depthwise = nn.Conv1d(
in_channels=conv_channels,
out_channels=conv_channels,
@@ -246,81 +222,84 @@ def forward(
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
- projected_queries = self.linear_q(hidden_states)
- projected_keys = self.linear_k(hidden_states)
+ projected_queries = self.q_proj(hidden_states)
+ projected_keys = self.k_proj(hidden_states)
qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
query_residual = projected_queries.view(*hidden_shape)
- key_residual = projected_keys.view(*input_shape, self.num_key_value_heads, self.head_dim)
-
- key_residual = key_residual.repeat_interleave(self.num_key_value_groups, dim=-2)
+ key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2)
query_residual = (query_residual + key_residual) * 0.5
- key_residual = query_residual.view(
- *input_shape, self.num_key_value_heads, self.num_key_value_groups, self.head_dim
- ).mean(dim=-2)
+ key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2)
qk_states = qk_states.transpose(1, 2)
use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx)
if use_precomputed_states:
cached_qk_states = past_key_values.layers[self.layer_idx].conv_states
- conv_input = torch.cat([cached_qk_states, qk_states], dim=-1)
+ qk_states = torch.cat([cached_qk_states, qk_states], dim=-1)
else:
- conv_input = F.pad(qk_states, (self.total_padding, 0))
+ qk_states = F.pad(qk_states, (self.conv_kernel_size, 0))
if past_key_values is not None:
- new_conv_state = qk_states[..., -self.total_padding :]
- if new_conv_state.shape[-1] < self.total_padding:
- new_conv_state = F.pad(new_conv_state, (self.total_padding - new_conv_state.shape[-1], 0))
+ new_conv_state = qk_states[..., -self.conv_kernel_size :]
+ if new_conv_state.shape[-1] < self.conv_kernel_size:
+ new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0))
past_key_values.update_conv_state(new_conv_state, self.layer_idx)
- convolved_qk_states = self.conv_qk_depthwise(conv_input)
- convolved_qk_states = self.conv_qk_grouped(convolved_qk_states).transpose(1, 2)
-
- query = (
- convolved_qk_states[..., : self.query_hidden_size].view(
- *input_shape, self.num_attention_heads, self.head_dim
- )
- + query_residual
- )
+ qk_states = self.conv_qk_depthwise(qk_states)
+ qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2)
- key = (
- convolved_qk_states[..., self.query_hidden_size :].view(
- *input_shape, self.num_key_value_heads, self.head_dim
- )
- + key_residual
- )
+ query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1]
+ query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual
+ key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual
- value_current = self.val_proj1(hidden_states)
- projected_v2 = self.val_proj2(hidden_states)
+ value_current = self.v_proj_current(hidden_states)
+ delayed_v_state = self.v_proj_delayed(hidden_states)
if use_precomputed_states:
- first_v2 = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
+ recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1)
else:
- first_v2 = self.val_proj2(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size))
- value_delayed = torch.cat([first_v2, projected_v2[:, :-1]], dim=1)
+ recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size))
+ value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1)
if past_key_values is not None:
- past_key_values.update_recurrent_state(projected_v2[:, -1, :], self.layer_idx)
+ past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx)
- value = torch.cat([value_current, value_delayed], dim=-1).view(
- *input_shape, self.num_key_value_heads, self.head_dim
- )
+ value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape)
return query, key, value
+class ZayaQKNorm(nn.Module):
+ def __init__(self, config: ZayaConfig, scaling: float):
+ super().__init__()
+ self.head_dim_scale = scaling**-1
+ self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
+
+ def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ norm_eps = torch.finfo(query_states.dtype).eps
+ query_states = query_states * (
+ self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ )
+ key_states = key_states * (
+ self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
+ )
+ key_states = key_states * self.temp[None, None, :, None]
+ return query_states, key_states
+
+
class ZayaAttention(Phi3Attention):
def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__(config, layer_idx)
del op_size # noqa: F821
self.layer_type = config.layer_types[layer_idx]
- self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+ self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
- self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
+ self.qk_norm = ZayaQKNorm(config, self.scaling)
self.qkv_proj = ZayaCCAProjection(
config=self.config,
layer_idx=layer_idx,
@@ -333,7 +312,7 @@ def forward(
past_key_values: Cache | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
batch_size, seq_length, _ = hidden_states.shape
mask_mapping = attention_mask or {}
@@ -342,13 +321,7 @@ def forward(
query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask)
- norm_eps = torch.finfo(query_states.dtype).eps
- head_dim_scale = self.scaling**-1
- query_states = query_states * (
- head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)
- )
- key_states = key_states * (head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps))
- key_states = key_states * self.temp[None, None, :, None]
+ query_states, key_states = self.qk_norm(query_states, key_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -360,9 +333,6 @@ def forward(
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
- if isinstance(causal_mask, torch.Tensor):
- causal_mask = causal_mask[:, :, : query_states.shape[-2], : key_states.shape[-2]]
-
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
@@ -378,10 +348,10 @@ def forward(
**kwargs,
)
- attn_output = attn_output.view(batch_size, seq_length, self.num_attention_heads * self.head_dim)
+ attn_output = attn_output.view(batch_size, seq_length, -1)
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights, past_key_values
+ return attn_output, attn_weights
class ZayaDecoderLayer(GradientCheckpointingLayer):
@@ -389,17 +359,11 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
self.config = config
self.self_attn = ZayaAttention(config, layer_idx)
- self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.zaya_block = ZayaSparseMoeBlock(
- config,
- config.num_experts,
- config.zaya_mlp_expansion,
- config.intermediate_size,
- layer_idx,
- )
- self.post_attention_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
- self.post_attention_res_scale = ResidualScaling(config.hidden_size)
- self.post_mlp_res_scale = ResidualScaling(config.hidden_size)
+ self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.mlp = ZayaSparseMoeBlock(config, layer_idx)
+ self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size)
+ self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size)
def forward(
self,
@@ -409,13 +373,11 @@ def forward(
past_key_values: Cache | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
- # Matches the original ZAYA `residual_in_fp32` path; norm casts back to the parameter dtype below.
- residual = residual.to(torch.float32)
- hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype))
+ hidden_states = self.input_layernorm(hidden_states)
- hidden_states, self_attn_weights, _ = self.self_attn(
+ hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
@@ -423,20 +385,20 @@ def forward(
**kwargs,
)
- residual = self.post_attention_res_scale(hidden_states, residual)
- hidden_states = self.post_attention_norm(residual.to(dtype=self.post_attention_norm.weight.dtype))
+ residual = self.post_attention_residual_scale(hidden_states, residual)
+ hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype))
- hidden_states, prev_router_hidden_states, _ = self.zaya_block(
+ hidden_states, prev_router_hidden_states, _ = self.mlp(
hidden_states,
prev_router_hidden_states,
)
- hidden_states = self.post_mlp_res_scale(hidden_states, residual)
+ hidden_states = self.post_mlp_residual_scale(hidden_states, residual)
- return hidden_states, self_attn_weights, prev_router_hidden_states
+ return hidden_states, prev_router_hidden_states
-class ResidualScaling(nn.Module):
+class ZayaResidualScaling(nn.Module):
def __init__(self, hidden_size: int):
super().__init__()
self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
@@ -445,20 +407,25 @@ def __init__(self, hidden_size: int):
self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
+ output_dtype = hidden_states.dtype
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
+ # Matches the original ZAYA `residual_in_fp32` path.
+ residual = residual.to(torch.float32)
residual = (residual + self.residual_bias) * self.residual_scale
- return hidden_states + residual
+ return (hidden_states + residual).to(output_dtype)
class ZayaRouterMLP(nn.Module):
- def __init__(self, hidden_size: int, num_experts: int):
+ def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float):
super().__init__()
+ self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps)
self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
self.act_fn = nn.GELU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.rmsnorm_eda(hidden_states)
hidden_states = self.act_fn(self.fc1(hidden_states))
hidden_states = self.act_fn(self.fc2(hidden_states))
return self.out_proj(hidden_states)
@@ -469,30 +436,24 @@ def __init__(
self,
config,
layer_idx: int,
- num_moe_experts: int,
- num_experts_per_tok: int,
- mlp_expansion: int,
- hidden_size: int | None = None,
) -> None:
super().__init__()
self.config = config
- self.hidden_size = int(hidden_size or getattr(config, "hidden_size"))
+ self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
- self.num_experts = num_moe_experts + 1
- self.topk = int(num_experts_per_tok)
- self.mlp_expansion = int(mlp_expansion)
+ self.num_experts = config.num_experts + 1
+ self.top_k = config.num_experts_per_tok
+ self.router_hidden_size = config.router_hidden_size
- self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True)
+ self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True)
self.use_eda = self.layer_idx != 0
-
- self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=config.norm_epsilon)
if self.use_eda:
- self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion))
+ self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size))
- self.router_mlp = ZayaRouterMLP(self.mlp_expansion, self.num_experts)
+ self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps)
self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32))
self.balancing_biases[-1] = -1.0
@@ -502,82 +463,59 @@ def forward(
hidden_states: torch.Tensor,
router_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ final_shape = (-1, self.top_k)
seq_length = hidden_states.shape[1]
router_hidden_states = self.down_proj(hidden_states)
- if self.use_eda and (router_states is not None):
+ if self.use_eda and router_states is not None:
router_hidden_states = router_hidden_states + router_states * self.router_states_scale
router_hidden_states_next = router_hidden_states[:, -seq_length:].clone()
- router_hidden_states = self.rmsnorm_eda(router_hidden_states)
- logits = self.router_mlp(router_hidden_states)
- expert_prob = torch.softmax(logits, dim=-1)
+ router_logits = self.router_mlp(router_hidden_states)
+ router_probs = torch.softmax(router_logits, dim=-1)
+
+ biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases
+ _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1)
+ router_probs = torch.gather(router_probs, dim=2, index=router_indices)
- expert_choice = expert_prob.detach().to(torch.float32) + self.balancing_biases
- _, expert_choice = torch.topk(expert_choice, self.topk, dim=-1)
- route_prob = torch.gather(expert_prob, dim=2, index=expert_choice)
+ # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask.
+ skip_expert = router_indices == self.config.num_experts
+ router_probs = router_probs.masked_fill(skip_expert, 0)
+ router_indices = router_indices.masked_fill(skip_expert, 0)
return (
- route_prob.reshape(-1, self.topk),
- expert_choice.reshape(-1, self.topk),
+ router_logits.reshape(-1, self.num_experts),
+ router_probs.reshape(final_shape),
+ router_indices.reshape(final_shape),
router_hidden_states_next,
- logits.reshape(-1, self.num_experts),
)
class ZayaExperts(Qwen3MoeExperts):
- def __init__(self, config, num_experts: int, intermediate_size: int):
- nn.Module.__init__(self)
- self.num_experts = num_experts
- self.hidden_dim = config.hidden_size
- self.intermediate_dim = intermediate_size // 2
- self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
- self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
- self.act_fn = ACT2FN[config.hidden_act]
+ pass
class ZayaSparseMoeBlock(nn.Module):
- def __init__(
- self,
- config,
- num_moe_experts: int,
- mlp_expansion: int,
- intermediate_size: int,
- layer_idx: int,
- ):
+ def __init__(self, config, layer_idx: int):
super().__init__()
self.config = config
self.hidden_dim = config.hidden_size
- self.num_moe_experts = num_moe_experts
- self.router = ZayaRouter(
- config=self.config,
- layer_idx=layer_idx,
- num_moe_experts=self.num_moe_experts,
- num_experts_per_tok=self.config.num_experts_per_tok,
- mlp_expansion=mlp_expansion,
- hidden_size=self.hidden_dim,
- )
- self.experts = ZayaExperts(self.config, self.num_moe_experts, intermediate_size=intermediate_size)
+ self.gate = ZayaRouter(self.config, layer_idx)
+ self.experts = ZayaExperts(self.config)
def forward(
self,
hidden_states: torch.Tensor,
prev_router_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
- route_prob, expert_choice, prev_router_hidden_states, router_logits = self.router(
+ router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate(
hidden_states, router_states=prev_router_hidden_states
)
- # if the router outputs num_moe_experts, just skip the tokens
- # by masking them with id=0 and prob=0 to reuse the expert code
- skip_expert = expert_choice == self.num_moe_experts
- route_prob = route_prob.masked_fill(skip_expert, 0)
- expert_choice = expert_choice.masked_fill(skip_expert, 0)
-
batch_size, seq_length, emb_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim)
- expert_output = self.experts(hidden_states_flat, expert_choice, route_prob)
+ expert_output = self.experts(hidden_states_flat, router_indices, router_probs)
expert_output = expert_output.view(batch_size, seq_length, emb_dim)
return expert_output, prev_router_hidden_states, router_logits
@@ -590,7 +528,7 @@ class ZayaPreTrainedModel(LlamaPreTrainedModel):
# ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache.
_can_compile_fullgraph = False
_can_record_outputs = {
- "router_logits": OutputRecorder(ZayaRouter, index=3),
+ "router_logits": OutputRecorder(ZayaRouter, index=0),
"hidden_states": ZayaDecoderLayer,
"attentions": ZayaAttention,
}
@@ -598,7 +536,7 @@ class ZayaPreTrainedModel(LlamaPreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
- if isinstance(module, ResidualScaling):
+ if isinstance(module, ZayaResidualScaling):
init.ones_(module.hidden_states_scale)
init.zeros_(module.hidden_states_bias)
init.ones_(module.residual_scale)
@@ -606,7 +544,7 @@ def _init_weights(self, module):
elif isinstance(module, ZayaModel):
init.ones_(module.input_hidden_states_scale)
init.zeros_(module.input_hidden_states_bias)
- elif isinstance(module, ZayaAttention):
+ elif isinstance(module, ZayaQKNorm):
init.zeros_(module.temp)
elif isinstance(module, ZayaRouter):
if module.use_eda:
@@ -633,7 +571,6 @@ def __init__(self, config: ZayaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- self.cache_layer_types = config.cache_layer_types
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -643,7 +580,7 @@ def __init__(self, config: ZayaConfig):
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
- self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon)
+ self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
@@ -679,11 +616,8 @@ def forward(
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- position_ids = torch.arange(
- past_seen_tokens,
- past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device,
- ).unsqueeze(0)
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ position_ids = position_ids.unsqueeze(0)
if attention_mask is not None and attention_mask.ndim != 2:
raise ValueError(
@@ -718,7 +652,7 @@ def forward(
layer_type = self.config.layer_types[layer_n]
emb_to_use = position_embeddings[layer_type]
mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
- layer_outputs = decoder_layer(
+ hidden_states, prev_router_hidden_states = decoder_layer(
hidden_states,
prev_router_hidden_states,
attention_mask=mask_mapping,
@@ -727,9 +661,6 @@ def forward(
**kwargs,
)
- hidden_states = layer_outputs[0]
- prev_router_hidden_states = layer_outputs[2]
-
hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
return MoeModelOutputWithPast(
@@ -754,8 +685,8 @@ def _update_causal_mask(
# Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
mask_creation_functions = {
- "full_attention": lambda: create_causal_mask(**mask_kwargs),
- "sliding_attention": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ "hybrid": lambda: create_causal_mask(**mask_kwargs),
+ "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
}
causal_mask_mapping = {}
for layer_type in set(self.config.layer_types):
@@ -764,18 +695,14 @@ def _update_causal_mask(
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
-class ZayaForCausalLM(ZayaPreTrainedModel, AfmoeForCausalLM):
+class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_is_stateful = True
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
- self.model = ZayaModel(config)
- self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
- self.post_init()
-
__all__ = [
"ZayaConfig",
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 94bd74093e15..37a081273dcb 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -26,7 +26,11 @@
import torch
from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel
- from transformers.cache_utils import DynamicCache, LinearAttentionAndFullAttentionLayer
+ from transformers.cache_utils import (
+ DynamicCache,
+ LinearAttentionAndFullAttentionLayer,
+ LinearAttentionAndSlidingWindowAttentionLayer,
+ )
from transformers.models.zaya.modeling_zaya import ZayaCCAProjection
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
@@ -46,15 +50,15 @@ def __init__(self, parent):
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
- intermediate_size=64,
+ moe_intermediate_size=32,
)
self.head_dim = 8
self.num_experts = 4
self.num_experts_per_tok = 1
- self.zaya_mlp_expansion = 4
+ self.router_hidden_size = 4
self.tie_word_embeddings = False
self.rope_parameters = {
- "full_attention": {
+ "hybrid": {
"rope_theta": 10000,
"rope_type": "default",
"partial_rotary_factor": 0.5,
@@ -69,7 +73,8 @@ class ZayaModelTest(CausalLMModelTest, unittest.TestCase):
def _get_conv_state_shape(self, batch_size: int, config):
conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim
- return (batch_size, conv_state_size, config.cca_time0 + config.cca_time1 - 2)
+ conv_kernel_size = config.cca_time0 + config.cca_time1 - 2
+ return (batch_size, conv_state_size, conv_kernel_size)
def _get_recurrent_state_shape(self, batch_size: int, config):
return (batch_size, config.num_key_value_heads * config.head_dim // 2)
@@ -84,8 +89,13 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
conv_shape = self._get_conv_state_shape(batch_size, config)
recurrent_shape = self._get_recurrent_state_shape(batch_size, config)
- for layer in past_key_values.layers:
- self.assertIs(type(layer), LinearAttentionAndFullAttentionLayer)
+ for layer_type, layer in zip(config.layer_types, past_key_values.layers):
+ expected_layer_class = (
+ LinearAttentionAndSlidingWindowAttentionLayer
+ if layer_type == "hybrid_sliding"
+ else LinearAttentionAndFullAttentionLayer
+ )
+ self.assertIs(type(layer), expected_layer_class)
self.assertEqual(layer.keys.shape, attention_shape)
self.assertEqual(layer.values.shape, attention_shape)
self.assertEqual(layer.conv_states.shape, conv_shape)
@@ -153,13 +163,13 @@ def test_model_rope_scaling_frequencies(self):
Copied from Laguna to adapt to per-layer-type rope configs.
"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- config.layer_types = ["full_attention", "sliding_attention"]
- partial_rotary_factor = config.partial_rotary_factor
+ config.layer_types = ["hybrid", "hybrid_sliding"]
+ partial_rotary_factor = config.rope_parameters["hybrid"]["partial_rotary_factor"]
def set_rope_params(rope_params):
config.rope_parameters = {
- "full_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor},
- "sliding_attention": {**rope_params, "partial_rotary_factor": partial_rotary_factor},
+ "hybrid": {**rope_params, "partial_rotary_factor": partial_rotary_factor},
+ "hybrid_sliding": {**rope_params, "partial_rotary_factor": partial_rotary_factor},
}
set_rope_params({"rope_type": "default", "rope_theta": 10_000.0})
@@ -186,15 +196,15 @@ def set_rope_params(rope_params):
set_rope_params({"rope_type": "default", "rope_theta": 10_000.0})
original_rope = rope_class(config=config).to(torch_device)
- original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="sliding_attention")
- original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="sliding_attention")
+ original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="hybrid_sliding")
+ original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="hybrid_sliding")
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
set_rope_params({"rope_type": "linear", "factor": scaling_factor, "rope_theta": 10_000.0})
linear_scaling_rope = rope_class(config=config).to(torch_device)
- linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="sliding_attention")
- linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="sliding_attention")
+ linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding")
+ linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding")
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
@@ -204,22 +214,20 @@ def set_rope_params(rope_params):
set_rope_params({"rope_type": "dynamic", "factor": scaling_factor, "rope_theta": 10_000.0})
ntk_scaling_rope = rope_class(config=config).to(torch_device)
- ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="sliding_attention")
- ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="sliding_attention")
+ ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding")
+ ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding")
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
- self.assertTrue(
- (ntk_scaling_rope.sliding_attention_inv_freq <= original_rope.sliding_attention_inv_freq).all()
- )
+ self.assertTrue((ntk_scaling_rope.hybrid_sliding_inv_freq <= original_rope.hybrid_sliding_inv_freq).all())
set_rope_params({"rope_type": "yarn", "factor": scaling_factor, "rope_theta": 10_000.0})
yarn_scaling_rope = rope_class(config=config).to(torch_device)
- yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="sliding_attention")
- yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="sliding_attention")
+ yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding")
+ yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding")
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
@@ -259,14 +267,14 @@ def test_sliding_attention_mask_is_used(self):
config = ZayaConfig(
vocab_size=128,
hidden_size=32,
- intermediate_size=64,
+ moe_intermediate_size=32,
num_hidden_layers=4,
num_experts=4,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=8,
- zaya_mlp_expansion=4,
- layer_types=["sliding_attention", "full_attention", "full_attention", "full_attention"],
+ router_hidden_size=4,
+ layer_types=["hybrid_sliding", "hybrid", "hybrid_sliding", "hybrid"],
sliding_window=3,
tie_word_embeddings=False,
attn_implementation="eager",
@@ -285,13 +293,13 @@ def test_cca_cache_matches_full_forward(self):
config = ZayaConfig(
vocab_size=128,
hidden_size=32,
- intermediate_size=64,
+ moe_intermediate_size=32,
num_hidden_layers=1,
num_experts=4,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=8,
- zaya_mlp_expansion=4,
+ router_hidden_size=4,
tie_word_embeddings=False,
)
torch.manual_seed(0)
@@ -312,13 +320,13 @@ def test_cca_cache_matches_full_forward_multi_token(self):
config = ZayaConfig(
vocab_size=128,
hidden_size=32,
- intermediate_size=64,
+ moe_intermediate_size=32,
num_hidden_layers=1,
num_experts=4,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=8,
- zaya_mlp_expansion=4,
+ router_hidden_size=4,
tie_word_embeddings=False,
)
torch.manual_seed(0)
From 7bb5122a923b786e892b7d49a1c021b933cc3b0c Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Thu, 14 May 2026 15:41:33 +0800
Subject: [PATCH 25/36] llama decoderlayer
---
src/transformers/models/zaya/modeling_zaya.py | 9 +++++----
src/transformers/models/zaya/modular_zaya.py | 11 +++--------
2 files changed, 8 insertions(+), 12 deletions(-)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index a9a11daf14bb..d2e0712f603e 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -439,11 +439,12 @@ def forward(
class ZayaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ZayaConfig, layer_idx: int):
super().__init__()
- self.config = config
- self.self_attn = ZayaAttention(config, layer_idx)
- self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = ZayaAttention(config=config, layer_idx=layer_idx)
self.mlp = ZayaSparseMoeBlock(config, layer_idx)
- self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.input_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size)
self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size)
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index d8655390ba61..5e325fb78d00 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -27,7 +27,6 @@
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -41,7 +40,7 @@
from ..afmoe.modeling_afmoe import AfmoeForCausalLM
from ..laguna.configuration_laguna import LagunaConfig
from ..laguna.modeling_laguna import LagunaRotaryEmbedding
-from ..llama.modeling_llama import LlamaPreTrainedModel, repeat_kv
+from ..llama.modeling_llama import LlamaDecoderLayer, LlamaPreTrainedModel, repeat_kv
from ..phi3.modeling_phi3 import Phi3Attention
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
apply_rotary_pos_emb,
@@ -354,14 +353,10 @@ def forward(
return attn_output, attn_weights
-class ZayaDecoderLayer(GradientCheckpointingLayer):
+class ZayaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: ZayaConfig, layer_idx: int):
- super().__init__()
- self.config = config
- self.self_attn = ZayaAttention(config, layer_idx)
- self.input_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ super().__init__(config, layer_idx)
self.mlp = ZayaSparseMoeBlock(config, layer_idx)
- self.post_attention_layernorm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size)
self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size)
From b315ae07e07ae1c50f5d841842ad7306555973d3 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 16 May 2026 11:51:23 +0800
Subject: [PATCH 26/36] improve
---
.../models/auto/tokenization_auto.py | 1 +
.../models/zaya/configuration_zaya.py | 6 -
src/transformers/models/zaya/modeling_zaya.py | 129 +++++++--------
src/transformers/models/zaya/modular_zaya.py | 149 +++++++-----------
tests/models/zaya/test_modeling_zaya.py | 49 ++----
tests/test_modeling_common.py | 1 +
6 files changed, 128 insertions(+), 207 deletions(-)
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index db34543e63a1..75677f8ea505 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -340,6 +340,7 @@
("xlstm", "GPTNeoXTokenizer" if is_tokenizers_available() else None),
("xmod", "XLMRobertaTokenizer" if is_tokenizers_available() else None),
("yoso", "AlbertTokenizer" if is_tokenizers_available() else None),
+ ("zaya", "GemmaTokenizer" if is_tokenizers_available() else None),
]
)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 6fb47ecb76a4..4a18bd4716f2 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -116,14 +116,8 @@ def validate_architecture(self):
raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
if self.num_attention_heads % self.num_key_value_heads != 0:
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
- if len(self.layer_types) != self.num_hidden_layers:
- raise ValueError("`layer_types` must have one entry per hidden layer.")
- if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}:
- raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
if "hybrid_sliding" in self.layer_types and self.sliding_window is None:
raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.")
- if self.sliding_window is not None and self.sliding_window <= 0:
- raise ValueError("`sliding_window` must be a strictly positive integer.")
__all__ = ["ZayaConfig"]
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index d2e0712f603e..0815020f0e2e 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -166,8 +166,6 @@ class ZayaCCAProjection(nn.Module):
Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token
`t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection
from **the recurrent cache**.
-
- Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys.
"""
def __init__(self, config: ZayaConfig, layer_idx: int):
@@ -229,7 +227,7 @@ def forward(
qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
query_residual = projected_queries.view(*hidden_shape)
- key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key_residual = projected_keys.view(*hidden_shape).transpose(1, 2)
key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2)
query_residual = (query_residual + key_residual) * 0.5
key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2)
@@ -255,6 +253,8 @@ def forward(
query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual
key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual
+ # The value path carries half of each value head from the current token and half from the previous token.
+ # During cached decoding, `recurrent_v_state` is the previous token's delayed projection.
value_current = self.v_proj_current(hidden_states)
delayed_v_state = self.v_proj_delayed(hidden_states)
if use_precomputed_states:
@@ -272,9 +272,13 @@ def forward(
class ZayaQKNorm(nn.Module):
- def __init__(self, config: ZayaConfig, scaling: float):
+ """
+ L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale.
+ """
+
+ def __init__(self, config: ZayaConfig):
super().__init__()
- self.head_dim_scale = scaling**-1
+ self.head_dim_scale = config.head_dim**0.5
self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@ -385,7 +389,7 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
- self.qk_norm = ZayaQKNorm(config, self.scaling)
+ self.qk_norm = ZayaQKNorm(config)
def forward(
self,
@@ -395,7 +399,7 @@ def forward(
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
- batch_size, seq_length, _ = hidden_states.shape
+ input_shape = hidden_states.shape[:-1]
mask_mapping = attention_mask or {}
causal_mask = mask_mapping.get("causal")
@@ -430,7 +434,7 @@ def forward(
**kwargs,
)
- attn_output = attn_output.view(batch_size, seq_length, -1)
+ attn_output = attn_output.view(*input_shape, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@@ -469,9 +473,9 @@ def forward(
)
residual = self.post_attention_residual_scale(hidden_states, residual)
- hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype))
+ hidden_states = self.post_attention_layernorm(residual)
- hidden_states, prev_router_hidden_states, _ = self.mlp(
+ hidden_states, prev_router_hidden_states = self.mlp(
hidden_states,
prev_router_hidden_states,
)
@@ -618,17 +622,15 @@ def forward(
class ZayaSparseMoeBlock(nn.Module):
def __init__(self, config, layer_idx: int):
super().__init__()
- self.config = config
- self.hidden_dim = config.hidden_size
- self.gate = ZayaRouter(self.config, layer_idx)
- self.experts = ZayaExperts(self.config)
+ self.gate = ZayaRouter(config, layer_idx)
+ self.experts = ZayaExperts(config)
def forward(
self,
hidden_states: torch.Tensor,
prev_router_hidden_states: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
- router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate(
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ _, router_probs, router_indices, prev_router_hidden_states = self.gate(
hidden_states, router_states=prev_router_hidden_states
)
@@ -637,7 +639,7 @@ def forward(
expert_output = self.experts(hidden_states_flat, router_indices, router_probs)
expert_output = expert_output.view(batch_size, seq_length, emb_dim)
- return expert_output, prev_router_hidden_states, router_logits
+ return expert_output, prev_router_hidden_states
@auto_docstring
@@ -658,7 +660,6 @@ class ZayaPreTrainedModel(PreTrainedModel):
"hidden_states": ZayaDecoderLayer,
"attentions": ZayaAttention,
}
- config_class = ZayaConfig
@torch.no_grad()
def _init_weights(self, module):
@@ -677,7 +678,7 @@ def _init_weights(self, module):
if module.use_eda:
init.ones_(module.router_states_scale)
init.zeros_(module.balancing_biases)
- module.balancing_biases[-1] = -1.0
+ module.balancing_biases[-1] = -1.0 # ignore: trf012
elif isinstance(module, ZayaExperts):
std = self.config.initializer_range
init.normal_(module.gate_up_proj, mean=0.0, std=std)
@@ -698,27 +699,20 @@ def __init__(self, config: ZayaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
-
+ self.rotary_emb = ZayaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
-
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
- self.rotary_emb = ZayaRotaryEmbedding(config=config)
-
+ # Initialize weights and apply final processing
self.post_init()
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
@merge_with_config_defaults
@capture_outputs
@auto_docstring
@@ -751,18 +745,23 @@ def forward(
"ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution."
)
- causal_mask_mapping = self._update_causal_mask(
- attention_mask,
- inputs_embeds,
- position_ids,
- past_key_values,
- )
- padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
-
- # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
- # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
- if inputs_embeds.shape[1] == 1:
- padding_mask = None
+ mask_kwargs = {
+ "config": self.config,
+ "inputs_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
+ mask_creation_functions = {
+ "hybrid": lambda: create_causal_mask(**mask_kwargs),
+ "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ }
+ causal_mask_mapping = {
+ layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types)
+ }
+ cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds)
hidden_states = inputs_embeds
@@ -775,50 +774,36 @@ def forward(
prev_router_hidden_states = None
- for layer_n, decoder_layer in enumerate(self.layers):
- layer_type = self.config.layer_types[layer_n]
- emb_to_use = position_embeddings[layer_type]
- mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
+ for idx, decoder_layer in enumerate(self.layers):
+ layer_type = self.config.layer_types[idx]
hidden_states, prev_router_hidden_states = decoder_layer(
hidden_states,
prev_router_hidden_states,
- attention_mask=mask_mapping,
+ attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask},
past_key_values=past_key_values,
- position_embeddings=emb_to_use,
+ position_embeddings=position_embeddings[layer_type],
**kwargs,
)
- hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
+ hidden_states = self.final_norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- position_ids: torch.Tensor,
- past_key_values: Cache,
- ):
- mask_kwargs = {
- "config": self.config,
- "inputs_embeds": input_tensor,
- "attention_mask": attention_mask,
- "past_key_values": past_key_values,
- "position_ids": position_ids,
- }
- # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
- sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
- mask_creation_functions = {
- "hybrid": lambda: create_causal_mask(**mask_kwargs),
- "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
- }
- causal_mask_mapping = {}
- for layer_type in set(self.config.layer_types):
- causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
- return causal_mask_mapping
+ def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds):
+ """
+ No need to zero padding states when cached convolution states are already available or all inputs are valid.
+ """
+ cca_mask = attention_mask
+ if (past_key_values is not None and past_key_values.has_previous_state()) or (
+ attention_mask is not None and torch.all(attention_mask == 1)
+ ):
+ cca_mask = None
+ elif attention_mask is not None:
+ cca_mask = attention_mask[:, -inputs_embeds.shape[1] :]
+ return cca_mask
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 5e325fb78d00..1967bf6fc64a 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -39,7 +39,7 @@
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..afmoe.modeling_afmoe import AfmoeForCausalLM
from ..laguna.configuration_laguna import LagunaConfig
-from ..laguna.modeling_laguna import LagunaRotaryEmbedding
+from ..laguna.modeling_laguna import LagunaModel, LagunaRotaryEmbedding
from ..llama.modeling_llama import LlamaDecoderLayer, LlamaPreTrainedModel, repeat_kv
from ..phi3.modeling_phi3 import Phi3Attention
from ..qwen3_5_moe.modeling_qwen3_5_moe import (
@@ -94,7 +94,7 @@ class ZayaConfig(LagunaConfig):
cca_time1: int = 2
# Fields declared by LagunaConfig but not used by ZAYA.
- # TP and PP are not tested yet, so remove for now
+ # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state.
base_model_tp_plan = AttributeError()
base_model_pp_plan = AttributeError()
intermediate_size = AttributeError()
@@ -135,14 +135,8 @@ def validate_architecture(self):
raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
if self.num_attention_heads % self.num_key_value_heads != 0:
raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
- if len(self.layer_types) != self.num_hidden_layers:
- raise ValueError("`layer_types` must have one entry per hidden layer.")
- if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}:
- raise ValueError(f"`layer_types` contains unsupported values: {sorted(invalid_layer_types)}.")
if "hybrid_sliding" in self.layer_types and self.sliding_window is None:
raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.")
- if self.sliding_window is not None and self.sliding_window <= 0:
- raise ValueError("`sliding_window` must be a strictly positive integer.")
class ZayaRotaryEmbedding(LagunaRotaryEmbedding):
@@ -163,8 +157,6 @@ class ZayaCCAProjection(nn.Module):
Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token
`t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection
from **the recurrent cache**.
-
- Final q/k states are L2-normalized to sqrt(head_dim). `temp` is the learned per-KV-head scale applied to keys.
"""
def __init__(self, config: ZayaConfig, layer_idx: int):
@@ -226,7 +218,7 @@ def forward(
qk_states = torch.cat([projected_queries, projected_keys], dim=-1)
query_residual = projected_queries.view(*hidden_shape)
- key_residual = projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key_residual = projected_keys.view(*hidden_shape).transpose(1, 2)
key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2)
query_residual = (query_residual + key_residual) * 0.5
key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2)
@@ -252,6 +244,8 @@ def forward(
query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual
key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual
+ # The value path carries half of each value head from the current token and half from the previous token.
+ # During cached decoding, `recurrent_v_state` is the previous token's delayed projection.
value_current = self.v_proj_current(hidden_states)
delayed_v_state = self.v_proj_delayed(hidden_states)
if use_precomputed_states:
@@ -269,9 +263,13 @@ def forward(
class ZayaQKNorm(nn.Module):
- def __init__(self, config: ZayaConfig, scaling: float):
+ """
+ L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale.
+ """
+
+ def __init__(self, config: ZayaConfig):
super().__init__()
- self.head_dim_scale = scaling**-1
+ self.head_dim_scale = config.head_dim**0.5
self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))
def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@ -298,7 +296,7 @@ def __init__(self, config: ZayaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
- self.qk_norm = ZayaQKNorm(config, self.scaling)
+ self.qk_norm = ZayaQKNorm(config)
self.qkv_proj = ZayaCCAProjection(
config=self.config,
layer_idx=layer_idx,
@@ -312,7 +310,7 @@ def forward(
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
- batch_size, seq_length, _ = hidden_states.shape
+ input_shape = hidden_states.shape[:-1]
mask_mapping = attention_mask or {}
causal_mask = mask_mapping.get("causal")
@@ -347,7 +345,7 @@ def forward(
**kwargs,
)
- attn_output = attn_output.view(batch_size, seq_length, -1)
+ attn_output = attn_output.view(*input_shape, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@@ -381,9 +379,9 @@ def forward(
)
residual = self.post_attention_residual_scale(hidden_states, residual)
- hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype))
+ hidden_states = self.post_attention_layernorm(residual)
- hidden_states, prev_router_hidden_states, _ = self.mlp(
+ hidden_states, prev_router_hidden_states = self.mlp(
hidden_states,
prev_router_hidden_states,
)
@@ -494,17 +492,15 @@ class ZayaExperts(Qwen3MoeExperts):
class ZayaSparseMoeBlock(nn.Module):
def __init__(self, config, layer_idx: int):
super().__init__()
- self.config = config
- self.hidden_dim = config.hidden_size
- self.gate = ZayaRouter(self.config, layer_idx)
- self.experts = ZayaExperts(self.config)
+ self.gate = ZayaRouter(config, layer_idx)
+ self.experts = ZayaExperts(config)
def forward(
self,
hidden_states: torch.Tensor,
prev_router_hidden_states: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
- router_logits, router_probs, router_indices, prev_router_hidden_states = self.gate(
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ _, router_probs, router_indices, prev_router_hidden_states = self.gate(
hidden_states, router_states=prev_router_hidden_states
)
@@ -513,13 +509,11 @@ def forward(
expert_output = self.experts(hidden_states_flat, router_indices, router_probs)
expert_output = expert_output.view(batch_size, seq_length, emb_dim)
- return expert_output, prev_router_hidden_states, router_logits
+ return expert_output, prev_router_hidden_states
class ZayaPreTrainedModel(LlamaPreTrainedModel):
config: ZayaConfig
- config_class = ZayaConfig
- _no_split_modules = ["ZayaDecoderLayer"]
# ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache.
_can_compile_fullgraph = False
_can_record_outputs = {
@@ -545,7 +539,7 @@ def _init_weights(self, module):
if module.use_eda:
init.ones_(module.router_states_scale)
init.zeros_(module.balancing_biases)
- module.balancing_biases[-1] = -1.0
+ module.balancing_biases[-1] = -1.0 # ignore: trf012
elif isinstance(module, ZayaExperts):
std = self.config.initializer_range
init.normal_(module.gate_up_proj, mean=0.0, std=std)
@@ -561,32 +555,14 @@ def _init_weights(self, module):
@auto_docstring
-class ZayaModel(ZayaPreTrainedModel):
+class ZayaModel(LagunaModel):
def __init__(self, config: ZayaConfig):
super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList(
- [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
-
- self.gradient_checkpointing = False
-
+ del self.norm
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
- self.rotary_emb = ZayaRotaryEmbedding(config=config)
-
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
@merge_with_config_defaults
@capture_outputs
@auto_docstring
@@ -619,18 +595,23 @@ def forward(
"ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution."
)
- causal_mask_mapping = self._update_causal_mask(
- attention_mask,
- inputs_embeds,
- position_ids,
- past_key_values,
- )
- padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] if attention_mask is not None else None
-
- # ZAYA's hybrid cache is not compileable, so generation keeps `attention_mask` as the original 2D padding mask.
- # CCA projection only needs it during multi-token prefill; single-token decoding uses the cached convolution state.
- if inputs_embeds.shape[1] == 1:
- padding_mask = None
+ mask_kwargs = {
+ "config": self.config,
+ "inputs_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
+ mask_creation_functions = {
+ "hybrid": lambda: create_causal_mask(**mask_kwargs),
+ "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ }
+ causal_mask_mapping = {
+ layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types)
+ }
+ cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds)
hidden_states = inputs_embeds
@@ -643,50 +624,36 @@ def forward(
prev_router_hidden_states = None
- for layer_n, decoder_layer in enumerate(self.layers):
- layer_type = self.config.layer_types[layer_n]
- emb_to_use = position_embeddings[layer_type]
- mask_mapping = {"causal": causal_mask_mapping[layer_type], "padding": padding_mask}
+ for idx, decoder_layer in enumerate(self.layers):
+ layer_type = self.config.layer_types[idx]
hidden_states, prev_router_hidden_states = decoder_layer(
hidden_states,
prev_router_hidden_states,
- attention_mask=mask_mapping,
+ attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask},
past_key_values=past_key_values,
- position_embeddings=emb_to_use,
+ position_embeddings=position_embeddings[layer_type],
**kwargs,
)
- hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
+ hidden_states = self.final_norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- position_ids: torch.Tensor,
- past_key_values: Cache,
- ):
- mask_kwargs = {
- "config": self.config,
- "inputs_embeds": input_tensor,
- "attention_mask": attention_mask,
- "past_key_values": past_key_values,
- "position_ids": position_ids,
- }
- # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
- sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
- mask_creation_functions = {
- "hybrid": lambda: create_causal_mask(**mask_kwargs),
- "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
- }
- causal_mask_mapping = {}
- for layer_type in set(self.config.layer_types):
- causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
- return causal_mask_mapping
+ def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds):
+ """
+ No need to zero padding states when cached convolution states are already available or all inputs are valid.
+ """
+ cca_mask = attention_mask
+ if (past_key_values is not None and past_key_values.has_previous_state()) or (
+ attention_mask is not None and torch.all(attention_mask == 1)
+ ):
+ cca_mask = None
+ elif attention_mask is not None:
+ cca_mask = attention_mask[:, -inputs_embeds.shape[1] :]
+ return cca_mask
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 37a081273dcb..ceeb40fd6a06 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -40,30 +40,16 @@ class ZayaModelTester(CausalLMModelTester):
if is_torch_available():
base_model_class = ZayaModel
- def __init__(self, parent):
+ def __init__(self, parent, **kwargs):
super().__init__(
parent=parent,
- batch_size=2,
- seq_length=7,
- vocab_size=128,
- hidden_size=32,
- num_hidden_layers=2,
- num_attention_heads=4,
- num_key_value_heads=2,
+ num_hidden_layers=4,
moe_intermediate_size=32,
+ num_experts_per_tok=1,
+ layer_types=["hybrid", "hybrid_sliding", "hybrid", "hybrid_sliding"],
+ sliding_window=64,
+ **kwargs,
)
- self.head_dim = 8
- self.num_experts = 4
- self.num_experts_per_tok = 1
- self.router_hidden_size = 4
- self.tie_word_embeddings = False
- self.rope_parameters = {
- "hybrid": {
- "rope_theta": 10000,
- "rope_type": "default",
- "partial_rotary_factor": 0.5,
- },
- }
@require_torch
@@ -101,17 +87,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
self.assertEqual(layer.conv_states.shape, conv_shape)
self.assertEqual(layer.recurrent_states.shape, recurrent_shape)
- def is_pipeline_test_to_skip(
- self,
- pipeline_test_case_name,
- config_class,
- model_architecture,
- tokenizer_name,
- image_processor_name,
- feature_extractor_name,
- processor_name,
- ):
- return True
@unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.")
def test_eager_padding_matches_padding_free_with_position_ids(self):
@@ -121,8 +96,11 @@ def test_eager_padding_matches_padding_free_with_position_ids(self):
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
- @unittest.skip("ZAYA uses MoE routing; equivalent-output comparisons are not stable for this architecture.")
- def test_model_outputs_equivalence(self, **kwargs):
+ @unittest.skip(
+ "ZAYA follows the original SWA behavior where sliding attention only applies the local causal pattern;"
+ "See https://github.com/huggingface/transformers/pull/45862#discussion_r3249556316"
+ )
+ def test_left_padding_compatibility(self):
pass
def test_attention_outputs(self):
@@ -163,7 +141,6 @@ def test_model_rope_scaling_frequencies(self):
Copied from Laguna to adapt to per-layer-type rope configs.
"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- config.layer_types = ["hybrid", "hybrid_sliding"]
partial_rotary_factor = config.rope_parameters["hybrid"]["partial_rotary_factor"]
def set_rope_params(rope_params):
@@ -239,10 +216,6 @@ def set_rope_params(rope_params):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
- @unittest.skip("ZAYA needs alternating attention and MoE layers in the tiny test configuration.")
- def test_num_layers_is_small(self):
- pass
-
def test_moe_router_logits(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = self.model_tester.causal_lm_class(config)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index fcd3547a06c7..41a8f5cbbbfb 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -814,6 +814,7 @@ def test_num_layers_is_small(self):
"Gemma3nVision2TextModelTest": 4, # need to test KV shared layer for both types: `full_attention` and `sliding_attention`
"BeitModelTest": 4, # BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers
"ZambaModelTest": 5, # The minimum number to test beyond the initial ["mamba", "mamba", "hybrid"] in `ZambaConfig._layers_block_type`
+ "ZayaModelTest": 4, # needs two passes over `hybrid` and `hybrid_sliding` layer types
}
target_num_hidden_layers = exceptional_num_hidden_layers.get(type(self).__name__, 2)
From 0df3204dbd5e00058e8c8b74283dfdb136f6e96f Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sat, 16 May 2026 12:17:32 +0800
Subject: [PATCH 27/36] update date
---
docs/source/en/model_doc/zaya.md | 2 +-
tests/models/zaya/test_modeling_zaya.py | 1 -
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index 06beb12e2e6f..199cd5d2935b 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-13.*
+*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-16.*
# ZAYA
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index ceeb40fd6a06..316d206004d0 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -87,7 +87,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
self.assertEqual(layer.conv_states.shape, conv_shape)
self.assertEqual(layer.recurrent_states.shape, recurrent_shape)
-
@unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
From d362c90c378b4b32b54513f1627b6d9d59ccc6a1 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sun, 17 May 2026 20:04:39 +0800
Subject: [PATCH 28/36] Fix ZAYA residual stream precision regression
---
src/transformers/models/zaya/modeling_zaya.py | 17 +++++++++--------
src/transformers/models/zaya/modular_zaya.py | 17 +++++++++--------
2 files changed, 18 insertions(+), 16 deletions(-)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index 0815020f0e2e..6224c0bcc6bd 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -462,7 +462,7 @@ def forward(
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype))
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
@@ -473,7 +473,7 @@ def forward(
)
residual = self.post_attention_residual_scale(hidden_states, residual)
- hidden_states = self.post_attention_layernorm(residual)
+ hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype))
hidden_states, prev_router_hidden_states = self.mlp(
hidden_states,
@@ -494,12 +494,10 @@ def __init__(self, hidden_size: int):
self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
- output_dtype = hidden_states.dtype
+ # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path.
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
- # Matches the original ZAYA `residual_in_fp32` path.
- residual = residual.to(torch.float32)
residual = (residual + self.residual_bias) * self.residual_scale
- return (hidden_states + residual).to(output_dtype)
+ return hidden_states + residual
class ZayaRouterMLP(nn.Module):
@@ -770,7 +768,10 @@ def forward(
for layer_type in set(self.config.layer_types)
}
- hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale
+ # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path.
+ hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to(
+ torch.float32
+ )
prev_router_hidden_states = None
@@ -785,7 +786,7 @@ def forward(
**kwargs,
)
- hidden_states = self.final_norm(hidden_states)
+ hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 1967bf6fc64a..c7b1c237a066 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -368,7 +368,7 @@ def forward(
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype))
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
@@ -379,7 +379,7 @@ def forward(
)
residual = self.post_attention_residual_scale(hidden_states, residual)
- hidden_states = self.post_attention_layernorm(residual)
+ hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype))
hidden_states, prev_router_hidden_states = self.mlp(
hidden_states,
@@ -400,12 +400,10 @@ def __init__(self, hidden_size: int):
self.residual_bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
- output_dtype = hidden_states.dtype
+ # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path.
hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
- # Matches the original ZAYA `residual_in_fp32` path.
- residual = residual.to(torch.float32)
residual = (residual + self.residual_bias) * self.residual_scale
- return (hidden_states + residual).to(output_dtype)
+ return hidden_states + residual
class ZayaRouterMLP(nn.Module):
@@ -620,7 +618,10 @@ def forward(
for layer_type in set(self.config.layer_types)
}
- hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale
+ # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path.
+ hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to(
+ torch.float32
+ )
prev_router_hidden_states = None
@@ -635,7 +636,7 @@ def forward(
**kwargs,
)
- hidden_states = self.final_norm(hidden_states)
+ hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
From f6178966056a9269e6303ec515aeadfd78773634 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 19 May 2026 20:04:07 +0800
Subject: [PATCH 29/36] code clean
---
.../models/zaya/convert_zaya_weights_to_hf.py | 2 +-
src/transformers/models/zaya/modeling_zaya.py | 29 ++++-----
src/transformers/models/zaya/modular_zaya.py | 29 ++++-----
tests/models/zaya/test_modeling_zaya.py | 60 ++-----------------
tests/test_modeling_common.py | 1 -
5 files changed, 27 insertions(+), 94 deletions(-)
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
index 2ac6cb7df869..ad279541ab83 100644
--- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -72,7 +72,7 @@ def _rename_common(rest: str) -> str:
("self_attn.qkv.val_proj1.", "self_attn.qkv_proj.v_proj_current."),
("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."),
("self_attn.qkv.", "self_attn.qkv_proj."),
- ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.rmsnorm_eda."),
+ ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.norm."),
("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."),
("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."),
("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."),
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index 6224c0bcc6bd..e10cf4388ef5 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -160,12 +160,8 @@ class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
- `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal
- `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail;
- for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
- Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token
- `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection
- from **the recurrent cache**.
+ This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D
+ convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.
"""
def __init__(self, config: ZayaConfig, layer_idx: int):
@@ -242,8 +238,7 @@ def forward(
if past_key_values is not None:
new_conv_state = qk_states[..., -self.conv_kernel_size :]
- if new_conv_state.shape[-1] < self.conv_kernel_size:
- new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0))
+ new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0))
past_key_values.update_conv_state(new_conv_state, self.layer_idx)
qk_states = self.conv_qk_depthwise(qk_states)
@@ -405,8 +400,8 @@ def forward(
causal_mask = mask_mapping.get("causal")
padding_mask = mask_mapping.get("padding")
+ # ZAYA replaces the usual independent q/k/v projections with CCA projection followed by special q/k normalization.
query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask)
-
query_states, key_states = self.qk_norm(query_states, key_states)
query_states = query_states.transpose(1, 2)
@@ -503,14 +498,14 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
class ZayaRouterMLP(nn.Module):
def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float):
super().__init__()
- self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps)
+ self.norm = ZayaRMSNorm(hidden_size, eps=rms_norm_eps)
self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
self.act_fn = nn.GELU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.rmsnorm_eda(hidden_states)
+ hidden_states = self.norm(hidden_states)
hidden_states = self.act_fn(self.fc1(hidden_states))
hidden_states = self.act_fn(self.fc2(hidden_states))
return self.out_proj(hidden_states)
@@ -676,7 +671,7 @@ def _init_weights(self, module):
if module.use_eda:
init.ones_(module.router_states_scale)
init.zeros_(module.balancing_biases)
- module.balancing_biases[-1] = -1.0 # ignore: trf012
+ module.balancing_biases[-1] = -1.0 # trf-ignore: TRF012
elif isinstance(module, ZayaExperts):
std = self.config.initializer_range
init.normal_(module.gate_up_proj, mean=0.0, std=std)
@@ -750,16 +745,14 @@ def forward(
"past_key_values": past_key_values,
"position_ids": position_ids,
}
- # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
- sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
mask_creation_functions = {
"hybrid": lambda: create_causal_mask(**mask_kwargs),
- "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ "hybrid_sliding": lambda: create_sliding_window_causal_mask(**mask_kwargs),
}
causal_mask_mapping = {
layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types)
}
- cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds)
+ cca_mask = self._update_cca_mask(attention_mask, past_key_values)
hidden_states = inputs_embeds
@@ -793,7 +786,7 @@ def forward(
past_key_values=past_key_values if use_cache else None,
)
- def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds):
+ def _update_cca_mask(self, attention_mask, past_key_values):
"""
No need to zero padding states when cached convolution states are already available or all inputs are valid.
"""
@@ -802,8 +795,6 @@ def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds):
attention_mask is not None and torch.all(attention_mask == 1)
):
cca_mask = None
- elif attention_mask is not None:
- cca_mask = attention_mask[:, -inputs_embeds.shape[1] :]
return cca_mask
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index c7b1c237a066..e85d6dd06591 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -151,12 +151,8 @@ class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
- `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal
- `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail;
- for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`.
- Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token
- `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection
- from **the recurrent cache**.
+ This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D
+ convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.
"""
def __init__(self, config: ZayaConfig, layer_idx: int):
@@ -233,8 +229,7 @@ def forward(
if past_key_values is not None:
new_conv_state = qk_states[..., -self.conv_kernel_size :]
- if new_conv_state.shape[-1] < self.conv_kernel_size:
- new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0))
+ new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0))
past_key_values.update_conv_state(new_conv_state, self.layer_idx)
qk_states = self.conv_qk_depthwise(qk_states)
@@ -316,8 +311,8 @@ def forward(
causal_mask = mask_mapping.get("causal")
padding_mask = mask_mapping.get("padding")
+ # ZAYA replaces the usual independent q/k/v projections with CCA projection followed by special q/k normalization.
query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask)
-
query_states, key_states = self.qk_norm(query_states, key_states)
query_states = query_states.transpose(1, 2)
@@ -409,14 +404,14 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
class ZayaRouterMLP(nn.Module):
def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float):
super().__init__()
- self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps)
+ self.norm = ZayaRMSNorm(hidden_size, eps=rms_norm_eps)
self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True)
self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, num_experts, bias=False)
self.act_fn = nn.GELU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.rmsnorm_eda(hidden_states)
+ hidden_states = self.norm(hidden_states)
hidden_states = self.act_fn(self.fc1(hidden_states))
hidden_states = self.act_fn(self.fc2(hidden_states))
return self.out_proj(hidden_states)
@@ -537,7 +532,7 @@ def _init_weights(self, module):
if module.use_eda:
init.ones_(module.router_states_scale)
init.zeros_(module.balancing_biases)
- module.balancing_biases[-1] = -1.0 # ignore: trf012
+ module.balancing_biases[-1] = -1.0 # trf-ignore: TRF012
elif isinstance(module, ZayaExperts):
std = self.config.initializer_range
init.normal_(module.gate_up_proj, mean=0.0, std=std)
@@ -600,16 +595,14 @@ def forward(
"past_key_values": past_key_values,
"position_ids": position_ids,
}
- # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection.
- sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None}
mask_creation_functions = {
"hybrid": lambda: create_causal_mask(**mask_kwargs),
- "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ "hybrid_sliding": lambda: create_sliding_window_causal_mask(**mask_kwargs),
}
causal_mask_mapping = {
layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types)
}
- cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds)
+ cca_mask = self._update_cca_mask(attention_mask, past_key_values)
hidden_states = inputs_embeds
@@ -643,7 +636,7 @@ def forward(
past_key_values=past_key_values if use_cache else None,
)
- def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds):
+ def _update_cca_mask(self, attention_mask, past_key_values):
"""
No need to zero padding states when cached convolution states are already available or all inputs are valid.
"""
@@ -652,8 +645,6 @@ def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds):
attention_mask is not None and torch.all(attention_mask == 1)
):
cca_mask = None
- elif attention_mask is not None:
- cca_mask = attention_mask[:, -inputs_embeds.shape[1] :]
return cca_mask
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index 316d206004d0..a222a957969c 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -43,10 +43,10 @@ class ZayaModelTester(CausalLMModelTester):
def __init__(self, parent, **kwargs):
super().__init__(
parent=parent,
- num_hidden_layers=4,
+ num_hidden_layers=2,
moe_intermediate_size=32,
num_experts_per_tok=1,
- layer_types=["hybrid", "hybrid_sliding", "hybrid", "hybrid_sliding"],
+ layer_types=["hybrid", "hybrid_sliding"],
sliding_window=64,
**kwargs,
)
@@ -95,13 +95,6 @@ def test_eager_padding_matches_padding_free_with_position_ids(self):
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
- @unittest.skip(
- "ZAYA follows the original SWA behavior where sliding attention only applies the local causal pattern;"
- "See https://github.com/huggingface/transformers/pull/45862#discussion_r3249556316"
- )
- def test_left_padding_compatibility(self):
- pass
-
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
@@ -215,22 +208,6 @@ def set_rope_params(rope_params):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
- def test_moe_router_logits(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = self.model_tester.causal_lm_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**inputs_dict, output_router_logits=True)
-
- expected_moe_layers = config.num_hidden_layers
- self.assertEqual(len(outputs.router_logits), expected_moe_layers)
- self.assertEqual(
- outputs.router_logits[0].shape,
- (self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1),
- )
-
def test_num_experts_per_tok_validation(self):
with self.assertRaisesRegex(StrictDataclassClassValidationError, "num_experts_per_tok=1"):
ZayaConfig(num_experts_per_tok=2)
@@ -240,13 +217,13 @@ def test_sliding_attention_mask_is_used(self):
vocab_size=128,
hidden_size=32,
moe_intermediate_size=32,
- num_hidden_layers=4,
+ num_hidden_layers=2,
num_experts=4,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=8,
router_hidden_size=4,
- layer_types=["hybrid_sliding", "hybrid", "hybrid_sliding", "hybrid"],
+ layer_types=["hybrid_sliding", "hybrid"],
sliding_window=3,
tie_word_embeddings=False,
attn_implementation="eager",
@@ -261,33 +238,6 @@ def test_sliding_attention_mask_is_used(self):
sliding_attention = outputs.attentions[0]
self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0))
- def test_cca_cache_matches_full_forward(self):
- config = ZayaConfig(
- vocab_size=128,
- hidden_size=32,
- moe_intermediate_size=32,
- num_hidden_layers=1,
- num_experts=4,
- num_attention_heads=4,
- num_key_value_heads=2,
- head_dim=8,
- router_hidden_size=4,
- tie_word_embeddings=False,
- )
- torch.manual_seed(0)
- cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device)
- cca.eval()
- hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device)
-
- with torch.no_grad():
- full = cca(hidden_states, None, None)
- cache = DynamicCache(config=config)
- cca(hidden_states[:, :4], cache, None)
- cached = cca(hidden_states[:, 4:], cache, None)
-
- for full_states, cached_states in zip(full, cached):
- torch.testing.assert_close(full_states[:, -1:], cached_states, rtol=1e-5, atol=1e-5)
-
def test_cca_cache_matches_full_forward_multi_token(self):
config = ZayaConfig(
vocab_size=128,
@@ -307,6 +257,8 @@ def test_cca_cache_matches_full_forward_multi_token(self):
hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device)
with torch.no_grad():
+ # Compare full CCA projection against a cached continuation. The second chunk must recover the same
+ # q/k/v states from the cached convolution tail and delayed recurrent value state.
full = cca(hidden_states, None, None)
cache = DynamicCache(config=config)
cca(hidden_states[:, :3], cache, None)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 41a8f5cbbbfb..fcd3547a06c7 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -814,7 +814,6 @@ def test_num_layers_is_small(self):
"Gemma3nVision2TextModelTest": 4, # need to test KV shared layer for both types: `full_attention` and `sliding_attention`
"BeitModelTest": 4, # BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers
"ZambaModelTest": 5, # The minimum number to test beyond the initial ["mamba", "mamba", "hybrid"] in `ZambaConfig._layers_block_type`
- "ZayaModelTest": 4, # needs two passes over `hybrid` and `hybrid_sliding` layer types
}
target_num_hidden_layers = exceptional_num_hidden_layers.get(type(self).__name__, 2)
From bc55e03537aef2a783283f8508bb6bea194207bb Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 19 May 2026 20:19:24 +0800
Subject: [PATCH 30/36] date
---
docs/source/en/model_doc/zaya.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index 199cd5d2935b..de7916038254 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-16.*
+*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-19.*
# ZAYA
From 6ad8e9f077473d7ed5eaa01265ec3be0dc6690ed Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 20 May 2026 20:01:02 +0800
Subject: [PATCH 31/36] update fsdp
---
src/transformers/models/zaya/configuration_zaya.py | 6 ++++++
.../models/zaya/convert_zaya_weights_to_hf.py | 2 ++
src/transformers/models/zaya/modeling_zaya.py | 5 +++--
src/transformers/models/zaya/modular_zaya.py | 13 +++++++++----
4 files changed, 20 insertions(+), 6 deletions(-)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 4a18bd4716f2..95de5cdf2919 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -86,6 +86,12 @@ class ZayaConfig(PreTrainedConfig):
cca_time0: int = 2
cca_time1: int = 2
+ base_model_fsdp_plan = {
+ "embed_tokens": "free_full_weight",
+ "layers.*": "free_full_weight",
+ "norm": "keep_full_weight",
+ }
+
def __post_init__(self, **kwargs):
self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
index ad279541ab83..a359ed5bb8aa 100644
--- a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
+++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py
@@ -107,6 +107,8 @@ def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) ->
match = _LAYER_PATTERN.match(name)
if match is None:
+ if name.startswith("model.final_norm."):
+ return f"model.norm.{name.removeprefix('model.final_norm.')}"
if old_num_hidden_layers is not None and name.startswith("model.res_scale."):
new_layer_idx = old_num_hidden_layers // 2 - 1
return f"model.layers.{new_layer_idx}.post_mlp_residual_scale.{name.removeprefix('model.res_scale.')}"
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index e10cf4388ef5..fd5fab1630aa 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -697,11 +697,11 @@ def __init__(self, config: ZayaConfig):
self.layers = nn.ModuleList(
[ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self.norm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = ZayaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
- self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
# Initialize weights and apply final processing
self.post_init()
@@ -779,7 +779,7 @@ def forward(
**kwargs,
)
- hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
@@ -803,6 +803,7 @@ class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_gather_output"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ _fsdp_plan = {"lm_head": "keep_full_weight"}
_is_stateful = True
def __init__(self, config, **kwargs):
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index e85d6dd06591..83ad9199a267 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -93,8 +93,14 @@ class ZayaConfig(LagunaConfig):
cca_time0: int = 2
cca_time1: int = 2
+ base_model_fsdp_plan = {
+ "embed_tokens": "free_full_weight",
+ "layers.*": "free_full_weight",
+ "norm": "keep_full_weight",
+ }
+
# Fields declared by LagunaConfig but not used by ZAYA.
- # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state.
+ # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. For TP, see discussion https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862
base_model_tp_plan = AttributeError()
base_model_pp_plan = AttributeError()
intermediate_size = AttributeError()
@@ -551,10 +557,8 @@ def _init_weights(self, module):
class ZayaModel(LagunaModel):
def __init__(self, config: ZayaConfig):
super().__init__(config)
- del self.norm
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
- self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
@merge_with_config_defaults
@capture_outputs
@@ -629,7 +633,7 @@ def forward(
**kwargs,
)
- hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype))
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
@@ -651,6 +655,7 @@ def _update_cca_mask(self, attention_mask, past_key_values):
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+ _fsdp_plan = {"lm_head": "keep_full_weight"}
_is_stateful = True
def __init__(self, config, **kwargs):
From ecb80ed181814aef33f65a39150f57a6f1c9edf8 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 20 May 2026 20:03:51 +0800
Subject: [PATCH 32/36] clean
---
src/transformers/models/zaya/modeling_zaya.py | 2 --
src/transformers/models/zaya/modular_zaya.py | 3 +++
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index fd5fab1630aa..befe762155e9 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -801,8 +801,6 @@ def _update_cca_mask(self, attention_mask, past_key_values):
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
- _tp_plan = {"lm_head": "colwise_gather_output"}
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
_fsdp_plan = {"lm_head": "keep_full_weight"}
_is_stateful = True
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 83ad9199a267..e83a97518ba4 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -658,6 +658,9 @@ class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel):
_fsdp_plan = {"lm_head": "keep_full_weight"}
_is_stateful = True
+ _tp_plan = AttributeError()
+ _pp_plan = AttributeError()
+
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
From db1db76278affd6643da34f63c5f01c0ba509cf9 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Wed, 20 May 2026 23:04:25 +0800
Subject: [PATCH 33/36] upstream
---
src/transformers/models/zaya/configuration_zaya.py | 12 ++++++------
src/transformers/models/zaya/modeling_zaya.py | 1 +
src/transformers/models/zaya/modular_zaya.py | 7 -------
3 files changed, 7 insertions(+), 13 deletions(-)
diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py
index 95de5cdf2919..690a521ee357 100644
--- a/src/transformers/models/zaya/configuration_zaya.py
+++ b/src/transformers/models/zaya/configuration_zaya.py
@@ -53,6 +53,12 @@ class ZayaConfig(PreTrainedConfig):
model_type = "zaya"
keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_fsdp_plan = {
+ "embed_tokens": "free_full_weight",
+ "layers.*": "free_full_weight",
+ "norm": "keep_full_weight",
+ }
+
vocab_size: int = 262272
hidden_size: int = 2048
num_hidden_layers: int = 40
@@ -86,12 +92,6 @@ class ZayaConfig(PreTrainedConfig):
cca_time0: int = 2
cca_time1: int = 2
- base_model_fsdp_plan = {
- "embed_tokens": "free_full_weight",
- "layers.*": "free_full_weight",
- "norm": "keep_full_weight",
- }
-
def __post_init__(self, **kwargs):
self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index befe762155e9..b401f22d190d 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -801,6 +801,7 @@ def _update_cca_mask(self, attention_mask, past_key_values):
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+ _sp_plan = {"lm_head": "colwise_loss_parallel"}
_fsdp_plan = {"lm_head": "keep_full_weight"}
_is_stateful = True
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index e83a97518ba4..7569d5e697a9 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -93,12 +93,6 @@ class ZayaConfig(LagunaConfig):
cca_time0: int = 2
cca_time1: int = 2
- base_model_fsdp_plan = {
- "embed_tokens": "free_full_weight",
- "layers.*": "free_full_weight",
- "norm": "keep_full_weight",
- }
-
# Fields declared by LagunaConfig but not used by ZAYA.
# TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. For TP, see discussion https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862
base_model_tp_plan = AttributeError()
@@ -655,7 +649,6 @@ def _update_cca_mask(self, attention_mask, past_key_values):
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
- _fsdp_plan = {"lm_head": "keep_full_weight"}
_is_stateful = True
_tp_plan = AttributeError()
From 86215d58d5eae2ef1021365beed3dc07eb36c450 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 26 May 2026 17:42:26 +0800
Subject: [PATCH 34/36] improve comments
---
src/transformers/models/zaya/modeling_zaya.py | 8 ++-
src/transformers/models/zaya/modular_zaya.py | 17 +++---
tests/models/zaya/test_modeling_zaya.py | 57 +++++++++++--------
3 files changed, 51 insertions(+), 31 deletions(-)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index b401f22d190d..f6e2a16c82af 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -158,7 +158,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class ZayaCCAProjection(nn.Module):
"""
- Projects hidden states into attention q/k/v states with ZAYA's CCA path.
+ Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path.
+ See https://arxiv.org/abs/2510.04476.
This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D
convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.
@@ -457,6 +458,8 @@ def forward(
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
+ # Match upstream's residual_in_fp32 path by keeping the residual stream in fp32 and avoiding extra
+ # fp32->bf16 round trips in the residual module.
hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype))
hidden_states, _ = self.self_attn(
@@ -623,6 +626,7 @@ def forward(
hidden_states: torch.Tensor,
prev_router_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # ZAYA carries router hidden states across decoder layers; the next layer consumes this state in its router.
_, router_probs, router_indices, prev_router_hidden_states = self.gate(
hidden_states, router_states=prev_router_hidden_states
)
@@ -770,6 +774,8 @@ def forward(
for idx, decoder_layer in enumerate(self.layers):
layer_type = self.config.layer_types[idx]
+ # Attention uses the prepared causal mask, while CCA projection still needs the raw 2D padding mask to
+ # zero padding tokens before convolution.
hidden_states, prev_router_hidden_states = decoder_layer(
hidden_states,
prev_router_hidden_states,
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index 7569d5e697a9..c3b5008fe988 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -35,8 +35,7 @@
TransformersKwargs,
auto_docstring,
)
-from ...utils.generic import merge_with_config_defaults
-from ...utils.output_capturing import OutputRecorder, capture_outputs
+from ...utils.output_capturing import OutputRecorder
from ..afmoe.modeling_afmoe import AfmoeForCausalLM
from ..laguna.configuration_laguna import LagunaConfig
from ..laguna.modeling_laguna import LagunaModel, LagunaRotaryEmbedding
@@ -94,7 +93,8 @@ class ZayaConfig(LagunaConfig):
cca_time1: int = 2
# Fields declared by LagunaConfig but not used by ZAYA.
- # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. For TP, see discussion https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862
+ # NOTE: TP is intentionally disabled for now because the useful degree is limited by ZAYA's 2 KV heads; see
+ # https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862. PP needs coverage for the cross-layer router state.
base_model_tp_plan = AttributeError()
base_model_pp_plan = AttributeError()
intermediate_size = AttributeError()
@@ -149,7 +149,8 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm):
class ZayaCCAProjection(nn.Module):
"""
- Projects hidden states into attention q/k/v states with ZAYA's CCA path.
+ Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path.
+ See https://arxiv.org/abs/2510.04476.
This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D
convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.
@@ -363,6 +364,8 @@ def forward(
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
+ # Match upstream's residual_in_fp32 path by keeping the residual stream in fp32 and avoiding extra
+ # fp32->bf16 round trips in the residual module.
hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype))
hidden_states, _ = self.self_attn(
@@ -493,6 +496,7 @@ def forward(
hidden_states: torch.Tensor,
prev_router_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # ZAYA carries router hidden states across decoder layers; the next layer consumes this state in its router.
_, router_probs, router_indices, prev_router_hidden_states = self.gate(
hidden_states, router_states=prev_router_hidden_states
)
@@ -554,9 +558,6 @@ def __init__(self, config: ZayaConfig):
self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size))
self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size))
- @merge_with_config_defaults
- @capture_outputs
- @auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -618,6 +619,8 @@ def forward(
for idx, decoder_layer in enumerate(self.layers):
layer_type = self.config.layer_types[idx]
+ # Attention uses the prepared causal mask, while CCA projection still needs the raw 2D padding mask to
+ # zero padding tokens before convolution.
hidden_states, prev_router_hidden_states = decoder_layer(
hidden_states,
prev_router_hidden_states,
diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py
index a222a957969c..86448222a20e 100644
--- a/tests/models/zaya/test_modeling_zaya.py
+++ b/tests/models/zaya/test_modeling_zaya.py
@@ -19,7 +19,7 @@
from parameterized import parameterized
from transformers import is_torch_available
-from transformers.testing_utils import cleanup, require_torch, slow, torch_device
+from transformers.testing_utils import Expectations, cleanup, require_torch, slow, torch_device
if is_torch_available():
@@ -213,30 +213,19 @@ def test_num_experts_per_tok_validation(self):
ZayaConfig(num_experts_per_tok=2)
def test_sliding_attention_mask_is_used(self):
- config = ZayaConfig(
- vocab_size=128,
- hidden_size=32,
- moe_intermediate_size=32,
- num_hidden_layers=2,
- num_experts=4,
- num_attention_heads=4,
- num_key_value_heads=2,
- head_dim=8,
- router_hidden_size=4,
- layer_types=["hybrid_sliding", "hybrid"],
- sliding_window=3,
- tie_word_embeddings=False,
- attn_implementation="eager",
- )
- model = ZayaModel(config).to(torch_device)
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.layer_types = ["hybrid_sliding"] + ["hybrid"] * (config.num_hidden_layers - 1)
+ config.sliding_window = 3
+ config._attn_implementation = "eager"
+
+ model = ZayaModel._from_config(config, attn_implementation="eager").to(torch_device)
model.eval()
- input_ids = torch.arange(6, device=torch_device).unsqueeze(0)
with torch.no_grad():
- outputs = model(input_ids=input_ids, output_attentions=True)
+ outputs = model(input_ids=inputs_dict["input_ids"].to(torch_device), output_attentions=True)
sliding_attention = outputs.attentions[0]
- self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0))
+ self.assertTrue(torch.all(sliding_attention[:, :, -1, : -config.sliding_window] == 0))
def test_cca_cache_matches_full_forward_multi_token(self):
config = ZayaConfig(
@@ -334,6 +323,18 @@ def test_model_logits(self):
self.assertEqual(logits.shape, (1, inputs.input_ids.shape[-1], model.config.vocab_size))
self.assertTrue(torch.isfinite(logits).all().item())
+ EXPECTED_LOGITS = Expectations(
+ {
+ (None, None): [
+ [0.3613, 0.3633, 0.3633],
+ [-1.3672, -1.3672, -1.3672],
+ [-2.8750, -2.8750, -2.8750],
+ ]
+ }
+ ) # fmt: skip
+ expected_slice = torch.tensor(EXPECTED_LOGITS.get_expectation(), dtype=logits.dtype)
+ torch.testing.assert_close(logits[0, -3:, -3:], expected_slice, rtol=1e-3, atol=1e-3)
+
expected_argmax = torch.tensor([[105, 9731, 107, 740, 564, 1601, 611, 236881, 236881, 107, 107]])
torch.testing.assert_close(logits.argmax(-1), expected_argmax)
@@ -366,6 +367,16 @@ def test_model_generation(self):
inputs = self.get_inputs().to(model.model.embed_tokens.weight.device)
with torch.no_grad():
- generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=3, top_k=None, top_p=None)
-
- self.assertEqual(generated_ids[0, -3:].tolist(), [107, 262146, 108])
+ generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=16, top_k=None, top_p=None)
+
+ expected_generated_ids = Expectations(
+ {
+ (None, None): [
+ 107, 262146, 108, 9259, 236888, 2088, 740, 564,
+ 6361, 611, 3124, 236881, 108, 236769, 10282, 236787,
+ ]
+ }
+ ) # fmt: skip
+ self.assertEqual(
+ generated_ids[0, inputs.input_ids.shape[-1] :].tolist(), expected_generated_ids.get_expectation()
+ )
From ebeb8c3fbf3e48c8a9d85c545cbf69f7ea013f64 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 26 May 2026 17:54:24 +0800
Subject: [PATCH 35/36] date
---
docs/source/en/model_doc/zaya.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md
index de7916038254..1ae3323416c8 100644
--- a/docs/source/en/model_doc/zaya.md
+++ b/docs/source/en/model_doc/zaya.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-19.*
+*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-26.*
# ZAYA
From d71306f07a44096dccde73c8ae5f783581d9b857 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 26 May 2026 23:19:49 +0800
Subject: [PATCH 36/36] update link!
---
src/transformers/models/zaya/modeling_zaya.py | 2 +-
src/transformers/models/zaya/modular_zaya.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py
index f6e2a16c82af..714910c8831a 100755
--- a/src/transformers/models/zaya/modeling_zaya.py
+++ b/src/transformers/models/zaya/modeling_zaya.py
@@ -159,7 +159,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path.
- See https://arxiv.org/abs/2510.04476.
+ See https://huggingface.co/papers/2510.04476.
This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D
convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.
diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py
index c3b5008fe988..c205de489e8c 100644
--- a/src/transformers/models/zaya/modular_zaya.py
+++ b/src/transformers/models/zaya/modular_zaya.py
@@ -150,7 +150,7 @@ class ZayaRMSNorm(Qwen3MoeRMSNorm):
class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path.
- See https://arxiv.org/abs/2510.04476.
+ See https://huggingface.co/papers/2510.04476.
This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D
convolution, q/k keep residual projection paths, and v uses a delayed recurrent state.