fix: expose alpha param through LigerFusedLinearDPO public API #1194
micdoh wants to merge 6 commits into
LigerFusedLinearDPOFunction.forward() accepted every base-class parameter except alpha, so the NLL scaling weight silently defaulted to 1.0 regardless of what callers passed. This adds alpha to both the Function and the LigerFusedLinearDPOLoss module, fixes the positional-arg order in the existing functional tests, and adds a regression test that verifies alpha actually affects the loss value when compute_nll_loss=True.
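To make the failure mode concrete, here is a minimal toy sketch of it, not the Liger kernel itself. It assumes the total loss has the shape `preference_loss + alpha * nll_loss` when `compute_nll_loss=True`. The names `combined_loss`, `wrapper_dropping_alpha`, and `wrapper_forwarding_alpha` are hypothetical stand-ins for the base class and the public wrapper.

```python
def combined_loss(preference_loss, nll_loss, alpha=1.0, compute_nll_loss=True):
    """Toy base-class behaviour (assumed shape): preference loss plus alpha-weighted NLL."""
    if not compute_nll_loss:
        return preference_loss
    return preference_loss + alpha * nll_loss


def wrapper_dropping_alpha(preference_loss, nll_loss, alpha=1.0, compute_nll_loss=True):
    # Mimics the reported bug: alpha never reaches the base computation,
    # so the base default of 1.0 is always used.
    return combined_loss(preference_loss, nll_loss, compute_nll_loss=compute_nll_loss)


def wrapper_forwarding_alpha(preference_loss, nll_loss, alpha=1.0, compute_nll_loss=True):
    # Mimics the fix: alpha is passed through explicitly.
    return combined_loss(
        preference_loss, nll_loss, alpha=alpha, compute_nll_loss=compute_nll_loss
    )


# Regression-style check: changing alpha must change the loss once it is plumbed through.
low = wrapper_forwarding_alpha(0.5, 2.0, alpha=0.25)
high = wrapper_forwarding_alpha(0.5, 2.0, alpha=1.0)
assert low != high
```

A real regression test would presumably compare losses from `LigerFusedLinearDPOLoss` at two `alpha` values on the same inputs. The toy version only shows why the values coincide when `alpha` is dropped.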
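Suggestion: append `alpha` to the end of the positional list
Thanks for plumbing `alpha` through. The PR inserts `alpha` after `beta` in `LigerFusedLinearDPOFunction.forward`, so existing positional callers of the public alias `liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply` would have every later argument shifted by one slot.
I'd suggest appending `alpha` as the last positional parameter instead.
This gives zero positional shift and zero breakage for existing positional callers.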
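The concern can be illustrated with a small sketch. The three functions below use hypothetical, shortened signatures and are not the real `forward` parameter list. They only show what happens to a legacy positional call when a parameter is inserted in the middle versus appended at the end.

```python
def forward_before(x, beta=0.1, compute_nll_loss=False):
    # Pre-PR shape (hypothetical): no alpha parameter, so it is effectively 1.0.
    return {"beta": beta, "alpha": 1.0, "compute_nll_loss": compute_nll_loss}


def forward_inserted(x, beta=0.1, alpha=1.0, compute_nll_loss=False):
    # alpha inserted right after beta: later positional slots move by one.
    return {"beta": beta, "alpha": alpha, "compute_nll_loss": compute_nll_loss}


def forward_appended(x, beta=0.1, compute_nll_loss=False, alpha=1.0):
    # alpha appended at the end: existing positional slots are unchanged.
    return {"beta": beta, "alpha": alpha, "compute_nll_loss": compute_nll_loss}


# A call written against the pre-PR signature: (x, beta, compute_nll_loss).
legacy_args = ("inputs", 0.2, True)

shifted = forward_inserted(*legacy_args)
# True lands in the alpha slot and compute_nll_loss falls back to its default.
assert shifted["alpha"] is True and shifted["compute_nll_loss"] is False

# With alpha appended, the legacy call behaves exactly as before.
assert forward_appended(*legacy_args) == forward_before(*legacy_args)
```

Because `True` is a valid number in Python, the shifted call raises no error, which is what makes the breakage silent.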
…l contract
Move `alpha` from after `beta` to the end of `LigerFusedLinearDPOFunction.forward`'s positional list so existing positional callers of the public `liger_fused_linear_dpo` / `.apply()` alias don't silently shift every later argument by one slot. The keyword-driven `LigerFusedLinearDPOLoss.__init__` API keeps `alpha` after `beta` as before.
Revert the two functional tests to the pre-PR positional list and add a regression test pinning the positional contract.
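The regression test added by this commit is not reproduced here. As a rough sketch of one way such a positional contract could be pinned, the snippet below inspects a function's parameter order with `inspect.signature`. The `forward` stub, its parameter names, and the helper `keeps_positional_contract` are all hypothetical.

```python
import inspect


def forward(ctx, x, beta=0.1, compute_nll_loss=False, alpha=1.0):
    # Hypothetical stub standing in for the real forward(); body is irrelevant here.
    return None


def keeps_positional_contract(fn, legacy_prefix, appended):
    """Return True if fn's parameters are exactly legacy_prefix followed by appended."""
    names = list(inspect.signature(fn).parameters)
    return names == list(legacy_prefix) + list(appended)


# The pre-existing positional order must be an untouched prefix,
# and any new parameter may only appear after it.
LEGACY_PREFIX = ["ctx", "x", "beta", "compute_nll_loss"]
assert keeps_positional_contract(forward, LEGACY_PREFIX, ["alpha"])
```

A signature check like this fails as soon as someone reorders or inserts a parameter. A behavioural test that calls the function positionally and compares outputs is a reasonable alternative when the signature itself is not considered part of the contract.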
…expose_dpo_alpha
# Conflicts:
#	src/liger_kernel/chunked_loss/dpo_loss.py
#	test/chunked_loss/test_dpo_loss.py
Ready for re-review.
Summary
Changes
`src/liger_kernel/chunked_loss/dpo_loss.py`
`test/chunked_loss/test_dpo_loss.py`
Test plan