Skip to content

Fix Mamba2.step() D handling when D_has_hdim=True#903

Closed
Chessing234 wants to merge 1 commit into
state-spaces:mainfrom
Chessing234:fix/issue-887-step-D-has-hdim
Closed

Fix Mamba2.step() D handling when D_has_hdim=True#903
Chessing234 wants to merge 1 commit into
state-spaces:mainfrom
Chessing234:fix/issue-887-step-D-has-hdim

Conversation

@Chessing234
Copy link
Copy Markdown
Contributor

Summary

  • When D_has_hdim=True, self.D has shape (d_ssm,) = (nheads * headdim,), but step() was treating it as shape (nheads,) in both the fallback path and the selective_state_update path.
  • Added conditional reshaping (rearrange(self.D, "(h p) -> h p", p=self.headdim)) to both code paths in step(), consistent with how forward() already handles D_has_hdim=True.

Details

In step(), there are two code paths that use self.D:

  1. Fallback path (when selective_state_update is None): was doing rearrange(self.D, "h -> h 1") which assumes D has shape (nheads,). Now conditionally uses rearrange(self.D, "(h p) -> h p", p=self.headdim) when D_has_hdim=True.

  2. selective_state_update path: was doing repeat(self.D, "h -> h p", p=self.headdim) which also assumes D has shape (nheads,). Now conditionally uses rearrange instead of repeat when D_has_hdim=True.

Both forward() code paths (lines 191 and 251) already handled this correctly via:

D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D

Fixes #887

Test plan

  • Instantiate Mamba2 with D_has_hdim=True and verify step() produces correct output matching forward() for single-token inputs
  • Verify shapes are consistent: self.D shape (d_ssm,) is reshaped to (nheads, headdim) before use in both step paths
  • Verify no regression when D_has_hdim=False (default behavior unchanged)

🤖 Generated with Claude Code

When D_has_hdim=True, self.D has shape (d_ssm,) = (nheads * headdim,)
but step() was treating it as shape (nheads,). Now reshape D
consistently with forward().

Fixes #887

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ccs1112
Copy link
Copy Markdown

ccs1112 commented Apr 10, 2026

I opened #893 on April 2 to address the same bug and also fix #888. It includes a parametrized test test_mamba2_step_forward_consistency covering both D_has_hdim cases with non-uniform D values to ensure the bug isn't masked.

@Chessing234
Copy link
Copy Markdown
Contributor Author

Closing as #893 (opened earlier by @ccs668899) already covers both this fix and #888, plus a parametrized test_mamba2_step_forward_consistency that catches the bug. Deferring to that PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mamba2.step() handles D incorrectly when D_has_dim=True

2 participants