From 314f4f81feef41cc9af69ae8e89a6b9dacf5b570 Mon Sep 17 00:00:00 2001 From: micdoh Date: Thu, 16 Apr 2026 12:01:27 +0100 Subject: [PATCH 1/3] fix: expose alpha param through LigerFusedLinearDPO public API LigerFusedLinearDPOFunction.forward() accepted every base-class parameter except alpha, so the NLL scaling weight silently defaulted to 1.0 regardless of what callers passed. This adds alpha to both the Function and the LigerFusedLinearDPOLoss module, fixes the positional-arg order in the existing functional tests, and adds a regression test that verifies alpha actually affects the loss value when compute_nll_loss=True. --- src/liger_kernel/chunked_loss/dpo_loss.py | 13 +++-- test/chunked_loss/test_dpo_loss.py | 59 +++++++++++++++++++---- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index f7a14e539..633aebf43 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -108,6 +108,7 @@ def forward( ref_bias=None, ignore_index=-100, beta=0.1, + alpha=1.0, compute_nll_loss=False, compiled=True, use_ref_model=True, @@ -126,7 +127,8 @@ def forward( ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size) ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,) ignore_index (int): Index to ignore in loss computation - beta (float): Weight for the odds ratio loss + beta (float): Weight for the direct preference loss + alpha (float): Weight for the NLL loss component compute_nll_loss (bool): Whether to compute the NLL loss compiled (bool): Whether to use torch compile use_ref_model (bool): Whether to use a reference model @@ -144,6 +146,7 @@ def forward( bias=bias, ignore_index=ignore_index, beta=beta, + alpha=alpha, compute_nll_loss=compute_nll_loss, compiled=compiled, use_ref_model=use_ref_model, @@ -158,7 +161,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -170,6 +173,7 @@ def __init__( self, ignore_index: int = -100, beta: float = 0.1, + alpha: float = 1.0, compute_nll_loss: bool = False, compiled: bool = True, use_ref_model: bool = True, @@ -180,7 +184,8 @@ def __init__( """ Args: ignore_index (int): Index to ignore in the loss. - beta (float): Weight for the odds ratio loss. + beta (float): Weight for the direct preference loss. + alpha (float): Weight for the NLL loss component. compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. @@ -190,6 +195,7 @@ def __init__( super().__init__() self.ignore_index = ignore_index self.beta = beta + self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model @@ -220,6 +226,7 @@ def forward( ref_bias, self.ignore_index, self.beta, + self.alpha, self.compute_nll_loss, self.compiled, self.use_ref_model, diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index de5762f26..2c39af796 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -643,8 +643,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_input, ref_weight1, ref_bias1, - -100, - 0.1, + -100, # ignore_index + 0.1, # beta + 1.0, # alpha compute_nll_loss, ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( @@ -655,8 +656,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_input, ref_weight2, ref_bias2, - -100, - 0.1, + -100, # ignore_index + 0.1, # beta + 1.0, # alpha compute_nll_loss, ) @@ -886,14 +888,15 @@ def test_correctness_functional_apo_loss_types( ref_input, ref_weight1, ref_bias1, - -100, - 0.1, + -100, # ignore_index + 0.1, # beta + 1.0, # alpha compute_nll_loss, - True, # compiled - True, # use_ref_model + True, # compiled + True, # use_ref_model False, # average_log_prob - 1, # chunk_size - loss_type, # loss_type + 1, # chunk_size + loss_type, ) # For comparison, create a LigerFusedLinearDPOLoss with the loss_type @@ -936,3 +939,39 @@ def test_invalid_loss_type(): # Should not raise an exception loss_fn = LigerFusedLinearDPOLoss(loss_type=loss_type) assert loss_fn.loss_type == loss_type + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_alpha_scales_nll_loss(dtype): + """ + Verify that alpha is actually forwarded and scales the NLL component. + With compute_nll_loss=True, loss(alpha=2) should differ from loss(alpha=1). + """ + B, T, H, V = 4, 16, 32, 64 + atol = 1e-4 if dtype == torch.float32 else 5e-2 + + _weight = torch.randn(V, H, device=device, dtype=dtype) + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + _input = torch.randn(B, T, H, device=device, dtype=dtype) + target = torch.randint(0, V, (B, T), device=device, dtype=torch.long) + + def run(alpha): + inp = _input.detach().clone().requires_grad_(True) + w = _weight.detach().clone().requires_grad_(True) + rw = _ref_weight.detach().clone().requires_grad_(True) + loss_fn = LigerFusedLinearDPOLoss( + beta=0.1, + alpha=alpha, + compute_nll_loss=True, + use_ref_model=True, + average_log_prob=False, + ) + loss, _ = loss_fn(w, inp, target, None, _input.detach(), rw, None) + return loss + + loss_alpha1 = run(alpha=1.0) + loss_alpha2 = run(alpha=2.0) + + assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), ( + f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}" + ) From f2e580d18788645f4f45251a7525922da2e6c4e4 Mon Sep 17 00:00:00 2001 From: micdoh Date: Fri, 22 May 2026 20:45:11 +0100 Subject: [PATCH 2/3] fix: append alpha at end of DPOFunction.forward to preserve positional contract Move `alpha` from after `beta` to the end of `LigerFusedLinearDPOFunction.forward`'s positional list so existing positional callers of the public `liger_fused_linear_dpo` / `.apply()` alias don't silently shift every later argument by one slot. The keyword-driven `LigerFusedLinearDPOLoss.__init__` API keeps `alpha` after `beta` as before. Revert the two functional tests to the pre-PR positional list and add a regression test pinning the positional contract. --- src/liger_kernel/chunked_loss/dpo_loss.py | 6 +-- test/chunked_loss/test_dpo_loss.py | 66 +++++++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 633aebf43..f4065c2f5 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -108,13 +108,13 @@ def forward( ref_bias=None, ignore_index=-100, beta=0.1, - alpha=1.0, compute_nll_loss=False, compiled=True, use_ref_model=True, average_log_prob=False, chunk_size=1, loss_type="sigmoid", + alpha=1.0, ): """ Fused linear layer with DPO loss. @@ -128,12 +128,12 @@ def forward( ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,) ignore_index (int): Index to ignore in loss computation beta (float): Weight for the direct preference loss - alpha (float): Weight for the NLL loss component compute_nll_loss (bool): Whether to compute the NLL loss compiled (bool): Whether to use torch compile use_ref_model (bool): Whether to use a reference model average_log_prob (bool): Whether to average the log probability per non-masked token chunk_size (int): Size of chunks for processing. + alpha (float): Weight for the NLL loss component Returns: torch.Tensor: Computed loss """ @@ -226,11 +226,11 @@ def forward( ref_bias, self.ignore_index, self.beta, - self.alpha, self.compute_nll_loss, self.compiled, self.use_ref_model, self.average_log_prob, self.chunk_size, self.loss_type, + self.alpha, ) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 2c39af796..342c7f7c0 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -645,7 +645,6 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias1, -100, # ignore_index 0.1, # beta - 1.0, # alpha compute_nll_loss, ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( @@ -658,7 +657,6 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias2, -100, # ignore_index 0.1, # beta - 1.0, # alpha compute_nll_loss, ) @@ -890,7 +888,6 @@ def test_correctness_functional_apo_loss_types( ref_bias1, -100, # ignore_index 0.1, # beta - 1.0, # alpha compute_nll_loss, True, # compiled True, # use_ref_model @@ -975,3 +972,66 @@ def run(alpha): assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), ( f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}" ) + + +def test_functional_positional_arg_contract(): + """ + Pin the positional-argument contract of the public functional alias. + + `alpha` is appended at the *end* of `LigerFusedLinearDPOFunction.forward` (not + inserted mid-list) precisely so that existing positional `.apply()` / + `liger_fused_linear_dpo` callers don't silently shift every later argument by + one slot. This test exercises the pre-PR positional list (no `alpha`) and + asserts it produces the same result as the keyword-driven `nn.Module` wrapper. + If a future param insertion shifts the positional slots, this diverges. + """ + B, T, H, V = 4, 8, 16, 32 + dtype = torch.float32 + + _input = torch.randn(B, T, H, device=device, dtype=dtype) + target = torch.randint(0, V, (B, T), device=device, dtype=torch.long) + _weight = torch.randn(V, H, device=device, dtype=dtype) + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Pre-PR positional list: the args after `beta` are + # compute_nll_loss, compiled, use_ref_model, average_log_prob, chunk_size, loss_type. + loss_positional, _ = liger_fused_linear_dpo( + _input.detach().clone().requires_grad_(True), + _weight.detach().clone().requires_grad_(True), + target, + None, # bias + ref_input, + _ref_weight.detach().clone().requires_grad_(True), + None, # ref_bias + -100, # ignore_index + 0.1, # beta + True, # compute_nll_loss + True, # compiled + True, # use_ref_model + False, # average_log_prob + 1, # chunk_size + "sigmoid", # loss_type + ) + + loss_module, _ = LigerFusedLinearDPOLoss( + ignore_index=-100, + beta=0.1, + alpha=1.0, + compute_nll_loss=True, + compiled=True, + use_ref_model=True, + average_log_prob=False, + chunk_size=1, + loss_type="sigmoid", + )( + _weight.detach().clone().requires_grad_(True), + _input.detach().clone().requires_grad_(True), + target, + None, # bias + ref_input, + _ref_weight.detach().clone().requires_grad_(True), + None, # ref_bias + ) + + assert_verbose_allclose(loss_positional, loss_module, atol=1e-5, rtol=1e-4) From 993e9c408e62356233efdd5dc8f6587b66d2e5a9 Mon Sep 17 00:00:00 2001 From: micdoh Date: Fri, 22 May 2026 20:48:13 +0100 Subject: [PATCH 3/3] style: apply ruff format to dpo loss tests --- test/chunked_loss/test_dpo_loss.py | 36 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 342c7f7c0..2c0f2ddaf 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -644,7 +644,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_weight1, ref_bias1, -100, # ignore_index - 0.1, # beta + 0.1, # beta compute_nll_loss, ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( @@ -656,7 +656,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_weight2, ref_bias2, -100, # ignore_index - 0.1, # beta + 0.1, # beta compute_nll_loss, ) @@ -886,13 +886,13 @@ def test_correctness_functional_apo_loss_types( ref_input, ref_weight1, ref_bias1, - -100, # ignore_index - 0.1, # beta + -100, # ignore_index + 0.1, # beta compute_nll_loss, - True, # compiled - True, # use_ref_model + True, # compiled + True, # use_ref_model False, # average_log_prob - 1, # chunk_size + 1, # chunk_size loss_type, ) @@ -1000,17 +1000,17 @@ def test_functional_positional_arg_contract(): _input.detach().clone().requires_grad_(True), _weight.detach().clone().requires_grad_(True), target, - None, # bias + None, # bias ref_input, _ref_weight.detach().clone().requires_grad_(True), - None, # ref_bias - -100, # ignore_index - 0.1, # beta - True, # compute_nll_loss - True, # compiled - True, # use_ref_model - False, # average_log_prob - 1, # chunk_size + None, # ref_bias + -100, # ignore_index + 0.1, # beta + True, # compute_nll_loss + True, # compiled + True, # use_ref_model + False, # average_log_prob + 1, # chunk_size "sigmoid", # loss_type ) @@ -1028,10 +1028,10 @@ def test_functional_positional_arg_contract(): _weight.detach().clone().requires_grad_(True), _input.detach().clone().requires_grad_(True), target, - None, # bias + None, # bias ref_input, _ref_weight.detach().clone().requires_grad_(True), - None, # ref_bias + None, # ref_bias ) assert_verbose_allclose(loss_positional, loss_module, atol=1e-5, rtol=1e-4)