Fix Mamba2.step() D handling when D_has_hdim=True #903
Closed
Chessing234 wants to merge 1 commit into
When D_has_hdim=True, self.D has shape (d_ssm,) = (nheads * headdim,) but step() was treating it as shape (nheads,). Now reshape D consistently with forward(). Fixes #887 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
- When `D_has_hdim=True`, `self.D` has shape `(d_ssm,)` = `(nheads * headdim,)`, but `step()` was treating it as shape `(nheads,)` in both the fallback path and the `selective_state_update` path.
- Apply the proper reshape (`rearrange(self.D, "(h p) -> h p", p=self.headdim)`) to both code paths in `step()`, consistent with how `forward()` already handles `D_has_hdim=True`.

Details
In `step()`, there are two code paths that use `self.D`:

1. Fallback path (when `selective_state_update is None`): was doing `rearrange(self.D, "h -> h 1")`, which assumes D has shape `(nheads,)`. Now conditionally uses `rearrange(self.D, "(h p) -> h p", p=self.headdim)` when `D_has_hdim=True`.
2. `selective_state_update` path: was doing `repeat(self.D, "h -> h p", p=self.headdim)`, which also assumes D has shape `(nheads,)`. Now conditionally uses `rearrange` instead of `repeat` when `D_has_hdim=True`.

Both `forward()` code paths (lines 191 and 251) already handled this correctly via the same `rearrange(self.D, "(h p) -> h p", p=self.headdim)` reshape.

Fixes #887
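The shape handling described for the two `step()` code paths can be illustrated with a small sketch. This is not the code from this PR: it uses NumPy as a stand-in for the einops `rearrange`/`repeat` calls, and the helper name `d_for_step` and the toy sizes are hypothetical, chosen only for illustration. It assumes the skip term is an elementwise `D * x` with `x` of shape `(batch, nheads, headdim)`.

```python
import numpy as np


def d_for_step(D, nheads, headdim, D_has_hdim):
    """Return D with shape (nheads, headdim) so it broadcasts against
    x of shape (batch, nheads, headdim). Hypothetical helper."""
    if D_has_hdim:
        # D is stored flat as (nheads * headdim,); split it into heads.
        # NumPy analogue of rearrange(D, "(h p) -> h p", p=headdim).
        return D.reshape(nheads, headdim)
    # D is (nheads,): one scalar per head, copied across headdim.
    # NumPy analogue of repeat(D, "h -> h p", p=headdim).
    return np.repeat(D[:, None], headdim, axis=1)


# Toy sizes, assumed for illustration only.
nheads, headdim, batch = 4, 8, 2
x = np.ones((batch, nheads, headdim))

D_flat = np.arange(nheads * headdim, dtype=np.float64)  # D_has_hdim=True
D_head = np.arange(nheads, dtype=np.float64)            # D_has_hdim=False

skip_flat = d_for_step(D_flat, nheads, headdim, True) * x
skip_head = d_for_step(D_head, nheads, headdim, False) * x

# The old per-head pattern ("h -> h 1") applied to the flat D yields
# shape (nheads * headdim, 1), which cannot broadcast against x.
try:
    _ = D_flat[:, None] * x
    old_pattern_ok = True
except ValueError:
    old_pattern_ok = False
```

In this sketch both branches return an array of shape `(nheads, headdim)`, so the same `D * x` expression covers either setting of `D_has_hdim`, while the old per-head pattern fails on the flat `D`.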
Test plan
- Instantiate `Mamba2` with `D_has_hdim=True` and verify `step()` produces correct output matching `forward()` for single-token inputs
- Verify `self.D` of shape `(d_ssm,)` is reshaped to `(nheads, headdim)` before use in both step paths
- Verify `D_has_hdim=False` still works (default behavior unchanged)

🤖 Generated with Claude Code