Skip to content

[new model] Add Zyphra/ZAYA1-8B#45862

Open
JJJYmmm wants to merge 39 commits into
huggingface:mainfrom
JJJYmmm:add_zaya1
Open

[new model] Add Zyphra/ZAYA1-8B#45862
JJJYmmm wants to merge 39 commits into
huggingface:mainfrom
JJJYmmm:add_zaya1

Conversation

@JJJYmmm
Copy link
Copy Markdown
Contributor

@JJJYmmm JJJYmmm commented May 9, 2026

Zyphra recently released ZAYA1-VL-8B, which has a small number of active parameters and looks like a nice fit here.

Since ZAYA1-VL depends on the text-only ZAYA1-8B backbone, which has not been merged yet, this PR adds support for the text-only ZAYA1 backbone first. I can follow up with the VL model in a separate PR if preferred. 😃

I also noticed that #42669 worked on a similar integration. This PR updates the implementation to better fit the current v5 codebase, including a cleaner CCA cache design and other code cleanups. I also checked the numerical outputs against Zyphra's implementation: https://github.com/Zyphra/transformers/tree/zaya1 cc @nanduruganesh

Tests:

RUN_SLOW=1 python -m pytest tests/models/zaya/test_modeling_zaya.py -q

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model is a bit more complicated so please let me know if certain stuff is unclear I tried to nudge towards a better structure. The main points are

  1. The CCA module is way too complicated for what it essentially tries, I tried to simplify a bit
  2. The split into layers is unnecessary, they should be one decoder layer with mlp and attn; this also fixes the residuals paths
  3. Modular can be used a lot more; the current code relies on a lot of v4 specific things / remote code but we actually don't need a lot of those
  4. The cache is somewhat natively integrated within the hybrid layer type
  5. RoPE can have its own layer types (see dsv4); it seems to me that we actually don't use SWA at all but it was used as workaround which is bad

It is super detailed this time, lmk if you would like a less detailed one next time. I'm usually inclined to go full force but some don't like that :D

Comment thread docs/source/en/model_doc/zaya.md
Comment thread docs/source/en/model_doc/zaya.md Outdated
Comment thread docs/source/en/model_doc/zaya.md Outdated
Comment thread src/transformers/models/zaya/__init__.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seeing the order of those residual I feel like the order is just messed up as we split the layer types

You want attn -> residual -> mlp -> residual but because of the implementation you have skip first residual -> attn -> residual -> mlp -> residual which could be fixed if we just fuse properly into the one decoder layer with 2 residuals each time

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still not resolved 😉

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ops, i forgot about this. now i shift the res_scale by one layer earlier, so we can avoid the residual between layers! 🫡

@nanduruganesh
Copy link
Copy Markdown

Thank you very much for the rebase and cleanups! About the interleaved sliding window / rope theta confusion, these configs are in place for the ZAYA1-74B-preview model which also uses this branch but has 4k-SWA / 10k rope base every other layer. All of @vasqu's suggestions sound good to me, and another user has found additional fixes to support GRPO trainer on this branch (PR). @JJJYmmm would you be able to integrate all the changes into your branch? Thanks again for this PR.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented May 12, 2026

@vasqu thank you for the detailed review! it was really helpful for me to catch up with the latest changes and learn the unified code style, so that’s totally ok. thanks a lot for your time 😃

in the latest code, i fixed most of the inheritance issues. since the original checkpoint has some uncommon kwargs / weight layouts, which would require a lot of custom code, i wrote a conversion script and uploaded the converted 8b checkpoint here: https://huggingface.co/JJJYmmm/ZAYA1-8B-HF. i also tested it with a fake 74b checkpoint with swa.

the conversion mainly does three things:

  1. use more common names in the config, e.g. intermediate_size, and update some fields like rope_parameters
  2. remove nn.Sequential and use explicit module names
  3. combine the separate attention and mlp layers into a single ZayaDecoderLayer, with the corresponding config fixes, e.g. num_hidden_layers: 80 -> 40
  4. 3d experts

what do you think about this conversion? 🫡 @nanduruganesh

@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented May 12, 2026

another user has found additional fixes to support GRPO trainer on this branch (Zyphra#2)

@nanduruganesh i also checked this pr, and most of the fixes are already covered in the current branch. the only exception seems to be 8. router_aux_loss_coef, but i think zaya does not use an auxiliary loss, right?

besides the conversion mentioned above, i also noticed a small detail in the official code about the SWA mask calculation. in the official branch, the swa mask is:

if window_size > 0:
    causal_mask = (
        torch.ones((seq_length, seq_length), dtype=torch.bool, device=query_states.device)
        .tril_(diagonal=0)
        .triu_(diagonal=-window_size)
    )
    attn_weights.masked_fill_(~causal_mask, -1e4)
elif attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask

this means that in the swa branch, attention_mask is discarded. for now, i kept the same behavior to preserve numerical consistency with the original implementation. is this expected?

EDIT: another reminder: in the conversion script, i increase swa window_size by one (4096 -> 4097).
this is because in the original logic, .tril_(diagonal=0).triu_(diagonal=-window_size) means the current query can attend to window_size keys. but in the current branch, window_size means the total window size directly.

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super small review, exciting! 🔥

Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
if past_key_values is not None and past_key_values.get_seq_length() > 0:
raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
past_key_values = make_zaya_cache(self.config)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's prevent having to add this and use a simple dynamic cache, registering the layer in

LAYER_TYPE_CACHE_MAPPING.update(
{
"full_attention": DynamicLayer,
# From a cache point of view, sliding and chunked are the same in how they should behave;
# only the mask differs.
"sliding_attention": DynamicSlidingWindowLayer,
"chunked_attention": DynamicSlidingWindowLayer,
# Linear-attention-shaped layers (mamba / conv / pure linear-attention / moe placeholders)
# don't grow per-token KV; they're tracked just so position bookkeeping stays consistent.
"mamba": LinearAttentionLayer,
"conv": LinearAttentionLayer,
"linear_attention": LinearAttentionLayer,
"moe": LinearAttentionLayer,
# Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state.
"hybrid": LinearAttentionAndFullAttentionLayer,
}
)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, i already reused the hybrid mapping. the current issue is this one: #45862 (comment)

to solve it in a simple way, i changed the layer_types logic from:

layer_types = getattr(decoder_config, "layer_types", None)

to:

getattr(decoder_config, "cache_layer_types", None) or getattr(decoder_config, "layer_types", None)

so models like zaya can keep layer_types for attention variants, while using cache_layer_types to describe the cache layout.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the hybrid and hybrid sliding for all in the end 🫡 see my earlier/first comments

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a new mapping for zaya 🫡

"hybrid_sliding": LinearAttentionAndSlidingWindowAttentionLayer

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still not resolved 😉

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heya another round, the focus is on really avoiding passing too many args and let these values live in the config.

One of my main ideas tbh is to introduce a sliding hybrid type that would be hybrid but SWA --> then we can use the same layer types across cache and masks. (instead of this current split)

Other than that, mostly details as we try to keep naming conventions the same where we can.

Comment thread src/transformers/models/zaya/__init__.py Outdated
Comment thread src/transformers/models/zaya/convert_zaya_weights_to_hf.py
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py
if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)):
if past_key_values is not None and past_key_values.get_seq_length() > 0:
raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.")
past_key_values = make_zaya_cache(self.config)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the hybrid and hybrid sliding for all in the end 🫡 see my earlier/first comments

Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking pretty good now, I think we are getting close to merge. Just a few more details here and there 🤗 thanks a lot for all the iterating

Comment thread src/transformers/models/auto/modeling_auto.py
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so it is mostly ready (few small details). I think the biggest things to manage would be

  1. Where would the converted checkpoints live - either a new -hf repo living under https://huggingface.co/Zyphra directly (cc @nanduruganesh) or we open a community to upload these
  2. TP specifications, would try to take a look at some point (after this round)
  3. Clarification on the mask behavior

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Zyphra/ZAYA1-8B"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we need another repo for these then since the weights need to be restructured 🤔 cc @nanduruganesh

Best would be to have some other repo with the -hf suffix

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My plan is to update "Zyphra/ZAYA1-8B" for this upstream merge and move the current checkpoint there to "Zyphra/ZAYA1-8B-Legacy" to support people still on the old runtime (e.g. the vllm / llama cpp currently still depends on old checkpoint).

cca_time1: int = 2

# Fields declared by LagunaConfig but not used by ZAYA.
# TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will try to check TP at some point. I feel like it shouldn't be too complicated because only the MoE might be more special (and we could ignore that for now)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we do CCA with GQA, TP-version of CCA will need a special gather to collect all the query heads for the QK-mean-add operation
(currently key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2), but with TP the query_residual would be sharded across devices)

Copy link
Copy Markdown
Contributor

@vasqu vasqu May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not fine if we are on a subset of self.num_key_value_heads as TP degree (e.g. 2 kv heads, TP=2)? It depends on the split that is created across the q heads no?

But agree in general, you are correct, i.e. we need to gather the heads there. That's a good point

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes if TP <= num kv heads we do not need to gather


class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does CCA exactly stand for? Could we expand the abbreviation once at least?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebump - let's make this (once) extended in the docs so people know what the abbreviaton stands for

Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/cache_utils.py
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py
@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented May 19, 2026

Will try to take a look a bit later today @JJJYmmm, I'm out for the week after today - just wanted to notify so you arent surprised why Im suddenly not answering / responding

Edit: I think I responded to all things for now and a few comments are still left so waiting for now. TP seems to be the only somewhat harder case but waiting on Ferdinand for his opinion because it definitely won't be as straightforward as I expected (kv heads == 2 destroys a lot of assumptions 😢)

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have mostly smaller comments now and are pretty much ready. Now it's onto preparing the hub repo if possible cc @nanduruganesh

For now, the only things missing are proper TP (and PP) support but given the complexity, we skip it for now.

Comment thread src/transformers/models/zaya/modular_zaya.py Outdated

class ZayaCCAProjection(nn.Module):
"""
Projects hidden states into attention q/k/v states with ZAYA's CCA path.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebump - let's make this (once) extended in the docs so people know what the abbreviaton stands for

Comment thread src/transformers/models/zaya/modular_zaya.py
pass


class ZayaSparseMoeBlock(nn.Module):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we could inherit from gpt oss here (for the init when we rename gate <-> router), seems a bit more fitting

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zaya needs ZayaRouter(config, layer_idx) because of the use_eda, and the GptOssMLP decorator @use_kernel_forward_from_hub("MegaBlocksMoeMLP") also needs to be overridden for zaya since it takes an extra input prev_router_hidden_states. so I think the current version is clear enough 😄

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, yea fair enough! Maybe something that tells modular to not inherit decorators would be nice? Feel like I've encountered this a few times, but not too often

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree! maybe adding a no-op decorator like @no_inherited_decorators is fine. I’ll take a look later and open another PR if necessary. 🫡

Comment thread src/transformers/models/zaya/modular_zaya.py
Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
Comment thread src/transformers/models/zaya/modular_zaya.py
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
Comment thread tests/models/zaya/test_modeling_zaya.py
Comment thread tests/models/zaya/test_modeling_zaya.py Outdated
@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented May 25, 2026

Failing CI tests should be fixed on main

@require_torch
class ZayaIntegrationTest(unittest.TestCase):
model = None
model_id = "Zyphra/ZAYA1-8B"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this even correct, we need a converted version or was this already done?

Comment thread src/transformers/models/zaya/modular_zaya.py Outdated
@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented May 26, 2026

Sorry about spam pinging @nanduruganesh 🙏
From the transformers side we are pretty much ready now, we only need the converted versions on the hub (and adjust the integration tests possibly, e.g. different model id).

Thanks a lot on all this work @JJJYmmm 🤗

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, zaya

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants