From e353f3c2721a423cba6dacd68ac631ddda29d4e0 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 09:17:04 -0700 Subject: [PATCH 1/4] clean up hyperball Signed-off-by: mikail --- .../muon_hyperball.py | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index c9a8c09d..66c58151 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -35,8 +35,8 @@ class MuonHyperball(muon.Muon): W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update})) - where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps - the weight matrix at constant scale while updating. + where :math:`R` is the user-specified hyperball radius. This keeps the weight matrix at + constant scale while updating. Warning: This optimizer is experimental and may change in future versions. @@ -49,52 +49,62 @@ class MuonHyperball(muon.Muon): *args: Arguments passed to Muon. hyperball_eps: Epsilon for numerical stability in normalization. Default: ``1e-8``. - hyperball_radius: Fixed radius for the hyperball. If ``None`` (default), - uses each parameter's initial Frobenius norm as its radius. If specified, all - parameters will be rescaled to have this radius at initialization. + hyperball_radius: Fixed radius for the hyperball. All parameters must + already have this Frobenius norm at construction time. **kwargs: Keyword arguments passed to Muon. + Raises: + ValueError: If any parameter has zero norm, or if a parameter's + Frobenius norm does not match ``hyperball_radius``. + """ def __init__( self, *args: Any, hyperball_eps: float = 1e-8, - hyperball_radius: float | None = None, + hyperball_radius: float, **kwargs: Any, ) -> None: self.hyperball_eps = hyperball_eps self.hyperball_radius = hyperball_radius super().__init__(*args, **kwargs) - # Validate and optionally rescale parameters based on hyperball_radius. with torch.no_grad(): for group in self.param_groups: for p in group["params"]: p_norm = p.norm() - # Validate that parameter has non-zero norm. - if p_norm.item() == 0: + if p_norm == 0: + raise ValueError( + "MuonHyperball requires all parameters to have non-zero norm. " + "Found parameter with zero norm." + ) + if not torch.isclose( + p_norm, + torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device), + rtol=1e-5, + atol=1e-8, + ): raise ValueError( - "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm." + f"hyperball_radius={self.hyperball_radius} was specified but a parameter " + f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the " + f"desired radius before constructing the optimizer." ) - # Rescale parameter to have the specified radius if provided. - if self.hyperball_radius is not None: - p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps)) @override def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None: - """Store the original weight norm and normalize the update using Frobenius norm. + """Normalize the update using Frobenius norm, scaled by R. Args: p: The parameter tensor. update: The orthogonalized gradient tensor. """ - # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm) - R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item() - self.state[p]["hyperball_R"] = R + if "hyperball_R" not in self.state[p]: + self.state[p]["hyperball_R"] = torch.tensor( + self.hyperball_radius, dtype=p.dtype, device=p.device + ) + R = self.state[p]["hyperball_R"] - # Normalize the update in-place and scale by R - # This modifies update to be: R * normalize(update) using Frobenius norm. update_norm = update.norm().clamp_min(self.hyperball_eps) update.mul_(R / update_norm) From 2fcf8486f9aa4ed6d39eb0b62020566c79e7b145 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 13:29:35 -0700 Subject: [PATCH 2/4] clean up hyperball Signed-off-by: mikail --- .../muon_hyperball.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index 66c58151..ef8a0651 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -121,3 +121,21 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) p.mul_(R / p_norm) + + @staticmethod + def _compute_tangent_projection( + param: torch.Tensor, grad_like: torch.Tensor + ) -> torch.Tensor: + """Compute the Riemannian gradient via tangent-space projection. + Frobenius sphere (entire matrix on a single sphere). + + Args: + param: Parameter tensor (2D). + grad_like: Gradient-like tensor (momentum buffer or gradient). + + Returns: + The tangent-space projected gradient. + """ + + projection = (param * grad_like).sum() / param.pow(2).sum().clamp(min=1e-12) + return grad_like - projection * param From 8b137d0c2dd36b8b8a8d6ed4d102da875026031e Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 13:31:39 -0700 Subject: [PATCH 3/4] clean up hyperball, revert some comments Signed-off-by: mikail --- .../muon_hyperball.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index ef8a0651..1bde8dc8 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -35,7 +35,7 @@ class MuonHyperball(muon.Muon): W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update})) - where :math:`R` is the user-specified hyperball radius. This keeps the weight matrix at + where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at constant scale while updating. Warning: @@ -122,20 +122,3 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: p_norm = p.norm().clamp_min(self.hyperball_eps) p.mul_(R / p_norm) - @staticmethod - def _compute_tangent_projection( - param: torch.Tensor, grad_like: torch.Tensor - ) -> torch.Tensor: - """Compute the Riemannian gradient via tangent-space projection. - Frobenius sphere (entire matrix on a single sphere). - - Args: - param: Parameter tensor (2D). - grad_like: Gradient-like tensor (momentum buffer or gradient). - - Returns: - The tangent-space projected gradient. - """ - - projection = (param * grad_like).sum() / param.pow(2).sum().clamp(min=1e-12) - return grad_like - projection * param From 3655f80e52952b02e09cf73bb6475d514be6c88b Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 13:33:21 -0700 Subject: [PATCH 4/4] linting Signed-off-by: mikail --- .../orthogonalized_optimizers/muon_hyperball.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index 1bde8dc8..8d75b747 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -100,9 +100,7 @@ def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> update: The orthogonalized gradient tensor. """ if "hyperball_R" not in self.state[p]: - self.state[p]["hyperball_R"] = torch.tensor( - self.hyperball_radius, dtype=p.dtype, device=p.device - ) + self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device) R = self.state[p]["hyperball_R"] update_norm = update.norm().clamp_min(self.hyperball_eps) @@ -121,4 +119,3 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) p.mul_(R / p_norm) -