
Mamba2 step() works silently but when D_has_hdim=True and selective_state_update=None but is #888

@PuR3Luck

Description


There seems to be an issue at https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py#L319, in the step() code path taken when selective_state_update is None.

When D_has_hdim=True, self.D has shape (d_ssm,) = (n_heads * headdim,).

So the rearrange unsqueezes the last dimension: (n_heads * headdim,) -> (n_heads * headdim, 1).
However, x and y have shape (b, h, p). Broadcasting aligns trailing dimensions, so the n_heads * headdim axis of the rearranged D is matched against the h (n_heads) axis of x, and the two shapes cannot be broadcast together.
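A minimal reproduction of the shape mismatch, written in NumPy rather than PyTorch (the broadcasting rules are the same). The sizes are arbitrary assumptions, and `reshape(-1, 1)` stands in for the einops rearrange that unsqueezes D in step().

```python
import numpy as np

# Assumed sizes for illustration; nheads > 1 and headdim > 1.
batch, nheads, headdim = 2, 4, 8

rng = np.random.default_rng(0)
x = rng.standard_normal((batch, nheads, headdim))  # x after "b (h p) -> b h p"
D = np.ones(nheads * headdim)                      # D_has_hdim=True: shape (d_ssm,)

# Stand-in for the rearrange in step(): (d_ssm,) -> (d_ssm, 1)
D_col = D.reshape(-1, 1)

# Trailing axes are aligned: (1, 32, 1) vs (2, 4, 8), and 32 != 4
try:
    _ = D_col * x
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

print(D_col.shape, x.shape, broadcast_ok)  # (32, 1) (2, 4, 8) False
```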

This behaviour is inconsistent with the forward() method.
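A minimal sketch of one possible fix, assuming the intent is to match forward(), i.e. split D as "(h p) -> h p" when D_has_hdim=True and keep the current (h,) -> (h, 1) unsqueeze otherwise. It uses NumPy in place of PyTorch/einops, and `skip_term` is a hypothetical helper, not a mamba_ssm function; in mamba2.py the equivalent change would be a different rearrange of self.D before it is multiplied by x.

```python
import numpy as np


def skip_term(x, D, headdim, D_has_hdim):
    """Compute D * x for x of shape (b, h, p).

    Hypothetical helper (not part of mamba_ssm) showing one way to reshape D
    so that both D layouts broadcast against x.
    """
    if D_has_hdim:
        # D: (h * p,) -> (h, p), i.e. the "(h p) -> h p" split
        D_b = D.reshape(-1, headdim)
    else:
        # D: (h,) -> (h, 1), the current unsqueeze, which is correct here
        D_b = D.reshape(-1, 1)
    return D_b * x


batch, nheads, headdim = 2, 4, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, nheads, headdim))

# D_has_hdim=True: one skip weight per channel
D_full = np.arange(nheads * headdim, dtype=float)
out_full = skip_term(x, D_full, headdim, D_has_hdim=True)

# D_has_hdim=False: one skip weight per head, shared across headdim
D_head = np.arange(nheads, dtype=float)
out_head = skip_term(x, D_head, headdim, D_has_hdim=False)

# Reference: apply D elementwise on the flattened "b (h p)" layout
x_flat = x.reshape(batch, -1)
print(out_full.shape,
      np.allclose(out_full.reshape(batch, -1), x_flat * D_full),
      np.allclose(out_head.reshape(batch, -1), x_flat * np.repeat(D_head, headdim)))
# (2, 4, 8) True True
```

The final check compares both branches against an elementwise multiply on the flattened "b (h p)" layout, where h is the outer index.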

If needed, I can contribute a PR for this.

Related issue: #887
