From 5d1954b87cc2743ca58b2c0b3eb0a4e2ed1551de Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Tue, 26 May 2026 20:45:00 +0000 Subject: [PATCH] Remove dead lce_forward_deprecated path from qwen2 lce_forward_deprecated in transformers/model/qwen2.py was the pre- logits_to_keep codepath. It has had zero references in src/, test/, or benchmark/ for some time and the active lce_forward right below it has been the only one wired into the monkey patch. Removing it also drops four imports that only it used (CrossEntropyLoss, CausalLMOutputWithPast, LigerFusedLinearCrossEntropyLoss, and the redundant blank line). Net: -120 LOC, no behavior change. --- src/liger_kernel/transformers/model/qwen2.py | 119 ------------------- 1 file changed, 119 deletions(-) diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py index 6d43caada..7a40ab761 100644 --- a/src/liger_kernel/transformers/model/qwen2.py +++ b/src/liger_kernel/transformers/model/qwen2.py @@ -5,130 +5,11 @@ import torch -from torch.nn import CrossEntropyLoss -from transformers.modeling_outputs import CausalLMOutputWithPast - -from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast -def lce_forward_deprecated( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - skip_logits: Optional[bool] = None, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy - - - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-1.5B") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - loss = None - logits = None - - if skip_logits and labels is None: - raise ValueError("skip_logits is True, but labels is None") - - if skip_logits is None: - # By default, if in training mode, don't materialize logits - skip_logits = self.training and labels is not None - - if self.training and (labels is not None): - shift_hidden_states = hidden_states[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # flatten tokens - shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) - shift_labels = shift_labels.view(-1) - - lce = LigerFusedLinearCrossEntropyLoss() - loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) - - else: - logits = self.lm_head(hidden_states) - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def lce_forward( self, input_ids: torch.LongTensor = None,