179 changes: 153 additions & 26 deletions specforge/modeling/draft/llama3_eagle.py
@@ -511,7 +511,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config):
+    def __init__(self, config, fused_input=True):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@@ -524,14 +524,19 @@ def __init__(self, config):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings

+        if fused_input:
+            in_features_size = self.hidden_size * 2
+        else:
+            in_features_size = self.hidden_size
+
self.q_proj = nn.Linear(
-            self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
+            in_features_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
-            self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
+            in_features_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
-            self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
+            in_features_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
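The new `fused_input` switch selects the input width of the attention projections: the Eagle3 fuse layer consumes the concatenation of the input embeddings and the target model's hidden states (2 * hidden_size), while the plain decoder layers added further down operate on hidden_size-wide activations. A minimal sketch of the effect (config values are illustrative, `LlamaAttention` is the class patched above, and `LlamaConfig` is assumed to be the stock transformers config whose attributes the constructor reads):

```python
# Sketch: check that fused_input only changes the projection input width.
from transformers import LlamaConfig  # assumed to match this file's config_class

cfg = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=8)

fused = LlamaAttention(cfg, fused_input=True)    # fuse layer: cat(embeds, hidden)
plain = LlamaAttention(cfg, fused_input=False)   # plain decoder-layer path

assert fused.q_proj.in_features == 2 * cfg.hidden_size
assert plain.q_proj.in_features == cfg.hidden_size
assert fused.o_proj.in_features == plain.o_proj.in_features  # output side unchanged
```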
@@ -972,8 +977,8 @@ class LlamaUSPFlashAttention(LlamaAttention):
LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
"""

-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, config, fused_input=True):
+        super().__init__(config, fused_input=fused_input)
assert (
dist.is_initialized()
), f"LlamaUSPAttention requires torch.distributed; call init_distributed first."
@@ -1239,6 +1244,70 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


class BasicDecoderLayer(nn.Module):
"""
The traditional decoder layer.
"""
def __init__(self, config, attention_backend: str = "sdpa"):
super().__init__()
self.hidden_size = config.hidden_size

if attention_backend == "sdpa":
self.self_attn = LlamaAttention(config=config, fused_input=False)
elif attention_backend == "flex_attention":
print_with_rank("Using flex attention on draft model training!")
self.self_attn = LlamaFlexAttention(config=config, fused_input=False)
elif attention_backend == "fa":
self.self_attn = LlamaFlashAttention(config=config, fused_input=False)
elif attention_backend == "usp":
self.self_attn = LlamaAttention(config=config, fused_input=False)
else:
raise ValueError(f"Unknown attention backend {attention_backend}")

self.attention_backend = attention_backend
self.mlp = LlamaMLP(config)

self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Basic decoder layer forward pass: pre-norm self-attention and a pre-norm
        MLP, each followed by a residual connection.
        """
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)

# First residual connection
hidden_states = residual + hidden_states

        # Feed-forward network with residual connection
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states
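Since `BasicDecoderLayer` is a standard pre-norm residual block over `hidden_size`-wide activations, a shape-level smoke test is straightforward. The sketch below assumes this file's `LlamaAttention.forward` accepts these keyword arguments and returns a `(batch, seq, hidden)` tensor, and uses illustrative config values:

```python
# Shape-level smoke test (sketch, not part of the patch).
import torch
from transformers import LlamaConfig  # assumed to match this file's config_class

cfg = LlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_hidden_layers=2,  # 1 fuse layer + 1 BasicDecoderLayer
)

layer = BasicDecoderLayer(cfg, attention_backend="sdpa")
x = torch.randn(2, 16, cfg.hidden_size)
position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)

out = layer(hidden_states=x, position_ids=position_ids)
assert out.shape == x.shape  # pre-norm attention + MLP with residuals keep the shape
```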


class LlamaDecoderLayer(nn.Module):
def __init__(self, config, attention_backend: str = "sdpa"):
super().__init__()
@@ -1322,6 +1391,71 @@ def forward(
return hidden_states


class LlamaMultiLayerDecoder(nn.Module):
def __init__(self, config, attention_backend: str = "sdpa"):
super().__init__()
self.config = config
self.num_additional_layers = config.num_hidden_layers - 1
        # Initialize the fuse layer (the Eagle3 decoder layer that merges the
        # input embeddings with the target model's hidden states)
self.fuselayer = LlamaDecoderLayer(config, attention_backend=attention_backend)

# initialize additional decoder layers
self.additional_layers = None
self.final_layernorm = None
if self.num_additional_layers > 0:
self.additional_layers = nn.ModuleList(
[
BasicDecoderLayer(config, attention_backend=attention_backend)
for _ in range(self.num_additional_layers)
]
)

self.final_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
input_emb: torch.Tensor,
hidden_states: torch.Tensor,
caches_hidden: Optional[List[List[List[torch.Tensor]]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Forward pass of the multi-layer decoder: the fuse layer combines the input
        embeddings with the target-model hidden states, then any additional decoder
        layers refine the result.
        """
hidden_states = self.fuselayer(
input_emb=input_emb,
hidden_states=hidden_states,
cache_hidden=caches_hidden[0] if caches_hidden is not None else None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=(
past_key_values[0] if past_key_values is not None else None
),
output_attentions=False,
use_cache=False,
)

if self.num_additional_layers > 0:
for i, layer in enumerate(self.additional_layers):
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
                    past_key_values=(
                        past_key_values[i + 1] if past_key_values is not None else None
                    ),
                    use_cache=use_cache,
                )

hidden_states = self.final_layernorm(hidden_states)

return hidden_states
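The cache indexing convention is the part most worth keeping in mind: the fuse layer reads `caches_hidden[0]` and `past_key_values[0]`, and additional layer `i` reads `past_key_values[i + 1]`, so any per-layer list handed to `forward` needs `num_hidden_layers` entries. A construction-level sketch (illustrative config values; no forward call, since the fuse layer's inputs are exercised elsewhere in this file):

```python
# Cache-wiring sketch (not part of the patch): index 0 feeds the fuse layer,
# index i + 1 feeds additional_layers[i].
from transformers import LlamaConfig  # assumed to match this file's config_class

cfg = LlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_hidden_layers=3,  # 1 fuse layer + 2 BasicDecoderLayers
)

decoder = LlamaMultiLayerDecoder(cfg, attention_backend="sdpa")
assert decoder.num_additional_layers == 2
assert len(decoder.additional_layers) == 2

# A per-layer KV-cache list passed to forward() therefore has one entry per
# hidden layer: [fuse layer, additional layer 0, additional layer 1].
past_key_values = [None] * cfg.num_hidden_layers
```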


class LlamaForCausalLMEagle3(Eagle3DraftModel):

config_class = LlamaConfig
@@ -1337,12 +1471,8 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
config.vocab_size, config.hidden_size, config.pad_token_id
)
self.num_hidden_layers = config.num_hidden_layers
-        self.midlayers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(config, attention_backend=attention_backend)
-                for _ in range(self.num_hidden_layers)
-            ]
-        )
+
+        self.midlayers = LlamaMultiLayerDecoder(config, attention_backend=attention_backend)

if hasattr(config, "target_hidden_size"):
self.fc = torch.nn.Linear(
@@ -1446,17 +1576,14 @@ def backbone(
past_key_values: Optional[List[Cache]] = None,
use_cache: bool = True,
) -> torch.Tensor:
-        for i, layer in enumerate(self.midlayers):
-            hidden_states = layer(
-                input_emb=input_embeds,
-                hidden_states=hidden_states,
-                cache_hidden=caches_hidden[i] if caches_hidden is not None else None,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=(
-                    past_key_values[i] if past_key_values is not None else None
-                ),
-                output_attentions=False,
-                use_cache=False,
-            )
-        return hidden_states
+
+        return self.midlayers(
+            input_emb=input_embeds,
+            hidden_states=hidden_states,
+            caches_hidden=caches_hidden,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=False,
+            use_cache=False,
+        )
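`backbone` now simply delegates to `self.midlayers`. With `num_hidden_layers == 1` the new decoder constructs no additional layers, so the multi-layer path is expected to reduce to the previous single fuse-layer call. A quick construction-level check (sketch, illustrative config values):

```python
# Single-layer sanity check (sketch, not part of the patch).
from transformers import LlamaConfig  # assumed to match this file's config_class

cfg = LlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_hidden_layers=1,
)

decoder = LlamaMultiLayerDecoder(cfg, attention_backend="sdpa")
assert decoder.num_additional_layers == 0
assert decoder.additional_layers is None
```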