Skip to content

fix: expose alpha param through LigerFusedLinearDPO public API#1194

Open
micdoh wants to merge 6 commits into
linkedin:mainfrom
micdoh:expose_dpo_alpha
Open

fix: expose alpha param through LigerFusedLinearDPO public API#1194
micdoh wants to merge 6 commits into
linkedin:mainfrom
micdoh:expose_dpo_alpha

Conversation

@micdoh
Copy link
Copy Markdown

@micdoh micdoh commented Apr 16, 2026

Summary

  • `LigerFusedLinearDPOFunction.forward()` was missing the `alpha` parameter entirely, so the NLL loss scaling weight silently defaulted to `1.0` no matter what callers passed — the parameter exists in `LigerFusedLinearPreferenceBase` but was never plumbed through `LigerFusedLinearDPOFunction`.
  • `LigerFusedLinearDPOLoss` (the `nn.Module` wrapper) similarly had no `alpha` argument.
  • Added one extra `None` to `backward()` to match the new positional count.
  • Fixed positional-arg ordering in the two existing functional tests that called `LigerFusedLinearDPOFunction.apply(...)` directly.
  • Added `test_alpha_scales_nll_loss` to verify `alpha` actually reaches the loss computation.

Changes

`src/liger_kernel/chunked_loss/dpo_loss.py`

  • Add `alpha: float = 1.0` to `LigerFusedLinearDPOFunction.forward()` and forward it to `super().forward()`
  • Add one `None` to `backward()` return (was 11, now 12 — one per non-tensor arg)
  • Add `alpha: float = 1.0` to `LigerFusedLinearDPOLoss.init()` and `self.alpha` assignment
  • Pass `self.alpha` to `LigerFusedLinearDPOFunction.apply()` in `LigerFusedLinearDPOLoss.forward()`

`test/chunked_loss/test_dpo_loss.py`

  • Insert `1.0` (alpha) in the correct positional slot in `test_correctness_functional` and `test_correctness_functional_apo_loss_types`
  • Add `test_alpha_scales_nll_loss` regression test

Test plan

  • Existing `test_correctness`, `test_correctness_functional`, `test_correctness_apo_loss_types`, `test_correctness_functional_apo_loss_types` all pass
  • New `test_alpha_scales_nll_loss` passes and confirms `alpha != 1.0` changes the loss value

micdoh and others added 2 commits April 16, 2026 12:03
LigerFusedLinearDPOFunction.forward() accepted every base-class parameter
except alpha, so the NLL scaling weight silently defaulted to 1.0 regardless
of what callers passed. This adds alpha to both the Function and the
LigerFusedLinearDPOLoss module, fixes the positional-arg order in the
existing functional tests, and adds a regression test that verifies alpha
actually affects the loss value when compute_nll_loss=True.
@yueyiming2009
Copy link
Copy Markdown
Collaborator

Suggestion: append alpha on the autograd Function rather than inserting it after beta

Thanks for plumbing alpha through — the __init__/wrapper changes look right. One concern with how it's inserted into LigerFusedLinearDPOFunction.forward.

The PR inserts alpha positionally between beta and compute_nll_loss in forward. LigerFusedLinearDPOFunction isn't only reached via LigerFusedLinearDPOLoss — it's also exposed as a public functional alias in chunked_loss/functional.py:

liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply

.apply() is positional-only, so for any external caller that passes args positionally past beta, every later argument silently shifts by one slot (compute_nll_loss ← old compiled, etc.). No exception is raised — just wrong behavior. That's the worst failure mode, and it's the kind of break that's invisible until someone's training run is subtly off. (The in-repo callers and tests are updated here, so this only bites third-party positional callers — but the functional API is part of the public surface.)

I'd suggest:

  1. Keep alpha right after beta in LigerFusedLinearDPOLoss.__init__ as-is — that's the documented, keyword-driven public API and reads naturally.
  2. In LigerFusedLinearDPOFunction.forward, append alpha at the end of the positional list (after discopop_tau) and add the matching trailing None in backward. forward already forwards everything to super().forward(...) by keyword, so the internal order doesn't matter.

This gives zero positional shift and zero breakage for existing .apply() / liger_fused_linear_dpo callers, while the user-facing __init__ ordering stays exactly as you have it. Note the autograd Function's positional order is already an implementation detail that diverges from the base (LigerFusedLinearPreferenceBase.forward actually has alpha before beta, and is called by keyword), so append-only here is consistent with how autograd.Function signatures are conventionally evolved.

  1. Optionally, add a regression test that calls liger_fused_linear_dpo(...) with the pre-PR positional argument list and asserts unchanged semantics — that pins the positional contract so a future param insertion can't silently break it again.

micdoh added 3 commits May 22, 2026 20:45
…l contract

Move `alpha` from after `beta` to the end of
`LigerFusedLinearDPOFunction.forward`'s positional list so existing
positional callers of the public `liger_fused_linear_dpo` /
`.apply()` alias don't silently shift every later argument by one slot.
The keyword-driven `LigerFusedLinearDPOLoss.__init__` API keeps `alpha`
after `beta` as before. Revert the two functional tests to the pre-PR
positional list and add a regression test pinning the positional contract.
…expose_dpo_alpha

# Conflicts:
#	src/liger_kernel/chunked_loss/dpo_loss.py
#	test/chunked_loss/test_dpo_loss.py
@micdoh
Copy link
Copy Markdown
Author

micdoh commented May 22, 2026

Ready for re-review.

  • Addressed @yueyiming2009's feedback: alpha is now appended at the end of LigerFusedLinearDPOFunction.forward (after discopop_tau) instead of inserted after beta, so existing positional liger_fused_linear_dpo / .apply() callers don't silently shift. The keyword-driven LigerFusedLinearDPOLoss.__init__ API keeps alpha after beta as before.
  • Added test_functional_positional_arg_contract to pin the positional contract so a future param insertion can't silently break it again.
  • Merged latest main and resolved conflicts with the new label_smoothing / discopop_tau work (alpha ordered last; backward None-count updated accordingly).
  • ruff check and ruff format --check pass locally (checkstyle).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants