Mamba2Mixer: use_cache with seq_len > 1 silently produces incorrect results (both CPU and GPU paths) #46032

@clara2911

System Info

Summary

Hi, first of all: thank you for your amazing work on incorporating Mamba2 into the transformers library.
We are researching Mamba in a setting with very long sequences, and for computational efficiency we want to process each sequence in chunks (e.g. 10% of the sequence at a time). However, this is not supported: with use_cache, starting chunk x from the last state of chunk x-1 only works when seq_len=1 (L=1), and the model silently produces incorrect results when seq_len > 1.

Details:

The Mamba2Mixer implementation only carries state over when tokens are processed one at a time. It assumes one of two modes:

  • with use_cache (a cached previous state exists), it processes timesteps one by one
  • without use_cache, it uses the fast CUDA kernel and processes the whole sequence in parallel, always starting from the initial state.

We are missing the option of:

  • use_cache and process multiple tokens of the sequence at once
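
To illustrate the missing option, here is a toy scalar recurrence. It is not the Mamba2Mixer code: `ssm_scan` and all the numbers are made up for illustration. It sketches why carrying the final state of one chunk into the next should reproduce the full-sequence output, while restarting the second chunk from zero does not:

```python
def ssm_scan(xs, decay, in_gain, out_gain, state=0.0):
    """Toy recurrence: h_t = decay[t] * h_{t-1} + in_gain[t] * x_t, y_t = out_gain[t] * h_t.

    Returns the outputs for this chunk and the final state, so the caller
    can pass that state into the next chunk.
    """
    ys = []
    for x, a, b, c in zip(xs, decay, in_gain, out_gain):
        state = a * state + b * x
        ys.append(c * state)
    return ys, state


# Hypothetical per-step inputs and parameters (L = 8).
xs = [0.5, -1.0, 2.0, 1.5, -0.5, 1.0, 0.25, -2.0]
decay = [0.9, 0.8, 0.9, 0.7, 0.95, 0.9, 0.85, 0.9]
in_gain = [1.0, 0.5, 1.0, 2.0, 1.0, 0.5, 1.0, 1.0]
out_gain = [1.0, 2.0, 0.5, 1.0, 1.0, 2.0, 0.5, 1.0]

# Reference: the whole sequence in one call, starting from a zero state.
full, _ = ssm_scan(xs, decay, in_gain, out_gain)

# Chunked WITH carry-over: the second chunk starts from the first chunk's final state.
split = 3
head, carried = ssm_scan(xs[:split], decay[:split], in_gain[:split], out_gain[:split])
tail, _ = ssm_scan(xs[split:], decay[split:], in_gain[split:], out_gain[split:], state=carried)

# Chunked WITHOUT carry-over: the second chunk restarts from zero.
tail_reset, _ = ssm_scan(xs[split:], decay[split:], in_gain[split:], out_gain[split:])

chunked_matches = (head + tail) == full
reset_matches = (head + tail_reset) == full
print(chunked_matches)  # True
print(reset_matches)    # False
```

The carried-state version performs exactly the same arithmetic in the same order as the full scan, which is why the comparison can be exact in this toy case. A real chunked kernel would more likely match only up to floating-point tolerance.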

In Mamba2Mixer:

is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

CPU path

If is_decoding=True, it processes only the first input position (L=1): dt = dt[:, 0, :][:, None, ...], and performs a single SSM step instead of a full scan. If a multi-token chunk is passed in, only its first token is processed and the rest are silently dropped, which produces incorrect results.
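
The effect of that indexing can be checked in isolation. This is a minimal sketch with made-up sizes (`B`, `L`, `H` are illustrative, not taken from any real config); it only shows the shape change of the quoted expression, not the rest of the mixer:

```python
import torch

B, L, H = 2, 5, 8  # batch, chunk length, number of heads (illustrative sizes)
dt = torch.randn(B, L, H)

# The expression quoted above: keep position 0 on the sequence axis,
# then re-insert a length-1 sequence axis.
dt_step = dt[:, 0, :][:, None, ...]

print(tuple(dt.shape))       # (2, 5, 8)
print(tuple(dt_step.shape))  # (2, 1, 8)

# The result is exactly the first timestep; the other L - 1 timesteps are gone.
same_as_first = torch.equal(dt_step, dt[:, :1, :])
dropped_positions = L - dt_step.shape[1]
print(same_as_first, dropped_positions)  # True 4
```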

GPU path

If is_decoding=True, it does:

_, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
hidden_states_B_C = causal_conv1d_update(hidden_states_B_C, ...)  # expects 2D input: (B, conv_dim)

The .squeeze(1) only removes dim 1 when its size is 1. If seq_len > 1, squeeze(1) is a no-op: the tensor stays (B, L, proj_dim) instead of becoming (B, proj_dim), and causal_conv1d_update, which expects 2D (B, conv_dim) input, will crash.
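
The squeeze behavior is easy to confirm on its own. In the sketch below, `take_single_step` is a hypothetical helper (not part of transformers) showing one way a decoding branch could fail loudly instead of relying on squeeze(1):

```python
import torch

B, proj_dim = 2, 12  # illustrative sizes
single = torch.zeros(B, 1, proj_dim)  # seq_len == 1
multi = torch.zeros(B, 4, proj_dim)   # seq_len == 4

single_out = single.squeeze(1)
multi_out = multi.squeeze(1)
print(tuple(single_out.shape))  # (2, 12)
print(tuple(multi_out.shape))   # (2, 4, 12), dim 1 is left untouched


def take_single_step(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical guard: return the (B, D) slice, or raise if seq_len != 1."""
    if t.shape[1] != 1:
        raise ValueError(f"expected seq_len == 1 in the decoding branch, got {t.shape[1]}")
    return t[:, 0]


step = take_single_step(single)  # works, shape (2, 12)
try:
    take_single_step(multi)
    guard_raised = False
except ValueError:
    guard_raised = True
print(guard_raised)  # True
```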

Parallel path

If is_decoding=False, it takes the parallel path (e.g. hidden_states_B_C = causal_conv1d_fn(x=hidden_states_B_C.transpose(1,2), ...)), which handles full sequences correctly but cannot incorporate a non-zero initial state, since initial_states is never passed through to mamba_chunk_scan_combined.
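
The same carry-over requirement applies to the convolution that precedes the scan. The following is a sketch in plain PyTorch, not the causal_conv1d kernels: `causal_conv` is a hypothetical helper, the sizes are made up, and the layout is channels-first (B, D, L). It assumes a kernel width greater than 1 and shows that a chunked depthwise causal convolution matches the full-sequence result only when the last K-1 inputs of the previous chunk are carried along:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, L, K = 1, 6, 10, 4  # batch, channels, sequence length, kernel width
x = torch.randn(B, D, L)
weight = torch.randn(D, 1, K)  # depthwise: one length-K filter per channel


def causal_conv(x, weight, history=None):
    """Depthwise causal conv over one chunk.

    `history` holds the K-1 inputs that preceded this chunk (zeros if None).
    Returns the chunk output and the history to hand to the next chunk.
    """
    k = weight.shape[-1]
    if history is None:
        history = x.new_zeros(x.shape[0], x.shape[1], k - 1)
    padded = torch.cat([history, x], dim=-1)
    y = F.conv1d(padded, weight, groups=x.shape[1])
    return y, padded[..., -(k - 1):]


# Reference: the whole sequence at once.
full, _ = causal_conv(x, weight)

# Chunked with carried conv history.
split = 3
head, hist = causal_conv(x[..., :split], weight)
tail, _ = causal_conv(x[..., split:], weight, history=hist)
chunked = torch.cat([head, tail], dim=-1)

# Chunked without history: the second chunk sees zeros instead of real context.
tail_reset, _ = causal_conv(x[..., split:], weight)
reset = torch.cat([head, tail_reset], dim=-1)

print(torch.allclose(full, chunked, atol=1e-5))  # True
print(torch.allclose(full, reset, atol=1e-5))    # False
```

A chunked forward therefore needs two pieces of carried state per layer: the conv history and the SSM state.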

Related issues and PRs in state-spaces/mamba (all for Mamba1 except #641):

  • state-spaces/mamba#641: "Chunk-Wise Inference does not match Full-Length Inference" (Mamba2, mamba-ssm package, same bug demonstrated with reproduction code)
  • state-spaces/mamba#536: "Chunked inference" (asks if fast chunked inference with state carry-over is possible)
  • state-spaces/mamba#101: "Using ssm_state and conv_state during training" (Mamba1, same conceptual limitation)
  • state-spaces/mamba#830: "Parallel forward with state" (PR implementing the fix for Mamba1 only)
  • state-spaces/mamba#488: "Initial state support for Mamba SSM (1)" (PR implementing chunked prefill for Mamba1 only)

Who can help?

@zucchini-nlp @Cyrilvallez 🤗

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

```python
import torch
from transformers import Mamba2Config
from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
from transformers.cache_utils import DynamicCache

# GIVEN Mamba2 model
config = Mamba2Config(
    hidden_size=64,
    num_hidden_layers=2,
    state_size=16,
    num_heads=8,
    head_dim=16,
    n_groups=1,
    expand=2,
    conv_kernel=4,
    chunk_size=8,
    vocab_size=2,
    use_bias=True,
    use_conv_bias=True,
)
model = Mamba2Model(config).eval()

B, L = 1, 8  # B = batch_size, L = seq_len
inputs_embeds = torch.randn(B, L, config.hidden_size)

# WHEN first, we process token by token (seq_len=1), with result_tbt as output

cache_tbt = DynamicCache(config=config)
outputs_tbt = []
for t in range(L):
    out = model(inputs_embeds=inputs_embeds[:, t : t + 1, :], cache_params=cache_tbt, use_cache=True)
    outputs_tbt.append(out.last_hidden_state)
result_tbt = torch.cat(outputs_tbt, dim=1)  # (B, L, hidden_size)

# WHEN second, we process the same sequence as one token followed by a single multi-token chunk
cache_chunk = DynamicCache(config=config)
out_first = model(inputs_embeds=inputs_embeds[:, :1, :], cache_params=cache_chunk, use_cache=True)

out_rest = model(inputs_embeds=inputs_embeds[:, 1:, :], cache_params=cache_chunk, use_cache=True)
result_chunk = torch.cat([out_first.last_hidden_state, out_rest.last_hidden_state], dim=1)
error = (result_tbt[:, 1:, :] - result_chunk[:, 1:, :]).abs().max().item()

# THEN the results do not match, even though they should.
print(f"difference between two processing paths (should be equal) is {error}")
```

Expected behavior

Expected: no difference between the two processing paths (maximum absolute error of zero, up to floating-point tolerance).

Labels: Good Second Issue, bug
