
Gram Newton Schulz #297

Open
KakaruHayate wants to merge 19 commits into openvpi:polar_express from KakaruHayate:gram_ns_pr

Conversation

@KakaruHayate

No description provided.

Update muon.py

Update muon.py
This reverts commit 5b6ce35.
This reverts commit f097a1e.
Add OptimizerTimerCallback to basics/base_task.py to measure GPU optimizer step time using torch.cuda.Event and torch.cuda.synchronize. The callback records start/end events around optimizer steps (after epoch 0) and logs the elapsed milliseconds as "stats/optimizer_step_duration_ms" via pl_module.log (on_step, shown in prog_bar). The callback is registered in the Trainer callbacks so durations appear in TensorBoard/console. Note: a local timer_callback variable is instantiated but the callbacks list also constructs a new OptimizerTimerCallback (minor redundancy).
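The callback described above can be sketched in a framework-free form. This is a minimal CPU analogue: `time.perf_counter` stands in for `torch.cuda.Event` plus `torch.cuda.synchronize`, and the hook names are hypothetical stand-ins for the Lightning hooks the real callback uses.

```python
import time

class OptimizerTimerCallback:
    """Sketch of the PR's timing callback (hypothetical hook names).

    The real implementation records torch.cuda.Event pairs around the
    optimizer step and synchronizes before reading the elapsed time;
    here time.perf_counter stands in so the sketch runs without a GPU.
    """

    def __init__(self):
        self._start = None
        self.last_duration_ms = None

    def on_before_optimizer_step(self, epoch):
        # The PR skips timing during epoch 0 (warm-up).
        if epoch > 0:
            self._start = time.perf_counter()

    def on_after_optimizer_step(self, epoch):
        if epoch > 0 and self._start is not None:
            self.last_duration_ms = (time.perf_counter() - self._start) * 1000.0
            # Real code: pl_module.log("stats/optimizer_step_duration_ms",
            #                          self.last_duration_ms, on_step=True,
            #                          prog_bar=True)
```

Registered in the Trainer's callback list, each logged value then shows up in TensorBoard and the progress bar as described.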

Update base_task.py
Introduce a mud() implementation (MomentUm Decorrelation) that performs lightweight orthogonalization via row-normalization, row-gram construction, lower-triangular extraction and forward triangular solve. Update Muon optimizer to replace the boolean use_gram_ns with a string method selector (defaults to 'gram_ns') and dispatch dynamically between 'gram_ns', 'mud', and 'ns5' implementations, raising on unknown methods. Also preserve bfloat16 handling and tensor transpose logic; mud() returns a contiguous tensor.
Cast input to float32 for the triangular solve (triangular_solve_cuda not implemented for BFloat16), while preserving the original dtype and casting the result back before returning. Also corrected row normalization to use dim=1 (instead of -1) and tightened eps from 1e-7 to 1e-8. Added explanatory comment and small cleanup.

Update muon.py

Update muon.py
Clamp the diagonal entries of the lower-triangular Gram matrix in mud_whiten with a minimum of 1e-5 before solving the triangular system. This prevents T from having all-zero diagonal values (which would cause singular/ill-conditioned solves) and improves numerical stability of the forward solve.
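The mud() pipeline described across these commits (row-normalize, build the row Gram matrix, take its lower triangle, clamp the diagonal, forward triangular solve) can be sketched in NumPy. This is illustrative only: the real code operates on torch tensors, casts to float32 for the triangular solve, and uses torch's solver rather than the explicit forward substitution shown here.

```python
import numpy as np

def mud_whiten(x, eps=1e-8, diag_min=1e-5):
    """NumPy sketch of the mud() orthogonalization path (illustrative)."""
    # Row-normalize along dim=1, matching the commit's correction.
    a = x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)
    # Row Gram matrix and its lower-triangular part.
    g = a @ a.T
    t = np.tril(g)
    # Clamp the diagonal away from zero so the solve is well-conditioned.
    np.fill_diagonal(t, np.maximum(np.diag(t), diag_min))
    # Forward substitution: solve T @ y = a for y.
    y = np.zeros_like(a)
    for i in range(t.shape[0]):
        y[i] = (a[i] - t[i, :i] @ y[:i]) / t[i, i]
    return np.ascontiguousarray(y)
```

With already-orthonormal rows the Gram matrix is the identity, so the input passes through unchanged, which is a quick sanity check on the construction.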
Update muon.py

Update muon.py
Drop bfloat16 detection and runtime BF16 paths: remove get_bf16_support_map and the bf16_support_map field, eliminate use_bf16 parameters from zeropower_via_newtonschulz5, gram_newton_schulz and mud_whiten, and stop passing use_bf16 from Muon.step. Simplify tensor casts to explicit float32/float16 usage and clean up related conditional logic. This streamlines the orthogonalization codepaths and avoids BF16-specific code (e.g. triangular_solve_cuda incompatibilities).
Apply safety_factor scaling to all POLAR_EXPRESS_COEFFICIENTS except the final tuple. The list comprehension now iterates over _unmodified_polar_express_coefficients[:-1] and the original last element is appended unchanged, preserving that coefficient (likely for correctness or numerical stability).
Remove extraneous blank lines and trailing spaces in modules/optimizer/muon.py and tidy formatting around the normalization, transpose and Newton–Schulz loops. No functional logic was changed.
Remove unused coefficient tables and the mud_whiten path, and streamline orthogonalization to always use gram_newton_schulz. Add collections import and switch get_params_for_muon to a BFS that excludes Embedding modules and only collects trainable params with ndim >= 2. Cast intermediate X to float16 in zeropower_via_newtonschulz5 for faster half-precision ops, and drop unused imports (itertools.repeat) and redundant method dispatch in the Muon step. These changes reduce complexity and unify the orthogonalization flow.
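The BFS parameter collection described above can be illustrated without torch. `ToyModule` below is a hypothetical stand-in for `nn.Module`, reduced to just a type name, parameter metadata, and children; the real `get_params_for_muon` walks actual modules and tensors.

```python
from collections import deque

class ToyModule:
    """Hypothetical stand-in for nn.Module, just enough to show the BFS."""
    def __init__(self, kind, params=(), children=()):
        self.kind = kind               # e.g. "Linear", "Embedding"
        self.params = list(params)     # (name, ndim, requires_grad) triples
        self.children = list(children)

def get_params_for_muon(root):
    """Breadth-first traversal that skips Embedding modules and keeps
    only trainable parameters with ndim >= 2, per the commit message."""
    selected, queue = [], deque([root])
    while queue:
        mod = queue.popleft()
        if mod.kind == "Embedding":
            continue  # Embedding modules are excluded entirely
        for name, ndim, requires_grad in mod.params:
            if requires_grad and ndim >= 2:
                selected.append(name)
        queue.extend(mod.children)
    return selected
```

Embedding weights and 1-D parameters (biases, norm scales) are filtered out, leaving only the matrix-shaped parameters that Muon orthogonalizes.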

Update muon.py

Update muon.py

Update muon.py

Update muon.py

Update muon.py
Ensure ns_coefficients contains exactly `steps` entries by padding with the last POLAR_EXPRESS_COEFFICIENTS value when `steps` exceeds the predefined list. Adds itertools.repeat import and applies the fix in zeropower_via_newtonschulz5 and gram_newton_schulz to avoid out-of-range access during iteration.
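The padding fix is a one-liner with `itertools.repeat`: chain the table with infinite copies of its last entry, then slice to exactly `steps` items. The coefficient values below are placeholders; only the chain/islice/repeat pattern reflects the commit.

```python
from itertools import chain, islice, repeat

# Illustrative values; the real table is POLAR_EXPRESS_COEFFICIENTS in muon.py.
coeffs = [(3.0, -4.0, 2.0), (2.0, -2.5, 1.0)]
steps = 5

# Pad with the last tuple so the iteration loop never runs out of coefficients.
ns_coefficients = list(islice(chain(coeffs, repeat(coeffs[-1])), steps))
```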
