Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 75 additions & 34 deletions modules/optimizer/muon.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -7,7 +8,9 @@
from itertools import repeat
from .chained_optimizer import ChainedOptimizer, OptimizerSpec

coeffs_list = [

# https://arxiv.org/pdf/2505.16932
_unmodified_polar_express_coefficients = [
(8.28721201814563, -23.595886519098837, 17.300387312530933),
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
Expand All @@ -19,28 +22,15 @@
]

# safety factor for numerical stability (but exclude the last polynomial)
coeffs_list = [(a / 1.01 , b / 1.01**3 , c / 1.01**5) for (a, b, c) in coeffs_list[: -1]] + [coeffs_list[-1]]


def get_bf16_support_map():
bf16_support_map = {}

if not torch.cuda.is_available():
return bf16_support_map
# safety_factor was previously 1.01; 'Dao-AILab/gram-newton-schulz' uses 1.05
safety_factor = 1.05
POLAR_EXPRESS_COEFFICIENTS = [
(a / safety_factor , b / safety_factor**3 , c / safety_factor**5)
for (a, b, c) in _unmodified_polar_express_coefficients[: -1]
] + [_unmodified_polar_express_coefficients[-1]]

device_count = torch.cuda.device_count()
if device_count == 0:
return bf16_support_map

for i in range(device_count):
device = torch.device(f'cuda:{i}')
major, minor = torch.cuda.get_device_capability(device)
bf16_support_map[device] = (major >= 8)

return bf16_support_map


def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
Expand All @@ -51,29 +41,76 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
#a, b, c = (3.4445, -4.7750, 2.0315)

X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32)
X = G.to(torch.float32)

# Ensure spectral norm is at most 1
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
X = X.to(torch.float16)

# Perform the NS iterations
hs = coeffs_list[: steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list)))
ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS)))
if X.size(-2) < X.size(-1):
for a, b, c in hs:
for i in range(steps):
a, b, c = ns_coefficients[i]
A = torch.bmm(X, X.mT)
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
X = torch.baddbmm(X, A, X, beta=a, alpha=1)
else:
for a, b, c in hs:
for i in range(steps):
a, b, c = ns_coefficients[i]
A = torch.bmm(X.mT, X)
A = torch.baddbmm(A, A, A, beta=b, alpha=c)
X = torch.baddbmm(X, X, A, beta=a, alpha=1)

return X


def gram_newton_schulz(G: Tensor, steps: int, reset_iterations=None) -> Tensor:
    """
    Gram Newton-Schulz iteration to compute the orthogonalization of G.

    Mathematically identical to standard Newton-Schulz, but iterates on the
    smaller NxN Gram matrix R = X X^T to save up to 50% FLOPs: each quintic
    update (a*X + b*R@X + c*R^2@X) is accumulated into a polynomial Q of R and
    only multiplied into X at reset points and at the very end.

    Args:
        G: batched input of shape (B, M, N); asserted to be 3-D.
        steps: number of Newton-Schulz iterations. When steps exceeds the
            number of Polar Express coefficient tuples, the last tuple is
            repeated.
        reset_iterations: iteration indices at which the accumulated Q is
            flushed into X and R is recomputed (helps fp16 stability).
            Defaults to [2]. NOTE(review): consecutive reset indices
            (e.g. [1, 2]) would dereference a None Q — callers are expected
            to space resets apart; TODO confirm intended constraint.

    Returns:
        Orthogonalized tensor with the same shape and dtype as G.
    """
    # Sentinel instead of a mutable default argument ([2]) — avoids sharing
    # one list object across calls.
    if reset_iterations is None:
        reset_iterations = [2]
    assert G.ndim == 3
    original_shape = G.shape
    dtype = G.dtype

    # Normalize in fp32, then run the iterations in fp16.
    X = G.to(torch.float32)
    # Frobenius normalization upper-bounds the spectral norm by 1.
    X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
    # Orient X so the leading side is the short one: keeps R as small as possible.
    should_transpose = X.size(-2) > X.size(-1)
    if should_transpose:
        X = X.mT
    X = X.to(torch.float16)

    # Extend the coefficient schedule by repeating the final tuple if needed.
    ns_coefficients = POLAR_EXPRESS_COEFFICIENTS[:steps] + list(repeat(POLAR_EXPRESS_COEFFICIENTS[-1], steps - len(POLAR_EXPRESS_COEFFICIENTS)))
    if X.size(-2) != X.size(-1):
        R = torch.bmm(X, X.mT)  # Gram matrix of the short side
        Q = None  # lazily-applied accumulated polynomial in R
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            if i in reset_iterations and i != 0:
                # Flush Q into X and rebuild R from the updated X.
                X = torch.bmm(Q, X)
                R = torch.bmm(X, X.mT)
                Q = None
            # Z = b_i*R + c_i*R^2, so (a_i*I + Z) is one quintic step's factor.
            Z = torch.baddbmm(R, R, R, beta=b_i, alpha=c_i)
            if i != 0 and i not in reset_iterations:
                # Compose into the running product: Q <- a_i*Q + Q@Z = Q @ (a_i*I + Z)
                Q = torch.baddbmm(Q, Q, Z, beta=a_i, alpha=1.0)
            else:
                # Start a fresh accumulator: Q = a_i*I + Z
                Q = Z.clone()
                Q.diagonal(dim1=-2, dim2=-1).add_(a_i)
            if i < steps - 1 and (i + 1) not in reset_iterations:
                # Advance the Gram matrix: R <- (a_i*I + Z) @ R @ (a_i*I + Z)
                RZ = torch.baddbmm(R, R, Z, beta=a_i, alpha=1.0)
                R = torch.baddbmm(RZ, Z, RZ, beta=a_i, alpha=1.0)
        # Apply the final accumulated Q; restore the original orientation.
        X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q)
    else:
        # Square input: the Gram trick saves nothing, run the standard iteration.
        for i, (a_i, b_i, c_i) in enumerate(ns_coefficients):
            A = torch.bmm(X, X.mT)
            B = torch.baddbmm(A, A, A, beta=b_i, alpha=c_i)
            X = torch.baddbmm(X, B, X, beta=a_i, alpha=1.0)

    return X.to(dtype).view(original_shape)


class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
Expand All @@ -100,7 +137,6 @@ class Muon(torch.optim.Optimizer):
def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
super().__init__(params, defaults)
self.bf16_support_map = get_bf16_support_map()

@torch.no_grad()
def step(self, closure=None):
Expand Down Expand Up @@ -129,8 +165,8 @@ def step(self, closure=None):
original_shape = g.shape
if g.ndim >= 4: # for the case of conv filters
g = g.view(g.size(0), g.size(1), -1)
use_bf16 = self.bf16_support_map.get(g.device, False)
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
g = gram_newton_schulz(g, steps=group["ns_steps"])

if group["weight_decay"] > 0:
torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"])
torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5)
Expand All @@ -145,15 +181,20 @@ def get_params_for_muon(model) -> List[Parameter]:
Returns:
A list of parameters that should be optimized with muon.
"""
excluded_module_classes = (nn.Embedding)
muon_params = []
for module in model.modules():
for name, param in module.named_parameters(recurse=False):
# BFS through all submodules and exclude parameters from certain module types
queue = collections.deque([model])
while queue:
module = queue.popleft()
if isinstance(module, excluded_module_classes):
continue
for param in module.parameters(recurse=False):
if not param.requires_grad:
continue
if name == 'weight_g':
continue
if not isinstance(module, nn.Embedding) and param.ndim >= 2:
if param.ndim >= 2:
muon_params.append(param)
queue.extend(list(module.children()))
return muon_params


Expand Down