diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index dcf796651..7db8c2e91 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -173,6 +173,7 @@ def forward( loss_type="sigmoid", label_smoothing=0.0, discopop_tau=0.05, + alpha=1.0, ): """ Fused linear layer with DPO loss. @@ -185,7 +186,7 @@ def forward( ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size) ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,) ignore_index (int): Index to ignore in loss computation - beta (float): Weight for the odds ratio loss + beta (float): Weight for the direct preference loss compute_nll_loss (bool): Whether to compute the NLL loss compiled (bool): Whether to use torch compile use_ref_model (bool): Whether to use a reference model @@ -194,6 +195,7 @@ def forward( loss_type (str): Variant of DPO loss to compute. label_smoothing (float): Label smoothing for "robust" / "exo_pair" / cDPO. discopop_tau (float): Temperature for the DiscoPOP modulation term. 
+ alpha (float): Weight for the NLL loss component Returns: torch.Tensor: Computed loss """ @@ -206,6 +208,7 @@ def forward( bias=bias, ignore_index=ignore_index, beta=beta, + alpha=alpha, compute_nll_loss=compute_nll_loss, compiled=compiled, use_ref_model=use_ref_model, @@ -222,7 +225,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -247,6 +250,7 @@ def __init__( self, ignore_index: int = -100, beta: float = 0.1, + alpha: float = 1.0, compute_nll_loss: bool = False, compiled: bool = True, use_ref_model: bool = True, @@ -259,7 +263,8 @@ def __init__( """ Args: ignore_index (int): Index to ignore in the loss. - beta (float): Weight for the odds ratio loss. + beta (float): Weight for the direct preference loss. + alpha (float): Weight for the NLL loss component. compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. 
@@ -274,6 +279,7 @@ def __init__( super().__init__() self.ignore_index = ignore_index self.beta = beta + self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model @@ -321,4 +327,5 @@ def forward( self.loss_type, self.label_smoothing, self.discopop_tau, + self.alpha, ) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 600c79f46..144cb341e 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -869,8 +869,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_input, ref_weight1, ref_bias1, - -100, - 0.1, + -100, # ignore_index + 0.1, # beta compute_nll_loss, ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( @@ -881,8 +881,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_input, ref_weight2, ref_bias2, - -100, - 0.1, + -100, # ignore_index + 0.1, # beta compute_nll_loss, ) @@ -1112,14 +1112,14 @@ def test_correctness_functional_apo_loss_types( ref_input, ref_weight1, ref_bias1, - -100, - 0.1, + -100, # ignore_index + 0.1, # beta compute_nll_loss, True, # compiled True, # use_ref_model False, # average_log_prob 1, # chunk_size - loss_type, # loss_type + loss_type, ) # For comparison, create a LigerFusedLinearDPOLoss with the loss_type @@ -1315,6 +1315,105 @@ def test_invalid_loss_type(): assert loss_fn.loss_type == loss_type +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_alpha_scales_nll_loss(dtype): + """ + Verify that alpha is actually forwarded and changes the loss (the exact NLL scaling factor is not asserted). With compute_nll_loss=True, loss(alpha=2) should differ from loss(alpha=1). 
+ """ + B, T, H, V = 4, 16, 32, 64 + atol = 1e-4 if dtype == torch.float32 else 5e-2 + + _weight = torch.randn(V, H, device=device, dtype=dtype) + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + _input = torch.randn(B, T, H, device=device, dtype=dtype) + target = torch.randint(0, V, (B, T), device=device, dtype=torch.long) + + def run(alpha): + inp = _input.detach().clone().requires_grad_(True) + w = _weight.detach().clone().requires_grad_(True) + rw = _ref_weight.detach().clone().requires_grad_(True) + loss_fn = LigerFusedLinearDPOLoss( + beta=0.1, + alpha=alpha, + compute_nll_loss=True, + use_ref_model=True, + average_log_prob=False, + ) + loss, _ = loss_fn(w, inp, target, None, _input.detach(), rw, None) + return loss + + loss_alpha1 = run(alpha=1.0) + loss_alpha2 = run(alpha=2.0) + + assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), ( + f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}" + ) + + +def test_functional_positional_arg_contract(): + """ + Pin the positional-argument contract of the public functional alias. + + `alpha` is appended at the *end* of `LigerFusedLinearDPOFunction.forward` (not + inserted mid-list) precisely so that existing positional `.apply()` / + `liger_fused_linear_dpo` callers don't silently shift every later argument by + one slot. This test exercises the pre-PR positional list (no `alpha`) and + asserts it produces the same result as the keyword-driven `nn.Module` wrapper. + If a future param insertion shifts the positional slots, this diverges. 
+ """ + B, T, H, V = 4, 8, 16, 32 + dtype = torch.float32 + + _input = torch.randn(B, T, H, device=device, dtype=dtype) + target = torch.randint(0, V, (B, T), device=device, dtype=torch.long) + _weight = torch.randn(V, H, device=device, dtype=dtype) + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Pre-PR positional list: the args after `beta` are + # compute_nll_loss, compiled, use_ref_model, average_log_prob, chunk_size, loss_type. + loss_positional, _ = liger_fused_linear_dpo( + _input.detach().clone().requires_grad_(True), + _weight.detach().clone().requires_grad_(True), + target, + None, # bias + ref_input, + _ref_weight.detach().clone().requires_grad_(True), + None, # ref_bias + -100, # ignore_index + 0.1, # beta + True, # compute_nll_loss + True, # compiled + True, # use_ref_model + False, # average_log_prob + 1, # chunk_size + "sigmoid", # loss_type + ) + + loss_module, _ = LigerFusedLinearDPOLoss( + ignore_index=-100, + beta=0.1, + alpha=1.0, + compute_nll_loss=True, + compiled=True, + use_ref_model=True, + average_log_prob=False, + chunk_size=1, + loss_type="sigmoid", + )( + _weight.detach().clone().requires_grad_(True), + _input.detach().clone().requires_grad_(True), + target, + None, # bias + ref_input, + _ref_weight.detach().clone().requires_grad_(True), + None, # ref_bias + ) + + assert_verbose_allclose(loss_positional, loss_module, atol=1e-5, rtol=1e-4) + + def test_label_smoothing_validation(): """Test that invalid label_smoothing values raise ValueError for the relevant loss types.""" with pytest.raises(ValueError, match="label_smoothing must be > 0 for loss_type='exo_pair'"):