[new model] Add Zyphra/ZAYA1-8B#45862
Conversation
vasqu
left a comment
There was a problem hiding this comment.
The model is a bit more complicated so please let me know if certain stuff is unclear I tried to nudge towards a better structure. The main points are
- The CCA module is way too complicated for what it essentially tries, I tried to simplify a bit
- The split into layers is unnecessary, they should be one decoder layer with mlp and attn; this also fixes the residuals paths
- Modular can be used a lot more; the current code relies on a lot of v4 specific things / remote code but we actually don't need a lot of those
- The cache is somewhat natively integrated within the hybrid layer type
- RoPE can have its own layer types (see dsv4); it seems to me that we actually don't use SWA at all but it was used as workaround which is bad
It is super detailed this time, lmk if you would like a less detailed one next time. I'm usually inclined to go full force but some don't like that :D
| if output_attentions: | ||
| all_self_attns += (layer_outputs[1],) | ||
|
|
||
| hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) |
There was a problem hiding this comment.
seeing the order of those residual I feel like the order is just messed up as we split the layer types
You want attn -> residual -> mlp -> residual but because of the implementation you have skip first residual -> attn -> residual -> mlp -> residual which could be fixed if we just fuse properly into the one decoder layer with 2 residuals each time
There was a problem hiding this comment.
still not resolved 😉
There was a problem hiding this comment.
ops, i forgot about this. now i shift the res_scale by one layer earlier, so we can avoid the residual between layers! 🫡
|
Thank you very much for the rebase and cleanups! About the interleaved sliding window / rope theta confusion, these configs are in place for the ZAYA1-74B-preview model which also uses this branch but has 4k-SWA / 10k rope base every other layer. All of @vasqu's suggestions sound good to me, and another user has found additional fixes to support GRPO trainer on this branch (PR). @JJJYmmm would you be able to integrate all the changes into your branch? Thanks again for this PR. |
but need to construct cache from _make_zaya_cache
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
@vasqu thank you for the detailed review! it was really helpful for me to catch up with the latest changes and learn the unified code style, so that’s totally ok. thanks a lot for your time 😃 in the latest code, i fixed most of the inheritance issues. since the original checkpoint has some uncommon kwargs / weight layouts, which would require a lot of custom code, i wrote a conversion script and uploaded the converted 8b checkpoint here: https://huggingface.co/JJJYmmm/ZAYA1-8B-HF. i also tested it with a fake 74b checkpoint with swa. the conversion mainly does three things:
what do you think about this conversion? 🫡 @nanduruganesh |
@nanduruganesh i also checked this pr, and most of the fixes are already covered in the current branch. the only exception seems to be besides the conversion mentioned above, i also noticed a small detail in the official code about the SWA mask calculation. in the official branch, the swa mask is: if window_size > 0:
causal_mask = (
torch.ones((seq_length, seq_length), dtype=torch.bool, device=query_states.device)
.tril_(diagonal=0)
.triu_(diagonal=-window_size)
)
attn_weights.masked_fill_(~causal_mask, -1e4)
elif attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_maskthis means that in the swa branch, EDIT: another reminder: in the conversion script, i increase swa |
ArthurZucker
left a comment
There was a problem hiding this comment.
Super small review, exciting! 🔥
| if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): | ||
| if past_key_values is not None and past_key_values.get_seq_length() > 0: | ||
| raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") | ||
| past_key_values = make_zaya_cache(self.config) |
There was a problem hiding this comment.
let's prevent having to add this and use a simple dynamic cache, registering the layer in
transformers/src/transformers/cache_utils.py
Lines 871 to 887 in cc832f9
There was a problem hiding this comment.
yes, i already reused the hybrid mapping. the current issue is this one: #45862 (comment)
to solve it in a simple way, i changed the layer_types logic from:
transformers/src/transformers/cache_utils.py
Line 1286 in cc832f9
to:
getattr(decoder_config, "cache_layer_types", None) or getattr(decoder_config, "layer_types", None)so models like zaya can keep layer_types for attention variants, while using cache_layer_types to describe the cache layout.
There was a problem hiding this comment.
Let's use the hybrid and hybrid sliding for all in the end 🫡 see my earlier/first comments
There was a problem hiding this comment.
add a new mapping for zaya 🫡
"hybrid_sliding": LinearAttentionAndSlidingWindowAttentionLayer| if output_attentions: | ||
| all_self_attns += (layer_outputs[1],) | ||
|
|
||
| hidden_states, residual = _apply_residual_scaling(hidden_states, residual, self.res_scale, self.final_norm) |
There was a problem hiding this comment.
still not resolved 😉
vasqu
left a comment
There was a problem hiding this comment.
Heya another round, the focus is on really avoiding passing too many args and let these values live in the config.
One of my main ideas tbh is to introduce a sliding hybrid type that would be hybrid but SWA --> then we can use the same layer types across cache and masks. (instead of this current split)
Other than that, mostly details as we try to keep naming conventions the same where we can.
| if use_cache and (past_key_values is None or not _is_zaya_cache(past_key_values)): | ||
| if past_key_values is not None and past_key_values.get_seq_length() > 0: | ||
| raise ValueError("ZAYA requires a native hybrid cache created from `make_zaya_cache`.") | ||
| past_key_values = make_zaya_cache(self.config) |
There was a problem hiding this comment.
Let's use the hybrid and hybrid sliding for all in the end 🫡 see my earlier/first comments
vasqu
left a comment
There was a problem hiding this comment.
Looking pretty good now, I think we are getting close to merge. Just a few more details here and there 🤗 thanks a lot for all the iterating
vasqu
left a comment
There was a problem hiding this comment.
Ok so it is mostly ready (few small details). I think the biggest things to manage would be
- Where would the converted checkpoints live - either a new
-hfrepo living under https://huggingface.co/Zyphra directly (cc @nanduruganesh) or we open a community to upload these - TP specifications, would try to take a look at some point (after this round)
- Clarification on the mask behavior
| ```python | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| model_id = "Zyphra/ZAYA1-8B" |
There was a problem hiding this comment.
I guess we need another repo for these then since the weights need to be restructured 🤔 cc @nanduruganesh
Best would be to have some other repo with the -hf suffix
There was a problem hiding this comment.
My plan is to update "Zyphra/ZAYA1-8B" for this upstream merge and move the current checkpoint there to "Zyphra/ZAYA1-8B-Legacy" to support people still on the old runtime (e.g. the vllm / llama cpp currently still depends on old checkpoint).
| cca_time1: int = 2 | ||
|
|
||
| # Fields declared by LagunaConfig but not used by ZAYA. | ||
| # TODO: add TP/PP plans. TP needs the router mlp, moe experts, and CCA projections to shard consistently; PP needs coverage for the cross-layer router state. |
There was a problem hiding this comment.
Will try to check TP at some point. I feel like it shouldn't be too complicated because only the MoE might be more special (and we could ignore that for now)
There was a problem hiding this comment.
Since we do CCA with GQA, TP-version of CCA will need a special gather to collect all the query heads for the QK-mean-add operation
(currently key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2), but with TP the query_residual would be sharded across devices)
There was a problem hiding this comment.
Are we not fine if we are on a subset of self.num_key_value_heads as TP degree (e.g. 2 kv heads, TP=2)? It depends on the split that is created across the q heads no?
But agree in general, you are correct, i.e. we need to gather the heads there. That's a good point
There was a problem hiding this comment.
Yes if TP <= num kv heads we do not need to gather
|
|
||
| class ZayaCCAProjection(nn.Module): | ||
| """ | ||
| Projects hidden states into attention q/k/v states with ZAYA's CCA path. |
There was a problem hiding this comment.
What does CCA exactly stand for? Could we expand the abbreviation once at least?
There was a problem hiding this comment.
Compressed Convolutional Attention, yes good idea
There was a problem hiding this comment.
Rebump - let's make this (once) extended in the docs so people know what the abbreviaton stands for
|
Will try to take a look a bit later today @JJJYmmm, I'm out for the week after today - just wanted to notify so you arent surprised why Im suddenly not answering / responding Edit: I think I responded to all things for now and a few comments are still left so waiting for now. TP seems to be the only somewhat harder case but waiting on Ferdinand for his opinion because it definitely won't be as straightforward as I expected (kv heads == 2 destroys a lot of assumptions 😢) |
vasqu
left a comment
There was a problem hiding this comment.
I think we have mostly smaller comments now and are pretty much ready. Now it's onto preparing the hub repo if possible cc @nanduruganesh
For now, the only things missing are proper TP (and PP) support but given the complexity, we skip it for now.
|
|
||
| class ZayaCCAProjection(nn.Module): | ||
| """ | ||
| Projects hidden states into attention q/k/v states with ZAYA's CCA path. |
There was a problem hiding this comment.
Rebump - let's make this (once) extended in the docs so people know what the abbreviaton stands for
| pass | ||
|
|
||
|
|
||
| class ZayaSparseMoeBlock(nn.Module): |
There was a problem hiding this comment.
nit: we could inherit from gpt oss here (for the init when we rename gate <-> router), seems a bit more fitting
There was a problem hiding this comment.
zaya needs ZayaRouter(config, layer_idx) because of the use_eda, and the GptOssMLP decorator @use_kernel_forward_from_hub("MegaBlocksMoeMLP") also needs to be overridden for zaya since it takes an extra input prev_router_hidden_states. so I think the current version is clear enough 😄
There was a problem hiding this comment.
Gotcha, yea fair enough! Maybe something that tells modular to not inherit decorators would be nice? Feel like I've encountered this a few times, but not too often
There was a problem hiding this comment.
agree! maybe adding a no-op decorator like @no_inherited_decorators is fine. I’ll take a look later and open another PR if necessary. 🫡
|
Failing CI tests should be fixed on main |
| @require_torch | ||
| class ZayaIntegrationTest(unittest.TestCase): | ||
| model = None | ||
| model_id = "Zyphra/ZAYA1-8B" |
There was a problem hiding this comment.
Is this even correct, we need a converted version or was this already done?
|
Sorry about spam pinging @nanduruganesh 🙏 Thanks a lot on all this work @JJJYmmm 🤗 |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: auto, zaya |
Zyphra recently released ZAYA1-VL-8B, which has a small number of active parameters and looks like a nice fit here.
Since ZAYA1-VL depends on the text-only ZAYA1-8B backbone, which has not been merged yet, this PR adds support for the text-only ZAYA1 backbone first. I can follow up with the VL model in a separate PR if preferred. 😃
I also noticed that #42669 worked on a similar integration. This PR updates the implementation to better fit the current v5 codebase, including a cleaner CCA cache design and other code cleanups. I also checked the numerical outputs against Zyphra's implementation: https://github.com/Zyphra/transformers/tree/zaya1 cc @nanduruganesh
Tests: