diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 662dacee..b98e5ac8 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -1,3 +1,4 @@ +import collections import torch import torch.nn as nn import torch.nn.functional as F @@ -7,7 +8,9 @@ from itertools import repeat from .chained_optimizer import ChainedOptimizer, OptimizerSpec -coeffs_list = [ + +# https://arxiv.org/pdf/2505.16932 +_unmodified_polar_express_coefficients = [ (8.28721201814563, -23.595886519098837, 17.300387312530933), (4.107059111542203, -2.9478499167379106, 0.5448431082926601), (3.9486908534822946, -2.908902115962949, 0.5518191394370137), @@ -19,28 +22,15 @@ ] # safety factor for numerical stability (but exclude last polynomial ) -coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]] - - -def get_bf16_support_map(): - bf16_support_map = {} - - if not torch.cuda.is_available(): - return bf16_support_map +# safety_factor = 1.01 # 'Dao-AILab/gram-newton-schulz' set 1.05 +safety_factor = 1.05 +POLAR_EXPRESS_COEFFICIENTS = [ + (a / safety_factor , b / safety_factor**3 , c / safety_factor**5) + for (a, b, c) in _unmodified_polar_express_coefficients[: -1] +] + [_unmodified_polar_express_coefficients[-1]] - device_count = torch.cuda.device_count() - if device_count == 0: - return bf16_support_map - for i in range(device_count): - device = torch.device(f'cuda:{i}') - major, minor = torch.cuda.get_device_capability(device) - bf16_support_map[device] = (major >= 8) - - return bf16_support_map - - -def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. 
def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: "List[int] | None" = None) -> Tensor:
    """
    Gram Newton-Schulz iteration to compute the orthogonalization of G.

    Mathematically identical to the standard Newton-Schulz iteration, but it
    iterates on the smaller NxN Gram matrix R = X X^T, saving up to 50% of
    the FLOPs for non-square inputs.

    Args:
        G: batched matrices of shape (B, M, N); asserted to be 3-D.
        steps: number of Newton-Schulz iterations to run.
        reset_iterations: iteration indices at which the accumulated
            polynomial factor Q is folded back into X and the Gram matrix is
            recomputed from scratch (re-anchoring to limit float16 error
            growth). Defaults to (2,).

    Returns:
        Tensor with the same shape and dtype as G.
    """
    if reset_iterations is None:
        # Immutable sentinel default avoids the shared mutable-default-argument
        # pitfall of the previous `=[2]` signature; membership tests are identical.
        reset_iterations = (2,)

    assert G.ndim == 3
    original_shape = G.shape
    dtype = G.dtype

    # Normalize in float32 so the spectral norm is at most 1 ...
    X = G.to(torch.float32)
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
    # ... and orient X wide (rows <= cols) so the Gram matrix is the small one.
    should_transpose = X.size(-2) > X.size(-1)
    if should_transpose:
        X = X.mT
    # NOTE(review): iterations run in float16 for speed; assumes the normalized
    # input is well-conditioned enough for half precision — confirm on target HW.
    X = X.to(torch.float16)

    # Pad the coefficient schedule by repeating the last tuple when `steps`
    # exceeds the number of published Polar Express coefficients.
    ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(
        repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS))
    )

    if X.size(-2) != X.size(-1):
        # Rectangular case: iterate on the Gram matrix R = X X^T, accumulating
        # the polynomial factor Q and applying it to X only at resets / the end.
        R = torch.bmm(X, X.mT)
        Q = None
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            if i in reset_iterations and i != 0:
                # Fold the accumulated factor into X and restart from a fresh
                # Gram matrix (numerical re-anchoring in float16).
                X = torch.bmm(Q, X)
                R = torch.bmm(X, X.mT)
                Q = None
            # Z = b*R + c*R^2 — the non-constant part of the quintic update.
            Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
            if i != 0 and i not in reset_iterations:
                # Compose this step into the running factor: Q <- a*Q + Q Z.
                Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
            else:
                # First step (or right after a reset): Q = a*I + Z.
                Q = Z.clone()
                Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
            if i < steps - 1 and (i + 1) not in reset_iterations:
                # Advance the Gram matrix: R <- (a*I + Z) R (a*I + Z),
                # expanded via RZ = a*R + R Z.
                RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
        # Apply the final factor, undoing the earlier transpose if any.
        X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
    else:
        # Square case: the Gram trick saves nothing; use the standard iteration.
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            A = torch.bmm(X, X.mT)
            B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
            X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0)

    return X.to(dtype).view(original_shape)
def get_params_for_muon(model) -> List[Parameter]:
    """
    Collect the parameters of *model* that should be optimized with Muon.

    Muon orthogonalizes matrix-shaped updates, so only parameters with
    ndim >= 2 are selected. Frozen parameters (requires_grad=False) are
    skipped, as are all parameters of excluded module types (currently
    nn.Embedding) together with anything nested inside them.

    Args:
        model: the nn.Module whose parameter tree should be scanned.

    Returns:
        A list of parameters that should be optimized with muon.
    """
    # Must be a tuple: bare "(nn.Embedding)" without the comma is just the
    # class itself — isinstance() happens to accept it, but it silently breaks
    # the moment a second excluded class is added.
    excluded_module_classes = (nn.Embedding,)
    muon_params = []
    # BFS through all submodules and exclude parameters from certain module types.
    queue = collections.deque([model])
    while queue:
        module = queue.popleft()
        if isinstance(module, excluded_module_classes):
            # Skipping before enqueueing children also excludes every module
            # nested inside an excluded one.
            continue
        for param in module.parameters(recurse=False):
            # Frozen parameters are never optimized.
            if not param.requires_grad:
                continue
            # Only matrix-shaped (>= 2-D) parameters are suitable for Muon.
            if param.ndim >= 2:
                muon_params.append(param)
        # deque.extend accepts any iterable — no need to materialize a list.
        queue.extend(module.children())
    return muon_params