System Info
Summary
Hi, first of all: thank you for your amazing work on incorporating Mamba2 in the `transformers` library.
We are researching Mamba for a setting with very long sequences, and for computational efficiency we want to process a sequence in chunks (e.g. 10% of the sequence at a time). However, this is not supported: `use_cache` (starting chunk x from the last state of chunk x-1) assumes `seq_len=1` (L=1), and silently produces incorrect results when `seq_len > 1`.
Details:
The `Mamba2Mixer` implementation only supports state carry-over when tokens are processed one at a time. It assumes:
- if `use_cache` (a cached previous state exists), it processes time steps one by one;
- if not `use_cache`, it uses the fast CUDA kernels and processes the whole chunk in parallel, always starting from the initial state.

We are missing the option of:
- `use_cache` and processing multiple tokens of the sequence at once.
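
To make the requested behavior concrete, here is a minimal sketch using a toy scalar SSM in pure Python. The helper `ssm_scan` and the scalar recurrence are illustrative assumptions, not the `Mamba2Mixer` code. The point it shows is that when each chunk starts from the final state of the previous chunk, the chunked run should reproduce the full-sequence result.

```python
import math
import random


def ssm_scan(x, dt, B, C, A, h0=0.0):
    """Sequential scan of a toy scalar SSM, starting from state h0.

    h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
    y_t = C_t * h_t
    Returns (outputs, final_state).
    """
    h = h0
    ys = []
    for x_t, dt_t, b_t, c_t in zip(x, dt, B, C):
        h = math.exp(dt_t * A) * h + dt_t * b_t * x_t
        ys.append(c_t * h)
    return ys, h


random.seed(0)
L, A, CHUNK = 20, -0.5, 5
x = [random.gauss(0.0, 1.0) for _ in range(L)]
dt = [random.uniform(0.01, 0.1) for _ in range(L)]
B = [random.gauss(0.0, 1.0) for _ in range(L)]
C = [random.gauss(0.0, 1.0) for _ in range(L)]

# Reference: the whole sequence in a single call, starting from a zero state.
y_full, h_full = ssm_scan(x, dt, B, C, A)

# Chunked: each chunk starts from the final state of the previous chunk.
y_chunked, h = [], 0.0
for start in range(0, L, CHUNK):
    sl = slice(start, start + CHUNK)
    y_part, h = ssm_scan(x[sl], dt[sl], B[sl], C[sl], A, h0=h)
    y_chunked.extend(y_part)

max_err = max(abs(a - b) for a, b in zip(y_full, y_chunked))
print(f"max abs difference between full and chunked scan: {max_err}")
```

Here the chunked run performs the same floating-point operations in the same order as the full run, so the difference is exactly zero. A real parallel/chunked kernel reorders the computation, so it would only match the recurrent path up to floating-point tolerance.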
In `Mamba2Mixer`:
```python
is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
```
CPU path
If `is_decoding=True`, it processes only the first time step (L=1), `dt = dt[:, 0, :][:, None, ...]`, and does a single SSM step instead of a full scan. If we pass a full sequence into this path, it silently processes only the first token, which produces incorrect results.
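
The following toy illustration of this failure mode again assumes a scalar SSM. `ssm_step`, `decode_first_token_only` and `decode_all_tokens` are hypothetical helpers, not transformers functions. A decoding branch that only reads index 0 consumes just one token of a multi-token chunk. Looping the same single-step update over every token in the chunk is a slow but correct fallback.

```python
import math


def ssm_step(h, x_t, dt_t, b_t, c_t, A):
    """One recurrent step of a toy scalar SSM; returns (y_t, new_state)."""
    h = math.exp(dt_t * A) * h + dt_t * b_t * x_t
    return c_t * h, h


def decode_first_token_only(h, x, dt, B, C, A):
    """Mimics the current CPU decoding branch: like `dt[:, 0, :]`, it reads index 0 only."""
    y, h = ssm_step(h, x[0], dt[0], B[0], C[0], A)
    return [y], h


def decode_all_tokens(h, x, dt, B, C, A):
    """Slow but correct fallback: apply the single-step update to every token in the chunk."""
    ys = []
    for x_t, dt_t, b_t, c_t in zip(x, dt, B, C):
        y, h = ssm_step(h, x_t, dt_t, b_t, c_t, A)
        ys.append(y)
    return ys, h


A = -1.0
x, dt = [1.0, 2.0, 3.0, 4.0], [0.1] * 4
B, C = [1.0] * 4, [1.0] * 4
h_cached = 0.5  # state carried over from the previous chunk

y_first, h_first = decode_first_token_only(h_cached, x, dt, B, C, A)
y_all, h_all = decode_all_tokens(h_cached, x, dt, B, C, A)
print(len(y_first), len(y_all))  # 1 output vs. 4 outputs for a 4-token chunk
```

Both functions agree on the first token. The first-token-only version never absorbs tokens 2 to 4 into the state, so every later chunk would also start from the wrong state.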
GPU path
If `is_decoding=True`, it does:
```python
_, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
hidden_states_B_C = causal_conv1d_update(hidden_states_B_C, ...)  # expects 2D input: (B, conv_dim)
```
`.squeeze(1)` only removes dim 1 when its size is 1. If `seq_len > 1`, `squeeze(1)` is a no-op, so the tensor stays `(B, L, proj_dim)` instead of `(B, proj_dim)`, and `causal_conv1d_update` (which expects a 2D `(B, conv_dim)` input) will crash.
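
The convolution needs the same kind of carry-over as the SSM state. Below is a minimal single-channel sketch. The helper `causal_conv1d_with_state` is hypothetical and is not the `causal_conv1d` package API. It prepends the cached last `K - 1` inputs to the new chunk, convolves, and keeps the last `K - 1` inputs as the new state. With that, chunks of any length give the same result as a single full pass, including a 1-token chunk followed by longer chunks.

```python
def causal_conv1d_with_state(x, w, conv_state):
    """Toy single-channel causal conv1d over a chunk of any length L >= 1.

    x: inputs of the current chunk
    w: kernel weights of length K; w[-1] multiplies the current input
    conv_state: the last K - 1 inputs seen before this chunk
    Returns (outputs, new_conv_state).
    """
    K = len(w)
    assert len(conv_state) == K - 1, "conv_state must hold the last K - 1 inputs"
    padded = list(conv_state) + list(x)  # prepend the cached history
    out = [sum(w[k] * padded[t + k] for k in range(K)) for t in range(len(x))]
    new_state = padded[len(padded) - (K - 1):]  # keep the last K - 1 inputs
    return out, new_state


w = [0.1, 0.2, 0.3, 0.4]  # K = 4, matching conv_kernel=4 in the reproduction config
x = [float(i) for i in range(1, 11)]
zero_state = [0.0] * (len(w) - 1)

# One pass over the full sequence.
out_full, state_full = causal_conv1d_with_state(x, w, zero_state)

# Same sequence as a 1-token chunk followed by two multi-token chunks.
out_chunked, state = [], zero_state
for chunk in (x[:1], x[1:7], x[7:]):
    out_part, state = causal_conv1d_with_state(chunk, w, state)
    out_chunked.extend(out_part)
```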
Parallel path
If `is_decoding=False`, it takes the parallel path (e.g. `hidden_states_B_C = causal_conv1d_fn(x=hidden_states_B_C.transpose(1, 2), ...)`). This path handles full sequences correctly, but it cannot start from a non-zero initial state because `initial_states` is never passed through to `mamba_chunk_scan_combined`.
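
For the parallel path, a non-zero initial state is cheap to add in principle, because the recurrence is linear in the state. With the toy recurrence `h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t` and `y_t = C_t * h_t`, the initial state `h0` contributes `C_t * exp(A * (dt_1 + ... + dt_t)) * h0` to `y_t`. The sketch below checks this on a scalar SSM. `ssm_scan` and `add_initial_state` are hypothetical helpers. This is not how `mamba_chunk_scan_combined` handles `initial_states` internally.

```python
import math


def ssm_scan(x, dt, B, C, A, h0=0.0):
    """Sequential scan of the toy scalar SSM from state h0; returns (outputs, final_state)."""
    h, ys = h0, []
    for x_t, dt_t, b_t, c_t in zip(x, dt, B, C):
        h = math.exp(dt_t * A) * h + dt_t * b_t * x_t
        ys.append(c_t * h)
    return ys, h


def add_initial_state(y_zero, dt, C, A, h0):
    """Correct outputs computed from a zero state for a non-zero initial state h0.

    The recurrence is linear in h, so h0 adds C_t * exp(A * (dt_1 + ... + dt_t)) * h0
    to y_t. This only depends on a cumulative sum of dt, which can be computed in parallel.
    """
    out, cum_dt = [], 0.0
    for y_t, dt_t, c_t in zip(y_zero, dt, C):
        cum_dt += dt_t
        out.append(y_t + c_t * math.exp(A * cum_dt) * h0)
    return out


A, h0 = -0.8, 0.7
x = [0.5, -1.0, 2.0, 0.3, 1.5]
dt = [0.1, 0.2, 0.05, 0.3, 0.15]
B = [1.0, 0.5, -0.7, 1.2, 0.9]
C = [0.8, -1.1, 0.4, 1.0, -0.6]

y_seq, _ = ssm_scan(x, dt, B, C, A, h0=h0)    # reference: sequential scan from h0
y_zero, _ = ssm_scan(x, dt, B, C, A, h0=0.0)  # what a zero-initialized parallel pass yields
y_par = add_initial_state(y_zero, dt, C, A, h0)
```

`y_par` matches the sequential reference up to floating-point rounding. Dropping the correction term, which is effectively what happens when `initial_states` is not passed, gives visibly different outputs.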
Related issues in state-spaces/mamba (all for Mamba1 except #641; there are even PRs with a fix, but only for Mamba1):
- state-spaces/mamba#641: "Chunk-Wise Inference does not match Full-Length Inference" (Mamba2, mamba-ssm package, same bug demonstrated with reproduction code)
- state-spaces/mamba#536: "Chunked inference" (asks whether fast chunked inference with state carry-over is possible)
- state-spaces/mamba#101: "Using ssm_state and conv_state during training" (Mamba1, same conceptual limitation)
- state-spaces/mamba#830: "Parallel forward with state" (PR implementing the fix for Mamba1 only)
- state-spaces/mamba#488: "Initial state support for Mamba SSM (1)" (PR implementing chunked prefill for Mamba1 only)
Who can help?
@zucchini-nlp @Cyrilvallez 🤗
Reproduction
```python
import torch
from transformers import Mamba2Config
from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
from transformers.cache_utils import DynamicCache

# GIVEN a small Mamba2 model
config = Mamba2Config(
    hidden_size=64,
    num_hidden_layers=2,
    state_size=16,
    num_heads=8,
    head_dim=16,
    n_groups=1,
    expand=2,
    conv_kernel=4,
    chunk_size=8,
    vocab_size=2,
    use_bias=True,
    use_conv_bias=True,
)
model = Mamba2Model(config).eval()

B, L = 1, 8  # B = batch_size, L = seq_len
inputs_embeds = torch.randn(B, L, config.hidden_size)

# WHEN first, we process the sequence token by token (seq_len=1), with result_tbt as output
cache_tbt = DynamicCache(config=config)
outputs_tbt = []
for t in range(L):
    out = model(inputs_embeds=inputs_embeds[:, t : t + 1, :], cache_params=cache_tbt, use_cache=True)
    outputs_tbt.append(out.last_hidden_state)
result_tbt = torch.cat(outputs_tbt, dim=1)  # (B, L, hidden_size)

# WHEN second, we process the same sequence in two chunks: the first token, then the remaining L - 1 tokens
cache_chunk = DynamicCache(config=config)
out_first = model(inputs_embeds=inputs_embeds[:, :1, :], cache_params=cache_chunk, use_cache=True)
out_rest = model(inputs_embeds=inputs_embeds[:, 1:, :], cache_params=cache_chunk, use_cache=True)
result_chunk = torch.cat([out_first.last_hidden_state, out_rest.last_hidden_state], dim=1)

error = (result_tbt[:, 1:, :] - result_chunk[:, 1:, :]).abs().max().item()
# THEN the results do not match, even though they should.
print(f"difference between two processing paths (should be equal) is {error}")
```
Expected behavior
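Expected no difference: the chunked and token-by-token outputs should match, i.e. the printed maximum absolute difference should be 0 (up to floating-point tolerance).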