
Mamba2.step() handles D incorrectly when D_has_hdim=True #887

@GiftedNovaHD


Hi, I think there's a bug in Mamba2.step() when D_has_hdim=True.

In __init__, self.D is initialized as

self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))

Thus, when D_has_hdim=True, self.D has shape (d_ssm,) = (nheads * headdim,).

I believe the forward path handles this correctly by reshaping D from (h * p,) to (h, p).

But I don't think step() makes the same distinction: in this line, it implicitly assumes that self.D is per head, with shape (nheads,), rather than per head dimension, with shape (nheads * headdim,).

So for D_has_hdim=True, step() seems inconsistent with forward(): forward() reshapes self.D from (h * p,) to (h, p), whereas step() treats it as if it had shape (h,), as mentioned earlier.
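To make the mismatch concrete, here is a minimal standalone sketch in plain PyTorch. It does not call mamba_ssm. It assumes that the skip term in step() broadcasts D as a per-head vector against x of shape (b, h, p), which is my reading of the code and not a verbatim copy of it. The sizes and the helper name skip_term are hypothetical and chosen only for illustration.

```python
import torch

batch, nheads, headdim = 2, 4, 8          # hypothetical sizes
d_ssm = nheads * headdim

x = torch.randn(batch, nheads, headdim)   # (b, h, p)
D = torch.randn(d_ssm)                    # shape of self.D when D_has_hdim=True

# forward()-style handling: view D as (h, p), then broadcast over the batch.
skip_forward = D.view(nheads, headdim) * x

# Assumed step()-style handling: treat D as per-head, i.e. (h,) -> (h, 1).
# With D of shape (h * p,), this gives (h * p, 1), which does not broadcast
# against (b, h, p).
try:
    _ = D[:, None] * x
    step_failed = False
except RuntimeError:
    step_failed = True

def skip_term(D, x, headdim, D_has_hdim):
    """Hypothetical shape-aware skip term covering both layouts of D."""
    if D_has_hdim:
        return D.view(-1, headdim) * x    # (h, p) * (b, h, p)
    return D[:, None] * x                 # (h, 1) * (b, h, p)
```

With a branch like the one in skip_term, the D_has_hdim=True case matches what forward() computes, and the per-head case is unchanged.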

I can contribute a PR to fix this.
