From ff3b72cd4ff6e64593c8b7a59f86ed9bcd94b82e Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 21:05:24 +0800 Subject: [PATCH 01/19] update gram-newton-schulz Update muon.py Update muon.py --- modules/optimizer/muon.py | 68 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 662dacee..28e205bd 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -74,6 +74,61 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X +def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor: + """ + Gram Newton-Schulz iteration to compute the orthogonalization of G. + Mathematically identical to standard Newton-Schulz but computes iterating + on the smaller NxN Gram matrix to save up to 50% FLOPs. + """ + assert G.ndim == 3 + a, b, c = (3.4445, -4.7750, 2.0315) + ns_coefficients = [(a, b, c)] * steps + + original_shape = G.shape + dtype = G.dtype + + X = G.to(torch.float32) + + should_transpose = X.size(-2) > X.size(-1) + if should_transpose: + X = X.mT + + X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + X = X.to(torch.float16) + + if X.size(-2) != X.size(-1): + R = torch.bmm(X, X.mT) + Q = None + + for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): + if i in reset_iterations and i != 0: + X = torch.bmm(Q, X) + R = torch.bmm(X, X.mT) + Q = None + + Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i) + if i != 0 and i not in reset_iterations: + Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0) + else: + Q = Z.clone() + Q.diagonal(dim1=-2, dim2=-1).add_(a_i) + if i < steps - 1 and (i + 1) not in reset_iterations: + RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0) + R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0) + + X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) + else: + for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): + A = torch.bmm(X, X.mT) + B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i) + X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0) + + if should_transpose: + X = X.mT + + return X.to(dtype).view(original_shape) + + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -97,10 +152,12 @@ class Muon(torch.optim.Optimizer): ns_steps: The number of Newton-Schulz iteration steps to use. """ - def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5, reset_iterations=None): + if reset_iterations is None: + reset_iterations = [3] + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations) super().__init__(params, defaults) - self.bf16_support_map = get_bf16_support_map() + # self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -129,8 +186,9 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - use_bf16 = self.bf16_support_map.get(g.device, False) - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) + # use_bf16 = self.bf16_support_map.get(g.device, False) + # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) + g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"]) if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) From 298956089b09dee69bfa68fff6d293f19ec44870 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 21:54:25 +0800 Subject: [PATCH 02/19] use bf16 --- modules/optimizer/muon.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 28e205bd..28c53544 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X -def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor: +def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor: """ Gram Newton-Schulz iteration to compute the orthogonalization of G. Mathematically identical to standard Newton-Schulz but computes iterating @@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te original_shape = G.shape dtype = G.dtype - X = G.to(torch.float32) + X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) should_transpose = X.size(-2) > X.size(-1) if should_transpose: @@ -107,16 +107,19 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te Q = None Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i) + if i != 0 and i not in reset_iterations: Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0) else: Q = Z.clone() Q.diagonal(dim1=-2, dim2=-1).add_(a_i) + if i < steps - 1 and (i + 1) not in reset_iterations: RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0) R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0) X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) + else: for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): A = torch.bmm(X, X.mT) @@ -157,7 +160,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr reset_iterations = [3] defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations) super().__init__(params, defaults) - # self.bf16_support_map = get_bf16_support_map() + self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -186,9 +189,9 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - # use_bf16 = self.bf16_support_map.get(g.device, False) + use_bf16 = self.bf16_support_map.get(g.device, False) # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) - g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"]) + g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"]) if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) From f0fa6d9bff88e9c2c79e2cb2fdc877061db87673 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 22:34:44 +0800 Subject: [PATCH 03/19] Revert "use bf16" This reverts commit 5b6ce35f7d0eca578d213ceb838ae62cf187c436. --- modules/optimizer/muon.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 28c53544..28e205bd 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X -def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor: +def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor: """ Gram Newton-Schulz iteration to compute the orthogonalization of G. Mathematically identical to standard Newton-Schulz but computes iterating @@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: original_shape = G.shape dtype = G.dtype - X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) + X = G.to(torch.float32) should_transpose = X.size(-2) > X.size(-1) if should_transpose: @@ -107,19 +107,16 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: Q = None Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i) - if i != 0 and i not in reset_iterations: Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0) else: Q = Z.clone() Q.diagonal(dim1=-2, dim2=-1).add_(a_i) - if i < steps - 1 and (i + 1) not in reset_iterations: RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0) R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0) X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) - else: for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): A = torch.bmm(X, X.mT) @@ -160,7 +157,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr reset_iterations = [3] defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations) super().__init__(params, defaults) - self.bf16_support_map = get_bf16_support_map() + # self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -189,9 +186,9 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - use_bf16 = self.bf16_support_map.get(g.device, False) + # use_bf16 = self.bf16_support_map.get(g.device, False) # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) - g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"]) + g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"]) if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) From bbdc0e686bb0933480bbe861ab290d95edd354ad Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 22:41:18 +0800 Subject: [PATCH 04/19] del useless code --- modules/optimizer/muon.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 28e205bd..e7cee4dc 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -122,9 +122,6 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te A = torch.bmm(X, X.mT) B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i) X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0) - - if should_transpose: - X = X.mT return X.to(dtype).view(original_shape) From 86a6342c66acef11a9cd019319c10050e68412ee Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 22:46:01 +0800 Subject: [PATCH 05/19] post-transpose --- modules/optimizer/muon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index e7cee4dc..39d2d347 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -89,13 +89,13 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te X = G.to(torch.float32) + X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + X = X.to(torch.float16) + should_transpose = X.size(-2) > X.size(-1) if should_transpose: X = X.mT - X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) - X = X.to(torch.float16) - if X.size(-2) != X.size(-1): R = torch.bmm(X, X.mT) Q = None From 08c929adbc3e528574d995fa7fdf74ae008e1556 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 23:21:32 +0800 Subject: [PATCH 06/19] Reapply "use bf16" This reverts commit f097a1e05fd0c46ae22f61b879ea033941b86d22. --- modules/optimizer/muon.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 39d2d347..12ccbe5a 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -74,7 +74,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X -def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Tensor: +def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor: """ Gram Newton-Schulz iteration to compute the orthogonalization of G. Mathematically identical to standard Newton-Schulz but computes iterating @@ -87,7 +87,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te original_shape = G.shape dtype = G.dtype - X = G.to(torch.float32) + X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) X = X.to(torch.float16) @@ -107,16 +107,19 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]) -> Te Q = None Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i) + if i != 0 and i not in reset_iterations: Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0) else: Q = Z.clone() Q.diagonal(dim1=-2, dim2=-1).add_(a_i) + if i < steps - 1 and (i + 1) not in reset_iterations: RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0) R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0) X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) + else: for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): A = torch.bmm(X, X.mT) @@ -154,7 +157,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr reset_iterations = [3] defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations) super().__init__(params, defaults) - # self.bf16_support_map = get_bf16_support_map() + self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -183,9 +186,9 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - # use_bf16 = self.bf16_support_map.get(g.device, False) + use_bf16 = self.bf16_support_map.get(g.device, False) # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) - g = gram_newton_schulz(g, steps=group["ns_steps"], reset_iterations=group["reset_iterations"]) + g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"]) if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) From d26584231133c59f783aa25fbbed148364fa1bae Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Tue, 31 Mar 2026 23:24:04 +0800 Subject: [PATCH 07/19] set bf16 when X.size(-2) == X.size(-1) --- modules/optimizer/muon.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 12ccbe5a..a6ac32b5 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -87,16 +87,16 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: original_shape = G.shape dtype = G.dtype - X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) + X = G.to(torch.float32) X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) - X = X.to(torch.float16) should_transpose = X.size(-2) > X.size(-1) if should_transpose: X = X.mT if X.size(-2) != X.size(-1): + X = X.to(torch.float16) R = torch.bmm(X, X.mT) Q = None @@ -121,6 +121,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) else: + X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): A = torch.bmm(X, X.mT) B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i) From f0a1c191425ef86d9a7dc639938933cace4bc01d Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 13:30:27 +0800 Subject: [PATCH 08/19] Update muon.py --- modules/optimizer/muon.py | 92 ++++++++++++++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 12 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index a6ac32b5..e93d6a97 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -22,6 +22,35 @@ coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]] +# https://x.com/YouJiacheng/status/1905861218138804534 +YOU_COEFFICIENTS = [ + [4.0848, -6.8946, 2.9270], + [3.9505, -6.3029, 2.6377], + [3.7418, -5.5913, 2.3037], + [2.8769, -3.1427, 1.2046], + [2.8366, -3.0525, 1.2012] +] + +# https://arxiv.org/pdf/2505.16932 +_unmodified_polar_express_coefficients = [ + (8.28721201814563, -23.595886519098837, 17.300387312530933), + (4.107059111542203, -2.9478499167379106, 0.5448431082926601), + (3.9486908534822946, -2.908902115962949, 0.5518191394370137), + (3.3184196573706015, -2.488488024314874, 0.51004894012372), + (2.300652019954817, -1.6689039845747493, 0.4188073119525673), + (1.891301407787398, -1.2679958271945868, 0.37680408948524835), + (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), + (1.875, -1.25, 0.375), # subsequent coeffs equal this numerically +] + +# safety factor for numerical stability (but exclude last polynomial ) +safety_factor = 1.01 # 'Dao-AILab/gram-newton-schulz' set 1.05 +POLAR_EXPRESS_COEFFICIENTS = [ + (a / safety_factor , b / safety_factor**3 , c / safety_factor**5) + for (a, b, c) in _unmodified_polar_express_coefficients +] + + def get_bf16_support_map(): bf16_support_map = {} @@ -39,8 +68,8 @@ def get_bf16_support_map(): return bf16_support_map - -def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor: + +def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coefficients: List[tuple]) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -51,7 +80,6 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor performance at all relative to UV^T, where USV^T = G is the SVD. """ assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - #a, b, c = (3.4445, -4.7750, 2.0315) X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) @@ -61,12 +89,14 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor # Perform the NS iterations hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list))) if X.size(-2) < X.size(-1): - for a, b, c in hs: + for i in range(steps): + a, b, c = ns_coefficients[i] A = torch.bmm(X, X.mT) A = torch.baddbmm(A, A, A, beta=b, alpha=c) X = torch.baddbmm(X, A, X, beta=a, alpha=1) else: - for a, b, c in hs: + for i in range(steps): + a, b, c = ns_coefficients[i] A = torch.bmm(X.mT, X) A = torch.baddbmm(A, A, A, beta=b, alpha=c) X = torch.baddbmm(X, X, A, beta=a, alpha=1) @@ -74,15 +104,13 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X -def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int]) -> Tensor: +def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor: """ Gram Newton-Schulz iteration to compute the orthogonalization of G. Mathematically identical to standard Newton-Schulz but computes iterating on the smaller NxN Gram matrix to save up to 50% FLOPs. """ assert G.ndim == 3 - a, b, c = (3.4445, -4.7750, 2.0315) - ns_coefficients = [(a, b, c)] * steps original_shape = G.shape dtype = G.dtype @@ -151,12 +179,35 @@ class Muon(torch.optim.Optimizer): momentum: The momentum used by the internal SGD. nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iteration steps to use. + ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs. + use_gram_ns: Whether to use the FLOP-saving Gram-NS implementation instead of standard NS. """ - def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5, reset_iterations=None): + def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, + ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS, + use_gram_ns=True): if reset_iterations is None: reset_iterations = [3] - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, reset_iterations=reset_iterations) + # set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS + + if ns_coefficients is None: + parsed_coefficients = [(3.4445, -4.7750, 2.0315)] * ns_steps + else: + parsed_coefficients = list(ns_coefficients) + if len(parsed_coefficients) < ns_steps: + parsed_coefficients += [parsed_coefficients[-1]] * (ns_steps - len(parsed_coefficients)) + parsed_coefficients = parsed_coefficients[:ns_steps] + + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + reset_iterations=reset_iterations, + ns_coefficients=parsed_coefficients, + use_gram_ns=use_gram_ns + ) super().__init__(params, defaults) self.bf16_support_map = get_bf16_support_map() @@ -187,9 +238,26 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) + use_bf16 = self.bf16_support_map.get(g.device, False) - # g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) - g = gram_newton_schulz(g, steps=group["ns_steps"], use_bf16=use_bf16, reset_iterations=group["reset_iterations"]) + + # Dynamic NS function invocation + if group["use_gram_ns"]: + g = gram_newton_schulz( + g, + steps=group["ns_steps"], + use_bf16=use_bf16, + reset_iterations=group["reset_iterations"], + ns_coefficients=group["ns_coefficients"] + ) + else: + g = zeropower_via_newtonschulz5( + g, + steps=group["ns_steps"], + use_bf16=use_bf16, + ns_coefficients=group["ns_coefficients"] + ) + if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) From 4e81fbbf5b185e82c5fdb48dc41d0b4056a9723a Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 14:22:10 +0800 Subject: [PATCH 09/19] Log optimizer step duration via callback Add OptimizerTimerCallback to basics/base_task.py to measure GPU optimizer step time using torch.cuda.Event and torch.cuda.synchronize. The callback records start/end events around optimizer steps (after epoch 0) and logs the elapsed milliseconds as "stats/optimizer_step_duration_ms" via pl_module.log (on_step, shown in prog_bar). The callback is registered in the Trainer callbacks so durations appear in TensorBoard/console. Note: a local timer_callback variable is instantiated but the callbacks list also constructs a new OptimizerTimerCallback (minor redundancy). Update base_task.py --- basics/base_task.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/basics/base_task.py b/basics/base_task.py index 656893d9..1baf4928 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -15,6 +15,7 @@ from torchmetrics import Metric, MeanMetric import lightning.pytorch as pl from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only +from lightning.pytorch.callbacks import Callback from basics.base_module import CategorizedModule from utils.hparams import hparams @@ -32,6 +33,37 @@ format=log_format, datefmt='%m/%d %I:%M:%S %p') +class OptimizerTimerCallback(Callback): + def __init__(self): + super().__init__() + # 使用 CUDA Event 确保获取的是 GPU 真实执行时间,而非 CPU 发射时间 + self.start_event = torch.cuda.Event(enable_timing=True) + self.end_event = torch.cuda.Event(enable_timing=True) + + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + # 只在第一个 Epoch 之后开始计时 + if trainer.current_epoch > 0: + self.start_event.record() + + def on_after_optimizer_step(self, trainer, pl_module, optimizer): + if trainer.current_epoch > 0: + self.end_event.record() + torch.cuda.synchronize() # 等待 GPU 完成该 Step 的所有计算 + + # 计算耗时(毫秒) + epoch_time_ms = self.start_event.elapsed_time(self.end_event) + + # 记录到 TensorBoard + # pl_module.log 会自动寻找当前配置的 Logger (如 TensorBoardLogger) + pl_module.log( + "stats/optimizer_step_duration_ms", + epoch_time_ms, + on_step=True, + on_epoch=False, + prog_bar=True + ) + + class BaseTask(pl.LightningModule): """ Base class for training tasks. @@ -423,6 +455,7 @@ def start(cls): ), # LearningRateMonitor(logging_interval='step'), DsTQDMProgressBar(), + OptimizerTimerCallback(), ], logger=DsTensorBoardLogger( save_dir=str(work_dir), From e8e33328924019acd6949da73601b44f0173bac6 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 14:51:24 +0800 Subject: [PATCH 10/19] Add MUD orthogonalization and method switch Introduce a mud() implementation (MomentUm Decorrelation) that performs lightweight orthogonalization via row-normalization, row-gram construction, lower-triangular extraction and forward triangular solve. Update Muon optimizer to replace the boolean use_gram_ns with a string method selector (defaults to 'gram_ns') and dispatch dynamically between 'gram_ns', 'mud', and 'ns5' implementations, raising on unknown methods. Also preserve bfloat16 handling and tensor transpose logic; mud() returns a contiguous tensor. --- modules/optimizer/muon.py | 48 ++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index e93d6a97..a38e51bf 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -158,6 +158,33 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: return X.to(dtype).view(original_shape) +def mud(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor: + """ + MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G. + A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training". + Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve. + """ + assert G.ndim == 3 + + X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32) + + should_transpose = X.size(-2) > X.size(-1) + if should_transpose: + X = X.mT + + for _ in range(passes): + X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Row normalization + G_mat = torch.bmm(X, X.mT) # Row Gram (k,k) + T = torch.tril(G_mat) # Lower-triangular of Gram + X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q + X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Renormalize rows + + if should_transpose: + X = X.mT + + return X.contiguous() + + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -180,12 +207,12 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iteration steps to use. ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs. - use_gram_ns: Whether to use the FLOP-saving Gram-NS implementation instead of standard NS. + method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud'). """ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS, - use_gram_ns=True): + method='gram_ns'): if reset_iterations is None: reset_iterations = [3] # set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS @@ -206,7 +233,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr ns_steps=ns_steps, reset_iterations=reset_iterations, ns_coefficients=parsed_coefficients, - use_gram_ns=use_gram_ns + method=method ) super().__init__(params, defaults) self.bf16_support_map = get_bf16_support_map() @@ -241,8 +268,9 @@ def step(self, closure=None): use_bf16 = self.bf16_support_map.get(g.device, False) - # Dynamic NS function invocation - if group["use_gram_ns"]: + # Dynamic orthogonalization method invocation + method = group.get("method", "gram_ns") + if method == 'gram_ns': g = gram_newton_schulz( g, steps=group["ns_steps"], @@ -250,13 +278,21 @@ def step(self, closure=None): reset_iterations=group["reset_iterations"], ns_coefficients=group["ns_coefficients"] ) - else: + elif method == 'mud': + g = mud( + g, + passes=1, + use_bf16=use_bf16 + ) + elif method == 'ns5': g = zeropower_via_newtonschulz5( g, steps=group["ns_steps"], use_bf16=use_bf16, ns_coefficients=group["ns_coefficients"] ) + else: + raise ValueError(f"Unknown orthogonalization method: {method}") if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) From 6d3786c71005696ce9edbe978a0e9ad260364032 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 15:45:29 +0800 Subject: [PATCH 11/19] Revert "Log optimizer step duration via callback" This reverts commit e380a729dc51bf82cec86307dae4675d9d7f48d7. --- basics/base_task.py | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 1baf4928..656893d9 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -15,7 +15,6 @@ from torchmetrics import Metric, MeanMetric import lightning.pytorch as pl from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only -from lightning.pytorch.callbacks import Callback from basics.base_module import CategorizedModule from utils.hparams import hparams @@ -33,37 +32,6 @@ format=log_format, datefmt='%m/%d %I:%M:%S %p') -class OptimizerTimerCallback(Callback): - def __init__(self): - super().__init__() - # 使用 CUDA Event 确保获取的是 GPU 真实执行时间,而非 CPU 发射时间 - self.start_event = torch.cuda.Event(enable_timing=True) - self.end_event = torch.cuda.Event(enable_timing=True) - - def on_before_optimizer_step(self, trainer, pl_module, optimizer): - # 只在第一个 Epoch 之后开始计时 - if trainer.current_epoch > 0: - self.start_event.record() - - def on_after_optimizer_step(self, trainer, pl_module, optimizer): - if trainer.current_epoch > 0: - self.end_event.record() - torch.cuda.synchronize() # 等待 GPU 完成该 Step 的所有计算 - - # 计算耗时(毫秒) - epoch_time_ms = self.start_event.elapsed_time(self.end_event) - - # 记录到 TensorBoard - # pl_module.log 会自动寻找当前配置的 Logger (如 TensorBoardLogger) - pl_module.log( - "stats/optimizer_step_duration_ms", - epoch_time_ms, - on_step=True, - on_epoch=False, - prog_bar=True - ) - - class BaseTask(pl.LightningModule): """ Base class for training tasks. @@ -455,7 +423,6 @@ def start(cls): ), # LearningRateMonitor(logging_interval='step'), DsTQDMProgressBar(), - OptimizerTimerCallback(), ], logger=DsTensorBoardLogger( save_dir=str(work_dir), From 4cd074844053356d40ddd3ff2ba31ad5009c47d4 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 17:25:03 +0800 Subject: [PATCH 12/19] Fix muon mud dtype handling and normalization Cast input to float32 for the triangular solve (triangular_solve_cuda not implemented for BFloat16), while preserving the original dtype and casting the result back before returning. Also corrected row normalization to use dim=1 (instead of -1) and tightened eps from 1e-7 to 1e-8. Added explanatory comment and small cleanup. Update muon.py Update muon.py --- modules/optimizer/muon.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index a38e51bf..810b2293 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -158,31 +158,34 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: return X.to(dtype).view(original_shape) -def mud(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor: +def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor: """ MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G. A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training". Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve. """ assert G.ndim == 3 - - X = G.to(dtype=torch.bfloat16 if use_bf16 else torch.float32) + dtype = G.dtype + + # X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) + # "triangular_solve_cuda" not implemented for 'BFloat16' + X = G.to(torch.float32) should_transpose = X.size(-2) > X.size(-1) if should_transpose: - X = X.mT + X = X.mT.contiguous() for _ in range(passes): - X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Row normalization + X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Row normalization G_mat = torch.bmm(X, X.mT) # Row Gram (k,k) T = torch.tril(G_mat) # Lower-triangular of Gram X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q - X = F.normalize(X, p=2.0, dim=-1, eps=1e-7) # Renormalize rows + X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Renormalize rows if should_transpose: X = X.mT - return X.contiguous() + return X.to(dtype).contiguous() class Muon(torch.optim.Optimizer): @@ -279,7 +282,7 @@ def step(self, closure=None): ns_coefficients=group["ns_coefficients"] ) elif method == 'mud': - g = mud( + g = mud_whiten( g, passes=1, use_bf16=use_bf16 From 16e1989f4170d9aad052e01777319750f582844f Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 19:55:19 +0800 Subject: [PATCH 13/19] Stabilize mud_whiten by clamping diagonal Clamp the diagonal entries of the lower-triangular Gram matrix in mud_whiten with a minimum of 1e-5 before solving the triangular system. This prevents T from having all-zero diagonal values (which would cause singular/ill-conditioned solves) and improves numerical stability of the forward solve. --- modules/optimizer/muon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 810b2293..4f6072aa 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -179,6 +179,7 @@ def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor: X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Row normalization G_mat = torch.bmm(X, X.mT) # Row Gram (k,k) T = torch.tril(G_mat) # Lower-triangular of Gram + T.diagonal(dim1=-2, dim2=-1).clamp_min_(1e-5) # avoid T all zero X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Renormalize rows From 931d824ec023be6f6d914bb6defb2451a2b7adc6 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Thu, 2 Apr 2026 23:37:06 +0800 Subject: [PATCH 14/19] safety_factor set 1.05 Update muon.py Update muon.py --- modules/optimizer/muon.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 4f6072aa..08b54a2b 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -44,7 +44,8 @@ ] # safety factor for numerical stability (but exclude last polynomial ) -safety_factor = 1.01 # 'Dao-AILab/gram-newton-schulz' set 1.05 +# safety_factor = 1.01 # 'Dao-AILab/gram-newton-schulz' set 1.05 +safety_factor = 1.05 POLAR_EXPRESS_COEFFICIENTS = [ (a / safety_factor , b / safety_factor**3 , c / safety_factor**5) for (a, b, c) in _unmodified_polar_express_coefficients From 237a0e8654fa2f84ac55618f6c5974356222bae4 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sat, 4 Apr 2026 11:54:52 +0800 Subject: [PATCH 15/19] Remove bf16 support and simplify tensor casting Drop bfloat16 detection and runtime BF16 paths: remove get_bf16_support_map and the bf16_support_map field, eliminate use_bf16 parameters from zeropower_via_newtonschulz5, gram_newton_schulz and mud_whiten, and stop passing use_bf16 from Muon.step. Simplify tensor casts to explicit float32/float16 usage and clean up related conditional logic. This streamlines the orthogonalization codepaths and avoids BF16-specific code (e.g. triangular_solve_cuda incompatibilities). --- modules/optimizer/muon.py | 42 ++++++++------------------------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 08b54a2b..3dce5138 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -52,25 +52,7 @@ ] -def get_bf16_support_map(): - bf16_support_map = {} - - if not torch.cuda.is_available(): - return bf16_support_map - - device_count = torch.cuda.device_count() - if device_count == 0: - return bf16_support_map - - for i in range(device_count): - device = torch.device(f'cuda:{i}') - major, minor = torch.cuda.get_device_capability(device) - bf16_support_map[device] = (major >= 8) - - return bf16_support_map - - -def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coefficients: List[tuple]) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tuple]) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -82,7 +64,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coeffi """ assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) + X = G.to(torch.float32) # Ensure spectral norm is at most 1 X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) @@ -105,7 +87,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool, ns_coeffi return X -def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor: +def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor: """ Gram Newton-Schulz iteration to compute the orthogonalization of G. Mathematically identical to standard Newton-Schulz but computes iterating @@ -124,8 +106,8 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: if should_transpose: X = X.mT + X = X.to(torch.float16) if X.size(-2) != X.size(-1): - X = X.to(torch.float16) R = torch.bmm(X, X.mT) Q = None @@ -150,7 +132,6 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) else: - X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): A = torch.bmm(X, X.mT) B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i) @@ -159,7 +140,7 @@ def gram_newton_schulz(G: Tensor, steps: int, use_bf16: bool, reset_iterations: return X.to(dtype).view(original_shape) -def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor: +def mud_whiten(G: Tensor, passes: int = 1) -> Tensor: """ MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G. A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training". @@ -168,7 +149,6 @@ def mud_whiten(G: Tensor, passes: int = 1, use_bf16: bool = False) -> Tensor: assert G.ndim == 3 dtype = G.dtype - # X = X.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) # "triangular_solve_cuda" not implemented for 'BFloat16' X = G.to(torch.float32) @@ -241,7 +221,6 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr method=method ) super().__init__(params, defaults) - self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -271,29 +250,24 @@ def step(self, closure=None): if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - use_bf16 = self.bf16_support_map.get(g.device, False) - # Dynamic orthogonalization method invocation method = group.get("method", "gram_ns") if method == 'gram_ns': g = gram_newton_schulz( g, - steps=group["ns_steps"], - use_bf16=use_bf16, + steps=group["ns_steps"], reset_iterations=group["reset_iterations"], ns_coefficients=group["ns_coefficients"] ) elif method == 'mud': g = mud_whiten( g, - passes=1, - use_bf16=use_bf16 + passes=1 ) elif method == 'ns5': g = zeropower_via_newtonschulz5( g, - steps=group["ns_steps"], - use_bf16=use_bf16, + steps=group["ns_steps"], ns_coefficients=group["ns_coefficients"] ) else: From 59b348eafedc5e4d6d54724bdfaf6f407e51a900 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sun, 5 Apr 2026 11:26:15 +0800 Subject: [PATCH 16/19] Keep last POLAR_EXPRESS_COEFFICIENT unscaled Apply safety_factor scaling to all POLAR_EXPRESS_COEFFICIENTS except the final tuple. The list comprehension now iterates over _unmodified_polar_express_coefficients[:-1] and the original last element is appended unchanged, preserving that coefficient (likely for correctness or numerical stability). --- modules/optimizer/muon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 3dce5138..18dc4ad3 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -48,8 +48,8 @@ safety_factor = 1.05 POLAR_EXPRESS_COEFFICIENTS = [ (a / safety_factor , b / safety_factor**3 , c / safety_factor**5) - for (a, b, c) in _unmodified_polar_express_coefficients -] + for (a, b, c) in _unmodified_polar_express_coefficients[: -1] +] + [_unmodified_polar_express_coefficients[-1]] def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tuple]) -> Tensor: From 32811e2feb5302a6d6bea70c5b057aa02cddb708 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sun, 5 Apr 2026 11:42:54 +0800 Subject: [PATCH 17/19] Cleanup whitespace in muon optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove extraneous blank lines and trailing spaces in modules/optimizer/muon.py and tidy formatting around the normalization, transpose and Newton–Schulz loops. No functional logic was changed. --- modules/optimizer/muon.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 18dc4ad3..498cf1cf 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -94,43 +94,34 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co on the smaller NxN Gram matrix to save up to 50% FLOPs. """ assert G.ndim == 3 - original_shape = G.shape dtype = G.dtype X = G.to(torch.float32) - X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) - should_transpose = X.size(-2) > X.size(-1) if should_transpose: X = X.mT - X = X.to(torch.float16) + if X.size(-2) != X.size(-1): R = torch.bmm(X, X.mT) Q = None - for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): if i in reset_iterations and i != 0: X = torch.bmm(Q, X) R = torch.bmm(X, X.mT) Q = None - Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i) - if i != 0 and i not in reset_iterations: Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0) else: Q = Z.clone() Q.diagonal(dim1=-2, dim2=-1).add_(a_i) - if i < steps - 1 and (i + 1) not in reset_iterations: RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0) R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0) - X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) - else: for i, (a_i, b_i, c_i) in enumerate(ns_coefficients): A = torch.bmm(X, X.mT) From 3a6fab7f4d06a203d338feab9af799e905770232 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sun, 5 Apr 2026 13:02:50 +0800 Subject: [PATCH 18/19] Simplify Muon optimizer and param selection Remove unused coefficient tables and the mud_whiten path, and streamline orthogonalization to always use gram_newton_schulz. Add collections import and switch get_params_for_muon to a BFS that excludes Embedding modules and only collects trainable params with ndim >= 2. Cast intermediate X to float16 in zeropower_via_newtonschulz5 for faster half-precision ops, and drop unused imports (itertools.repeat) and redundant method dispatch in the Muon step. These changes reduce complexity and unify the orthogonalization flow. Update muon.py Update muon.py Update muon.py Update muon.py Update muon.py --- modules/optimizer/muon.py | 131 ++++++-------------------------------- 1 file changed, 19 insertions(+), 112 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 498cf1cf..4e6f7c67 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -1,35 +1,12 @@ +import collections import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.nn import Module, Parameter, Embedding from typing import List -from itertools import repeat from .chained_optimizer import ChainedOptimizer, OptimizerSpec -coeffs_list = [ - (8.28721201814563, -23.595886519098837, 17.300387312530933), - (4.107059111542203, -2.9478499167379106, 0.5448431082926601), - (3.9486908534822946, -2.908902115962949, 0.5518191394370137), - (3.3184196573706015, -2.488488024314874, 0.51004894012372), - (2.300652019954817, -1.6689039845747493, 0.4188073119525673), - (1.891301407787398, -1.2679958271945868, 0.37680408948524835), - (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), - (1.875, -1.25, 0.375), # subsequent coeffs equal this numerically -] - -# safety factor for numerical stability (but exclude last polynomial ) -coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]] - - -# https://x.com/YouJiacheng/status/1905861218138804534 -YOU_COEFFICIENTS = [ - [4.0848, -6.8946, 2.9270], - [3.9505, -6.3029, 2.6377], - [3.7418, -5.5913, 2.3037], - [2.8769, -3.1427, 1.2046], - [2.8366, -3.0525, 1.2012] -] # https://arxiv.org/pdf/2505.16932 _unmodified_polar_express_coefficients = [ @@ -52,7 +29,7 @@ ] + [_unmodified_polar_express_coefficients[-1]] -def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tuple]) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -68,9 +45,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup # Ensure spectral norm is at most 1 X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + X = X.to(torch.float16) # Perform the NS iterations - hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list))) + ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] if X.size(-2) < X.size(-1): for i in range(steps): a, b, c = ns_coefficients[i] @@ -87,7 +65,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, ns_coefficients: List[tup return X -def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_coefficients: List[tuple]) -> Tensor: +def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) -> Tensor: """ Gram Newton-Schulz iteration to compute the orthogonalization of G. Mathematically identical to standard Newton-Schulz but computes iterating @@ -104,6 +82,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co X = X.mT X = X.to(torch.float16) + ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] if X.size(-2) != X.size(-1): R = torch.bmm(X, X.mT) Q = None @@ -131,36 +110,6 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int], ns_co return X.to(dtype).view(original_shape) -def mud_whiten(G: Tensor, passes: int = 1) -> Tensor: - """ - MomentUm Decorrelation (MUD) iteration to compute the orthogonalization of G. - A lightweight PyTorch implementation based on "Beyond Muon: MUD for Faster Transformer Training". - Constructs a lower-triangular approximation to the Gram matrix and applies forward triangular solve. - """ - assert G.ndim == 3 - dtype = G.dtype - - # "triangular_solve_cuda" not implemented for 'BFloat16' - X = G.to(torch.float32) - - should_transpose = X.size(-2) > X.size(-1) - if should_transpose: - X = X.mT.contiguous() - - for _ in range(passes): - X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Row normalization - G_mat = torch.bmm(X, X.mT) # Row Gram (k,k) - T = torch.tril(G_mat) # Lower-triangular of Gram - T.diagonal(dim1=-2, dim2=-1).clamp_min_(1e-5) # avoid T all zero - X = torch.linalg.solve_triangular(T, X, upper=False) # Forward solve: T X = Q - X = F.normalize(X, p=2.0, dim=-1, eps=1e-8) # Renormalize rows - - if should_transpose: - X = X.mT - - return X.to(dtype).contiguous() - - class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -182,35 +131,10 @@ class Muon(torch.optim.Optimizer): momentum: The momentum used by the internal SGD. nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iteration steps to use. - ns_coefficients: List of tuples (a, b, c) for each NS step. Defaults to standard 5-step NS coeffs. - method: String selector for the orthogonalization strategy ('ns5', 'gram_ns', or 'mud'). """ - def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, - ns_steps=5, reset_iterations=[2], ns_coefficients=POLAR_EXPRESS_COEFFICIENTS, - method='gram_ns'): - if reset_iterations is None: - reset_iterations = [3] - # set [3] on default and set [2] on POLAR_EXPRESS_COEFFICIENTS or YOU_COEFFICIENTS - - if ns_coefficients is None: - parsed_coefficients = [(3.4445, -4.7750, 2.0315)] * ns_steps - else: - parsed_coefficients = list(ns_coefficients) - if len(parsed_coefficients) < ns_steps: - parsed_coefficients += [parsed_coefficients[-1]] * (ns_steps - len(parsed_coefficients)) - parsed_coefficients = parsed_coefficients[:ns_steps] - - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - reset_iterations=reset_iterations, - ns_coefficients=parsed_coefficients, - method=method - ) + def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) super().__init__(params, defaults) @torch.no_grad() @@ -240,29 +164,7 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - - # Dynamic orthogonalization method invocation - method = group.get("method", "gram_ns") - if method == 'gram_ns': - g = gram_newton_schulz( - g, - steps=group["ns_steps"], - reset_iterations=group["reset_iterations"], - ns_coefficients=group["ns_coefficients"] - ) - elif method == 'mud': - g = mud_whiten( - g, - passes=1 - ) - elif method == 'ns5': - g = zeropower_via_newtonschulz5( - g, - steps=group["ns_steps"], - ns_coefficients=group["ns_coefficients"] - ) - else: - raise ValueError(f"Unknown orthogonalization method: {method}") + g = gram_newton_schulz(g, steps=group["ns_steps"]) if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) @@ -278,15 +180,20 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. """ + excluded_module_classes = (nn.Embedding) muon_params = [] - for module in model.modules(): - for name, param in module.named_parameters(recurse=False): + # BFS through all submodules and exclude parameters from certain module types + queue = collections.deque([model]) + while queue: + module = queue.popleft() + if isinstance(module, excluded_module_classes): + continue + for param in module.parameters(recurse=False): if not param.requires_grad: continue - if name == 'weight_g': - continue - if not isinstance(module, nn.Embedding) and param.ndim >= 2: + if param.ndim >= 2: muon_params.append(param) + queue.extend(list(module.children())) return muon_params From b101a8362a06d285ee92c02dffaf0758b531a120 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sun, 5 Apr 2026 15:33:31 +0800 Subject: [PATCH 19/19] =?UTF-8?q?Pad=20Newton=E2=80=93Schulz=20coefficient?= =?UTF-8?q?s=20to=20match=20steps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ensure ns_coefficients contains exactly `steps` entries by padding with the last POLAR_EXPRESS_COEFFICIENTS value when `steps` exceeds the predefined list. Adds itertools.repeat import and applies the fix in zeropower_via_newtonschulz5 and gram_newton_schulz to avoid out-of-range access during iteration. --- modules/optimizer/muon.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 4e6f7c67..b98e5ac8 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -5,6 +5,7 @@ from torch import Tensor from torch.nn import Module, Parameter, Embedding from typing import List +from itertools import repeat from .chained_optimizer import ChainedOptimizer, OptimizerSpec @@ -48,7 +49,7 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: X = X.to(torch.float16) # Perform the NS iterations - ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))) if X.size(-2) < X.size(-1): for i in range(steps): a, b, c = ns_coefficients[i] @@ -82,7 +83,7 @@ def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) - X = X.mT X = X.to(torch.float16) - ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))) if X.size(-2) != X.size(-1): R = torch.bmm(X, X.mT) Q = None