
Possible redundant norm computation in MuonHyperball #155

@Harry-Chen

Description


As described in the MuonH paper, the parameter's norm is fixed to its Frobenius norm at the first step, so this value needs to be computed at most once.
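Following that first-step definition, the radius and the update rescaling can be written as below. The notation is ours, not the paper's: $W_0$ is the parameter at the first step, $U_t$ is the orthogonalized update at step $t$, $r_{\text{user}}$ is `hyperball_radius`, and $\varepsilon$ is `hyperball_eps`.

$$
R = \begin{cases} r_{\text{user}} & \text{if a radius is given} \\ \lVert W_0 \rVert_F & \text{otherwise} \end{cases}
\qquad
U_t \leftarrow R \, \frac{U_t}{\max\left(\lVert U_t \rVert_F,\ \varepsilon\right)}
$$

Since $R$ depends only on $W_0$ (or on the user-given radius), nothing at steps $t > 0$ requires evaluating $\lVert W_t \rVert_F$.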

The current implementation:

    @override
    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Store the original weight norm and normalize the update using Frobenius norm.
        Args:
            p: The parameter tensor.
            update: The orthogonalized gradient tensor.
        """
        # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
        R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
        self.state[p]["hyperball_R"] = R
        # Normalize the update in-place and scale by R
        # This modifies update to be: R * normalize(update) using Frobenius norm.
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)

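This recomputes `p.norm().item()` on every optimizer step, although the value is only needed once: each call is a full reduction over the parameter, and `.item()` also forces a device-to-host synchronization when the parameter lives on a GPU. My suggestion is something like: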

    @override
    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        state = self.state[p]
        if "hyperball_R" not in state:
            # First call for this parameter: fix R once (user radius or ||W_0||_F).
            R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
            state["hyperball_R"] = R
        else:
            # Later calls: reuse the cached radius instead of recomputing the norm.
            R = state["hyperball_R"]
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)
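A minimal, self-contained sketch of the cached variant that can run outside the optimizer. `HyperballHookSketch` and its `norm_calls` counter are hypothetical stand-ins that only mimic the attributes the hook touches (`state`, `hyperball_radius`, `hyperball_eps`); the real optimizer class is not reproduced here. One detail matters for the caching to work: `self.state` is keyed by parameter tensors (as in `torch.optim.Optimizer`, where it is a `defaultdict(dict)`), so the membership test has to run on the per-parameter dict `self.state[p]`. Testing `"hyperball_R" in self.state` never finds the key, and the norm would then be recomputed on every step.

```python
from collections import defaultdict

import torch


class HyperballHookSketch:
    """Hypothetical stand-in exposing only what the hook uses."""

    def __init__(self, hyperball_radius=None, hyperball_eps=1e-8):
        self.hyperball_radius = hyperball_radius
        self.hyperball_eps = hyperball_eps
        # Same layout as torch.optim.Optimizer.state: parameter tensor -> dict.
        self.state = defaultdict(dict)
        # Hypothetical counter, only to show how often p.norm() runs.
        self.norm_calls = 0

    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        state = self.state[p]  # per-parameter dict, created on first access
        if "hyperball_R" not in state:
            if self.hyperball_radius is not None:
                R = self.hyperball_radius  # user radius: no norm needed at all
            else:
                self.norm_calls += 1
                R = p.norm().item()  # ||W_0||_F, computed once per parameter
            state["hyperball_R"] = R
        R = state["hyperball_R"]
        # Rescale the update in place so that its Frobenius norm equals R.
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)


opt = HyperballHookSketch()
p = torch.full((2, 2), 3.0)  # ||p||_F = sqrt(4 * 9) = 6
for step in range(3):
    update = torch.ones(2, 2) * (step + 1)
    opt.pre_weight_update_fn_inplace(p, update)
    p.mul_(0.5)  # stand-in for the weight changing between steps
print(opt.norm_calls)  # 1: the norm is computed only on the first call
```

Changing `p` between calls does not change the cached radius, which matches the first-step definition described in this issue; when `hyperball_radius` is set, `p.norm()` is never evaluated.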

Metadata


Assignees

No one assigned

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
