diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b0283d464d04..f443d2be5d57 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -894,6 +894,8 @@ title: Zamba - local: model_doc/zamba2 title: Zamba2 + - local: model_doc/zaya + title: ZAYA title: Text models - sections: - local: model_doc/aimv2 diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md new file mode 100644 index 000000000000..1ae3323416c8 --- /dev/null +++ b/docs/source/en/model_doc/zaya.md @@ -0,0 +1,63 @@ + +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-26.* + +# ZAYA + +## Overview + +ZAYA1 is a 760M active / 8.4B total parameter MoE language model trained by Zyphra. It combines Compressed +Convolutional Attention (CCA), a nonlinear ZAYA1 router, and residual scaling. + +ZAYA1 uses the Gemma 3 tokenizer. For more details, see the [ZAYA1 model card](https://huggingface.co/Zyphra/ZAYA1-8B) +and Zyphra's technical reports. + +This model was contributed by [JJJYmmm](https://github.com/JJJYmmm). + +## Usage examples + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_id = "Zyphra/ZAYA1-8B" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + +inputs = tokenizer.apply_chat_template( + [{"role": "user", "content": "Write a haiku about recursion in programming."}], + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + return_tensors="pt", +) +inputs = inputs.to(model.device) +outputs = model.generate(**inputs, max_new_tokens=256) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +## ZayaConfig + +[[autodoc]] ZayaConfig + +## ZayaModel + +[[autodoc]] ZayaModel + - forward + +## ZayaForCausalLM + +[[autodoc]] ZayaForCausalLM + - forward diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dfef404a42f1..993643dfe390 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -864,6 +864,33 @@ def reorder_cache(self, beam_idx: torch.LongTensor): DynamicLayer.reorder_cache(self, beam_idx) +class LinearAttentionAndSlidingWindowAttentionLayer(LinearAttentionLayer, DynamicSlidingWindowLayer): + # The dynamic sliding attention part makes it non-compileable + is_compileable = False + + def __init__(self, config: PreTrainedConfig | None = None): + DynamicSlidingWindowLayer.__init__(self, config) + LinearAttentionLayer.__init__(self) + + def lazy_initialization(self, *args, **kwargs) -> None: + # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args + if len(args) == 2 and len(kwargs) == 0: + DynamicSlidingWindowLayer.lazy_initialization(self, *args) + # Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`, + # it's always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states) + if len(args) == 0 and len(kwargs) == 1: + LinearAttentionLayer.lazy_initialization(self, **kwargs) + + def reset(self) -> None: + LinearAttentionLayer.reset(self) + DynamicSlidingWindowLayer.reset(self) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + LinearAttentionLayer.reorder_cache(self, beam_idx) + DynamicSlidingWindowLayer.reorder_cache(self, beam_idx) + + # Pre-register the standard layer types (some classes are shared between multiple types, # e.g. ``DynamicSlidingWindowLayer`` covers both ``"sliding_attention"`` and # ``"chunked_attention"`` — those need an explicit map entry rather than the @@ -883,6 +910,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor): "moe": LinearAttentionLayer, # Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state. "hybrid": LinearAttentionAndFullAttentionLayer, + "hybrid_sliding": LinearAttentionAndSlidingWindowAttentionLayer, } ) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index eb22824db8b0..427feb941bce 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -71,7 +71,8 @@ "attention", "sparse", "dense", - "hybrid", # for layers that have both mamba and attention in zamba and zamba2 + "hybrid", # for zamba/zamba2/zaya1, which use full attention + conv states + "hybrid_sliding", # for zaya1, which uses swa + conv states "moe", # for nemotron_h, which uses either attention, mamba or moe ) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f1967932a02d..0524ea34b3c5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -485,6 +485,7 @@ from .youtu import * from .zamba import * from .zamba2 import * + from .zaya import * from .zoedepth import * else: import sys diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index d6e3069ed025..95c8a7d4130f 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -651,6 +651,7 @@ ("youtu", "YoutuConfig"), ("zamba", "ZambaConfig"), ("zamba2", "Zamba2Config"), + ("zaya", "ZayaConfig"), ("zoedepth", "ZoeDepthConfig"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e9de6747076f..3532d20d6b25 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -519,6 +519,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("youtu", "YoutuModel"), ("zamba", "ZambaModel"), ("zamba2", "Zamba2Model"), + ("zaya", "ZayaModel"), ] ) @@ -783,6 +784,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("youtu", "YoutuForCausalLM"), ("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM"), + ("zaya", "ZayaForCausalLM"), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index cef2e4d5863e..7944ab3220fc 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -342,6 +342,7 @@ ("xlstm", "GPTNeoXTokenizer" if is_tokenizers_available() else None), ("xmod", "XLMRobertaTokenizer" if is_tokenizers_available() else None), ("yoso", "AlbertTokenizer" if is_tokenizers_available() else None), + ("zaya", "GemmaTokenizer" if is_tokenizers_available() else None), ] ) diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py new file mode 100644 index 000000000000..c28f97af94ea --- /dev/null +++ b/src/transformers/models/zaya/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 Zyphra and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_zaya import * + from .modeling_zaya import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py new file mode 100644 index 000000000000..690a521ee357 --- /dev/null +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -0,0 +1,129 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Literal + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +@strict +class ZayaConfig(PreTrainedConfig): + r""" + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. + router_hidden_size (`int`, *optional*, defaults to 256): + Hidden size used by the ZAYA router. + cca_time0 (`int`, *optional*, defaults to 2): + First temporal parameter of the CCA projection. + cca_time1 (`int`, *optional*, defaults to 2): + Second temporal parameter of the CCA projection. + + ```python + >>> from transformers import ZayaConfig, ZayaModel + + >>> configuration = ZayaConfig() + >>> model = ZayaModel(configuration) + + >>> configuration = model.config + ``` + """ + + model_type = "zaya" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_fsdp_plan = { + "embed_tokens": "free_full_weight", + "layers.*": "free_full_weight", + "norm": "keep_full_weight", + } + + vocab_size: int = 262272 + hidden_size: int = 2048 + num_hidden_layers: int = 40 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + hidden_act: str = "silu" + max_position_embeddings: int = 131072 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-5 + use_cache: bool = True + tie_word_embeddings: bool = True + rope_parameters: RopeParameters | dict | None = None + sliding_window: int | None = None + attention_dropout: float | int = 0.0 + moe_intermediate_size: int = 2048 + + num_experts_per_tok: int = 1 + num_experts: int = 16 + output_router_logits: bool = False + layer_types: list[str] | None = None + pad_token_id: int | None = 0 + bos_token_id: int | None = 2 + eos_token_id: int | list[int] | None = 106 + + # Zaya-specific attention + head_dim: int = 128 + attention_bias: bool = False + + lm_head_bias: bool = False + router_hidden_size: int = 256 + cca_time0: int = 2 + cca_time1: int = 2 + + def __post_init__(self, **kwargs): + self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) + + default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = { + "hybrid": { + "rope_type": "default", + "rope_theta": 5_000_000.0, + "partial_rotary_factor": 0.5, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": 10_000.0, + "partial_rotary_factor": 0.5, + }, + } + if self.rope_parameters is None: + self.rope_parameters = default_rope_params + + super().__post_init__(**kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"}) + + def convert_rope_params_to_dict(self, **kwargs): + # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly. + return kwargs + + def validate_architecture(self): + """Part of ``@strict``-powered validation.""" + if self.num_experts_per_tok != 1: + raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") + if "hybrid_sliding" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") + + +__all__ = ["ZayaConfig"] diff --git a/src/transformers/models/zaya/convert_zaya_weights_to_hf.py b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py new file mode 100644 index 000000000000..a359ed5bb8aa --- /dev/null +++ b/src/transformers/models/zaya/convert_zaya_weights_to_hf.py @@ -0,0 +1,376 @@ +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert original alternating-layer ZAYA checkpoints to the Transformers-native decoder-layer layout.""" + +import argparse +import json +import re +import shutil +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from transformers import ZayaConfig + + +_DEFAULT_ROPE_THETA = 5_000_000.0 +_DEFAULT_SWA_ROPE_THETA = 10_000.0 +_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$") +_LOCAL_EXPERT_PATTERN = re.compile( + r"^model\.layers\.(\d+)\.zaya_block\.experts\.local_experts\.(\d+)\.linear_fc([12])\.weight$" +) + +_UNUSED_CONFIG_KEYS = ( + "cca", + "num_query_groups", + "intermediate_size", + "ffn_hidden_size", + "moe_router_topk", + "norm_epsilon", + "zaya_mlp_expansion", + "activation_func", + "normalization", + "add_bias_linear", + "gated_linear_unit", + "fused_add_norm", + "apply_rope_fusion", + "bias_activation_fusion", + "activation_func_fp8_input_store", + "clamp_temp", + "kv_channels", + "mamba_cache_dtype", + "residual_in_fp32", + "rope_scaling", + "scale_residual_merge", + "zaya_high_prec", + "zaya_use_mod", + "zaya_use_eda", +) + + +def _rename_common(rest: str) -> str: + replacements = ( + ("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."), + ("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."), + ("self_attn.qkv.temp", "self_attn.qk_norm.temp"), + ("self_attn.qkv.linear_q.", "self_attn.qkv_proj.q_proj."), + ("self_attn.qkv.linear_k.", "self_attn.qkv_proj.k_proj."), + ("self_attn.qkv.val_proj1.", "self_attn.qkv_proj.v_proj_current."), + ("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."), + ("self_attn.qkv.", "self_attn.qkv_proj."), + ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.norm."), + ("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."), + ("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."), + ("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."), + ("zaya_block.router.", "mlp.gate."), + ("zaya_block.", "mlp."), + ) + for old, new in replacements: + if rest.startswith(old): + return new + rest.removeprefix(old) + return rest + + +def _expert_target(name: str) -> tuple[str, int] | None: + match = _LOCAL_EXPERT_PATTERN.match(name) + if match is None: + return None + + old_layer_idx = int(match.group(1)) + if old_layer_idx % 2 != 1: + raise ValueError(f"Expert weights are expected on odd ZAYA layers, got: {name}") + + new_layer_idx = old_layer_idx // 2 + expert_idx = int(match.group(2)) + projection = "gate_up_proj" if match.group(3) == "1" else "down_proj" + target = f"model.layers.{new_layer_idx}.mlp.experts.{projection}" + return target, expert_idx + + +def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> str | None: + if _expert_target(name) is not None: + return None + + match = _LAYER_PATTERN.match(name) + if match is None: + if name.startswith("model.final_norm."): + return f"model.norm.{name.removeprefix('model.final_norm.')}" + if old_num_hidden_layers is not None and name.startswith("model.res_scale."): + new_layer_idx = old_num_hidden_layers // 2 - 1 + return f"model.layers.{new_layer_idx}.post_mlp_residual_scale.{name.removeprefix('model.res_scale.')}" + return name + + old_layer_idx = int(match.group(1)) + rest = match.group(2) + new_layer_idx = old_layer_idx // 2 + + if old_layer_idx % 2 == 0: + rest = _rename_common(rest) + if rest.startswith("self_attn."): + return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.layers.{new_layer_idx}.input_layernorm.{rest.removeprefix('input_norm.')}" + if rest.startswith("res_scale."): + if old_layer_idx == 0: + return f"model.input_{rest.removeprefix('res_scale.')}" + return f"model.layers.{new_layer_idx - 1}.post_mlp_residual_scale.{rest.removeprefix('res_scale.')}" + else: + rest = _rename_common(rest) + if rest.startswith("mlp."): + return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.layers.{new_layer_idx}.post_attention_layernorm.{rest.removeprefix('input_norm.')}" + if rest.startswith("res_scale."): + return f"model.layers.{new_layer_idx}.post_attention_residual_scale.{rest.removeprefix('res_scale.')}" + + raise ValueError(f"Unexpected ZAYA layer weight name: {name}") + + +def _to_hybrid_layer_type(layer_type: str) -> str: + if layer_type == "full_attention": + return "hybrid" + if layer_type == "sliding_attention": + return "hybrid_sliding" + raise ValueError(f"Unsupported ZAYA layer type: {layer_type}") + + +def _convert_layer_types(config_dict: dict, old_num_hidden_layers: int, new_num_hidden_layers: int) -> list[str]: + layer_types = config_dict.get("layer_types") + if layer_types is not None: + if len(layer_types) == old_num_hidden_layers: + return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types[::2]] + if len(layer_types) == new_num_hidden_layers: + return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types] + raise ValueError("`layer_types` must match either the original or converted number of hidden layers.") + + swa_layers = config_dict.get("swa_layers") + if swa_layers is None: + return ["hybrid"] * new_num_hidden_layers + if len(swa_layers) == old_num_hidden_layers: + swa_layers = swa_layers[::2] + elif len(swa_layers) != new_num_hidden_layers: + raise ValueError("`swa_layers` must match either the original or converted number of hidden layers.") + return ["hybrid" if int(window_size) == 0 else "hybrid_sliding" for window_size in swa_layers] + + +def convert_config(input_dir: Path, output_dir: Path) -> None: + config_dict = json.loads((input_dir / "config.json").read_text()) + old_num_hidden_layers = int(config_dict["num_hidden_layers"]) + if old_num_hidden_layers % 2 != 0: + raise ValueError("Original ZAYA checkpoints must have an even number of alternating attention/MoE layers.") + + new_num_hidden_layers = old_num_hidden_layers // 2 + layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers) + partial_rotary_factor = 0.5 + rope_theta = config_dict.get("rope_theta", _DEFAULT_ROPE_THETA) + swa_rotary_base = config_dict.get("swa_rotary_base", _DEFAULT_SWA_ROPE_THETA) + rms_norm_eps = config_dict.get("rms_norm_eps", config_dict.get("norm_epsilon", ZayaConfig.rms_norm_eps)) + router_hidden_size = config_dict.get( + "router_hidden_size", config_dict.get("zaya_mlp_expansion", ZayaConfig.router_hidden_size) + ) + expert_ffn_size = config_dict.get("intermediate_size", config_dict.get("ffn_hidden_size")) + moe_intermediate_size = expert_ffn_size // 2 if expert_ffn_size is not None else ZayaConfig.moe_intermediate_size + num_experts_per_tok = config_dict.get( + "num_experts_per_tok", config_dict.get("moe_router_topk", ZayaConfig.num_experts_per_tok) + ) + + swa_layers = config_dict.get("swa_layers") or [] + sliding_window = config_dict.get("sliding_window") + if sliding_window is None: + positive_windows = [int(window_size) for window_size in swa_layers if int(window_size) > 0] + # Original ZAYA stores the number of previous tokens attended by SWA layers. Transformers' sliding window + # is the total local attention span, including the current token. + sliding_window = max(positive_windows) + 1 if positive_windows else None + + rope_parameters = { + "hybrid": { + "rope_type": "default", + "rope_theta": rope_theta, + "partial_rotary_factor": partial_rotary_factor, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": swa_rotary_base, + "partial_rotary_factor": partial_rotary_factor, + }, + } + + for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta", "swa_rotary_base"): + config_dict.pop(key, None) + + config_dict.update( + { + "architectures": ["ZayaForCausalLM"], + "num_hidden_layers": new_num_hidden_layers, + "moe_intermediate_size": moe_intermediate_size, + "num_experts_per_tok": num_experts_per_tok, + "rms_norm_eps": rms_norm_eps, + "router_hidden_size": router_hidden_size, + "layer_types": layer_types, + "sliding_window": sliding_window, + "rope_parameters": rope_parameters, + } + ) + ZayaConfig(**config_dict).save_pretrained(output_dir) + + +def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None: + for path in input_dir.iterdir(): + if path.name == "config.json": + continue + if path.name.endswith(".safetensors") or path.name.endswith(".bin"): + continue + if path.name in {"model.safetensors.index.json", "pytorch_model.bin.index.json"}: + continue + + output_path = output_dir / path.name + if path.is_dir(): + shutil.copytree(path, output_path, dirs_exist_ok=True) + else: + shutil.copy2(path, output_path) + + +def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[str]], dict[str, str], dict]: + index = json.loads((input_dir / "model.safetensors.index.json").read_text()) + old_weight_map = index["weight_map"] + old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"]) + converted_weight_map = {} + normal_sources_by_output_file = defaultdict(list) + expert_sources_by_target = defaultdict(list) + output_file_by_target = {} + + for source_key, filename in old_weight_map.items(): + expert_info = _expert_target(source_key) + if expert_info is not None: + target_key, expert_idx = expert_info + expert_sources_by_target[target_key].append((expert_idx, source_key)) + output_file_by_target.setdefault(target_key, filename) + converted_weight_map[target_key] = output_file_by_target[target_key] + continue + + target_key = convert_weight_name(source_key, old_num_hidden_layers) + if target_key in converted_weight_map: + raise ValueError(f"Duplicate converted weight name: {target_key}") + converted_weight_map[target_key] = filename + normal_sources_by_output_file[filename].append((source_key, target_key)) + + index["weight_map"] = converted_weight_map + return normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, index + + +def _load_sources(input_dir: Path, source_keys: list[str], old_weight_map: dict[str, str]) -> dict[str, torch.Tensor]: + sources_by_file = defaultdict(list) + for source_key in source_keys: + sources_by_file[old_weight_map[source_key]].append(source_key) + + tensors = {} + for filename, keys in sources_by_file.items(): + with safe_open(input_dir / filename, framework="pt", device="cpu") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + return tensors + + +def convert_safetensors(input_dir: Path, output_dir: Path) -> None: + index_path = input_dir / "model.safetensors.index.json" + if not index_path.exists(): + safetensors_path = input_dir / "model.safetensors" + if not safetensors_path.exists(): + raise FileNotFoundError("Only safetensors ZAYA checkpoints are supported by this converter.") + + old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"]) + with safe_open(safetensors_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + state_dict = {} + expert_groups = defaultdict(list) + for key in f.keys(): + expert_info = _expert_target(key) + if expert_info is not None: + target_key, expert_idx = expert_info + expert_groups[target_key].append((expert_idx, f.get_tensor(key))) + continue + state_dict[convert_weight_name(key, old_num_hidden_layers)] = f.get_tensor(key) + for target_key, expert_tensors in expert_groups.items(): + state_dict[target_key] = torch.stack([tensor for _, tensor in sorted(expert_tensors)], dim=0) + save_file(state_dict, output_dir / "model.safetensors", metadata=metadata) + return + + old_index = json.loads(index_path.read_text()) + old_weight_map = old_index["weight_map"] + normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, converted_index = ( + _build_weight_plan(input_dir) + ) + output_filenames = sorted(set(converted_index["weight_map"].values())) + + metadata_by_file = {} + for filename in sorted(set(old_weight_map.values())): + with safe_open(input_dir / filename, framework="pt", device="cpu") as f: + metadata_by_file[filename] = f.metadata() + + for output_filename in output_filenames: + shard = {} + normal_sources = normal_sources_by_output_file.get(output_filename, []) + source_keys = [source_key for source_key, _ in normal_sources] + + expert_groups_for_shard = { + target_key: sorted(sources) + for target_key, sources in expert_sources_by_target.items() + if output_file_by_target[target_key] == output_filename + } + for sources in expert_groups_for_shard.values(): + source_keys.extend(source_key for _, source_key in sources) + + loaded_tensors = _load_sources(input_dir, source_keys, old_weight_map) + for source_key, target_key in normal_sources: + shard[target_key] = loaded_tensors[source_key] + for target_key, sources in expert_groups_for_shard.items(): + shard[target_key] = torch.stack([loaded_tensors[source_key] for _, source_key in sources], dim=0) + + save_file(shard, output_dir / output_filename, metadata=metadata_by_file.get(output_filename)) + + (output_dir / "model.safetensors.index.json").write_text( + json.dumps(converted_index, indent=2, sort_keys=True) + "\n" + ) + + +def convert_checkpoint(input_dir: str, output_dir: str, overwrite: bool = False) -> None: + input_path = Path(input_dir).expanduser().resolve() + output_path = Path(output_dir).expanduser().resolve() + if input_path == output_path: + raise ValueError("Please write the converted checkpoint to a different output directory.") + if output_path.exists() and any(output_path.iterdir()): + if not overwrite: + raise FileExistsError(f"{output_path} already exists and is not empty. Pass --overwrite to replace it.") + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + copy_non_weight_files(input_path, output_path) + convert_config(input_path, output_path) + convert_safetensors(input_path, output_path) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--input_dir", required=True, help="Path to the original alternating-layer ZAYA checkpoint.") + parser.add_argument("--output_dir", required=True, help="Path where the converted checkpoint should be written.") + parser.add_argument("--overwrite", action="store_true", help="Overwrite a non-empty output directory.") + args = parser.parse_args() + convert_checkpoint(args.input_dir, args.output_dir, overwrite=args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py new file mode 100755 index 000000000000..714910c8831a --- /dev/null +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -0,0 +1,886 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections.abc import Callable +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import init + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_zaya import ZayaConfig + + +class ZayaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: ZayaConfig): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + self.layer_types = list(set(config.layer_types)) + self.rope_type = {} + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue + + self.rope_type[layer_type] = rope_params["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) + + @staticmethod + def compute_default_rope_parameters( + config: ZayaConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters[layer_type]["rope_theta"] + # key difference to gemma3: partial rope + partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids, layer_type=None): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class ZayaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + ZayaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ZayaCCAProjection(nn.Module): + """ + Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path. + See https://huggingface.co/papers/2510.04476. + + This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D + convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. + """ + + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim + + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, + ) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, + ): + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) + if use_precomputed_states: + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + + if past_key_values is not None: + new_conv_state = qk_states[..., -self.conv_kernel_size :] + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) + + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) + + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + + # The value path carries half of each value head from the current token and half from the previous token. + # During cached decoding, `recurrent_v_state` is the previous token's delayed projection. + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) + if use_precomputed_states: + recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) + else: + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1) + + if past_key_values is not None: + past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx) + + value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape) + + return query, key, value + + +class ZayaQKNorm(nn.Module): + """ + L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale. + """ + + def __init__(self, config: ZayaConfig): + super().__init__() + self.head_dim_scale = config.head_dim**0.5 + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + + def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + norm_eps = torch.finfo(query_states.dtype).eps + query_states = query_states * ( + self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * ( + self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * self.temp[None, None, :, None] + return query_states, key_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class ZayaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.qkv_proj = ZayaCCAProjection( + config=self.config, + layer_idx=layer_idx, + ) + self.layer_type = config.layer_types[layer_idx] + self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.qk_norm = ZayaQKNorm(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") + + # ZAYA replaces the usual independent q/k/v projections with CCA projection followed by special q/k normalization. + query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) + query_states, key_states = self.qk_norm(query_states, key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + causal_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.view(*input_shape, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class ZayaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ZayaAttention(config=config, layer_idx=layer_idx) + self.mlp = ZayaSparseMoeBlock(config, layer_idx) + self.input_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) + self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + residual = hidden_states + # Match upstream's residual_in_fp32 path by keeping the residual stream in fp32 and avoiding extra + # fp32->bf16 round trips in the residual module. + hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype)) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + **kwargs, + ) + + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) + + hidden_states, prev_router_hidden_states = self.mlp( + hidden_states, + prev_router_hidden_states, + ) + + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) + + return hidden_states, prev_router_hidden_states + + +class ZayaResidualScaling(nn.Module): + def __init__(self, hidden_size: int): + super().__init__() + self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) + self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. + hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale + residual = (residual + self.residual_bias) * self.residual_scale + return hidden_states + residual + + +class ZayaRouterMLP(nn.Module): + def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): + super().__init__() + self.norm = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) + self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) + self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.act_fn(self.fc1(hidden_states)) + hidden_states = self.act_fn(self.fc2(hidden_states)) + return self.out_proj(hidden_states) + + +class ZayaRouter(nn.Module): + def __init__( + self, + config, + layer_idx: int, + ) -> None: + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.num_experts = config.num_experts + 1 + self.top_k = config.num_experts_per_tok + self.router_hidden_size = config.router_hidden_size + + self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True) + + self.use_eda = self.layer_idx != 0 + if self.use_eda: + self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size)) + + self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps) + + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, + router_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + final_shape = (-1, self.top_k) + seq_length = hidden_states.shape[1] + + router_hidden_states = self.down_proj(hidden_states) + + if self.use_eda and router_states is not None: + router_hidden_states = router_hidden_states + router_states * self.router_states_scale + + router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + + biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases + _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1) + router_probs = torch.gather(router_probs, dim=2, index=router_indices) + + # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask. + skip_expert = router_indices == self.config.num_experts + router_probs = router_probs.masked_fill(skip_expert, 0) + router_indices = router_indices.masked_fill(skip_expert, 0) + + return ( + router_logits.reshape(-1, self.num_experts), + router_probs.reshape(final_shape), + router_indices.reshape(final_shape), + router_hidden_states_next, + ) + + +@use_experts_implementation +class ZayaExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class ZayaSparseMoeBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.gate = ZayaRouter(config, layer_idx) + self.experts = ZayaExperts(config) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # ZAYA carries router hidden states across decoder layers; the next layer consumes this state in its router. + _, router_probs, router_indices, prev_router_hidden_states = self.gate( + hidden_states, router_states=prev_router_hidden_states + ) + + batch_size, seq_length, emb_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) + expert_output = self.experts(hidden_states_flat, router_indices, router_probs) + expert_output = expert_output.view(batch_size, seq_length, emb_dim) + + return expert_output, prev_router_hidden_states + + +@auto_docstring +class ZayaPreTrainedModel(PreTrainedModel): + config: ZayaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ZayaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache. + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(ZayaRouter, index=0), + "hidden_states": ZayaDecoderLayer, + "attentions": ZayaAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, ZayaResidualScaling): + init.ones_(module.hidden_states_scale) + init.zeros_(module.hidden_states_bias) + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaModel): + init.ones_(module.input_hidden_states_scale) + init.zeros_(module.input_hidden_states_bias) + elif isinstance(module, ZayaQKNorm): + init.zeros_(module.temp) + elif isinstance(module, ZayaRouter): + if module.use_eda: + init.ones_(module.router_states_scale) + init.zeros_(module.balancing_biases) + module.balancing_biases[-1] = -1.0 # trf-ignore: TRF012 + elif isinstance(module, ZayaExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, ZayaRotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq) + getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq) + + +@auto_docstring +class ZayaModel(ZayaPreTrainedModel): + def __init__(self, config: ZayaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = ZayaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values) + + hidden_states = inputs_embeds + + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) + } + + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. + hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to( + torch.float32 + ) + + prev_router_hidden_states = None + + for idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[idx] + # Attention uses the prepared causal mask, while CCA projection still needs the raw 2D padding mask to + # zero padding tokens before convolution. + hidden_states, prev_router_hidden_states = decoder_layer( + hidden_states, + prev_router_hidden_states, + attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask}, + past_key_values=past_key_values, + position_embeddings=position_embeddings[layer_type], + **kwargs, + ) + + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + def _update_cca_mask(self, attention_mask, past_key_values): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. + """ + cca_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + cca_mask = None + return cca_mask + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _sp_plan = {"lm_head": "colwise_loss_parallel"} + _fsdp_plan = {"lm_head": "keep_full_weight"} + _is_stateful = True + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = ZayaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_router_logits: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, ZayaForCausalLM + + >>> model = ZayaForCausalLM.from_pretrained("meta-zaya/Zaya-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-zaya/Zaya-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["ZayaPreTrainedModel", "ZayaModel", "ZayaForCausalLM"] diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py new file mode 100644 index 000000000000..c205de489e8c --- /dev/null +++ b/src/transformers/models/zaya/modular_zaya.py @@ -0,0 +1,670 @@ +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Zaya model.""" + +from collections.abc import Callable +from typing import Any, Literal + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from huggingface_hub.dataclasses import strict +from torch import nn +from torch.nn import init + +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PreTrainedConfig +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_outputs import MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, +) +from ...utils.output_capturing import OutputRecorder +from ..afmoe.modeling_afmoe import AfmoeForCausalLM +from ..laguna.configuration_laguna import LagunaConfig +from ..laguna.modeling_laguna import LagunaModel, LagunaRotaryEmbedding +from ..llama.modeling_llama import LlamaDecoderLayer, LlamaPreTrainedModel, repeat_kv +from ..phi3.modeling_phi3 import Phi3Attention +from ..qwen3_5_moe.modeling_qwen3_5_moe import ( + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeExperts, Qwen3MoeRMSNorm + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +@strict +class ZayaConfig(LagunaConfig): + r""" + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. + router_hidden_size (`int`, *optional*, defaults to 256): + Hidden size used by the ZAYA router. + cca_time0 (`int`, *optional*, defaults to 2): + First temporal parameter of the CCA projection. + cca_time1 (`int`, *optional*, defaults to 2): + Second temporal parameter of the CCA projection. + + ```python + >>> from transformers import ZayaConfig, ZayaModel + + >>> configuration = ZayaConfig() + >>> model = ZayaModel(configuration) + + >>> configuration = model.config + ``` + """ + + model_type = "zaya" + + vocab_size: int = 262272 + moe_intermediate_size: int = 2048 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + tie_word_embeddings: bool = True + rms_norm_eps: float = 1e-5 + sliding_window: int | None = None + pad_token_id: int | None = 0 + bos_token_id: int | None = 2 + eos_token_id: int | list[int] | None = 106 + + num_experts_per_tok: int = 1 + num_experts: int = 16 + + lm_head_bias: bool = False + router_hidden_size: int = 256 + cca_time0: int = 2 + cca_time1: int = 2 + + # Fields declared by LagunaConfig but not used by ZAYA. + # NOTE: TP is intentionally disabled for now because the useful degree is limited by ZAYA's 2 KV heads; see + # https://github.com/huggingface/transformers/pull/45862#discussion_r3266709862. PP needs coverage for the cross-layer router state. + base_model_tp_plan = AttributeError() + base_model_pp_plan = AttributeError() + intermediate_size = AttributeError() + shared_expert_intermediate_size = AttributeError() + router_aux_loss_coef = AttributeError() + num_attention_heads_per_layer = AttributeError() + mlp_layer_types = AttributeError() + moe_routed_scaling_factor = AttributeError() + moe_apply_router_weight_on_input = AttributeError() + moe_router_logit_softcapping = AttributeError() + + def __post_init__(self, **kwargs): + self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) + + default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = { + "hybrid": { + "rope_type": "default", + "rope_theta": 5_000_000.0, + "partial_rotary_factor": 0.5, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": 10_000.0, + "partial_rotary_factor": 0.5, + }, + } + if self.rope_parameters is None: + self.rope_parameters = default_rope_params + + PreTrainedConfig.__post_init__(self, **kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"}) + + def convert_rope_params_to_dict(self, **kwargs): + # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly. + return kwargs + + def validate_architecture(self): + if self.num_experts_per_tok != 1: + raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") + if "hybrid_sliding" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") + + +class ZayaRotaryEmbedding(LagunaRotaryEmbedding): + pass + + +class ZayaRMSNorm(Qwen3MoeRMSNorm): + pass + + +class ZayaCCAProjection(nn.Module): + """ + Projects hidden states into attention q/k/v states with ZAYA's Compressed Convolutional Attention (CCA) path. + See https://huggingface.co/papers/2510.04476. + + This follows the usual q/k/v projection flow, with three ZAYA-specific changes: q/k are mixed by a causal 1D + convolution, q/k keep residual projection paths, and v uses a delayed recurrent state. + """ + + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim + + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, + ) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, + ): + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) + if use_precomputed_states: + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + + if past_key_values is not None: + new_conv_state = qk_states[..., -self.conv_kernel_size :] + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) + + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) + + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + + # The value path carries half of each value head from the current token and half from the previous token. + # During cached decoding, `recurrent_v_state` is the previous token's delayed projection. + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) + if use_precomputed_states: + recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) + else: + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1) + + if past_key_values is not None: + past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx) + + value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape) + + return query, key, value + + +class ZayaQKNorm(nn.Module): + """ + L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale. + """ + + def __init__(self, config: ZayaConfig): + super().__init__() + self.head_dim_scale = config.head_dim**0.5 + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + + def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + norm_eps = torch.finfo(query_states.dtype).eps + query_states = query_states * ( + self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * ( + self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * self.temp[None, None, :, None] + return query_states, key_states + + +class ZayaAttention(Phi3Attention): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__(config, layer_idx) + del op_size # noqa: F821 + self.layer_type = config.layer_types[layer_idx] + self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.qk_norm = ZayaQKNorm(config) + self.qkv_proj = ZayaCCAProjection( + config=self.config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") + + # ZAYA replaces the usual independent q/k/v projections with CCA projection followed by special q/k normalization. + query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) + query_states, key_states = self.qk_norm(query_states, key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + causal_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.view(*input_shape, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class ZayaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.mlp = ZayaSparseMoeBlock(config, layer_idx) + self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) + self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + residual = hidden_states + # Match upstream's residual_in_fp32 path by keeping the residual stream in fp32 and avoiding extra + # fp32->bf16 round trips in the residual module. + hidden_states = self.input_layernorm(residual.to(dtype=self.input_layernorm.weight.dtype)) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + **kwargs, + ) + + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = self.post_attention_layernorm(residual.to(dtype=self.post_attention_layernorm.weight.dtype)) + + hidden_states, prev_router_hidden_states = self.mlp( + hidden_states, + prev_router_hidden_states, + ) + + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) + + return hidden_states, prev_router_hidden_states + + +class ZayaResidualScaling(nn.Module): + def __init__(self, hidden_size: int): + super().__init__() + self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) + self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. + hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale + residual = (residual + self.residual_bias) * self.residual_scale + return hidden_states + residual + + +class ZayaRouterMLP(nn.Module): + def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): + super().__init__() + self.norm = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) + self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) + self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.act_fn(self.fc1(hidden_states)) + hidden_states = self.act_fn(self.fc2(hidden_states)) + return self.out_proj(hidden_states) + + +class ZayaRouter(nn.Module): + def __init__( + self, + config, + layer_idx: int, + ) -> None: + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.num_experts = config.num_experts + 1 + self.top_k = config.num_experts_per_tok + self.router_hidden_size = config.router_hidden_size + + self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True) + + self.use_eda = self.layer_idx != 0 + if self.use_eda: + self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size)) + + self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps) + + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, + router_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + final_shape = (-1, self.top_k) + seq_length = hidden_states.shape[1] + + router_hidden_states = self.down_proj(hidden_states) + + if self.use_eda and router_states is not None: + router_hidden_states = router_hidden_states + router_states * self.router_states_scale + + router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + + biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases + _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1) + router_probs = torch.gather(router_probs, dim=2, index=router_indices) + + # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask. + skip_expert = router_indices == self.config.num_experts + router_probs = router_probs.masked_fill(skip_expert, 0) + router_indices = router_indices.masked_fill(skip_expert, 0) + + return ( + router_logits.reshape(-1, self.num_experts), + router_probs.reshape(final_shape), + router_indices.reshape(final_shape), + router_hidden_states_next, + ) + + +class ZayaExperts(Qwen3MoeExperts): + pass + + +class ZayaSparseMoeBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.gate = ZayaRouter(config, layer_idx) + self.experts = ZayaExperts(config) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # ZAYA carries router hidden states across decoder layers; the next layer consumes this state in its router. + _, router_probs, router_indices, prev_router_hidden_states = self.gate( + hidden_states, router_states=prev_router_hidden_states + ) + + batch_size, seq_length, emb_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) + expert_output = self.experts(hidden_states_flat, router_indices, router_probs) + expert_output = expert_output.view(batch_size, seq_length, emb_dim) + + return expert_output, prev_router_hidden_states + + +class ZayaPreTrainedModel(LlamaPreTrainedModel): + config: ZayaConfig + # ZAYA generation uses the native hybrid dynamic cache, which is not a compileable cache. + _can_compile_fullgraph = False + _can_record_outputs = { + "router_logits": OutputRecorder(ZayaRouter, index=0), + "hidden_states": ZayaDecoderLayer, + "attentions": ZayaAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, ZayaResidualScaling): + init.ones_(module.hidden_states_scale) + init.zeros_(module.hidden_states_bias) + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaModel): + init.ones_(module.input_hidden_states_scale) + init.zeros_(module.input_hidden_states_bias) + elif isinstance(module, ZayaQKNorm): + init.zeros_(module.temp) + elif isinstance(module, ZayaRouter): + if module.use_eda: + init.ones_(module.router_states_scale) + init.zeros_(module.balancing_biases) + module.balancing_biases[-1] = -1.0 # trf-ignore: TRF012 + elif isinstance(module, ZayaExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, ZayaRotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq) + getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq) + + +@auto_docstring +class ZayaModel(LagunaModel): + def __init__(self, config: ZayaConfig): + super().__init__(config) + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values) + + hidden_states = inputs_embeds + + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) + } + + # Keep the residual stream in fp32 to match the original ZAYA `residual_in_fp32` path. + hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to( + torch.float32 + ) + + prev_router_hidden_states = None + + for idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[idx] + # Attention uses the prepared causal mask, while CCA projection still needs the raw 2D padding mask to + # zero padding tokens before convolution. + hidden_states, prev_router_hidden_states = decoder_layer( + hidden_states, + prev_router_hidden_states, + attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask}, + past_key_values=past_key_values, + position_embeddings=position_embeddings[layer_type], + **kwargs, + ) + + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + def _update_cca_mask(self, attention_mask, past_key_values): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. + """ + cca_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + cca_mask = None + return cca_mask + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _is_stateful = True + + _tp_plan = AttributeError() + _pp_plan = AttributeError() + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) + + +__all__ = [ + "ZayaConfig", + "ZayaPreTrainedModel", + "ZayaModel", + "ZayaForCausalLM", +] diff --git a/tests/models/zaya/__init__.py b/tests/models/zaya/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/models/zaya/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py new file mode 100644 index 000000000000..86448222a20e --- /dev/null +++ b/tests/models/zaya/test_modeling_zaya.py @@ -0,0 +1,382 @@ +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch ZAYA model.""" + +import unittest + +from huggingface_hub.errors import StrictDataclassClassValidationError +from parameterized import parameterized + +from transformers import is_torch_available +from transformers.testing_utils import Expectations, cleanup, require_torch, slow, torch_device + + +if is_torch_available(): + import torch + + from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel + from transformers.cache_utils import ( + DynamicCache, + LinearAttentionAndFullAttentionLayer, + LinearAttentionAndSlidingWindowAttentionLayer, + ) + from transformers.models.zaya.modeling_zaya import ZayaCCAProjection + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class ZayaModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = ZayaModel + + def __init__(self, parent, **kwargs): + super().__init__( + parent=parent, + num_hidden_layers=2, + moe_intermediate_size=32, + num_experts_per_tok=1, + layer_types=["hybrid", "hybrid_sliding"], + sliding_window=64, + **kwargs, + ) + + +@require_torch +class ZayaModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = ZayaModelTester + test_all_params_have_gradient = False + + def _get_conv_state_shape(self, batch_size: int, config): + conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim + conv_kernel_size = config.cca_time0 + config.cca_time1 - 2 + return (batch_size, conv_state_size, conv_kernel_size) + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.num_key_value_heads * config.head_dim // 2) + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + if not isinstance(past_key_values, DynamicCache): + raise ValueError("The cache does not use the correct Cache") + + config = config.get_text_config(decoder=True) + self.assertEqual(config.num_hidden_layers, len(past_key_values)) + attention_shape = (batch_size, config.num_key_value_heads, seq_length, config.head_dim) + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) + + for layer_type, layer in zip(config.layer_types, past_key_values.layers): + expected_layer_class = ( + LinearAttentionAndSlidingWindowAttentionLayer + if layer_type == "hybrid_sliding" + else LinearAttentionAndFullAttentionLayer + ) + self.assertIs(type(layer), expected_layer_class) + self.assertEqual(layer.keys.shape, attention_shape) + self.assertEqual(layer.values.shape, attention_shape) + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) + + @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + config._attn_implementation = "eager" + + for model_class in self.all_model_classes: + model = model_class._from_config(config, attn_implementation="eager") + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class({**inputs_dict, "output_attentions": True}, model_class)) + + expected_attn_layers = config.num_hidden_layers + self.assertEqual(len(outputs.attentions), expected_attn_layers) + self.assertEqual( + outputs.attentions[0].shape, + ( + self.model_tester.batch_size, + config.num_attention_heads, + self.model_tester.seq_length, + self.model_tester.seq_length, + ), + ) + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + @unittest.skip( + "RoPE-scaling-from-config test doesn't match ZAYA's nested per-layer-type rope_parameters (same as e.g. Laguna, Gemma3)." + ) + def test_model_rope_scaling_from_config(self, scaling_type): + pass + + def test_model_rope_scaling_frequencies(self): + """ + Tests the frequency properties of the different RoPE scaling types on the model RoPE layer. + Copied from Laguna to adapt to per-layer-type rope configs. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + partial_rotary_factor = config.rope_parameters["hybrid"]["partial_rotary_factor"] + + def set_rope_params(rope_params): + config.rope_parameters = { + "hybrid": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + "hybrid_sliding": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + } + + set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) + + base_model = self.model_tester.base_model_class(config) + possible_rope_attributes = [ + "pos_emb", + "rotary_emb", + "global_rotary_emb", + "local_rotary_emb", + ] + for name, module in base_model.named_modules(): + if any(potential_name in name for potential_name in possible_rope_attributes): + rope_class = type(module) + break + + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + x = torch.randn(1, dtype=torch.float32, device=torch_device) + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + + set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) + original_rope = rope_class(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="hybrid_sliding") + original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + set_rope_params({"rope_type": "linear", "factor": scaling_factor, "rope_theta": 10_000.0}) + linear_scaling_rope = rope_class(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + set_rope_params({"rope_type": "dynamic", "factor": scaling_factor, "rope_theta": 10_000.0}) + ntk_scaling_rope = rope_class(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.hybrid_sliding_inv_freq <= original_rope.hybrid_sliding_inv_freq).all()) + + set_rope_params({"rope_type": "yarn", "factor": scaling_factor, "rope_theta": 10_000.0}) + yarn_scaling_rope = rope_class(config=config).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + def test_num_experts_per_tok_validation(self): + with self.assertRaisesRegex(StrictDataclassClassValidationError, "num_experts_per_tok=1"): + ZayaConfig(num_experts_per_tok=2) + + def test_sliding_attention_mask_is_used(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.layer_types = ["hybrid_sliding"] + ["hybrid"] * (config.num_hidden_layers - 1) + config.sliding_window = 3 + config._attn_implementation = "eager" + + model = ZayaModel._from_config(config, attn_implementation="eager").to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(input_ids=inputs_dict["input_ids"].to(torch_device), output_attentions=True) + + sliding_attention = outputs.attentions[0] + self.assertTrue(torch.all(sliding_attention[:, :, -1, : -config.sliding_window] == 0)) + + def test_cca_cache_matches_full_forward_multi_token(self): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + moe_intermediate_size=32, + num_hidden_layers=1, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + router_hidden_size=4, + tie_word_embeddings=False, + ) + torch.manual_seed(0) + cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device) + cca.eval() + hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) + + with torch.no_grad(): + # Compare full CCA projection against a cached continuation. The second chunk must recover the same + # q/k/v states from the cached convolution tail and delayed recurrent value state. + full = cca(hidden_states, None, None) + cache = DynamicCache(config=config) + cca(hidden_states[:, :3], cache, None) + cached = cca(hidden_states[:, 3:], cache, None) + + for full_states, cached_states in zip(full, cached): + torch.testing.assert_close(full_states[:, 3:], cached_states, rtol=1e-5, atol=1e-5) + + def test_zaya_cache_reorder_and_reset(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + cache = DynamicCache(config=config) + conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim + cache.update_conv_state( + torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view( + 2, conv_state_size, 2 + ), + 0, + ) + cache.update_recurrent_state( + torch.arange( + 2 * config.num_key_value_heads * config.head_dim // 2, device=torch_device, dtype=torch.float32 + ).view(2, config.num_key_value_heads * config.head_dim // 2), + 0, + ) + self.assertEqual(cache.layers[0].recurrent_states.shape[-1], config.num_key_value_heads * config.head_dim // 2) + + cache.reorder_cache(torch.tensor([1, 0], device=torch_device)) + self.assertEqual(cache.layers[0].conv_states.shape[0], 2) + + cache.reset() + self.assertFalse(cache.has_previous_state(0)) + self.assertEqual(cache.layers[0].conv_states.sum().item(), 0) + self.assertEqual(cache.layers[0].recurrent_states.sum().item(), 0) + + +@require_torch +class ZayaIntegrationTest(unittest.TestCase): + model = None + model_id = "Zyphra/ZAYA1-8B" + + @classmethod + def get_model(cls): + if cls.model is None: + cls.model = ZayaForCausalLM.from_pretrained(cls.model_id, device_map="auto", dtype=torch.bfloat16) + return cls.model + + @classmethod + def tearDownClass(cls): + if cls.model is not None: + del cls.model + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def get_inputs(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = tokenizer("Hello! How can I assist you today?", return_tensors="pt") + self.assertEqual( + inputs.input_ids.tolist(), + [[2, 9259, 236888, 2088, 740, 564, 6361, 611, 3124, 236881, 106]], + ) + return inputs + + @slow + def test_model_logits(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + logits = model(**inputs, use_cache=False, return_dict=True).logits.float().cpu() + + self.assertEqual(logits.shape, (1, inputs.input_ids.shape[-1], model.config.vocab_size)) + self.assertTrue(torch.isfinite(logits).all().item()) + + EXPECTED_LOGITS = Expectations( + { + (None, None): [ + [0.3613, 0.3633, 0.3633], + [-1.3672, -1.3672, -1.3672], + [-2.8750, -2.8750, -2.8750], + ] + } + ) # fmt: skip + expected_slice = torch.tensor(EXPECTED_LOGITS.get_expectation(), dtype=logits.dtype) + torch.testing.assert_close(logits[0, -3:, -3:], expected_slice, rtol=1e-3, atol=1e-3) + + expected_argmax = torch.tensor([[105, 9731, 107, 740, 564, 1601, 611, 236881, 236881, 107, 107]]) + torch.testing.assert_close(logits.argmax(-1), expected_argmax) + + @slow + def test_model_cache_matches_full_forward(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + full_logits = model(**inputs, use_cache=False).logits[:, -1] + prefill_outputs = model( + input_ids=inputs.input_ids[:, :-1], + attention_mask=inputs.attention_mask[:, :-1], + use_cache=True, + return_dict=True, + ) + cached_logits = model( + input_ids=inputs.input_ids[:, -1:], + attention_mask=inputs.attention_mask, + past_key_values=prefill_outputs.past_key_values, + use_cache=True, + return_dict=True, + ).logits[:, -1] + + torch.testing.assert_close(cached_logits.float().cpu(), full_logits.float().cpu(), rtol=1e-4, atol=1e-4) + + @slow + def test_model_generation(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=16, top_k=None, top_p=None) + + expected_generated_ids = Expectations( + { + (None, None): [ + 107, 262146, 108, 9259, 236888, 2088, 740, 564, + 6361, 611, 3124, 236881, 108, 236769, 10282, 236787, + ] + } + ) # fmt: skip + self.assertEqual( + generated_ids[0, inputs.input_ids.shape[-1] :].tolist(), expected_generated_ids.get_expectation() + )