Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/zaya/configuration_zaya.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
apply_rope_fusion=True,
bias_activation_fusion=True,
activation_func_fp8_input_store=False,
router_aux_loss_coef=0.001,
sliding_window=None,
rope_scaling=None,
rope_parameters=None,
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
self.residual_in_fp32 = residual_in_fp32
self.bias_activation_fusion = bias_activation_fusion
self.activation_func_fp8_input_store = activation_func_fp8_input_store
self.router_aux_loss_coef = router_aux_loss_coef
self.sliding_window = sliding_window
self.partial_rotary_factor = partial_rotary_factor
self.rope_theta = rope_theta
Expand Down
17 changes: 13 additions & 4 deletions src/transformers/models/zaya/modular_zaya.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import is_torch_fx_available
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from .configuration_zaya import ZayaConfig

if is_flash_attn_2_available():
Expand Down Expand Up @@ -183,6 +184,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
Applies Rotary Position Embedding to the query and key tensors. This version
Expand Down Expand Up @@ -215,6 +217,7 @@ class ZayaRotaryEmbedding(Glm4RotaryEmbedding):
pass


@use_kernel_forward_from_hub("RMSNorm")
class ZayaRMSNorm(LlamaRMSNorm):
pass

Expand Down Expand Up @@ -916,7 +919,7 @@ def forward(
}


class ZayaDecoderATTLayer(nn.Module):
class ZayaDecoderATTLayer(GradientCheckpointingLayer):
def __init__(self, config: ZayaConfig, layer_n: int, training: bool):

super().__init__()
Expand Down Expand Up @@ -1568,15 +1571,18 @@ class ZayaPreTrainedModel(PreTrainedModel):
config_class = ZayaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ZayaDecoderLayer"]
_no_split_modules = ["ZayaDecoderATTLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = False
# MoE models don't work with torch.compile (`torch.where(condition)` not
# supported)
_supports_static_cache = False
_can_record_outputs = {
"hidden_states": ZayaDecoderATTLayer,
"attentions": ZayaAttention,
}


Zaya_INPUTS_DOCSTRING = r"""
Expand Down Expand Up @@ -2073,6 +2079,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(


class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin):
_tp_plan = {"lm_head": "colwise_gather_output"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.config = config
Expand Down