Adding Layerwise GN (https://arxiv.org/pdf/2510.09378) #59

@switiz

Is your feature request related to a problem? Please describe.

Large-batch LLM pretraining currently relies on first-order or approximate second-order optimizers (e.g., AdamW, SOAP, Muon, Shampoo).
However, even the strongest of these under-use curvature information and need significantly more iterations to reach the same loss than a full Gauss-Newton (GN) preconditioner.

For example, the paper “The Potential of Second-Order Optimization for LLMs: A Study with Full Gauss-Newton” (arXiv:2510.09378) reports that:

  • GN reaches a loss of 3.25 in 54 steps
  • SOAP requires 292 steps to reach the same loss

That is roughly 5.4× fewer iterations (292 / 54 ≈ 5.4), and GN also extends the critical batch size.

Describe the solution you’d like

Add a Layerwise Gauss-Newton (GN) preconditioning mode, with an optional “Full GN (oracle)” flag for research comparison.

1. Optimizer Core

  • Integrate GN updates via preconditioning based on Jacobian-vector products (JVPs), so that neither the Hessian nor the GN matrix is ever materialized.
  • GN acts as a plug-in preconditioner that wraps existing optimizers (SOAP / Muon / Shampoo).
  • Use inner-loop Muon or AdamW to minimize the quadratic objective under GN preconditioning, with optional line search for stability.
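
To make the request concrete, here is a minimal pure-Python sketch of the matrix-free GN-vector product and an inner loop that minimizes the quadratic objective. Everything in it is an assumption for illustration, not the paper's code or an existing API: the toy model (one tanh unit, three logits), the hand-written `jvp` / `vjp`, the names `gn_vector_product` and `inner_loop_step`, and the use of plain gradient descent as a stand-in for the Muon or AdamW inner optimizer.

```python
import math

# Hypothetical toy model (not from the paper): scalar input X, one tanh hidden
# unit, three output logits. Parameters: theta = [a, b0, b1, b2].
X = 0.7

def forward(theta):
    h = math.tanh(theta[0] * X)
    return h, [bk * h for bk in theta[1:]]      # hidden activation, logits z

def jvp(theta, v):
    """J v: directional derivative of the logits along v (forward mode)."""
    b = theta[1:]
    h, _ = forward(theta)
    dh = (1.0 - h * h) * X * v[0]
    return [v[k + 1] * h + b[k] * dh for k in range(3)]

def vjp(theta, u):
    """J^T u: pull a logit-space vector back to parameter space (reverse mode)."""
    b = theta[1:]
    h, _ = forward(theta)
    da = sum(u[k] * b[k] for k in range(3)) * (1.0 - h * h) * X
    return [da] + [u[k] * h for k in range(3)]

def softmax(z):
    m = max(z)
    e = [math.exp(zk - m) for zk in z]
    s = sum(e)
    return [ek / s for ek in e]

def output_hvp(z, u):
    """(nabla_z^2 L) u for softmax cross-entropy: (diag(p) - p p^T) u."""
    p = softmax(z)
    pu = sum(pk * uk for pk, uk in zip(p, u))
    return [pk * (uk - pu) for pk, uk in zip(p, u)]

def gn_vector_product(theta, v, damping=0.0):
    """G v = J^T (nabla_z^2 L) J v, plus damping * v. G is never formed."""
    _, z = forward(theta)
    gv = vjp(theta, output_hvp(z, jvp(theta, v)))
    return [gk + damping * vk for gk, vk in zip(gv, v)]

def gradient(theta, label):
    """Loss gradient g = J^T (p - onehot(label))."""
    _, z = forward(theta)
    p = softmax(z)
    dz = [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]
    return vjp(theta, dz)

def inner_loop_step(theta, label, damping=1e-2, lr=2.0, iters=500):
    """Minimize q(d) = g.d + 0.5 d.(G + damping I).d by gradient descent on d."""
    g = gradient(theta, label)
    d = [0.0] * len(theta)
    for _ in range(iters):
        gd = gn_vector_product(theta, d, damping)
        d = [dk - lr * (gk + gdk) for dk, gk, gdk in zip(d, g, gd)]
    return d

theta = [0.5, 1.0, -0.5, 0.3]
step = inner_loop_step(theta, label=1)
theta_next = [t + s for t, s in zip(theta, step)]
```

The damping term is there because G is singular in this toy (the softmax output Hessian has the all-ones vector in its null space); a real implementation would likely need a similar safeguard or the optional line search mentioned above.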

2. Layerwise GN Variant (Default)

  • Compute per-layer GN updates, ignoring cross-layer curvature.
  • Nearly matches Full GN on medium-scale LLMs with large batches: it needs only ~1.4× more steps than Full GN, but ~3.4× fewer steps than SOAP.
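
As a rough illustration of what "ignoring cross-layer curvature" means, the sketch below keeps only the diagonal blocks of a GN matrix and solves one small system per layer. The 4×4 matrix `G`, the gradient `g`, the layer boundaries, and the helper names are all made up for the example; `G` is stored explicitly only because the toy is tiny, whereas a real implementation would use matrix-free products.

```python
def matvec(A, v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]

def conjugate_gradient(A, b, iters=50, tol=1e-24):
    """Solve A x = b for a symmetric positive-definite A."""
    x = [0.0] * len(b)
    r = list(b)
    p = list(r)
    rs = sum(ri * ri for ri in r)
    for _ in range(iters):
        if rs < tol:                      # converged (also guards division by zero)
            break
        Ap = matvec(A, p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x

def layerwise_gn_step(G, g, layers, damping=0.0):
    """Solve (G_ll + damping I) d_l = -g_l for each layer block independently."""
    d = [0.0] * len(g)
    for lo, hi in layers:
        block = [[G[i][j] + (damping if i == j else 0.0) for j in range(lo, hi)]
                 for i in range(lo, hi)]
        d[lo:hi] = conjugate_gradient(block, [-g[i] for i in range(lo, hi)])
    return d

# Made-up symmetric positive-definite GN matrix for two layers of two parameters.
G = [[2.0, 0.5, 0.3, 0.0],
     [0.5, 1.0, 0.0, 0.2],
     [0.3, 0.0, 1.5, 0.4],
     [0.0, 0.2, 0.4, 1.0]]
g = [1.0, -1.0, 0.5, 0.0]

full_step = layerwise_gn_step(G, g, [(0, 4)])               # one block: Full GN
layer_step = layerwise_gn_step(G, g, [(0, 2), (2, 4)])      # cross-layer terms dropped
```

Passing a single block covering all parameters recovers the Full GN step, so the same routine could back both the default mode and the "Full GN (oracle)" flag in a sketch like this.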

3. Why GN?

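  • GN keeps only the curvature of the loss with respect to the model outputs and drops the curvature of the network itself, so the GN matrix is positive semi-definite for convex losses such as cross-entropy.
    This avoids the negative-curvature instability of full Newton updates while significantly improving iteration efficiency at scale.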
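
For reference, the standard decomposition behind this point, with a one-parameter worked example. The toy model z = θ² with squared-error loss is my own illustration, not something from the paper.

\[
\nabla_\theta^2 L \;=\; \underbrace{J^\top \left(\nabla_z^2 L\right) J}_{G \text{ (kept by GN)}} \;+\; \underbrace{\sum_k \frac{\partial L}{\partial z_k}\, \nabla_\theta^2 z_k}_{\text{dropped by GN}}
\]

\[
z = \theta^2,\quad L = \tfrac{1}{2}(z - y)^2,\quad y = 2,\ \theta = 0.5:
\qquad G = (2\theta)^2 \cdot 1 = 1,
\qquad \nabla_\theta^2 L = 1 + (z - y)\cdot 2 = 1 - 3.5 = -2.5
\]

At this point the full Hessian is negative, so a raw Newton step would move in the wrong direction, while the GN curvature stays positive.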

Describe alternatives you’ve considered

  • Existing SOAP / Muon / Shampoo implementations already provide approximate second-order preconditioning,
    but they lack full curvature fidelity and plateau earlier in large-batch regimes.
  • Extending these optimizers with GN-based preconditioning could preserve backward compatibility
    while improving step efficiency.
  • The GN-prox-linear variant was analyzed but offered little gain, suggesting the loss curvature alone captures most of the benefit.

Additional context

  • Treat Full GN as a research-only configuration (≈ 4–5× slower wall-clock).
  • Layerwise GN is the practical, scalable variant to evaluate for improved step efficiency and batch scaling.
  • Recommended evaluation setup: 45M- and 150M-parameter models under large-batch regimes.
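  • GN update formula:
    \[
    \theta_{t+1} = \theta_t - G^{-1} g,\quad G = J^\top \left(\nabla_z^2 L\right) J
    \]
    where g is the loss gradient, J is the Jacobian of the model outputs z with respect to the parameters θ, and ∇_z² L is the Hessian of the loss with respect to z. This is implemented efficiently via JVPs without explicit Hessian storage.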

Labels: enhancement (New feature or request)
