Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions mamba_ssm/modules/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,21 @@ def step(self, hidden_states, conv_state, ssm_state):
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
if self.D_has_hdim:
y = y + rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim) * x
else:
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
dt = repeat(dt, "b h -> b h p", p=self.headdim)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
D = repeat(self.D, "h -> h p", p=self.headdim)
if self.D_has_hdim:
D = rearrange(self.D, "(h p) -> h p", p=self.headdim)
else:
D = repeat(self.D, "h -> h p", p=self.headdim)
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
Expand Down
65 changes: 65 additions & 0 deletions tests/test_mamba2_d_has_hdim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Test that Mamba2 forward() and step() produce consistent outputs when D_has_hdim=True.

Regression test for https://github.com/state-spaces/mamba/issues/887
"""
import torch
import pytest
from einops import rearrange

from mamba_ssm.modules.mamba2 import Mamba2


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="test builds the Mamba2 model and its inputs on a CUDA device",
)
@pytest.mark.parametrize("D_has_hdim", [True, False])
def test_mamba2_step_forward_consistency(D_has_hdim):
    """step() must match forward() for D_has_hdim=True and D_has_hdim=False.

    Builds a small Mamba2 layer (rmsnorm disabled, single group), overwrites
    its ``D`` parameter with random values, and runs the same random input
    two ways: once through the full-sequence forward pass and once token by
    token through ``step()`` starting from a freshly allocated inference
    cache. The two outputs are compared from position ``d_conv - 1`` onward.

    The test is skipped (rather than erroring) when CUDA is unavailable,
    because the model and all tensors are created on ``"cuda"``.

    Args:
        D_has_hdim: Forwarded to the ``Mamba2`` constructor; both settings
            are exercised via parametrization.
    """
    torch.manual_seed(42)
    batch, seqlen = 2, 16
    d_model, headdim, d_state = 256, 64, 64
    device = "cuda"
    dtype = torch.float32

    model = Mamba2(
        d_model=d_model,
        headdim=headdim,
        d_state=d_state,
        d_conv=4,
        D_has_hdim=D_has_hdim,
        ngroups=1,
        rmsnorm=False,
        use_mem_eff_path=False,
        device=device,
        dtype=dtype,
    )
    model.eval()

    # Randomize D so non-uniform values expose the bug
    with torch.no_grad():
        model.D.copy_(torch.randn_like(model.D))

    x = torch.randn(batch, seqlen, d_model, device=device, dtype=dtype)

    with torch.no_grad():
        # Forward pass — reference output
        out_forward = model(x)

        # Step pass — one token at a time, threading the states through
        conv_state, ssm_state = model.allocate_inference_cache(batch, seqlen, dtype=dtype)
        step_outputs = []
        for t in range(seqlen):
            out_t, conv_state, ssm_state = model.step(
                x[:, t : t + 1, :], conv_state, ssm_state
            )
            step_outputs.append(out_t)
    out_step = torch.cat(step_outputs, dim=1)

    # After conv warmup (d_conv - 1 = 3 tokens), outputs should match
    d_conv = model.d_conv
    out_fwd_tail = out_forward[:, d_conv - 1 :, :]
    out_step_tail = out_step[:, d_conv - 1 :, :]

    max_diff = (out_fwd_tail - out_step_tail).abs().max().item()
    print(f"D_has_hdim={D_has_hdim}, max diff after conv warmup: {max_diff:.2e}")
    assert torch.allclose(out_fwd_tail, out_step_tail, rtol=1e-3, atol=1e-3), (
        f"forward/step mismatch with D_has_hdim={D_has_hdim}: max diff {max_diff:.2e}"
    )