diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 11e3c9008d56..de07fec0f88f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -883,6 +883,8 @@ title: Zamba - local: model_doc/zamba2 title: Zamba2 + - local: model_doc/zaya + title: ZAYA title: Text models - sections: - local: model_doc/aimv2 @@ -1379,6 +1381,8 @@ title: Qwen3VL - local: model_doc/qwen3_vl_moe title: Qwen3VLMoe + - local: model_doc/zaya1_vl + title: ZAYA1-VL - local: model_doc/sam3 title: SAM3 - local: model_doc/sam3_video diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md new file mode 100644 index 000000000000..199cd5d2935b --- /dev/null +++ b/docs/source/en/model_doc/zaya.md @@ -0,0 +1,63 @@ + +*This model was released on 2026-05-06 and added to Hugging Face Transformers on 2026-05-16.* + +# ZAYA + +## Overview + +ZAYA1 is a 760M active / 8.4B total parameter MoE language model trained by Zyphra. It combines Compressed +Convolutional Attention (CCA), a nonlinear ZAYA1 router, and residual scaling. + +ZAYA1 uses the Gemma 3 tokenizer. For more details, see the [ZAYA1 model card](https://huggingface.co/Zyphra/ZAYA1-8B) +and Zyphra's technical reports. + +This model was contributed by [JJJYmmm](https://github.com/JJJYmmm). 
+ +## Usage examples + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_id = "Zyphra/ZAYA1-8B" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + +inputs = tokenizer.apply_chat_template( + [{"role": "user", "content": "Write a haiku about recursion in programming."}], + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + return_tensors="pt", +) +inputs = inputs.to(model.device) +outputs = model.generate(**inputs, max_new_tokens=256) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +## ZayaConfig + +[[autodoc]] ZayaConfig + +## ZayaModel + +[[autodoc]] ZayaModel + - forward + +## ZayaForCausalLM + +[[autodoc]] ZayaForCausalLM + - forward diff --git a/docs/source/en/model_doc/zaya1_vl.md b/docs/source/en/model_doc/zaya1_vl.md new file mode 100644 index 000000000000..4a3320189f10 --- /dev/null +++ b/docs/source/en/model_doc/zaya1_vl.md @@ -0,0 +1,95 @@ + +*This model was released on 2026-05-08 and added to Hugging Face Transformers on 2026-05-15.* + +# ZAYA1-VL + +## Overview + +ZAYA1-VL is a vision-language model from Zyphra built on top of the ZAYA1 text decoder and the Qwen2.5-VL vision +encoder. It adds vision-token-specific LoRA parameters in the text decoder and uses bidirectional attention between +image placeholder tokens. + +For more details, see the [ZAYA1-VL model card](https://huggingface.co/Zyphra/ZAYA1-VL-8B). + +This model was contributed by [JJJYmmm](https://github.com/JJJYmmm). 
class LinearAttentionAndSlidingWindowAttentionLayer(LinearAttentionLayer, DynamicSlidingWindowLayer):
    """Cache layer that is both a `LinearAttentionLayer` and a `DynamicSlidingWindowLayer`.

    Every operation is forwarded explicitly to each parent class rather than through `super()`.
    """

    # The dynamic sliding attention part makes it non-compileable
    is_compileable = False

    def __init__(self, config: PreTrainedConfig | None = None):
        # Only the sliding-window parent receives the config; the linear-attention parent takes no arguments.
        DynamicSlidingWindowLayer.__init__(self, config)
        LinearAttentionLayer.__init__(self)

    def lazy_initialization(self, *args, **kwargs) -> None:
        """Dispatch lazy initialization to the parent that matches the call shape."""
        call_shape = (len(args), len(kwargs))
        if call_shape == (2, 0):
            # The attention cache path (`update`) calls this with exactly 2 positional args.
            DynamicSlidingWindowLayer.lazy_initialization(self, *args)
        elif call_shape == (0, 1):
            # The linear-attention path (`update_conv_state` / `update_recurrent_state`) always passes a single
            # kwarg, because it needs to know whether the conv or the ssm states are being initialized.
            LinearAttentionLayer.lazy_initialization(self, **kwargs)

    def reset(self) -> None:
        """Reset both parent states, linear-attention first."""
        for parent in (LinearAttentionLayer, DynamicSlidingWindowLayer):
            parent.reset(self)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for parent in (LinearAttentionLayer, DynamicSlidingWindowLayer):
            parent.reorder_cache(self, beam_idx)
"hybrid": LinearAttentionAndFullAttentionLayer, + "hybrid_sliding": LinearAttentionAndSlidingWindowAttentionLayer, } ) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 7ba033c538d8..e495bbdc69c0 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -71,7 +71,8 @@ "attention", "sparse", "dense", - "hybrid", # for layers that have both mamba and attention in zamba and zamba2 + "hybrid", # for zamba/zamba2/zaya1, which use full attention + conv states + "hybrid_sliding", # for zaya1, which uses swa + conv states "moe", # for nemotron_h, which uses either attention, mamba or moe ) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 406c5f7be0fc..caf647163c7c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -479,6 +479,8 @@ from .youtu import * from .zamba import * from .zamba2 import * + from .zaya import * + from .zaya1_vl import * from .zoedepth import * else: import sys diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 048dd5275537..c87ad12d9534 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -643,6 +643,10 @@ ("youtu", "YoutuConfig"), ("zamba", "ZambaConfig"), ("zamba2", "Zamba2Config"), + ("zaya", "ZayaConfig"), + ("zaya1_vl", "Zaya1VLConfig"), + ("zaya1_vl_text", "Zaya1VLTextConfig"), + ("zaya1_vl_vision", "Zaya1VLVisionConfig"), ("zoedepth", "ZoeDepthConfig"), ] ) @@ -858,6 +862,8 @@ ("xclip_vision_model", "x_clip"), ("xlm-roberta", "xlm_roberta"), ("xlm-roberta-xl", "xlm_roberta_xl"), + ("zaya1_vl_text", "zaya1_vl"), + ("zaya1_vl_vision", "zaya1_vl"), ] ) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index bc848896e1ee..9f80ec080ab0 100644 --- 
a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -154,6 +154,7 @@ ("vit_msn", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), ("vivit", {"torchvision": "VivitImageProcessor"}), ("xclip", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("zaya1_vl", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2202cc773db0..81a0b3bed549 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -512,6 +512,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("youtu", "YoutuModel"), ("zamba", "ZambaModel"), ("zamba2", "Zamba2Model"), + ("zaya", "ZayaModel"), + ("zaya1_vl", "Zaya1VLModel"), + ("zaya1_vl_text", "Zaya1VLTextModel"), ] ) @@ -774,6 +777,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("youtu", "YoutuForCausalLM"), ("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM"), + ("zaya", "ZayaForCausalLM"), ] ) @@ -1050,6 +1054,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("video_llava", "VideoLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"), + ("zaya1_vl", "Zaya1VLForConditionalGeneration"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 30c6fc520c49..79c7ff8a61d7 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -189,6 +189,7 @@ ("wavlm", "Wav2Vec2Processor"), ("whisper", "WhisperProcessor"), ("xclip", "XCLIPProcessor"), + ("zaya1_vl", "Zaya1VLProcessor"), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py 
b/src/transformers/models/auto/tokenization_auto.py index db34543e63a1..75677f8ea505 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -340,6 +340,7 @@ ("xlstm", "GPTNeoXTokenizer" if is_tokenizers_available() else None), ("xmod", "XLMRobertaTokenizer" if is_tokenizers_available() else None), ("yoso", "AlbertTokenizer" if is_tokenizers_available() else None), + ("zaya", "GemmaTokenizer" if is_tokenizers_available() else None), ] ) diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py new file mode 100644 index 000000000000..c28f97af94ea --- /dev/null +++ b/src/transformers/models/zaya/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 Zyphra and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_zaya import * + from .modeling_zaya import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py new file mode 100644 index 000000000000..4a18bd4716f2 --- /dev/null +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -0,0 +1,123 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
@strict
class ZayaConfig(PreTrainedConfig):
    r"""
    lm_head_bias (`bool`, *optional*, defaults to `False`):
        Whether to add a bias to the language modeling head.
    router_hidden_size (`int`, *optional*, defaults to 256):
        Hidden size used by the ZAYA router.
    cca_time0 (`int`, *optional*, defaults to 2):
        First temporal parameter of the CCA projection.
    cca_time1 (`int`, *optional*, defaults to 2):
        Second temporal parameter of the CCA projection.

    ```python
    >>> from transformers import ZayaConfig, ZayaModel

    >>> configuration = ZayaConfig()
    >>> model = ZayaModel(configuration)

    >>> configuration = model.config
    ```
    """

    model_type = "zaya"
    keys_to_ignore_at_inference = ["past_key_values"]

    vocab_size: int = 262272
    hidden_size: int = 2048
    num_hidden_layers: int = 40
    num_attention_heads: int = 8
    num_key_value_heads: int = 2
    hidden_act: str = "silu"
    max_position_embeddings: int = 131072
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-5
    use_cache: bool = True
    tie_word_embeddings: bool = True
    rope_parameters: RopeParameters | dict | None = None
    sliding_window: int | None = None
    attention_dropout: float | int = 0.0
    moe_intermediate_size: int = 2048

    num_experts_per_tok: int = 1
    num_experts: int = 16
    output_router_logits: bool = False
    layer_types: list[str] | None = None
    pad_token_id: int | None = 0
    bos_token_id: int | None = 2
    eos_token_id: int | list[int] | None = 106

    # Zaya-specific attention
    head_dim: int = 128
    attention_bias: bool = False

    lm_head_bias: bool = False
    router_hidden_size: int = 256
    cca_time0: int = 2
    cca_time1: int = 2

    def __post_init__(self, **kwargs):
        # Without explicit layer types every layer is "hybrid"; otherwise keep a private list copy of the input.
        self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types)

        # RoPE defaults are nested per layer type: the two entries differ only in `rope_theta`.
        default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = {
            "hybrid": {
                "rope_type": "default",
                "rope_theta": 5_000_000.0,
                "partial_rotary_factor": 0.5,
            },
            "hybrid_sliding": {
                "rope_type": "default",
                "rope_theta": 10_000.0,
                "partial_rotary_factor": 0.5,
            },
        }
        # User-provided `rope_parameters` are kept as-is; the defaults only apply when nothing was given.
        if self.rope_parameters is None:
            self.rope_parameters = default_rope_params

        # The layer-type names are passed as keys to ignore during the parent's RoPE validation.
        super().__post_init__(**kwargs, ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"})

    def convert_rope_params_to_dict(self, **kwargs):
        # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly.
        return kwargs

    def validate_architecture(self):
        """Part of ``@strict``-powered validation."""
        # Only top-1 routing is accepted.
        if self.num_experts_per_tok != 1:
            raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.")
        # The attention head count must divide evenly by the key/value head count.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.")
        # Sliding-window layers need a configured window size.
        if "hybrid_sliding" in self.layer_types and self.sliding_window is None:
            raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.")
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert original alternating-layer ZAYA checkpoints to the Transformers-native decoder-layer layout.""" + +import argparse +import json +import re +import shutil +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from transformers import ZayaConfig + + +_DEFAULT_ROPE_THETA = 5_000_000.0 +_DEFAULT_SWA_ROPE_THETA = 10_000.0 +_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$") +_LOCAL_EXPERT_PATTERN = re.compile( + r"^model\.layers\.(\d+)\.zaya_block\.experts\.local_experts\.(\d+)\.linear_fc([12])\.weight$" +) + +_UNUSED_CONFIG_KEYS = ( + "cca", + "num_query_groups", + "intermediate_size", + "ffn_hidden_size", + "moe_router_topk", + "norm_epsilon", + "zaya_mlp_expansion", + "activation_func", + "normalization", + "add_bias_linear", + "gated_linear_unit", + "fused_add_norm", + "apply_rope_fusion", + "bias_activation_fusion", + "activation_func_fp8_input_store", + "clamp_temp", + "kv_channels", + "mamba_cache_dtype", + "residual_in_fp32", + "rope_scaling", + "scale_residual_merge", + "zaya_high_prec", + "zaya_use_mod", + "zaya_use_eda", +) + + +def _rename_common(rest: str) -> str: + replacements = ( + ("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."), + ("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."), + ("self_attn.qkv.temp", "self_attn.qk_norm.temp"), + ("self_attn.qkv.linear_q.", "self_attn.qkv_proj.q_proj."), + ("self_attn.qkv.linear_k.", "self_attn.qkv_proj.k_proj."), + ("self_attn.qkv.val_proj1.", 
"self_attn.qkv_proj.v_proj_current."), + ("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."), + ("self_attn.qkv.", "self_attn.qkv_proj."), + ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.rmsnorm_eda."), + ("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."), + ("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."), + ("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."), + ("zaya_block.router.", "mlp.gate."), + ("zaya_block.", "mlp."), + ) + for old, new in replacements: + if rest.startswith(old): + return new + rest.removeprefix(old) + return rest + + +def _expert_target(name: str) -> tuple[str, int] | None: + match = _LOCAL_EXPERT_PATTERN.match(name) + if match is None: + return None + + old_layer_idx = int(match.group(1)) + if old_layer_idx % 2 != 1: + raise ValueError(f"Expert weights are expected on odd ZAYA layers, got: {name}") + + new_layer_idx = old_layer_idx // 2 + expert_idx = int(match.group(2)) + projection = "gate_up_proj" if match.group(3) == "1" else "down_proj" + target = f"model.layers.{new_layer_idx}.mlp.experts.{projection}" + return target, expert_idx + + +def convert_weight_name(name: str, old_num_hidden_layers: int | None = None) -> str | None: + if _expert_target(name) is not None: + return None + + match = _LAYER_PATTERN.match(name) + if match is None: + if old_num_hidden_layers is not None and name.startswith("model.res_scale."): + new_layer_idx = old_num_hidden_layers // 2 - 1 + return f"model.layers.{new_layer_idx}.post_mlp_residual_scale.{name.removeprefix('model.res_scale.')}" + return name + + old_layer_idx = int(match.group(1)) + rest = match.group(2) + new_layer_idx = old_layer_idx // 2 + + if old_layer_idx % 2 == 0: + rest = _rename_common(rest) + if rest.startswith("self_attn."): + return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.layers.{new_layer_idx}.input_layernorm.{rest.removeprefix('input_norm.')}" + if 
rest.startswith("res_scale."): + if old_layer_idx == 0: + return f"model.input_{rest.removeprefix('res_scale.')}" + return f"model.layers.{new_layer_idx - 1}.post_mlp_residual_scale.{rest.removeprefix('res_scale.')}" + else: + rest = _rename_common(rest) + if rest.startswith("mlp."): + return f"model.layers.{new_layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.layers.{new_layer_idx}.post_attention_layernorm.{rest.removeprefix('input_norm.')}" + if rest.startswith("res_scale."): + return f"model.layers.{new_layer_idx}.post_attention_residual_scale.{rest.removeprefix('res_scale.')}" + + raise ValueError(f"Unexpected ZAYA layer weight name: {name}") + + +def _to_hybrid_layer_type(layer_type: str) -> str: + if layer_type == "full_attention": + return "hybrid" + if layer_type == "sliding_attention": + return "hybrid_sliding" + raise ValueError(f"Unsupported ZAYA layer type: {layer_type}") + + +def _convert_layer_types(config_dict: dict, old_num_hidden_layers: int, new_num_hidden_layers: int) -> list[str]: + layer_types = config_dict.get("layer_types") + if layer_types is not None: + if len(layer_types) == old_num_hidden_layers: + return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types[::2]] + if len(layer_types) == new_num_hidden_layers: + return [_to_hybrid_layer_type(layer_type) for layer_type in layer_types] + raise ValueError("`layer_types` must match either the original or converted number of hidden layers.") + + swa_layers = config_dict.get("swa_layers") + if swa_layers is None: + return ["hybrid"] * new_num_hidden_layers + if len(swa_layers) == old_num_hidden_layers: + swa_layers = swa_layers[::2] + elif len(swa_layers) != new_num_hidden_layers: + raise ValueError("`swa_layers` must match either the original or converted number of hidden layers.") + return ["hybrid" if int(window_size) == 0 else "hybrid_sliding" for window_size in swa_layers] + + +def convert_config(input_dir: Path, output_dir: Path) -> None: + config_dict 
= json.loads((input_dir / "config.json").read_text()) + old_num_hidden_layers = int(config_dict["num_hidden_layers"]) + if old_num_hidden_layers % 2 != 0: + raise ValueError("Original ZAYA checkpoints must have an even number of alternating attention/MoE layers.") + + new_num_hidden_layers = old_num_hidden_layers // 2 + layer_types = _convert_layer_types(config_dict, old_num_hidden_layers, new_num_hidden_layers) + partial_rotary_factor = 0.5 + rope_theta = config_dict.get("rope_theta", _DEFAULT_ROPE_THETA) + swa_rotary_base = config_dict.get("swa_rotary_base", _DEFAULT_SWA_ROPE_THETA) + rms_norm_eps = config_dict.get("rms_norm_eps", config_dict.get("norm_epsilon", ZayaConfig.rms_norm_eps)) + router_hidden_size = config_dict.get( + "router_hidden_size", config_dict.get("zaya_mlp_expansion", ZayaConfig.router_hidden_size) + ) + expert_ffn_size = config_dict.get("intermediate_size", config_dict.get("ffn_hidden_size")) + moe_intermediate_size = expert_ffn_size // 2 if expert_ffn_size is not None else ZayaConfig.moe_intermediate_size + num_experts_per_tok = config_dict.get( + "num_experts_per_tok", config_dict.get("moe_router_topk", ZayaConfig.num_experts_per_tok) + ) + + swa_layers = config_dict.get("swa_layers") or [] + sliding_window = config_dict.get("sliding_window") + if sliding_window is None: + positive_windows = [int(window_size) for window_size in swa_layers if int(window_size) > 0] + # Original ZAYA stores the number of previous tokens attended by SWA layers. Transformers' sliding window + # is the total local attention span, including the current token. 
+ sliding_window = max(positive_windows) + 1 if positive_windows else None + + rope_parameters = { + "hybrid": { + "rope_type": "default", + "rope_theta": rope_theta, + "partial_rotary_factor": partial_rotary_factor, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": swa_rotary_base, + "partial_rotary_factor": partial_rotary_factor, + }, + } + + for key in (*_UNUSED_CONFIG_KEYS, "swa_layers", "rope_theta", "swa_rotary_base"): + config_dict.pop(key, None) + + config_dict.update( + { + "architectures": ["ZayaForCausalLM"], + "num_hidden_layers": new_num_hidden_layers, + "moe_intermediate_size": moe_intermediate_size, + "num_experts_per_tok": num_experts_per_tok, + "rms_norm_eps": rms_norm_eps, + "router_hidden_size": router_hidden_size, + "layer_types": layer_types, + "sliding_window": sliding_window, + "rope_parameters": rope_parameters, + } + ) + ZayaConfig(**config_dict).save_pretrained(output_dir) + + +def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None: + for path in input_dir.iterdir(): + if path.name == "config.json": + continue + if path.name.endswith(".safetensors") or path.name.endswith(".bin"): + continue + if path.name in {"model.safetensors.index.json", "pytorch_model.bin.index.json"}: + continue + + output_path = output_dir / path.name + if path.is_dir(): + shutil.copytree(path, output_path, dirs_exist_ok=True) + else: + shutil.copy2(path, output_path) + + +def _build_weight_plan(input_dir: Path) -> tuple[dict[str, str], dict[str, list[str]], dict[str, str], dict]: + index = json.loads((input_dir / "model.safetensors.index.json").read_text()) + old_weight_map = index["weight_map"] + old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"]) + converted_weight_map = {} + normal_sources_by_output_file = defaultdict(list) + expert_sources_by_target = defaultdict(list) + output_file_by_target = {} + + for source_key, filename in old_weight_map.items(): + expert_info = 
_expert_target(source_key) + if expert_info is not None: + target_key, expert_idx = expert_info + expert_sources_by_target[target_key].append((expert_idx, source_key)) + output_file_by_target.setdefault(target_key, filename) + converted_weight_map[target_key] = output_file_by_target[target_key] + continue + + target_key = convert_weight_name(source_key, old_num_hidden_layers) + if target_key in converted_weight_map: + raise ValueError(f"Duplicate converted weight name: {target_key}") + converted_weight_map[target_key] = filename + normal_sources_by_output_file[filename].append((source_key, target_key)) + + index["weight_map"] = converted_weight_map + return normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, index + + +def _load_sources(input_dir: Path, source_keys: list[str], old_weight_map: dict[str, str]) -> dict[str, torch.Tensor]: + sources_by_file = defaultdict(list) + for source_key in source_keys: + sources_by_file[old_weight_map[source_key]].append(source_key) + + tensors = {} + for filename, keys in sources_by_file.items(): + with safe_open(input_dir / filename, framework="pt", device="cpu") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + return tensors + + +def convert_safetensors(input_dir: Path, output_dir: Path) -> None: + index_path = input_dir / "model.safetensors.index.json" + if not index_path.exists(): + safetensors_path = input_dir / "model.safetensors" + if not safetensors_path.exists(): + raise FileNotFoundError("Only safetensors ZAYA checkpoints are supported by this converter.") + + old_num_hidden_layers = int(json.loads((input_dir / "config.json").read_text())["num_hidden_layers"]) + with safe_open(safetensors_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + state_dict = {} + expert_groups = defaultdict(list) + for key in f.keys(): + expert_info = _expert_target(key) + if expert_info is not None: + target_key, expert_idx = expert_info + 
expert_groups[target_key].append((expert_idx, f.get_tensor(key))) + continue + state_dict[convert_weight_name(key, old_num_hidden_layers)] = f.get_tensor(key) + for target_key, expert_tensors in expert_groups.items(): + state_dict[target_key] = torch.stack([tensor for _, tensor in sorted(expert_tensors)], dim=0) + save_file(state_dict, output_dir / "model.safetensors", metadata=metadata) + return + + old_index = json.loads(index_path.read_text()) + old_weight_map = old_index["weight_map"] + normal_sources_by_output_file, expert_sources_by_target, output_file_by_target, converted_index = ( + _build_weight_plan(input_dir) + ) + output_filenames = sorted(set(converted_index["weight_map"].values())) + + metadata_by_file = {} + for filename in sorted(set(old_weight_map.values())): + with safe_open(input_dir / filename, framework="pt", device="cpu") as f: + metadata_by_file[filename] = f.metadata() + + for output_filename in output_filenames: + shard = {} + normal_sources = normal_sources_by_output_file.get(output_filename, []) + source_keys = [source_key for source_key, _ in normal_sources] + + expert_groups_for_shard = { + target_key: sorted(sources) + for target_key, sources in expert_sources_by_target.items() + if output_file_by_target[target_key] == output_filename + } + for sources in expert_groups_for_shard.values(): + source_keys.extend(source_key for _, source_key in sources) + + loaded_tensors = _load_sources(input_dir, source_keys, old_weight_map) + for source_key, target_key in normal_sources: + shard[target_key] = loaded_tensors[source_key] + for target_key, sources in expert_groups_for_shard.items(): + shard[target_key] = torch.stack([loaded_tensors[source_key] for _, source_key in sources], dim=0) + + save_file(shard, output_dir / output_filename, metadata=metadata_by_file.get(output_filename)) + + (output_dir / "model.safetensors.index.json").write_text( + json.dumps(converted_index, indent=2, sort_keys=True) + "\n" + ) + + +def 
convert_checkpoint(input_dir: str, output_dir: str, overwrite: bool = False) -> None: + input_path = Path(input_dir).expanduser().resolve() + output_path = Path(output_dir).expanduser().resolve() + if input_path == output_path: + raise ValueError("Please write the converted checkpoint to a different output directory.") + if output_path.exists() and any(output_path.iterdir()): + if not overwrite: + raise FileExistsError(f"{output_path} already exists and is not empty. Pass --overwrite to replace it.") + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + copy_non_weight_files(input_path, output_path) + convert_config(input_path, output_path) + convert_safetensors(input_path, output_path) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--input_dir", required=True, help="Path to the original alternating-layer ZAYA checkpoint.") + parser.add_argument("--output_dir", required=True, help="Path where the converted checkpoint should be written.") + parser.add_argument("--overwrite", action="store_true", help="Overwrite a non-empty output directory.") + args = parser.parse_args() + convert_checkpoint(args.input_dir, args.output_dir, overwrite=args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py new file mode 100755 index 000000000000..0815020f0e2e --- /dev/null +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -0,0 +1,888 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya.py file directly. One of our CI enforces this. 
class ZayaRotaryEmbedding(nn.Module):
    """
    Rotary position embedding that keeps one set of inverse frequencies per attention layer type.

    For every distinct entry of `config.layer_types` that has RoPE parameters, the module registers the
    non-persistent buffers `{layer_type}_inv_freq` and `{layer_type}_original_inv_freq` and stores the matching
    `{layer_type}_attention_scaling` attribute. `forward` looks these up by the `layer_type` it is given.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ZayaConfig):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # Deduplicate the layer types; going through `set` means the resulting order is not guaranteed.
        self.layer_types = list(set(config.layer_types))
        self.rope_type = {}
        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            # Layer types without RoPE parameters get no buffers and no `rope_type` entry.
            if rope_params is None:
                continue

            self.rope_type[layer_type] = rope_params["rope_type"]
            # "default" uses the partial-rotary aware initializer below; any other type is looked up in
            # `ROPE_INIT_FUNCTIONS`.
            rope_init_fn: Callable = self.compute_default_rope_parameters
            if self.rope_type[layer_type] != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type)
            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
            # Separate copy of the initial frequencies (presumably restored by `dynamic_rope_update` — TODO confirm).
            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: ZayaConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters[layer_type]["rope_theta"]
        # key difference to gemma3: partial rope
        partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0)
        # Fall back to hidden_size // num_attention_heads when `head_dim` is missing or falsy.
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        # Only this many leading channels of each head are rotated.
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """
        Build the cos/sin tables for `position_ids` using the frequencies registered for `layer_type`.

        `x` is only read for its device and dtype. `layer_type` must name a layer type that received buffers in
        `__init__`; the default `None` resolves to the attribute name `None_inv_freq`, which is never registered.
        Assumes `position_ids` is (batch, seq_len) — TODO confirm against callers.

        Returns:
            Tuple `(cos, sin)`, both multiplied by the layer type's attention scaling and cast to `x.dtype`.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        # Broadcast the frequencies over the batch so the matmul below forms frequency x position products.
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (or a non-string device type) falls back to "cpu" for the autocast context.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Repeat the frequencies so the table spans both halves of the rotated channels.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times so the head count matches the number of attention heads.

    Same result as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With `n_rep == 1` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
+ """ + + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim + + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, + ) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, + ): + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + 
projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) + if use_precomputed_states: + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + + if past_key_values is not None: + new_conv_state = qk_states[..., -self.conv_kernel_size :] + if new_conv_state.shape[-1] < self.conv_kernel_size: + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) + + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) + + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + + # The value path carries half of each value head from the current token and half from the previous token. + # During cached decoding, `recurrent_v_state` is the previous token's delayed projection. 
class ZayaQKNorm(nn.Module):
    """
    Rescales each query/key head vector to an L2 norm of sqrt(head_dim), then multiplies the keys by the learned
    per-KV-head factor `temp`.
    """

    def __init__(self, config: ZayaConfig):
        super().__init__()
        self.head_dim_scale = config.head_dim**0.5
        self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads))

    def _rescale(self, states: torch.Tensor, eps: float) -> torch.Tensor:
        # Clamp the norm so an all-zero head vector does not trigger a division by zero.
        norms = states.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)
        return states * (self.head_dim_scale / norms)

    def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Both tensors share the clamp value derived from the query dtype.
        eps = torch.finfo(query_states.dtype).eps
        normed_queries = self._rescale(query_states, eps)
        normed_keys = self._rescale(key_states, eps)
        # `temp` is broadcast along dim 2, i.e. the head axis of a (batch, seq, heads, head_dim) layout.
        return normed_queries, normed_keys * self.temp[None, None, :, None]


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotates the leading channels of the query and key tensors with Rotary Position Embedding.

    Unlike the GLM version, cos and sin are not interleaved. Only the first `cos.shape[-1]` channels of each
    head are rotated; any remaining channels are passed through unchanged.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos and sin are unsqueezed so that they broadcast against q and k. With cos/sin
            of shape [batch_size, seq_len, head_dim], use 1 when q and k are
            [batch_size, heads, seq_len, head_dim] and 2 when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` with the rotated query and key tensors, each with the same shape as its input.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    # The width of the cos/sin tables decides how many channels take part in the rotation.
    rotary_dim = cos.shape[-1]

    def _embed(states):
        rotated, passthrough = states[..., :rotary_dim], states[..., rotary_dim:]
        rotated = (rotated * cos) + (rotate_half(rotated) * sin)
        # Re-attach the untouched tail to restore the full head dimension.
        return torch.cat([rotated, passthrough], dim=-1)

    return _embed(q), _embed(k)
None]: + input_shape = hidden_states.shape[:-1] + + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") + + query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) + + query_states, key_states = self.qk_norm(query_states, key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + causal_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.view(*input_shape, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class ZayaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ZayaAttention(config=config, layer_idx=layer_idx) + self.mlp = ZayaSparseMoeBlock(config, layer_idx) + self.input_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ZayaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) + self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = 
class ZayaResidualScaling(nn.Module):
    """
    Learned per-channel affine mix of a branch output with its residual stream.

    Computes `(hidden_states + hidden_states_bias) * hidden_states_scale
    + (residual + residual_bias) * residual_scale`, with the residual upcast to float32 first, and returns the
    sum in the dtype of `hidden_states`. Scales start at one and biases at zero, so a fresh module is a plain
    residual addition.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Registration order is kept stable: branch scale/bias first, then residual scale/bias.
        self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size))
        self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size))
        self.residual_scale = nn.Parameter(torch.ones(hidden_size))
        self.residual_bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
        target_dtype = hidden_states.dtype
        branch = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
        # Matches the original ZAYA `residual_in_fp32` path.
        skip = (residual.to(torch.float32) + self.residual_bias) * self.residual_scale
        return (branch + skip).to(target_dtype)
@use_experts_implementation
class ZayaExperts(nn.Module):
    """Routed expert MLPs whose weights are packed into 3D parameters indexed by expert on the first dim."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # Gate and up projections are fused along the output dimension and split again in `forward`.
        gate_up_shape = (self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)
        down_shape = (self.num_experts, self.hidden_dim, self.intermediate_dim)
        self.gate_up_proj = nn.Parameter(torch.empty(gate_up_shape))
        self.down_proj = nn.Parameter(torch.empty(down_shape))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run each token through its selected experts and accumulate the weighted expert outputs.

        `hidden_states` is indexed by token on dim 0; `top_k_index` and `top_k_weights` are indexed as
        [token, top_k slot]. The result has the shape and dtype of `hidden_states`.
        """
        output = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # One-hot routing table rearranged to [expert, top_k slot, token].
            routing_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
            # Only experts that received at least one token are visited below.
            active_experts = (routing_mask.sum(dim=(-1, -2)) > 0).nonzero()

        for hit in active_experts:
            expert_idx = hit[0]
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(routing_mask[expert_idx])
            tokens = hidden_states[token_idx]
            gate, up = nn.functional.linear(tokens, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            activated = self.act_fn(gate) * up
            projected = nn.functional.linear(activated, self.down_proj[expert_idx])
            # Scale by the routing weight of the slot through which this token reached the expert.
            weighted = projected * top_k_weights[token_idx, top_k_pos, None]
            output.index_add_(0, token_idx, weighted.to(output.dtype))

        return output
+ _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(ZayaRouter, index=0), + "hidden_states": ZayaDecoderLayer, + "attentions": ZayaAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, ZayaResidualScaling): + init.ones_(module.hidden_states_scale) + init.zeros_(module.hidden_states_bias) + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaModel): + init.ones_(module.input_hidden_states_scale) + init.zeros_(module.input_hidden_states_bias) + elif isinstance(module, ZayaQKNorm): + init.zeros_(module.temp) + elif isinstance(module, ZayaRouter): + if module.use_eda: + init.ones_(module.router_states_scale) + init.zeros_(module.balancing_biases) + module.balancing_biases[-1] = -1.0 # ignore: trf012 + elif isinstance(module, ZayaExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, ZayaRotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq) + getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq) + + +@auto_docstring +class ZayaModel(ZayaPreTrainedModel): + def __init__(self, config: ZayaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ZayaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + 
self.rotary_emb = ZayaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) + + hidden_states = inputs_embeds + + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) + } + + hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + + prev_router_hidden_states = None + + for idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[idx] + hidden_states, prev_router_hidden_states = decoder_layer( + hidden_states, + prev_router_hidden_states, + attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask}, + past_key_values=past_key_values, + position_embeddings=position_embeddings[layer_type], + **kwargs, + ) + + hidden_states = self.final_norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. 
# ZAYA decoder (`ZayaModel`) with a linear language-modeling head on top.
@auto_docstring(checkpoint="Zyphra/ZAYA1-8B")
class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
    # Declares `lm_head.weight` as tied to `model.embed_tokens.weight`.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    _is_stateful = True

    def __init__(self, config, **kwargs):
        # NOTE(review): `**kwargs` is accepted but not forwarded to the parent constructor.
        super().__init__(config)
        self.model = ZayaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, ZayaForCausalLM

        >>> model = ZayaForCausalLM.from_pretrained("Zyphra/ZAYA1-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/ZAYA1-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> generated_text = tokenizer.batch_decode(
        ...     generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        ... )[0]
        ```"""
        # An explicit argument wins; otherwise fall back to the config default.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since `slice(-0, None)` spans the whole
        # sequence); a tensor is used directly as the index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # The loss is only computed when labels are given; no router auxiliary loss is added in this method.
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
+ +"""PyTorch Zaya model.""" + +from collections.abc import Callable +from typing import Any, Literal + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from huggingface_hub.dataclasses import strict +from torch import nn +from torch.nn import init + +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PreTrainedConfig +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_outputs import MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, +) +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..afmoe.modeling_afmoe import AfmoeForCausalLM +from ..laguna.configuration_laguna import LagunaConfig +from ..laguna.modeling_laguna import LagunaModel, LagunaRotaryEmbedding +from ..llama.modeling_llama import LlamaDecoderLayer, LlamaPreTrainedModel, repeat_kv +from ..phi3.modeling_phi3 import Phi3Attention +from ..qwen3_5_moe.modeling_qwen3_5_moe import ( + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeExperts, Qwen3MoeRMSNorm + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +@strict +class ZayaConfig(LagunaConfig): + r""" + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. + router_hidden_size (`int`, *optional*, defaults to 256): + Hidden size used by the ZAYA router. + cca_time0 (`int`, *optional*, defaults to 2): + Kernel size of the depthwise causal convolution applied to the concatenated q/k states in the CCA projection. + cca_time1 (`int`, *optional*, defaults to 2): + Kernel size of the grouped causal convolution that follows the depthwise convolution in the CCA projection. 
+ + ```python + >>> from transformers import ZayaConfig, ZayaModel + + >>> configuration = ZayaConfig() + >>> model = ZayaModel(configuration) + + >>> configuration = model.config + ``` + """ + + model_type = "zaya" + + vocab_size: int = 262272 + moe_intermediate_size: int = 2048 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + tie_word_embeddings: bool = True + rms_norm_eps: float = 1e-5 + sliding_window: int | None = None + pad_token_id: int | None = 0 + bos_token_id: int | None = 2 + eos_token_id: int | list[int] | None = 106 + + num_experts_per_tok: int = 1 + num_experts: int = 16 + + lm_head_bias: bool = False + router_hidden_size: int = 256 + cca_time0: int = 2 + cca_time1: int = 2 + + # Fields declared by LagunaConfig but not used by ZAYA. + # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. + base_model_tp_plan = AttributeError() + base_model_pp_plan = AttributeError() + intermediate_size = AttributeError() + shared_expert_intermediate_size = AttributeError() + router_aux_loss_coef = AttributeError() + num_attention_heads_per_layer = AttributeError() + mlp_layer_types = AttributeError() + moe_routed_scaling_factor = AttributeError() + moe_apply_router_weight_on_input = AttributeError() + moe_router_logit_softcapping = AttributeError() + + def __post_init__(self, **kwargs): + self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) + + default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = { + "hybrid": { + "rope_type": "default", + "rope_theta": 5_000_000.0, + "partial_rotary_factor": 0.5, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": 10_000.0, + "partial_rotary_factor": 0.5, + }, + } + if self.rope_parameters is None: + self.rope_parameters = default_rope_params + + PreTrainedConfig.__post_init__(self, **kwargs, 
ignore_keys_at_rope_validation={"hybrid", "hybrid_sliding"}) + + def convert_rope_params_to_dict(self, **kwargs): + # No legacy flat RoPE format is supported here; conversion writes the nested ZAYA layer-type format directly. + return kwargs + + def validate_architecture(self): + if self.num_experts_per_tok != 1: + raise ValueError("ZAYA currently supports `num_experts_per_tok=1` only.") + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError("`num_attention_heads` must be a multiple of `num_key_value_heads`.") + if "hybrid_sliding" in self.layer_types and self.sliding_window is None: + raise ValueError("`sliding_window` must be set when `layer_types` contains `hybrid_sliding`.") + + +class ZayaRotaryEmbedding(LagunaRotaryEmbedding): + pass + + +class ZayaRMSNorm(Qwen3MoeRMSNorm): + pass + + +class ZayaCCAProjection(nn.Module): + """ + Projects hidden states into attention q/k/v states with ZAYA's CCA path. + + `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal + `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; + for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. + Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token + `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection + from **the recurrent cache**. 
+ """ + + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim + + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, + ) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, + ): + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + 
projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) + if use_precomputed_states: + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + + if past_key_values is not None: + new_conv_state = qk_states[..., -self.conv_kernel_size :] + if new_conv_state.shape[-1] < self.conv_kernel_size: + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) + + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) + + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + + # The value path concatenates the current token's `v_proj_current` output with the previous token's `v_proj_delayed` output along the channel dimension (before the reshape into heads), so the first half of the value channels comes from the current token and the second half from the previous token. + # During cached decoding, `recurrent_v_state` is the previous token's delayed projection. 
+ value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) + if use_precomputed_states: + recurrent_v_state = past_key_values.layers[self.layer_idx].recurrent_states.unsqueeze(1) + else: + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + value_delayed = torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1) + + if past_key_values is not None: + past_key_values.update_recurrent_state(delayed_v_state[:, -1, :], self.layer_idx) + + value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape) + + return query, key, value + + +class ZayaQKNorm(nn.Module): + """ + L2-normalizes q/k states to sqrt(head_dim) and applies ZAYA's learned per-KV-head key scale. + """ + + def __init__(self, config: ZayaConfig): + super().__init__() + self.head_dim_scale = config.head_dim**0.5 + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + + def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + norm_eps = torch.finfo(query_states.dtype).eps + query_states = query_states * ( + self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * ( + self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * self.temp[None, None, :, None] + return query_states, key_states + + +class ZayaAttention(Phi3Attention): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__(config, layer_idx) + del op_size # noqa: F821 + self.layer_type = config.layer_types[layer_idx] + self.sliding_window = config.sliding_window if self.layer_type == "hybrid_sliding" else None + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, 
bias=config.attention_bias + ) + self.qk_norm = ZayaQKNorm(config) + self.qkv_proj = ZayaCCAProjection( + config=self.config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") + + query_states, key_states, value_states = self.qkv_proj(hidden_states, past_key_values, padding_mask) + + query_states, key_states = self.qk_norm(query_states, key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + causal_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.view(*input_shape, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class ZayaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: ZayaConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.mlp = ZayaSparseMoeBlock(config, layer_idx) + self.post_attention_residual_scale = ZayaResidualScaling(config.hidden_size) + 
self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + **kwargs, + ) + + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = self.post_attention_layernorm(residual) + + hidden_states, prev_router_hidden_states = self.mlp( + hidden_states, + prev_router_hidden_states, + ) + + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) + + return hidden_states, prev_router_hidden_states + + +class ZayaResidualScaling(nn.Module): + def __init__(self, hidden_size: int): + super().__init__() + self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) + self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor): + output_dtype = hidden_states.dtype + hidden_states = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale + # Matches the original ZAYA `residual_in_fp32` path. 
+ residual = residual.to(torch.float32) + residual = (residual + self.residual_bias) * self.residual_scale + return (hidden_states + residual).to(output_dtype) + + +class ZayaRouterMLP(nn.Module): + def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): + super().__init__() + self.rmsnorm_eda = ZayaRMSNorm(hidden_size, eps=rms_norm_eps) + self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) + self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.rmsnorm_eda(hidden_states) + hidden_states = self.act_fn(self.fc1(hidden_states)) + hidden_states = self.act_fn(self.fc2(hidden_states)) + return self.out_proj(hidden_states) + + +class ZayaRouter(nn.Module): + def __init__( + self, + config, + layer_idx: int, + ) -> None: + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.num_experts = config.num_experts + 1 + self.top_k = config.num_experts_per_tok + self.router_hidden_size = config.router_hidden_size + + self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True) + + self.use_eda = self.layer_idx != 0 + if self.use_eda: + self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size)) + + self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps) + + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, + router_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + final_shape = (-1, self.top_k) + seq_length = hidden_states.shape[1] + + router_hidden_states = self.down_proj(hidden_states) + + if self.use_eda and router_states is not None: 
+ router_hidden_states = router_hidden_states + router_states * self.router_states_scale + + router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + + biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases + _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1) + router_probs = torch.gather(router_probs, dim=2, index=router_indices) + + # If the router selects the extra skip expert, mask it before `ZayaExperts` builds its one-hot expert mask. + skip_expert = router_indices == self.config.num_experts + router_probs = router_probs.masked_fill(skip_expert, 0) + router_indices = router_indices.masked_fill(skip_expert, 0) + + return ( + router_logits.reshape(-1, self.num_experts), + router_probs.reshape(final_shape), + router_indices.reshape(final_shape), + router_hidden_states_next, + ) + + +class ZayaExperts(Qwen3MoeExperts): + pass + + +class ZayaSparseMoeBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.gate = ZayaRouter(config, layer_idx) + self.experts = ZayaExperts(config) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + _, router_probs, router_indices, prev_router_hidden_states = self.gate( + hidden_states, router_states=prev_router_hidden_states + ) + + batch_size, seq_length, emb_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) + expert_output = self.experts(hidden_states_flat, router_indices, router_probs) + expert_output = expert_output.view(batch_size, seq_length, emb_dim) + + return expert_output, prev_router_hidden_states + + +class ZayaPreTrainedModel(LlamaPreTrainedModel): + config: ZayaConfig + # ZAYA generation uses the native hybrid dynamic cache, which is not a 
compileable cache. + _can_compile_fullgraph = False + _can_record_outputs = { + "router_logits": OutputRecorder(ZayaRouter, index=0), + "hidden_states": ZayaDecoderLayer, + "attentions": ZayaAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, ZayaResidualScaling): + init.ones_(module.hidden_states_scale) + init.zeros_(module.hidden_states_bias) + init.ones_(module.residual_scale) + init.zeros_(module.residual_bias) + elif isinstance(module, ZayaModel): + init.ones_(module.input_hidden_states_scale) + init.zeros_(module.input_hidden_states_bias) + elif isinstance(module, ZayaQKNorm): + init.zeros_(module.temp) + elif isinstance(module, ZayaRouter): + if module.use_eda: + init.ones_(module.router_states_scale) + init.zeros_(module.balancing_biases) + module.balancing_biases[-1] = -1.0 # ignore: trf012 + elif isinstance(module, ZayaExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, ZayaRotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq) + getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq) + + +@auto_docstring +class ZayaModel(LagunaModel): + def __init__(self, config: ZayaConfig): + super().__init__(config) + del self.norm + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + + @merge_with_config_defaults + @capture_outputs + 
@auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) + + hidden_states = inputs_embeds + + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) + } + + hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + + prev_router_hidden_states = None + + for idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[idx] + hidden_states, prev_router_hidden_states = decoder_layer( + hidden_states, + prev_router_hidden_states, + attention_mask={"causal": causal_mask_mapping[layer_type], "padding": cca_mask}, + past_key_values=past_key_values, + position_embeddings=position_embeddings[layer_type], + **kwargs, + ) + + hidden_states = self.final_norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. 
+ """ + cca_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + cca_mask = None + elif attention_mask is not None: + cca_mask = attention_mask[:, -inputs_embeds.shape[1] :] + return cca_mask + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-8B") +class ZayaForCausalLM(AfmoeForCausalLM, ZayaPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _is_stateful = True + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) + + +__all__ = [ + "ZayaConfig", + "ZayaPreTrainedModel", + "ZayaModel", + "ZayaForCausalLM", +] diff --git a/src/transformers/models/zaya1_vl/__init__.py b/src/transformers/models/zaya1_vl/__init__.py new file mode 100644 index 000000000000..ccda0a4dcc84 --- /dev/null +++ b/src/transformers/models/zaya1_vl/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 Zyphra and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_zaya1_vl import * + from .modeling_zaya1_vl import * + from .processing_zaya1_vl import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/zaya1_vl/configuration_zaya1_vl.py b/src/transformers/models/zaya1_vl/configuration_zaya1_vl.py new file mode 100644 index 000000000000..8f8de4a45519 --- /dev/null +++ b/src/transformers/models/zaya1_vl/configuration_zaya1_vl.py @@ -0,0 +1,209 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya1_vl/modular_zaya1_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya1_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Literal + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-VL-8B") +@strict +class Zaya1VLVisionConfig(PreTrainedConfig): + r""" + window_size (`int`, *optional*, defaults to 112): + Window size used by the Qwen2.5-VL vision encoder. + out_hidden_size (`int`, *optional*, defaults to 2048): + Output hidden size after the vision merger. + fullatt_block_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`): + Vision encoder layers that use full attention. + """ + + model_type = "zaya1_vl_vision" + base_config_key = "vision_config" + + depth: int = 32 + + hidden_size: int = 1280 + hidden_act: str = "silu" + intermediate_size: int = 3420 + num_heads: int = 16 + in_channels: int = 3 + patch_size: int | list[int] | tuple[int, int] = 14 + spatial_merge_size: int = 2 + temporal_patch_size: int | list[int] | tuple[int, int] = 1 + window_size: int = 112 + out_hidden_size: int = 2048 + fullatt_block_indexes: list[int] | tuple[int, ...] = (7, 15, 23, 31) + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-VL-8B") +@strict +class Zaya1VLTextConfig(PreTrainedConfig): + r""" + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the language modeling head. + router_hidden_size (`int`, *optional*, defaults to 256): + Hidden size used by the ZAYA router. + cca_time0 (`int`, *optional*, defaults to 2): + First temporal parameter of the CCA projection. + cca_time1 (`int`, *optional*, defaults to 2): + Second temporal parameter of the CCA projection. + + vision_lora (`bool`, *optional*, defaults to `True`): + Whether to enable LoRA modules that are applied only on vision-token positions. 
+ vision_lora_rank_attn (`int`, *optional*, defaults to 8): + LoRA rank for the CCA and attention output projections applied to vision-token positions. + vision_lora_rank_mlp (`int`, *optional*, defaults to 32): + LoRA rank for the MoE expert projections applied to vision-token positions. + """ + + model_type = "zaya1_vl_text" + keys_to_ignore_at_inference = ["past_key_values"] + + vocab_size: int = 262272 + hidden_size: int = 2048 + num_hidden_layers: int = 40 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + hidden_act: str = "silu" + max_position_embeddings: int = 131072 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-5 + use_cache: bool = True + tie_word_embeddings: bool = True + rope_parameters: RopeParameters | dict | None = None + sliding_window: int | None = None + attention_dropout: float | int = 0.0 + moe_intermediate_size: int = 2048 + + num_experts_per_tok: int = 1 + num_experts: int = 16 + output_router_logits: bool = False + layer_types: list[str] | None = None + pad_token_id: int | None = 0 + bos_token_id: int | None = 2 + eos_token_id: int | list[int] | None = 106 + + # Zaya1VLText-specific attention + head_dim: int = 128 + attention_bias: bool = False + + lm_head_bias: bool = False + router_hidden_size: int = 256 + cca_time0: int = 2 + cca_time1: int = 2 + base_config_key = "text_config" + + vision_lora: bool = True + vision_lora_rank_attn: int = 8 + vision_lora_rank_mlp: int = 32 + + def __post_init__(self, **kwargs): + self.layer_types = ["hybrid"] * self.num_hidden_layers if self.layer_types is None else list(self.layer_types) + + default_rope_params: dict[Literal["hybrid", "hybrid_sliding"], dict[str, Any]] = { + "hybrid": { + "rope_type": "default", + "rope_theta": 5_000_000.0, + "partial_rotary_factor": 0.5, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": 10_000.0, + "partial_rotary_factor": 0.5, + }, + } + if self.rope_parameters is None: + self.rope_parameters = default_rope_params + + 
# Composite configuration: holds one text sub-config and one vision sub-config (see `sub_configs`).
@auto_docstring(checkpoint="Zyphra/ZAYA1-VL-8B")
@strict
class Zaya1VLConfig(PreTrainedConfig):
    r"""
    text_config (`dict` or `Zaya1VLTextConfig`, *optional*):
        Configuration for the ZAYA text decoder.
    vision_config (`dict` or `Zaya1VLVisionConfig`, *optional*):
        Configuration for the Qwen2.5-VL vision encoder.
    image_token_id (`int`, *optional*, defaults to 262147):
        Token id used as an image placeholder.
    vision_start_token_id (`int`, *optional*, defaults to 255999):
        Token id that starts an image span.
    vision_end_token_id (`int`, *optional*, defaults to 256000):
        Token id that ends an image span.
    """

    model_type = "zaya1_vl"
    # Classes used by `__post_init__` to materialize `vision_config` / `text_config` from dicts or defaults.
    sub_configs = {"vision_config": Zaya1VLVisionConfig, "text_config": Zaya1VLTextConfig}
    keys_to_ignore_at_inference = ["past_key_values"]

    text_config: dict | PreTrainedConfig | None = None
    vision_config: dict | PreTrainedConfig | None = None

    image_token_id: int = 262147
    vision_start_token_id: int = 255999
    vision_end_token_id: int = 256000

    tie_word_embeddings: bool = True
    output_router_logits: bool = False

    def __post_init__(self, **kwargs):
        """Materialize both sub-configs, then forward the remaining `kwargs` to the parent `__post_init__`.

        A sub-config given as a `dict` is expanded into the class registered in `sub_configs`; `None` falls back
        to that class's defaults (for the text config, seeded with any flat text-related `kwargs`). A value that
        is neither a `dict` nor `None` is left untouched.
        """
        if isinstance(self.vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**self.vision_config)
        elif self.vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        # Hub configs are saved as flat dicts so we pop some of kwargs to init `TextConfig`.
        # The pop happens unconditionally: every key matching a text-config `__init__` parameter (plus the three
        # RoPE keys appended below) is removed from `kwargs` and therefore never reaches `super().__post_init__`.
        text_params = inspect.signature(self.sub_configs["text_config"].__init__).parameters.keys()
        text_params = list(text_params) + ["rope_parameters", "rope_scaling", "rope_theta"]
        text_kwargs = {key: kwargs.pop(key) for key in text_params if key in kwargs}

        if isinstance(self.text_config, dict):
            # An explicit nested dict wins; the flat `text_kwargs` popped above are not used in this branch.
            self.text_config = self.sub_configs["text_config"](**self.text_config)
        elif self.text_config is None:
            # No nested text config was given: build it from the flat text keys collected above.
            text_kwargs["dtype"] = kwargs.get("torch_dtype", kwargs.get("dtype"))  # don't pop the dtype
            self.text_config = self.sub_configs["text_config"](**text_kwargs)

        super().__post_init__(**kwargs)
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert original ZAYA1-VL checkpoints to the Transformers-native layout.""" + +import argparse +import json +import re +import shutil +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from transformers import Zaya1VLConfig, Zaya1VLTextConfig + + +_DEFAULT_SWA_ROPE_THETA = 10_000.0 +_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$") +_EXPERT_PATTERN = re.compile( + r"^model\.layers\.(\d+)\.mlp\.zaya_block\.experts\.local_experts\.(\d+)\." 
+ r"(linear_fc[12]|lora_fc[12]\.[01])\.weight$" +) + +_UNUSED_CONFIG_KEYS = ( + "activation_func", + "activation_func_fp8_input_store", + "add_bias_linear", + "apply_rope_fusion", + "ar_threshold", + "bias_activation_fusion", + "cca", + "clamp_temp", + "ffn_hidden_size", + "fused_add_norm", + "gated_linear_unit", + "lora_rank", + "moe_router_topk", + "norm_epsilon", + "normalization", + "num_query_groups", + "projector_hidden_act", + "residual_in_fp32", + "rotary_base", + "scale_residual_merge", + "temporal_patch_size", + "use_lora_att", + "use_rope_scaling", + "zaya_mlp_expansion", + "zaya_use_eda", + "zaya_use_mod", +) + +_VISION_CONFIG_UNUSED_KEYS = ("_attn_implementation_autoset", "in_chans", "model_type", "torch_dtype") + + +def _rename_common(rest: str) -> str: + replacements = ( + ("self_attn.qkv.conv_qk.0.", "self_attn.qkv_proj.conv_qk_depthwise."), + ("self_attn.qkv.conv_qk.1.", "self_attn.qkv_proj.conv_qk_grouped."), + ("self_attn.qkv.temp", "self_attn.qk_norm.temp"), + ("self_attn.qkv.linear_q.", "self_attn.qkv_proj.q_proj."), + ("self_attn.qkv.linear_k.", "self_attn.qkv_proj.k_proj."), + ("self_attn.qkv.val_proj1.", "self_attn.qkv_proj.v_proj_current."), + ("self_attn.qkv.val_proj2.", "self_attn.qkv_proj.v_proj_delayed."), + ("self_attn.qkv.lora_linear_q.0.", "self_attn.qkv_proj.q_lora_a."), + ("self_attn.qkv.lora_linear_q.1.", "self_attn.qkv_proj.q_lora_b."), + ("self_attn.qkv.lora_linear_k.0.", "self_attn.qkv_proj.k_lora_a."), + ("self_attn.qkv.lora_linear_k.1.", "self_attn.qkv_proj.k_lora_b."), + ("self_attn.qkv.lora_val_proj1.0.", "self_attn.qkv_proj.v_current_lora_a."), + ("self_attn.qkv.lora_val_proj1.1.", "self_attn.qkv_proj.v_current_lora_b."), + ("self_attn.qkv.lora_val_proj2.0.", "self_attn.qkv_proj.v_delayed_lora_a."), + ("self_attn.qkv.lora_val_proj2.1.", "self_attn.qkv_proj.v_delayed_lora_b."), + ("self_attn.lora_linear_o.0.", "self_attn.o_lora_a."), + ("self_attn.lora_linear_o.1.", "self_attn.o_lora_b."), + ("self_attn.qkv.", 
"self_attn.qkv_proj."), + ("zaya_block.router.rmsnorm_eda.", "mlp.gate.router_mlp.rmsnorm_eda."), + ("zaya_block.router.router_mlp.0.", "mlp.gate.router_mlp.fc1."), + ("zaya_block.router.router_mlp.2.", "mlp.gate.router_mlp.fc2."), + ("zaya_block.router.router_mlp.4.", "mlp.gate.router_mlp.out_proj."), + ("zaya_block.router.", "mlp.gate."), + ("zaya_block.", "mlp."), + ) + for old, new in replacements: + if rest.startswith(old): + return new + rest.removeprefix(old) + return rest + + +def _expert_target(name: str) -> tuple[str, int] | None: + match = _EXPERT_PATTERN.match(name) + if match is None: + return None + + layer_idx = int(match.group(1)) + expert_idx = int(match.group(2)) + projection = match.group(3) + target_projection = { + "linear_fc1": "gate_up_proj", + "linear_fc2": "down_proj", + "lora_fc1.0": "lora_gate_up_proj_a", + "lora_fc1.1": "lora_gate_up_proj_b", + "lora_fc2.0": "lora_down_proj_a", + "lora_fc2.1": "lora_down_proj_b", + }[projection] + target = f"model.language_model.layers.{layer_idx}.mlp.experts.{target_projection}" + return target, expert_idx + + +def convert_weight_name(name: str, num_hidden_layers: int) -> str | None: + if _expert_target(name) is not None: + return None + if name.startswith("vision_tower."): + return f"model.visual.{name.removeprefix('vision_tower.')}" + if name.startswith("model.embed_tokens."): + return f"model.language_model.embed_tokens.{name.removeprefix('model.embed_tokens.')}" + if name.startswith("model.final_norm."): + return f"model.language_model.final_norm.{name.removeprefix('model.final_norm.')}" + if name.startswith("model.res_scale."): + return ( + f"model.language_model.layers.{num_hidden_layers - 1}.post_mlp_residual_scale." 
+ f"{name.removeprefix('model.res_scale.')}" + ) + + match = _LAYER_PATTERN.match(name) + if match is None: + return name + + layer_idx = int(match.group(1)) + rest = match.group(2) + if rest.startswith("attn."): + rest = _rename_common(rest.removeprefix("attn.")) + if rest.startswith("self_attn."): + return f"model.language_model.layers.{layer_idx}.{rest}" + if rest.startswith("input_norm."): + return f"model.language_model.layers.{layer_idx}.input_layernorm.{rest.removeprefix('input_norm.')}" + if rest.startswith("res_scale."): + if layer_idx == 0: + return f"model.language_model.input_{rest.removeprefix('res_scale.')}" + return ( + f"model.language_model.layers.{layer_idx - 1}.post_mlp_residual_scale." + f"{rest.removeprefix('res_scale.')}" + ) + if rest.startswith("mlp."): + rest = _rename_common(rest.removeprefix("mlp.")) + if rest.startswith("mlp."): + return f"model.language_model.layers.{layer_idx}.{rest}" + if rest.startswith("input_norm."): + return ( + f"model.language_model.layers.{layer_idx}.post_attention_layernorm.{rest.removeprefix('input_norm.')}" + ) + if rest.startswith("res_scale."): + return ( + f"model.language_model.layers.{layer_idx}.post_attention_residual_scale." 
+ f"{rest.removeprefix('res_scale.')}" + ) + + raise ValueError(f"Unexpected ZAYA1-VL layer weight name: {name}") + + +def convert_config(input_dir: Path, output_dir: Path) -> None: + config_dict = json.loads((input_dir / "config.json").read_text()) + num_hidden_layers = int(config_dict["num_hidden_layers"]) + rms_norm_eps = config_dict.get("rms_norm_eps", config_dict.get("norm_epsilon", Zaya1VLTextConfig.rms_norm_eps)) + router_hidden_size = config_dict.get("router_hidden_size", config_dict.get("zaya_mlp_expansion", 256)) + expert_ffn_size = config_dict.get("intermediate_size", config_dict.get("ffn_hidden_size")) + moe_intermediate_size = ( + expert_ffn_size // 2 if expert_ffn_size is not None else Zaya1VLTextConfig.moe_intermediate_size + ) + num_experts_per_tok = config_dict.get("num_experts_per_tok", config_dict.get("moe_router_topk", 1)) + rope_theta = config_dict.get("rope_theta", config_dict.get("rotary_base", 1_000_000.0)) + swa_rotary_base = config_dict.get("swa_rotary_base", _DEFAULT_SWA_ROPE_THETA) + + vision_config = dict(config_dict.pop("vision_config", {})) + vision_config["in_channels"] = vision_config.get("in_channels", vision_config.get("in_chans", 3)) + vision_config.setdefault("depth", 32) + vision_config.setdefault("hidden_size", 1280) + vision_config.setdefault("intermediate_size", 3420) + vision_config.setdefault("num_heads", 16) + vision_config.setdefault("hidden_act", "silu") + vision_config.setdefault("patch_size", vision_config.pop("spatial_patch_size", 14)) + vision_config.setdefault("spatial_merge_size", 2) + vision_config.setdefault("temporal_patch_size", config_dict.get("temporal_patch_size", 1)) + vision_config.setdefault("tokens_per_second", 2) + vision_config.setdefault("window_size", 112) + vision_config.setdefault("out_hidden_size", config_dict["hidden_size"]) + vision_config.setdefault("fullatt_block_indexes", [7, 15, 23, 31]) + for key in _VISION_CONFIG_UNUSED_KEYS: + vision_config.pop(key, None) + + rope_parameters = { + 
"hybrid": { + "rope_type": "default", + "rope_theta": rope_theta, + "partial_rotary_factor": config_dict.get("rope_pct", 0.5), + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": swa_rotary_base, + "partial_rotary_factor": config_dict.get("rope_pct", 0.5), + }, + } + + for key in (*_UNUSED_CONFIG_KEYS, "rope_pct", "swa_layers", "swa_rotary_base"): + config_dict.pop(key, None) + + image_token_id = config_dict.pop("image_token_id", Zaya1VLConfig.image_token_id) + vision_start_token_id = config_dict.pop("vision_start_token_id", Zaya1VLConfig.vision_start_token_id) + vision_end_token_id = config_dict.pop("vision_end_token_id", Zaya1VLConfig.vision_end_token_id) + tie_word_embeddings = config_dict.get("tie_word_embeddings", Zaya1VLConfig.tie_word_embeddings) + output_router_logits = config_dict.get("output_router_logits", Zaya1VLConfig.output_router_logits) + + text_config = { + **config_dict, + "model_type": "zaya1_vl_text", + "rms_norm_eps": rms_norm_eps, + "moe_intermediate_size": moe_intermediate_size, + "router_hidden_size": router_hidden_size, + "num_experts_per_tok": num_experts_per_tok, + "layer_types": ["hybrid"] * num_hidden_layers, + "rope_parameters": rope_parameters, + } + text_config.pop("architectures", None) + text_config.pop("model_type", None) + text_config = Zaya1VLTextConfig(**text_config).to_dict() + + config = Zaya1VLConfig( + architectures=["Zaya1VLForConditionalGeneration"], + text_config=text_config, + vision_config=vision_config, + image_token_id=image_token_id, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + tie_word_embeddings=tie_word_embeddings, + output_router_logits=output_router_logits, + ) + config.save_pretrained(output_dir) + + +def copy_non_weight_files(input_dir: Path, output_dir: Path) -> None: + for path in input_dir.iterdir(): + if path.name == "config.json": + continue + if path.name.endswith(".safetensors") or path.name.endswith(".bin"): + continue + if path.name in 
def _build_weight_plan(input_dir: Path):
    """Plan how every tensor listed in the source safetensors index is renamed and where it is written.

    Reads ``model.safetensors.index.json`` and ``config.json`` from ``input_dir`` and classifies each source
    key: keys recognised by ``_expert_target`` are grouped per stacked target tensor, every other key is renamed
    with ``convert_weight_name`` and stays in its original shard file.

    Args:
        input_dir: Directory containing the original checkpoint index and config.

    Returns:
        A 5-tuple ``(plain_by_file, experts_by_target, expert_file_by_target, source_weight_map, index)``:
        ``plain_by_file`` maps a shard filename to ``(source_key, target_key)`` pairs; ``experts_by_target``
        maps a stacked target key to ``(expert_idx, source_key)`` pairs; ``expert_file_by_target`` maps a
        stacked target key to the shard it is written to; ``source_weight_map`` is the untouched source
        ``weight_map``; ``index`` is the loaded index with ``weight_map`` replaced by the converted one.

    Raises:
        ValueError: If two non-expert source keys convert to a name already present in the converted map.
    """
    index = json.loads((input_dir / "model.safetensors.index.json").read_text())
    source_weight_map = index["weight_map"]
    checkpoint_config = json.loads((input_dir / "config.json").read_text())
    layer_count = int(checkpoint_config["num_hidden_layers"])

    new_weight_map = {}
    plain_by_file = defaultdict(list)
    experts_by_target = defaultdict(list)
    expert_file_by_target = {}

    for original_name, shard in source_weight_map.items():
        expert_match = _expert_target(original_name)
        if expert_match is None:
            converted_name = convert_weight_name(original_name, layer_count)
            if converted_name in new_weight_map:
                raise ValueError(f"Duplicate converted weight name: {converted_name}")
            new_weight_map[converted_name] = shard
            plain_by_file[shard].append((original_name, converted_name))
        else:
            stacked_name, expert_position = expert_match
            experts_by_target[stacked_name].append((expert_position, original_name))
            # The shard of the first expert seen for a stacked tensor becomes that tensor's output shard.
            new_weight_map[stacked_name] = expert_file_by_target.setdefault(stacked_name, shard)

    index["weight_map"] = new_weight_map
    return plain_by_file, experts_by_target, expert_file_by_target, source_weight_map, index
def convert_weights(input_dir: Path, output_dir: Path) -> None:
    """Write the converted safetensors shards and index for a checkpoint into ``output_dir``.

    Per-expert tensors are stacked along a new leading dimension, ordered by expert index, and stored under
    the stacked target key chosen by ``_build_weight_plan``. All other tensors are copied to a shard with the
    same filename under their converted names.

    Args:
        input_dir: Directory holding the original shards, ``model.safetensors.index.json`` and ``config.json``.
        output_dir: Existing directory that receives the converted shards and the rewritten index.

    Raises:
        ValueError: If the experts collected for a stacked tensor are not indexed ``0..N-1`` exactly once each.
            Stacking such a set would silently place an expert at the wrong position of the stacked tensor.
    """
    (
        normal_sources_by_output_file,
        expert_sources_by_target,
        output_file_by_target,
        old_weight_map,
        index,
    ) = _build_weight_plan(input_dir)
    expert_tensors_by_output_file = defaultdict(dict)

    for target_key, indexed_sources in expert_sources_by_target.items():
        indexed_sources = sorted(indexed_sources)
        # Position i of the stacked tensor must hold expert i, so the indices have to be contiguous from 0.
        expert_indices = [expert_idx for expert_idx, _ in indexed_sources]
        if expert_indices != list(range(len(indexed_sources))):
            raise ValueError(
                f"Cannot stack experts for {target_key}: expected expert indices "
                f"0..{len(indexed_sources) - 1} exactly once each, found {expert_indices}."
            )
        source_keys = [source_key for _, source_key in indexed_sources]
        sources = _load_sources(input_dir, source_keys, old_weight_map)
        expert_tensors_by_output_file[output_file_by_target[target_key]][target_key] = torch.stack(
            [sources[source_key] for source_key in source_keys], dim=0
        ).contiguous()

    for filename, source_and_target_keys in normal_sources_by_output_file.items():
        tensors = {}
        with safe_open(input_dir / filename, framework="pt", device="cpu") as f:
            for source_key, target_key in source_and_target_keys:
                tensors[target_key] = f.get_tensor(source_key)
        # Stacked experts assigned to this shard are written together with its regular tensors.
        tensors.update(expert_tensors_by_output_file.pop(filename, {}))
        save_file(tensors, output_dir / filename, metadata={"format": "pt"})

    # Shards that contained nothing but expert tensors were not visited by the loop above.
    for filename, tensors in expert_tensors_by_output_file.items():
        save_file(tensors, output_dir / filename, metadata={"format": "pt"})

    (output_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2, sort_keys=True) + "\n")
required=True) + parser.add_argument("--output_dir", type=Path, required=True) + args = parser.parse_args() + convert_checkpoint(args.input_dir, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/zaya1_vl/modeling_zaya1_vl.py b/src/transformers/models/zaya1_vl/modeling_zaya1_vl.py new file mode 100644 index 000000000000..b62e7eec7a20 --- /dev/null +++ b/src/transformers/models/zaya1_vl/modeling_zaya1_vl.py @@ -0,0 +1,1545 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya1_vl/modular_zaya1_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya1_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPooling, MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check +from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_zaya1_vl import Zaya1VLConfig, Zaya1VLTextConfig, Zaya1VLVisionConfig + + +class Zaya1VLRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Zaya1VLConfig): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + self.layer_types = list(set(config.layer_types)) + self.rope_type = {} + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue + + self.rope_type[layer_type] = rope_params["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + 
self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) + + @staticmethod + def compute_default_rope_parameters( + config: Zaya1VLConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters[layer_type]["rope_theta"] + # key difference to gemma3: partial rope + partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids, layer_type=None): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _make_lora_pair(in_features: int, rank: int, out_features: int) -> tuple[nn.Linear, nn.Linear]: + return nn.Linear(in_features, rank, bias=False), nn.Linear(rank, out_features, bias=False) + + +def _apply_masked_lora( + output: torch.Tensor, + hidden_states: torch.Tensor, + lora_a: nn.Linear | torch.Tensor, + lora_b: nn.Linear | torch.Tensor, + mask: torch.Tensor | None, +) -> torch.Tensor: + if mask is None: + return output + indices = mask.nonzero(as_tuple=True) + if indices[0].numel() == 0: + return output + hidden_states = hidden_states[indices] + hidden_states = F.linear(hidden_states, 
lora_a) if isinstance(lora_a, torch.Tensor) else lora_a(hidden_states) + hidden_states = F.linear(hidden_states, lora_b) if isinstance(lora_b, torch.Tensor) else lora_b(hidden_states) + return output.index_put(indices, hidden_states.to(output.dtype), accumulate=True) + + +class Zaya1VLCCAProjection(nn.Module): + """ + Projects hidden states into attention q/k/v states with ZAYA1_VL's CCA path. + + `q_proj` and `k_proj` produce the residual q/k states and are concatenated into `qk_states`. The causal + `conv_qk_depthwise` + `conv_qk_grouped` stack mixes the current q/k stream with the cached pre-convolution tail; + for example, decoding token `t` uses the cached q/k states from previous tokens plus the current `qk_states[:, t]`. + Values are built from `v_proj_current(hidden_states[:, t])` and a delayed `v_proj_delayed`: during prefill token + `t` uses `v_proj_delayed(hidden_states[:, t - 1])`, while decoding reads the previous delayed value projection + from **the recurrent cache**. + """ + + def __init__(self, config: Zaya1VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + + self.depthwise_kernel_size = config.cca_time0 + self.grouped_kernel_size = config.cca_time1 + self.conv_kernel_size = (self.depthwise_kernel_size - 1) + (self.grouped_kernel_size - 1) + + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + query_hidden_size = self.num_attention_heads * self.head_dim + key_value_hidden_size = self.num_key_value_heads * self.head_dim + + self.q_proj = nn.Linear(self.hidden_size, query_hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, key_value_hidden_size, bias=self.config.attention_bias) + self.v_proj_current = nn.Linear(self.hidden_size, 
key_value_hidden_size // 2, bias=self.config.attention_bias) + self.v_proj_delayed = nn.Linear(self.hidden_size, key_value_hidden_size // 2, bias=self.config.attention_bias) + + conv_channels = key_value_hidden_size + query_hidden_size + self.conv_qk_depthwise = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.depthwise_kernel_size, + groups=conv_channels, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=self.grouped_kernel_size, + groups=(self.num_key_value_heads + self.num_attention_heads), + padding=0, + stride=1, + ) + if config.vision_lora: + self.q_lora_a, self.q_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_attention_heads * self.head_dim + ) + self.k_lora_a, self.k_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_key_value_heads * self.head_dim + ) + self.v_current_lora_a, self.v_current_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_key_value_heads * self.head_dim // 2 + ) + self.v_delayed_lora_a, self.v_delayed_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_key_value_heads * self.head_dim // 2 + ) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, + image_mask: torch.Tensor | None = None, + ): + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + + if self.config.vision_lora and image_mask is not None: + # visual specific: apply LoRA only on vision-token positions + projected_queries = _apply_masked_lora( + projected_queries, hidden_states, 
self.q_lora_a, self.q_lora_b, image_mask + ) + projected_keys = _apply_masked_lora( + projected_keys, hidden_states, self.k_lora_a, self.k_lora_b, image_mask + ) + + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) + if use_precomputed_states: + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + + if past_key_values is not None: + new_conv_state = qk_states[..., -self.conv_kernel_size :] + if new_conv_state.shape[-1] < self.conv_kernel_size: + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - new_conv_state.shape[-1], 0)) + past_key_values.update_conv_state(new_conv_state, self.layer_idx) + + qk_states = self.conv_qk_depthwise(qk_states) + qk_states = self.conv_qk_grouped(qk_states).transpose(1, 2) + + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) + + if self.config.vision_lora and image_mask is not None: + # visual specific: apply LoRA only on vision-token positions + value_current = _apply_masked_lora( + value_current, hidden_states, self.v_current_lora_a, 
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the negated upper part followed by the
    lower part.
    """
    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]
    return torch.cat((upper.neg(), lower), dim=-1)
# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the leading `cos.shape[-1]` features of `q` and `k` with rotary position embeddings.

    Unlike the GLM variant, cos and sin are used as-is (no interleaving). Features beyond the rotary
    width are passed through unchanged and concatenated back after the rotation.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so they broadcast against `q` and `k`. Use 1 when
            `q`/`k` are laid out as [batch_size, heads, seq_len, head_dim] and 2 when they are laid out as
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    # Only the first `rotary_width` features are rotated (partial or full rotary).
    rotary_width = cos_b.shape[-1]

    def _rotate_leading(states):
        rotated, untouched = states[..., :rotary_width], states[..., rotary_width:]
        rotated = (rotated * cos_b) + (rotate_half(rotated) * sin_b)
        # Re-attach the features that were not rotated.
        return torch.cat([rotated, untouched], dim=-1)

    return _rotate_leading(q), _rotate_leading(k)
past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + image_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + + mask_mapping = attention_mask or {} + causal_mask = mask_mapping.get("causal") + padding_mask = mask_mapping.get("padding") + + query_states, key_states, value_states = self.qkv_proj( + hidden_states, past_key_values, padding_mask, image_mask=image_mask + ) + query_states, key_states = self.qk_norm(query_states, key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + causal_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.view(*input_shape, -1) + output = self.o_proj(attn_output) + + if self.config.vision_lora and image_mask is not None: + # visual specific: apply LoRA only on vision-token positions + output = _apply_masked_lora(output, attn_output, self.o_lora_a, self.o_lora_b, image_mask) + + return output, attn_weights + + +def identity_decorator(cls): + """ + modular transformers need new decorators to overwrite the old ones e.g. use_experts_implementation; + this decorator is just used to skip them. 
@identity_decorator
class Zaya1VLExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    Each expert is a gated MLP: a fused `gate_up_proj` whose output is split into `gate` and `up`,
    followed by `down_proj(act_fn(gate) * up)`. When `config.vision_lora` is set, every expert also owns
    low-rank `lora_*_a` / `lora_*_b` parameters that are applied through `_apply_masked_lora` together
    with the image-token mask.
    """

    def __init__(self, config: Zaya1VLTextConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]
        self.vision_lora = config.vision_lora
        if self.vision_lora:
            self.lora_gate_up_proj_a = nn.Parameter(
                torch.empty(self.num_experts, config.vision_lora_rank_mlp, self.hidden_dim)
            )
            self.lora_gate_up_proj_b = nn.Parameter(
                torch.empty(self.num_experts, 2 * self.intermediate_dim, config.vision_lora_rank_mlp)
            )
            self.lora_down_proj_a = nn.Parameter(
                torch.empty(self.num_experts, config.vision_lora_rank_mlp, self.intermediate_dim)
            )
            self.lora_down_proj_b = nn.Parameter(
                torch.empty(self.num_experts, self.hidden_dim, config.vision_lora_rank_mlp)
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
        image_mask_flat: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run every token through its selected experts and sum the weighted expert outputs.

        Args:
            hidden_states: Flattened token states, indexed by token along dim 0.
            top_k_index: Expert index chosen for each (token, top-k slot).
            top_k_weights: Routing weight for each (token, top-k slot).
            image_mask_flat: Optional per-token mask, indexed like `hidden_states` along dim 0, marking
                vision-token positions. The expert LoRA is applied only when this is given and
                `vision_lora` is enabled.

        Returns:
            A tensor shaped like `hidden_states` holding the accumulated expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
            # -> (num_experts, top_k, num_tokens)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only visit experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        # The original guard tested `image_mask_curr_expert`, which had just been set to None, so the
        # LoRA branch could never run. The decision has to depend on the caller-supplied mask.
        use_vision_lora = self.vision_lora and image_mask_flat is not None

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = F.linear(current_state, self.gate_up_proj[expert_idx])

            image_mask_curr_expert = image_mask_flat[token_idx] if use_vision_lora else None
            if image_mask_curr_expert is not None:
                # visual specific: apply expert LoRA only on vision-token positions
                gate_up = _apply_masked_lora(
                    gate_up,
                    current_state,
                    self.lora_gate_up_proj_a[expert_idx],
                    self.lora_gate_up_proj_b[expert_idx],
                    image_mask_curr_expert,
                )

            gate, up = gate_up.chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            down = F.linear(current_hidden_states, self.down_proj[expert_idx])

            if image_mask_curr_expert is not None:
                # visual specific: apply expert LoRA only on vision-token positions
                down = _apply_masked_lora(
                    down,
                    current_hidden_states,
                    self.lora_down_proj_a[expert_idx],
                    self.lora_down_proj_b[expert_idx],
                    image_mask_curr_expert,
                )

            # Scale by the routing weight of the slot that selected this expert, then scatter-add back.
            down = down * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, down.to(final_hidden_states.dtype))

        return final_hidden_states
class Zaya1VLRouter(nn.Module):
    """Top-k expert router with one extra "skip" slot and an optional carry-over of router states."""

    def __init__(
        self,
        config,
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.router_hidden_size = config.router_hidden_size
        self.top_k = config.num_experts_per_tok
        # One routing slot more than there are real experts; index `config.num_experts` is the skip slot.
        self.num_experts = config.num_experts + 1

        self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True)

        # Every layer but the first may mix in the router states handed over by the caller.
        self.use_eda = layer_idx != 0
        if self.use_eda:
            self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size))

        self.router_mlp = Zaya1VLRouterMLP(self.router_hidden_size, self.num_experts, config.rms_norm_eps)

        # Selection-only biases: zero for the real experts, -1 for the skip slot.
        initial_biases = torch.zeros(self.num_experts, dtype=torch.float32)
        initial_biases[-1] = -1.0
        self.register_buffer("balancing_biases", initial_biases)

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_states: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pick `top_k` routing slots per token.

        Returns, in order: the router logits flattened to `(-1, num_experts)`, the selected
        probabilities and the selected indices (both flattened to `(-1, top_k)`), and a copy of the
        router hidden states for the caller to pass on.
        """
        mixed_states = self.down_proj(hidden_states)
        if self.use_eda and router_states is not None:
            mixed_states = mixed_states + router_states * self.router_states_scale

        # Keep the trailing `seq_len` positions as the state forwarded to the next router.
        carry_states = mixed_states[:, -hidden_states.shape[1] :].clone()

        logits = self.router_mlp(mixed_states)
        probs = torch.softmax(logits, dim=-1)

        # Biases only influence which slots are picked; the returned weights are the unbiased probabilities.
        selection_scores = probs.detach().to(torch.float32) + self.balancing_biases
        chosen = torch.topk(selection_scores, self.top_k, dim=-1).indices
        chosen_probs = torch.gather(probs, dim=2, index=chosen)

        # If the router selects the extra skip expert, mask it before `Zaya1VLExperts` builds its one-hot expert mask.
        is_skip = chosen == self.config.num_experts
        chosen_probs = chosen_probs.masked_fill(is_skip, 0)
        chosen = chosen.masked_fill(is_skip, 0)

        flat_top_k = (-1, self.top_k)
        return (
            logits.reshape(-1, self.num_experts),
            chosen_probs.reshape(flat_top_k),
            chosen.reshape(flat_top_k),
            carry_states,
        )
self.hidden_states_bias) * self.hidden_states_scale + # Matches the original ZAYA1_VL `residual_in_fp32` path. + residual = residual.to(torch.float32) + residual = (residual + self.residual_bias) * self.residual_scale + return (hidden_states + residual).to(output_dtype) + + +class Zaya1VLDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Zaya1VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Zaya1VLAttention(config=config, layer_idx=layer_idx) + self.mlp = Zaya1VLSparseMoeBlock(config, layer_idx) + self.input_layernorm = Zaya1VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Zaya1VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_residual_scale = Zaya1VLResidualScaling(config.hidden_size) + self.post_mlp_residual_scale = Zaya1VLResidualScaling(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + attention_mask: dict[str, Any] | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + image_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + image_mask=image_mask, + **kwargs, + ) + + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = self.post_attention_layernorm(residual) + + hidden_states, prev_router_hidden_states = self.mlp( + hidden_states, + prev_router_hidden_states, + image_mask=image_mask, + ) + + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) + + return hidden_states, 
@auto_docstring
class Zaya1VLPreTrainedModel(PreTrainedModel):
    config: Zaya1VLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Must name classes that exist in this file: the vision block here is `Zaya1VLVisionBlock`
    # (the previous entry, "Qwen2_5_VLVisionBlock", matched no module of this model).
    _no_split_modules = ["Zaya1VLDecoderLayer", "Zaya1VLVisionBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    # ZAYA1_VL generation uses the native hybrid dynamic cache, which is not a compileable cache.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    _can_record_outputs = {
        "router_logits": OutputRecorder(Zaya1VLRouter, index=0),
        "hidden_states": Zaya1VLDecoderLayer,
        "attentions": Zaya1VLAttention,
    }
    input_modalities = ("image", "text")

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize ZAYA1-VL specific parameters and buffers on top of the parent class defaults."""
        super()._init_weights(module)
        if isinstance(module, Zaya1VLResidualScaling):
            # Identity scaling: scale 1, bias 0 on both branches.
            init.ones_(module.hidden_states_scale)
            init.zeros_(module.hidden_states_bias)
            init.ones_(module.residual_scale)
            init.zeros_(module.residual_bias)
        elif isinstance(module, Zaya1VLModel):
            # NOTE(review): in the visible code `input_hidden_states_scale` / `input_hidden_states_bias`
            # are declared on `Zaya1VLTextModel`; confirm `Zaya1VLModel` is the intended class here.
            init.ones_(module.input_hidden_states_scale)
            init.zeros_(module.input_hidden_states_bias)
        elif isinstance(module, Zaya1VLQKNorm):
            init.zeros_(module.temp)
        elif isinstance(module, Zaya1VLRouter):
            if module.use_eda:
                init.ones_(module.router_states_scale)
            # Same values the router sets in its constructor: zeros, with -1 on the skip slot.
            init.zeros_(module.balancing_biases)
            module.balancing_biases[-1] = -1.0  # ignore: trf012
        elif isinstance(module, Zaya1VLExperts):
            std = self.config.initializer_range
            init.normal_(module.gate_up_proj, mean=0.0, std=std)
            init.normal_(module.down_proj, mean=0.0, std=std)
            # specific for visual expert lora
            if module.vision_lora:
                lora_param_names = "lora_gate_up_proj_a", "lora_gate_up_proj_b", "lora_down_proj_a", "lora_down_proj_b"
                for param_name in lora_param_names:
                    init.normal_(getattr(module, param_name), mean=0.0, std=0.02)
        elif isinstance(module, Zaya1VLRotaryEmbedding):
            # Recompute the inverse frequencies for every layer type and refresh both buffers.
            for layer_type in module.layer_types:
                rope_init_fn = module.compute_default_rope_parameters
                if module.rope_type[layer_type] != "default":
                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq)
                getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq)
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
+ sliding_mask_kwargs = {**mask_kwargs, "attention_mask": None} + mask_creation_functions = { + "hybrid": lambda: create_causal_mask(**mask_kwargs), + "hybrid_sliding": lambda: create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + causal_mask_mapping = { + layer_type: mask_creation_functions[layer_type]() for layer_type in set(self.config.layer_types) + } + cca_mask = self._update_cca_mask(attention_mask, past_key_values, inputs_embeds) + if inputs_embeds.shape[1] == 1: + image_mask = None + + hidden_states = inputs_embeds + position_embeddings = { + layer_type: self.rotary_emb(hidden_states, position_ids, layer_type) + for layer_type in set(self.config.layer_types) + } + + hidden_states = (hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale + prev_router_hidden_states = None + + for layer_n, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[layer_n] + causal_mask = causal_mask_mapping[layer_type] + if image_mask is not None and causal_mask is not None and causal_mask.shape[-1] == image_mask.shape[-1]: + image_pair_mask = image_mask[:, None, :, None] & image_mask[:, None, None, :] + causal_mask = causal_mask.clone().masked_fill(image_pair_mask, 0) + mask_mapping = {"causal": causal_mask, "padding": cca_mask} + hidden_states, prev_router_hidden_states = decoder_layer( + hidden_states, + prev_router_hidden_states, + attention_mask=mask_mapping, + past_key_values=past_key_values, + position_embeddings=position_embeddings[layer_type], + image_mask=image_mask, + **kwargs, + ) + + hidden_states = self.final_norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + def _update_cca_mask(self, attention_mask, past_key_values, inputs_embeds): + """ + No need to zero padding states when cached convolution states are already available or all inputs are valid. 
+ """ + cca_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + cca_mask = None + elif attention_mask is not None: + cca_mask = attention_mask[:, -inputs_embeds.shape[1] :] + return cca_mask + + +class Zaya1VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int | list[int] | tuple[int, int] = 14, + temporal_patch_size: int | list[int] | tuple[int, int] = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: 
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to vision queries and keys.

    The rotation is computed in float32 and each result is cast back to the dtype of its own input.
    `cos` and `sin` get a new second-to-last dimension so they broadcast over it.
    """
    cos_fp32 = cos.unsqueeze(-2).float()
    sin_fp32 = sin.unsqueeze(-2).float()

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        states_fp32 = states.float()
        rotated = (states_fp32 * cos_fp32) + (rotate_half(states_fp32) * sin_fp32)
        return rotated.to(states.dtype)

    return _rotate(q), _rotate(k)
hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return 
class Zaya1VLVisionBlock(GradientCheckpointingLayer):
    # Pre-norm vision block: attention and MLP sub-layers, each wrapped in a residual connection.
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        width = config.hidden_size
        self.norm1 = Zaya1VLRMSNorm(width, eps=1e-6)
        self.norm2 = Zaya1VLRMSNorm(width, eps=1e-6)
        self.attn = Zaya1VLVisionAttention(config=config)
        self.mlp = Zaya1VLMLP(config, bias=True)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        cu_seqlens (`torch.Tensor`):
            Cumulative sequence lengths used for packed variable-length attention in Flash Attention kernels.
        rotary_pos_emb (`torch.Tensor`, *optional*):
            Precomputed rotary positional embeddings applied to the vision attention query/key states.
        """
        # Attention sub-layer on the normalized input, added back onto the residual stream.
        attn_update = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attn = hidden_states + attn_update

        # Feed-forward sub-layer, again with a residual connection.
        mlp_update = self.mlp(self.norm2(after_attn))
        return after_attn + mlp_update
in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Zaya1VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Zaya1VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + self.post_init() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw.tolist(): + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + grid_thw_list = grid_thw.tolist() + + for grid_t, grid_h, grid_w in grid_thw_list: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % 
@merge_with_config_defaults
@capture_outputs
def forward(
    self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
) -> tuple | BaseModelOutputWithPooling:
    """
    Embed packed patches, run them through the vision blocks in window order, and merge them spatially.

    Args:
        hidden_states (`torch.Tensor`):
            Patch inputs passed to `self.patch_embed`. After embedding, the tensor is 2D `(seq_len, hidden)`.
        grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
            The temporal, height and width grid of each image. Every one of the `t` frames contributes
            `h * w` positions to the full-attention `cu_seqlens`.

    Returns:
        `BaseModelOutputWithPooling`: `last_hidden_state` holds the block outputs (still in window order);
        `pooler_output` holds the merger output re-sorted with `argsort(window_index)`.
    """
    hidden_states = self.patch_embed(hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    window_index, cu_window_seqlens = self.get_window_index(grid_thw)
    cu_window_seqlens = torch.tensor(
        cu_window_seqlens,
        device=hidden_states.device,
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    # Collapse repeated boundaries so that zero-length windows do not appear as separate sequences.
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

    # Reorder both the patch states and their rotary frequencies by window, one merge unit at a time.
    seq_len, _ = hidden_states.size()
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    rotary_pos_emb = rotary_pos_emb[window_index, :, :]
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    position_embeddings = (emb.cos(), emb.sin())

    # Full-attention boundaries: one sequence of `h * w` positions per frame.
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0,
        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852 for more information
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    for layer_num, blk in enumerate(self.blocks):
        # Layers listed in `fullatt_block_indexes` attend over whole frames; all others attend per window.
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
        else:
            cu_seqlens_now = cu_window_seqlens

        hidden_states = blk(
            hidden_states,
            cu_seqlens=cu_seqlens_now,
            position_embeddings=position_embeddings,
            **kwargs,
        )

    merged_hidden_states = self.merger(hidden_states)
    # Undo the window permutation on the merged tokens only.
    reverse_indices = torch.argsort(window_index)
    merged_hidden_states = merged_hidden_states[reverse_indices, :]

    return BaseModelOutputWithPooling(
        last_hidden_state=hidden_states,
        pooler_output=merged_hidden_states,
    )
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width grid of each image after image preprocessing. 
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    use_cache: bool | None = None,
    output_router_logits: bool | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> MoeCausalLMOutputWithPast:
    r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, Zaya1VLForConditionalGeneration

    >>> model = Zaya1VLForConditionalGeneration.from_pretrained("Zyphra/ZAYA1-VL-8B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/ZAYA1-VL-8B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    ```"""
    # An explicit argument wins; otherwise fall back to the value stored on the config.
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )

    # Image inputs (e.g. `pixel_values`, `image_grid_thw`) are not named here; they reach the model via `kwargs`.
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_router_logits=output_router_logits,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used directly as an index.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    loss = None
    if labels is not None:
        loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

    return MoeCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    image_grid_thw=None,
    is_first_iteration=False,
    **kwargs,
):
    """
    Assemble the inputs for one generation step via the parent implementation, then blank out `pixel_values`
    on every cached step after the first one.
    """
    prepared = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        position_ids=position_ids,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        use_cache=use_cache,
        is_first_iteration=is_first_iteration,
        **kwargs,
    )
    if is_first_iteration or not use_cache:
        # First step, or no cache at all: keep whatever the parent prepared.
        return prepared
    prepared["pixel_values"] = None
    return prepared
import initialization as init +from ...cache_utils import Cache, DynamicCache +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_outputs import MoeModelOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import MultiModalData, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import TransformersKwargs, auto_docstring +from ..llama.modeling_llama import repeat_kv +from ..llava.modeling_llava import LlavaModel +from ..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig +from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel +from ..qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig +from ..qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor, Qwen2VLProcessorKwargs +from ..qwen3_5_moe.modeling_qwen3_5_moe import apply_rotary_pos_emb +from ..zaya.configuration_zaya import ZayaConfig +from ..zaya.modeling_zaya import ( + ZayaAttention, + ZayaCCAProjection, + ZayaDecoderLayer, + ZayaExperts, + ZayaForCausalLM, + ZayaModel, + ZayaPreTrainedModel, + ZayaRotaryEmbedding, + ZayaSparseMoeBlock, + eager_attention_forward, +) + + +@auto_docstring(checkpoint="Zyphra/ZAYA1-VL-8B") +@strict +class Zaya1VLVisionConfig(Qwen2_5_VLVisionConfig): + r""" + window_size (`int`, *optional*, defaults to 112): + Window size used by the Qwen2.5-VL vision encoder. + out_hidden_size (`int`, *optional*, defaults to 2048): + Output hidden size after the vision merger. + fullatt_block_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`): + Vision encoder layers that use full attention. 
# Text-decoder configuration: extends `ZayaConfig` with the three vision-LoRA fields declared below.
@auto_docstring(checkpoint="Zyphra/ZAYA1-VL-8B")
@strict
class Zaya1VLTextConfig(ZayaConfig):
    r"""
    lm_head_bias (`bool`, *optional*, defaults to `False`):
        Whether to add a bias to the language modeling head.
    router_hidden_size (`int`, *optional*, defaults to 256):
        Hidden size used by the ZAYA router.
    cca_time0 (`int`, *optional*, defaults to 2):
        First temporal parameter of the CCA projection.
    cca_time1 (`int`, *optional*, defaults to 2):
        Second temporal parameter of the CCA projection.

    vision_lora (`bool`, *optional*, defaults to `True`):
        Whether to enable LoRA modules that are applied only on vision-token positions.
    vision_lora_rank_attn (`int`, *optional*, defaults to 8):
        LoRA rank for the CCA and attention output projections applied to vision-token positions.
    vision_lora_rank_mlp (`int`, *optional*, defaults to 32):
        LoRA rank for the MoE expert projections applied to vision-token positions.
    """

    model_type = "zaya1_vl_text"
    # Presumably the key under which this sub-config sits in the composite config -- TODO confirm.
    base_config_key = "text_config"

    # Only these fields are declared in this class; the other documented options are not defined here
    # (presumably inherited from `ZayaConfig` -- verify against that class).
    vision_lora: bool = True
    vision_lora_rank_attn: int = 8
    vision_lora_rank_mlp: int = 32
+ """ + + model_type = "zaya1_vl" + sub_configs = {"vision_config": Zaya1VLVisionConfig, "text_config": Zaya1VLTextConfig} + + image_token_id: int = 262147 + vision_start_token_id: int = 255999 + vision_end_token_id: int = 256000 + video_token_id = AttributeError() + + tie_word_embeddings: bool = True + output_router_logits: bool = False + + +class Zaya1VLRotaryEmbedding(ZayaRotaryEmbedding): + pass + + +def _make_lora_pair(in_features: int, rank: int, out_features: int) -> tuple[nn.Linear, nn.Linear]: + return nn.Linear(in_features, rank, bias=False), nn.Linear(rank, out_features, bias=False) + + +def _apply_masked_lora( + output: torch.Tensor, + hidden_states: torch.Tensor, + lora_a: nn.Linear | torch.Tensor, + lora_b: nn.Linear | torch.Tensor, + mask: torch.Tensor | None, +) -> torch.Tensor: + if mask is None: + return output + indices = mask.nonzero(as_tuple=True) + if indices[0].numel() == 0: + return output + hidden_states = hidden_states[indices] + hidden_states = F.linear(hidden_states, lora_a) if isinstance(lora_a, torch.Tensor) else lora_a(hidden_states) + hidden_states = F.linear(hidden_states, lora_b) if isinstance(lora_b, torch.Tensor) else lora_b(hidden_states) + return output.index_put(indices, hidden_states.to(output.dtype), accumulate=True) + + +class Zaya1VLCCAProjection(ZayaCCAProjection): + def __init__(self, config: Zaya1VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + if config.vision_lora: + self.q_lora_a, self.q_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_attention_heads * self.head_dim + ) + self.k_lora_a, self.k_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_key_value_heads * self.head_dim + ) + self.v_current_lora_a, self.v_current_lora_b = _make_lora_pair( + self.hidden_size, config.vision_lora_rank_attn, self.num_key_value_heads * self.head_dim // 2 + ) + self.v_delayed_lora_a, self.v_delayed_lora_b = _make_lora_pair( + self.hidden_size, 
config.vision_lora_rank_attn, self.num_key_value_heads * self.head_dim // 2 + ) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None, + padding_mask: torch.Tensor | None = None, + image_mask: torch.Tensor | None = None, + ): + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + + if self.config.vision_lora and image_mask is not None: + # visual specific: apply LoRA only on vision-token positions + projected_queries = _apply_masked_lora( + projected_queries, hidden_states, self.q_lora_a, self.q_lora_b, image_mask + ) + projected_keys = _apply_masked_lora( + projected_keys, hidden_states, self.k_lora_a, self.k_lora_b, image_mask + ) + + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = projected_keys.view(*hidden_shape).transpose(1, 2) + key_residual = repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = qk_states.transpose(1, 2) + use_precomputed_states = past_key_values is not None and past_key_values.has_previous_state(self.layer_idx) + if use_precomputed_states: + cached_qk_states = past_key_values.layers[self.layer_idx].conv_states + qk_states = torch.cat([cached_qk_states, qk_states], dim=-1) + else: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + + if past_key_values is not None: + new_conv_state = qk_states[..., -self.conv_kernel_size :] + if new_conv_state.shape[-1] < self.conv_kernel_size: + new_conv_state = F.pad(new_conv_state, (self.conv_kernel_size - 
class Zaya1VLAttention(ZayaAttention):
    """ZAYA attention whose output projection gets an extra low-rank term on image-token positions."""

    def __init__(self, config: Zaya1VLTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        if config.vision_lora:
            attn_width = config.num_attention_heads * self.head_dim
            self.o_lora_a, self.o_lora_b = _make_lora_pair(
                attn_width, config.vision_lora_rank_attn, config.hidden_size
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: dict[str, Any] | None = None,
        past_key_values: Cache | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        image_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Run attention over `hidden_states` and return `(output, attn_weights)`.

        `attention_mask` is a mapping read for the keys `"causal"` (handed to the attention kernel) and
        `"padding"` (handed to the QKV projection). `image_mask` marks the positions that receive the
        vision-specific LoRA terms.
        """
        leading_shape = hidden_states.shape[:-1]

        masks = attention_mask or {}
        causal_mask = masks.get("causal")
        padding_mask = masks.get("padding")

        queries, keys, values = self.qkv_proj(hidden_states, past_key_values, padding_mask, image_mask=image_mask)
        queries, keys = self.qk_norm(queries, keys)

        # Move the head axis in front of the sequence axis for the attention kernels.
        queries, keys, values = (states.transpose(1, 2) for states in (queries, keys, values))

        cos, sin = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        if past_key_values is not None:
            keys, values = past_key_values.update(keys, values, self.layer_idx)

        attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            causal_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.view(*leading_shape, -1)
        output = self.o_proj(attn_output)

        if self.config.vision_lora and image_mask is not None:
            # visual specific: the low-rank output correction touches vision-token positions only
            output = _apply_masked_lora(output, attn_output, self.o_lora_a, self.o_lora_b, image_mask)

        return output, attn_weights
@identity_decorator
class Zaya1VLExperts(ZayaExperts):
    """ZAYA experts with optional per-expert LoRA weights that are applied to vision-token rows only."""

    def __init__(self, config: Zaya1VLTextConfig):
        super().__init__(config)
        self.vision_lora = config.vision_lora
        if self.vision_lora:
            # Per-expert low-rank factors, stored as raw weights in `F.linear` layout (out_features, in_features).
            self.lora_gate_up_proj_a = nn.Parameter(
                torch.empty(self.num_experts, config.vision_lora_rank_mlp, self.hidden_dim)
            )
            self.lora_gate_up_proj_b = nn.Parameter(
                torch.empty(self.num_experts, 2 * self.intermediate_dim, config.vision_lora_rank_mlp)
            )
            self.lora_down_proj_a = nn.Parameter(
                torch.empty(self.num_experts, config.vision_lora_rank_mlp, self.intermediate_dim)
            )
            self.lora_down_proj_b = nn.Parameter(
                torch.empty(self.num_experts, self.hidden_dim, config.vision_lora_rank_mlp)
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
        image_mask_flat: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Route every token through its selected experts and sum the weighted expert outputs.

        Args:
            hidden_states: Token states with tokens on the first axis (indexed by `token_idx`).
            top_k_index: Expert ids per token, indexed as `[token, top_k_slot]`.
            top_k_weights: Routing weights, indexed as `[token, top_k_slot]`.
            image_mask_flat: Optional per-token mask; `True` rows additionally receive the expert LoRA terms.

        Returns:
            A tensor shaped like `hidden_states` holding the accumulated expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only visit experts that at least one token was routed to.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = F.linear(current_state, self.gate_up_proj[expert_idx])

            # Fix: the guard must test the incoming mask. Testing `image_mask_curr_expert` (just set to
            # None) made both LoRA branches unreachable.
            image_mask_curr_expert = None
            if self.vision_lora and image_mask_flat is not None:
                image_mask_curr_expert = image_mask_flat[token_idx]
                # visual specific: apply expert LoRA only on vision-token positions
                gate_up = _apply_masked_lora(
                    gate_up,
                    current_state,
                    self.lora_gate_up_proj_a[expert_idx],
                    self.lora_gate_up_proj_b[expert_idx],
                    image_mask_curr_expert,
                )

            gate, up = gate_up.chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            down = F.linear(current_hidden_states, self.down_proj[expert_idx])

            if image_mask_curr_expert is not None:
                # visual specific: apply expert LoRA only on vision-token positions
                down = _apply_masked_lora(
                    down,
                    current_hidden_states,
                    self.lora_down_proj_a[expert_idx],
                    self.lora_down_proj_b[expert_idx],
                    image_mask_curr_expert,
                )

            down = down * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, down.to(final_hidden_states.dtype))

        return final_hidden_states
class Zaya1VLPreTrainedModel(ZayaPreTrainedModel):
    """Shared base for the ZAYA1-VL models; adds initialization of the expert LoRA parameters."""

    # NOTE(review): the vision block class in the generated modeling code is named `Zaya1VLVisionBlock`;
    # confirm whether "Qwen2_5_VLVisionBlock" is renamed by the modular converter or should be updated here.
    _no_split_modules = ["Zaya1VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
    input_modalities = ("image", "text")

    def _init_weights(self, module):
        """Run the parent initialization, then draw the expert LoRA weights from N(0, 0.02)."""
        # Fix: `super()` already binds `self`; passing it again would shift `module` into an extra argument.
        super()._init_weights(module)

        # specific for visual expert lora
        if isinstance(module, Zaya1VLExperts) and module.vision_lora:
            lora_param_names = ("lora_gate_up_proj_a", "lora_gate_up_proj_b", "lora_down_proj_a", "lora_down_proj_b")
            for param_name in lora_param_names:
                init.normal_(getattr(module, param_name), mean=0.0, std=0.02)
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Original ZAYA SWA only applies the local causal pattern; padding tokens are zeroed before the CCA projection. 
# Vision tower for ZAYA1-VL: a direct subclass of the Qwen2.5-VL vision transformer with no added behavior.
class Zaya1VLVisionModel(Qwen2_5_VisionTransformerPretrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        # Pure pass-through to the parent constructor.
        # NOTE(review): presumably kept so the modular converter emits an explicit `__init__` -- confirm.
        super().__init__(config, *inputs, **kwargs)
__init__(self, config: Zaya1VLConfig): + Zaya1VLPreTrainedModel.__init__(self, config) + self.visual = Zaya1VLVisionModel._from_config(config.vision_config) + self.language_model = Zaya1VLTextModel(config.text_config) + self.post_init() + + def get_image_features( + self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor, **kwargs: Unpack[TransformersKwargs] + ) -> torch.FloatTensor: + r""" + pixel_values (`torch.FloatTensor`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): + The temporal, height and width grid of each image after image preprocessing. + """ + pixel_values = pixel_values.type(self.visual.dtype) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs).pooler_output + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width grid of each image after image preprocessing. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + image_mask = None + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values, image_grid_thw, **kwargs) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features.unsqueeze(0) + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + image_mask = image_mask[..., 0] + + return self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_mask=image_mask, + use_cache=use_cache, + **kwargs, + ) + + +@auto_docstring +class Zaya1VLForConditionalGeneration(ZayaForCausalLM, Zaya1VLPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: Zaya1VLConfig): + super().__init__(self, config) + self.vocab_size = config.text_config.vocab_size + self.lm_head = nn.Linear( + config.text_config.hidden_size, config.text_config.vocab_size, bias=config.text_config.lm_head_bias + ) + self.post_init() + + def get_image_features( + self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor, **kwargs: Unpack[TransformersKwargs] + ) -> torch.FloatTensor: + return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + position_ids=None, + use_cache=True, + pixel_values=None, + image_grid_thw=None, + is_first_iteration=False, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + 
past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + use_cache=use_cache, + is_first_iteration=is_first_iteration, + **kwargs, + ) + if not is_first_iteration and use_cache: + model_inputs["pixel_values"] = None + return model_inputs + + +class Zaya1VLProcessorKwargs(Qwen2VLProcessorKwargs): + pass + + +@auto_docstring +class Zaya1VLProcessor(Qwen2VLProcessor): + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = getattr(tokenizer, "image_token", "") + self.image_token_id = getattr(tokenizer, "image_token_id", None) or tokenizer.convert_tokens_to_ids( + self.image_token + ) + ProcessorMixin.__init__(self, image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + **kwargs: Unpack[Zaya1VLProcessorKwargs], + ) -> BatchFeature: + r""" + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + Zaya1VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + image_inputs = {} + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + text = text.copy() if isinstance(text, list) else [text] + if images is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids") + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"]) + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for image inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per image. 
+ """ + if image_sizes is None: + return MultiModalData() + + images_kwargs = {**Zaya1VLProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs} + merge_size = images_kwargs.get("merge_size") or self.image_processor.merge_size + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + return MultiModalData(num_image_tokens=num_image_tokens, num_image_patches=num_image_patches) + + @property + def model_input_names(self): + return self.image_processor.model_input_names + self.tokenizer.model_input_names + + +__all__ = [ + "Zaya1VLTextConfig", + "Zaya1VLVisionConfig", + "Zaya1VLConfig", + "Zaya1VLVisionModel", + "Zaya1VLModel", + "Zaya1VLPreTrainedModel", + "Zaya1VLTextModel", + "Zaya1VLForConditionalGeneration", + "Zaya1VLProcessor", +] diff --git a/src/transformers/models/zaya1_vl/processing_zaya1_vl.py b/src/transformers/models/zaya1_vl/processing_zaya1_vl.py new file mode 100644 index 000000000000..cf73ff172b6a --- /dev/null +++ b/src/transformers/models/zaya1_vl/processing_zaya1_vl.py @@ -0,0 +1,145 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya1_vl/modular_zaya1_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya1_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import auto_docstring + + +class Zaya1VLProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": True, + }, + } + + +@auto_docstring +class Zaya1VLProcessor(ProcessorMixin): + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = getattr(tokenizer, "image_token", "") + self.image_token_id = getattr(tokenizer, "image_token_id", None) or tokenizer.convert_tokens_to_ids( + self.image_token + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + @auto_docstring + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + **kwargs: Unpack[Zaya1VLProcessorKwargs], + ) -> BatchFeature: + r""" + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + Zaya1VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + image_inputs = {} + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + text = text.copy() if isinstance(text, list) else [text] + if images is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids") + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"]) + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for image inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per image. 
+ """ + if image_sizes is None: + return MultiModalData() + + images_kwargs = {**Zaya1VLProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs} + merge_size = images_kwargs.get("merge_size") or self.image_processor.merge_size + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + return MultiModalData(num_image_tokens=num_image_tokens, num_image_patches=num_image_patches) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text. 
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + return self.image_processor.model_input_names + self.tokenizer.model_input_names + + +__all__ = ["Zaya1VLProcessor"] diff --git a/tests/models/zaya/__init__.py b/tests/models/zaya/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/models/zaya/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py new file mode 100644 index 000000000000..316d206004d0 --- /dev/null +++ b/tests/models/zaya/test_modeling_zaya.py @@ -0,0 +1,419 @@ +# Copyright 2026 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch ZAYA model.""" + +import unittest + +from huggingface_hub.errors import StrictDataclassClassValidationError +from parameterized import parameterized + +from transformers import is_torch_available +from transformers.testing_utils import cleanup, require_torch, slow, torch_device + + +if is_torch_available(): + import torch + + from transformers import AutoTokenizer, ZayaConfig, ZayaForCausalLM, ZayaModel + from transformers.cache_utils import ( + DynamicCache, + LinearAttentionAndFullAttentionLayer, + LinearAttentionAndSlidingWindowAttentionLayer, + ) + from transformers.models.zaya.modeling_zaya import ZayaCCAProjection + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class ZayaModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = ZayaModel + + def __init__(self, parent, **kwargs): + super().__init__( + parent=parent, + num_hidden_layers=4, + moe_intermediate_size=32, + num_experts_per_tok=1, + layer_types=["hybrid", "hybrid_sliding", "hybrid", "hybrid_sliding"], + sliding_window=64, + **kwargs, + ) + + +@require_torch +class ZayaModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = ZayaModelTester + test_all_params_have_gradient = False + + def _get_conv_state_shape(self, batch_size: int, config): + conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim + conv_kernel_size = config.cca_time0 + config.cca_time1 - 2 + return (batch_size, conv_state_size, conv_kernel_size) + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.num_key_value_heads * config.head_dim // 2) + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + if not isinstance(past_key_values, DynamicCache): + raise ValueError("The cache does not use the correct Cache") + + config = config.get_text_config(decoder=True) + 
self.assertEqual(config.num_hidden_layers, len(past_key_values)) + attention_shape = (batch_size, config.num_key_value_heads, seq_length, config.head_dim) + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) + + for layer_type, layer in zip(config.layer_types, past_key_values.layers): + expected_layer_class = ( + LinearAttentionAndSlidingWindowAttentionLayer + if layer_type == "hybrid_sliding" + else LinearAttentionAndFullAttentionLayer + ) + self.assertIs(type(layer), expected_layer_class) + self.assertEqual(layer.keys.shape, attention_shape) + self.assertEqual(layer.values.shape, attention_shape) + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) + + @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("ZAYA uses key/query normalization which is not equivalent under padding-free packing.") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip( + "ZAYA follows the original SWA behavior where sliding attention only applies the local causal pattern;" + "See https://github.com/huggingface/transformers/pull/45862#discussion_r3249556316" + ) + def test_left_padding_compatibility(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + config._attn_implementation = "eager" + + for model_class in self.all_model_classes: + model = model_class._from_config(config, attn_implementation="eager") + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class({**inputs_dict, "output_attentions": True}, model_class)) + + expected_attn_layers = config.num_hidden_layers + 
self.assertEqual(len(outputs.attentions), expected_attn_layers) + self.assertEqual( + outputs.attentions[0].shape, + ( + self.model_tester.batch_size, + config.num_attention_heads, + self.model_tester.seq_length, + self.model_tester.seq_length, + ), + ) + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + @unittest.skip( + "RoPE-scaling-from-config test doesn't match ZAYA's nested per-layer-type rope_parameters (same as e.g. Laguna, Gemma3)." + ) + def test_model_rope_scaling_from_config(self, scaling_type): + pass + + def test_model_rope_scaling_frequencies(self): + """ + Tests the frequency properties of the different RoPE scaling types on the model RoPE layer. + Copied from Laguna to adapt to per-layer-type rope configs. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + partial_rotary_factor = config.rope_parameters["hybrid"]["partial_rotary_factor"] + + def set_rope_params(rope_params): + config.rope_parameters = { + "hybrid": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + "hybrid_sliding": {**rope_params, "partial_rotary_factor": partial_rotary_factor}, + } + + set_rope_params({"rope_type": "default", "rope_theta": 10_000.0}) + + base_model = self.model_tester.base_model_class(config) + possible_rope_attributes = [ + "pos_emb", + "rotary_emb", + "global_rotary_emb", + "local_rotary_emb", + ] + for name, module in base_model.named_modules(): + if any(potential_name in name for potential_name in possible_rope_attributes): + rope_class = type(module) + break + + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + x = torch.randn(1, dtype=torch.float32, device=torch_device) + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + + set_rope_params({"rope_type": "default", 
"rope_theta": 10_000.0}) + original_rope = rope_class(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short, layer_type="hybrid_sliding") + original_cos_long, original_sin_long = original_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + set_rope_params({"rope_type": "linear", "factor": scaling_factor, "rope_theta": 10_000.0}) + linear_scaling_rope = rope_class(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + set_rope_params({"rope_type": "dynamic", "factor": scaling_factor, "rope_theta": 10_000.0}) + ntk_scaling_rope = rope_class(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with 
self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.hybrid_sliding_inv_freq <= original_rope.hybrid_sliding_inv_freq).all()) + + set_rope_params({"rope_type": "yarn", "factor": scaling_factor, "rope_theta": 10_000.0}) + yarn_scaling_rope = rope_class(config=config).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short, layer_type="hybrid_sliding") + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long, layer_type="hybrid_sliding") + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + def test_moe_router_logits(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = self.model_tester.causal_lm_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**inputs_dict, output_router_logits=True) + + expected_moe_layers = config.num_hidden_layers + self.assertEqual(len(outputs.router_logits), expected_moe_layers) + self.assertEqual( + outputs.router_logits[0].shape, + (self.model_tester.batch_size * self.model_tester.seq_length, config.num_experts + 1), + ) + + def test_num_experts_per_tok_validation(self): + with self.assertRaisesRegex(StrictDataclassClassValidationError, "num_experts_per_tok=1"): + ZayaConfig(num_experts_per_tok=2) + + def test_sliding_attention_mask_is_used(self): + config = ZayaConfig( + 
vocab_size=128, + hidden_size=32, + moe_intermediate_size=32, + num_hidden_layers=4, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + router_hidden_size=4, + layer_types=["hybrid_sliding", "hybrid", "hybrid_sliding", "hybrid"], + sliding_window=3, + tie_word_embeddings=False, + attn_implementation="eager", + ) + model = ZayaModel(config).to(torch_device) + model.eval() + input_ids = torch.arange(6, device=torch_device).unsqueeze(0) + + with torch.no_grad(): + outputs = model(input_ids=input_ids, output_attentions=True) + + sliding_attention = outputs.attentions[0] + self.assertTrue(torch.all(sliding_attention[:, :, -1, :3] == 0)) + + def test_cca_cache_matches_full_forward(self): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + moe_intermediate_size=32, + num_hidden_layers=1, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + router_hidden_size=4, + tie_word_embeddings=False, + ) + torch.manual_seed(0) + cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device) + cca.eval() + hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) + + with torch.no_grad(): + full = cca(hidden_states, None, None) + cache = DynamicCache(config=config) + cca(hidden_states[:, :4], cache, None) + cached = cca(hidden_states[:, 4:], cache, None) + + for full_states, cached_states in zip(full, cached): + torch.testing.assert_close(full_states[:, -1:], cached_states, rtol=1e-5, atol=1e-5) + + def test_cca_cache_matches_full_forward_multi_token(self): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + moe_intermediate_size=32, + num_hidden_layers=1, + num_experts=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + router_hidden_size=4, + tie_word_embeddings=False, + ) + torch.manual_seed(0) + cca = ZayaCCAProjection(config, layer_idx=0).to(torch_device) + cca.eval() + hidden_states = torch.randn(1, 5, config.hidden_size, device=torch_device) + + with 
torch.no_grad(): + full = cca(hidden_states, None, None) + cache = DynamicCache(config=config) + cca(hidden_states[:, :3], cache, None) + cached = cca(hidden_states[:, 3:], cache, None) + + for full_states, cached_states in zip(full, cached): + torch.testing.assert_close(full_states[:, 3:], cached_states, rtol=1e-5, atol=1e-5) + + def test_zaya_cache_reorder_and_reset(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + cache = DynamicCache(config=config) + conv_state_size = config.num_key_value_heads * config.head_dim + config.num_attention_heads * config.head_dim + cache.update_conv_state( + torch.arange(2 * conv_state_size * 2, device=torch_device, dtype=torch.float32).view( + 2, conv_state_size, 2 + ), + 0, + ) + cache.update_recurrent_state( + torch.arange( + 2 * config.num_key_value_heads * config.head_dim // 2, device=torch_device, dtype=torch.float32 + ).view(2, config.num_key_value_heads * config.head_dim // 2), + 0, + ) + self.assertEqual(cache.layers[0].recurrent_states.shape[-1], config.num_key_value_heads * config.head_dim // 2) + + cache.reorder_cache(torch.tensor([1, 0], device=torch_device)) + self.assertEqual(cache.layers[0].conv_states.shape[0], 2) + + cache.reset() + self.assertFalse(cache.has_previous_state(0)) + self.assertEqual(cache.layers[0].conv_states.sum().item(), 0) + self.assertEqual(cache.layers[0].recurrent_states.sum().item(), 0) + + +@require_torch +class ZayaIntegrationTest(unittest.TestCase): + model = None + model_id = "Zyphra/ZAYA1-8B" + + @classmethod + def get_model(cls): + if cls.model is None: + cls.model = ZayaForCausalLM.from_pretrained(cls.model_id, device_map="auto", dtype=torch.bfloat16) + return cls.model + + @classmethod + def tearDownClass(cls): + if cls.model is not None: + del cls.model + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def get_inputs(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = 
tokenizer("Hello! How can I assist you today?", return_tensors="pt") + self.assertEqual( + inputs.input_ids.tolist(), + [[2, 9259, 236888, 2088, 740, 564, 6361, 611, 3124, 236881, 106]], + ) + return inputs + + @slow + def test_model_logits(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + logits = model(**inputs, use_cache=False, return_dict=True).logits.float().cpu() + + self.assertEqual(logits.shape, (1, inputs.input_ids.shape[-1], model.config.vocab_size)) + self.assertTrue(torch.isfinite(logits).all().item()) + + expected_argmax = torch.tensor([[105, 9731, 107, 740, 564, 1601, 611, 236881, 236881, 107, 107]]) + torch.testing.assert_close(logits.argmax(-1), expected_argmax) + + @slow + def test_model_cache_matches_full_forward(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + full_logits = model(**inputs, use_cache=False).logits[:, -1] + prefill_outputs = model( + input_ids=inputs.input_ids[:, :-1], + attention_mask=inputs.attention_mask[:, :-1], + use_cache=True, + return_dict=True, + ) + cached_logits = model( + input_ids=inputs.input_ids[:, -1:], + attention_mask=inputs.attention_mask, + past_key_values=prefill_outputs.past_key_values, + use_cache=True, + return_dict=True, + ).logits[:, -1] + + torch.testing.assert_close(cached_logits.float().cpu(), full_logits.float().cpu(), rtol=1e-4, atol=1e-4) + + @slow + def test_model_generation(self): + model = self.get_model() + inputs = self.get_inputs().to(model.model.embed_tokens.weight.device) + + with torch.no_grad(): + generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=3, top_k=None, top_p=None) + + self.assertEqual(generated_ids[0, -3:].tolist(), [107, 262146, 108]) diff --git a/tests/models/zaya1_vl/__init__.py b/tests/models/zaya1_vl/__init__.py new file mode 100644 index 000000000000..19b2f1409e84 --- /dev/null +++ 
b/tests/models/zaya1_vl/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/models/zaya1_vl/test_modeling_zaya1_vl.py b/tests/models/zaya1_vl/test_modeling_zaya1_vl.py
new file mode 100644
index 000000000000..4d9111f6228f
--- /dev/null
+++ b/tests/models/zaya1_vl/test_modeling_zaya1_vl.py
@@ -0,0 +1,117 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Zaya1-VL model."""
+
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import require_torch
+
+
+if is_torch_available():
+    import torch
+
+    from transformers import (
+        Zaya1VLConfig,
+        Zaya1VLForConditionalGeneration,
+        Zaya1VLTextConfig,
+    )
+
+
+def _tiny_config():
+    """Return a deliberately tiny `Zaya1VLConfig` (one text layer, vision depth 1) for quick tests."""
+    return Zaya1VLConfig(
+        text_config=Zaya1VLTextConfig(
+            vocab_size=128,
+            hidden_size=32,
+            num_hidden_layers=1,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            head_dim=8,
+            moe_intermediate_size=16,
+            num_experts=2,
+            router_hidden_size=4,
+            tie_word_embeddings=False,
+        ),
+        tie_word_embeddings=False,
+        image_token_id=127,
+        vision_start_token_id=126,
+        vision_end_token_id=125,
+        vision_config={
+            "depth": 1,
+            "hidden_size": 16,
+            "intermediate_size": 16,
+            "num_heads": 4,
+            "patch_size": 2,
+            "temporal_patch_size": 1,
+            "spatial_merge_size": 2,
+            "out_hidden_size": 32,
+            "fullatt_block_indexes": [0],
+            "window_size": 4,
+        },
+    )
+
+
+def _tiny_image_inputs(config):
+    """Return model kwargs for one 3-token sequence that contains a single image placeholder token."""
+    input_ids = torch.tensor([[2, config.image_token_id, 5]])
+    vision_config = config.vision_config
+    # Width of one flattened patch row; the leading 3 is presumably the image channel count.
+    patch_dim = 3 * vision_config.temporal_patch_size * vision_config.patch_size**2
+    return {
+        "input_ids": input_ids,
+        "attention_mask": torch.ones_like(input_ids),
+        # 4 patch rows, matching the 1 * 2 * 2 entries declared in `image_grid_thw`.
+        "pixel_values": torch.randn(4, patch_dim),
+        "image_grid_thw": torch.tensor([[1, 2, 2]]),
+    }
+
+
+@require_torch
+class Zaya1VLModelTest(unittest.TestCase):
+    """Shape-level smoke tests for a randomly initialised tiny Zaya1-VL model fed one synthetic image."""
+
+    def setUp(self):
+        # Weights and pixel values are random; seeding makes any failure reproducible.
+        torch.manual_seed(0)
+
+    def test_image_forward(self):
+        """Forward pass with an image returns logits and router logits of the expected shapes."""
+        config = _tiny_config()
+        model = Zaya1VLForConditionalGeneration(config).eval()
+        inputs = _tiny_image_inputs(config)
+        seq_len = inputs["input_ids"].shape[-1]
+
+        with torch.no_grad():
+            outputs = model(**inputs, output_router_logits=True)
+
+        self.assertEqual(outputs.logits.shape, (1, seq_len, config.text_config.vocab_size))
+        self.assertEqual(len(outputs.router_logits), config.text_config.num_hidden_layers)
+        # Batch size is 1, so the batch holds `seq_len` tokens in total; one router row is expected per token.
+        self.assertEqual(outputs.router_logits[0].shape, (seq_len, config.text_config.num_experts + 1))
+
+    def test_image_generation(self):
+        """Non-sampling generation with an image appends exactly `max_new_tokens` ids to the prompt."""
+        config = _tiny_config()
+        model = Zaya1VLForConditionalGeneration(config).eval()
+        inputs = _tiny_image_inputs(config)
+
+        with torch.no_grad():
+            generated_ids = model.generate(**inputs, max_new_tokens=2, do_sample=False)
+
+        self.assertEqual(generated_ids.shape, (1, inputs["input_ids"].shape[-1] + 2))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/models/zaya1_vl/test_processing_zaya1_vl.py b/tests/models/zaya1_vl/test_processing_zaya1_vl.py
new file mode 100644
index 000000000000..005d3d3b8cc5
--- /dev/null
+++ b/tests/models/zaya1_vl/test_processing_zaya1_vl.py
@@ -0,0 +1,96 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import is_tokenizers_available, is_torch_available
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.testing_utils import require_tokenizers, require_torch
+
+
+if is_torch_available():
+    import torch
+
+if is_tokenizers_available():
+    from tokenizers import Tokenizer
+    from tokenizers.models import WordLevel
+    from tokenizers.pre_tokenizers import Sequence, Split, WhitespaceSplit
+
+    from transformers import PreTrainedTokenizerFast, Zaya1VLProcessor
+
+
+class DummyZaya1VLImageProcessor(BaseImageProcessor):  # stub: returns fixed tensors, never reads `images`
+    model_input_names = ["pixel_values", "image_grid_thw"]
+    merge_size = 2
+
+    def __call__(self, images=None, **kwargs):
+        return BatchFeature(
+            {
+                "pixel_values": torch.zeros(4, 3),
+                "image_grid_thw": torch.tensor([[1, 4, 4]]),  # 1 * 4 * 4 = 16 grid entries
+            }
+        )
+
+    def get_number_of_image_patches(self, height, width, images_kwargs):
+        return height * width
+
+
+def get_tokenizer():  # 4-entry word-level tokenizer whose id 2 is the image placeholder
+    tokenizer = Tokenizer(WordLevel({"<unk>": 0, "<pad>": 1, "<image>": 2, "hello": 3}, unk_token="<unk>"))
+    tokenizer.pre_tokenizer = Sequence([Split("<image>", behavior="isolated"), WhitespaceSplit()])
+    tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer, unk_token="<unk>", pad_token="<pad>")
+    tokenizer.image_token = "<image>"  # NOTE(review): special-token spellings were reconstructed -- confirm
+    tokenizer.image_token_id = tokenizer.convert_tokens_to_ids("<image>")
+    return tokenizer
+
+
+@require_torch
+@require_tokenizers
+class Zaya1VLProcessorTest(unittest.TestCase):
+    def get_processor(self):
+        return Zaya1VLProcessor(DummyZaya1VLImageProcessor(), get_tokenizer())
+
+    def test_image_token_expansion_without_default_token_type_ids(self):
+        inputs = self.get_processor()(text="<image> hello", images=[object()], return_tensors="pt")
+
+        self.assertEqual(inputs.input_ids.tolist(), [[2, 2, 2, 2, 3]])  # one placeholder becomes 4 image ids
+        self.assertEqual(inputs.pixel_values.shape, (4, 3))
+        self.assertNotIn("mm_token_type_ids", inputs)
+
+    def test_return_mm_token_type_ids(self):
+        inputs = self.get_processor()(
+            text="<image> hello",
+            images=[object()],
+            return_mm_token_type_ids=True,
+            return_tensors="pt",
+        )
+
+        self.assertEqual(inputs.mm_token_type_ids.tolist(), [[1, 1, 1, 1, 0]])  # 1 = image position, 0 = text
+
+    def test_get_num_multimodal_tokens(self):
+        output = self.get_processor()._get_num_multimodal_tokens(image_sizes=[(4, 4), (8, 4)])
+
+        self.assertEqual(output["num_image_patches"], [16, 32])  # height * width from the dummy processor
+        self.assertEqual(output["num_image_tokens"], [4, 8])
+
+    def test_model_input_names_excludes_default_mm_token_type_ids(self):
+        self.assertEqual(
+            self.get_processor().model_input_names,
+            ["pixel_values", "image_grid_thw", "input_ids", "attention_mask"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index fcd3547a06c7..41a8f5cbbbfb 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -814,6 +814,7 @@ def test_num_layers_is_small(self):
             "Gemma3nVision2TextModelTest": 4,  # need to test KV shared layer for both types: `full_attention` and `sliding_attention`
             "BeitModelTest": 4,  # BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers
             "ZambaModelTest": 5,  # The minimum number to test beyond the initial ["mamba", "mamba", "hybrid"] in `ZambaConfig._layers_block_type`
+            "ZayaModelTest": 4,  # needs two passes over `hybrid` and `hybrid_sliding` layer types
         }
 
         target_num_hidden_layers = exceptional_num_hidden_layers.get(type(self).__name__, 2)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 5a7484409e31..658075a1155a 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -89,6 +89,7 @@
     "Qwen2_5_VisionTransformerPretrainedModel",
     "Qwen3VLVisionModel",
     "Qwen3VLMoeVisionModel",
+    "Zaya1VLVisionModel",
     "Qwen3_5VisionModel",
     "Qwen3_5MoeVisionModel",
     "SwitchTransformersStack",
@@ -253,6 +254,7 @@
     "GlmOcrTextModel",  # Building part of bigger (tested) model
     "Qwen2VLTextModel",  # Building part of bigger (tested) model
     "Qwen2_5_VLTextModel",  # Building part of bigger (tested) model
+    "Zaya1VLTextModel",  # Building part of bigger (tested) model
     "MiniCPMV4_6Model",  # Building part of bigger (tested) model. Tested implicitly through MiniCPMV4_6ForConditionalGeneration.
     "MiniCPMV4_6ForConditionalGeneration",  # Tested in MiniCPMV4_6ModelTest via VLMModelTest; check_repo doesn't detect VLMModelTest.conditional_generation_class.
     "InternVLVisionModel",  # Building part of bigger (tested) model