diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index a97588e6f..affbf562b 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -79,7 +79,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, void* delta_bias_ptr, void* x_ptr, bool has_z, - bool delta_softplus) { + bool delta_softplus, + const at::Tensor initial_state) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -138,6 +139,18 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, } params.out_batch_stride = out.stride(0); params.out_d_stride = out.stride(1); + + // Set initial state if provided + params.initial_state_ptr = initial_state.defined() ? initial_state.data_ptr() : nullptr; + if (initial_state.defined()) { + params.initial_state_batch_stride = initial_state.stride(0); + params.initial_state_d_stride = initial_state.stride(1); + params.initial_state_dstate_stride = initial_state.stride(2); + } else { + params.initial_state_batch_stride = 0; + params.initial_state_d_stride = 0; + params.initial_state_dstate_stride = 0; + } } void set_ssm_params_bwd(SSMParamsBwd ¶ms, @@ -181,7 +194,7 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, // If not recompute_out_z, pass dout instead of out_z. // This won't be used by the bwd kernel recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, at::Tensor()); if (!recompute_out_z) { params.out_z_ptr = nullptr; } // Set the pointers and strides. @@ -229,7 +242,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, - bool delta_softplus) { + bool delta_softplus, + const c10::optional &initial_state_ = c10::nullopt) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -293,6 +307,15 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, CHECK_SHAPE(delta_bias, dim); } + if (initial_state_.has_value()) { + auto initial_state = initial_state_.value(); + TORCH_CHECK(initial_state.scalar_type() == weight_type); + TORCH_CHECK(initial_state.is_cuda()); + TORCH_CHECK(initial_state.dim() == 3); + CHECK_SHAPE(initial_state, batch_size, dim, dstate); + TORCH_CHECK(initial_state.stride(-1) == 1 || initial_state.size(-1) == 1); + } + at::Tensor z, out_z; const bool has_z = z_.has_value(); if (has_z) { @@ -319,7 +342,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, x.data_ptr(), has_z, - delta_softplus); + delta_softplus, + initial_state_.has_value() ? initial_state_.value() : at::Tensor()); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index e2c7bcdbd..e5544cd6c 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -66,6 +66,10 @@ struct SSMParamsBase { void *__restrict__ x_ptr; void *__restrict__ z_ptr; void *__restrict__ out_z_ptr; + void *__restrict__ initial_state_ptr; // Optional initial state (batch, dim, dstate) + index_t initial_state_batch_stride; + index_t initial_state_d_stride; + index_t initial_state_dstate_stride; }; struct SSMParamsBwd: public SSMParamsBase { diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 80e9e37e3..f4900a010 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -23,7 +23,7 @@ template + bool kHasZ_, bool kHasInitialState_, typename input_t_, typename weight_t_> struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; @@ -43,6 +43,7 @@ struct Selective_Scan_fwd_kernel_traits { static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kHasZ = kHasZ_; + static constexpr bool kHasInitialState = kHasInitialState_; static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; @@ -76,6 +77,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsVariableB = Ktraits::kIsVariableB; constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kHasInitialState = Ktraits::kHasInitialState; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int kNRows = Ktraits::kNRows; @@ -218,8 +220,21 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { #pragma unroll for (int i = 0; i < kNItems; ++i) { if constexpr (!kIsComplex) { - thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + weight_t delta_a = exp2f(delta_vals[r][i] * A_val[r]); + weight_t delta_b_u = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; + + if constexpr (kHasInitialState) { + if (chunk == 0 && i == 0 && threadIdx.x == 0) { + const weight_t *initial_state = reinterpret_cast(params.initial_state_ptr) + + batch_id * params.initial_state_batch_stride + + dim_id * params.initial_state_d_stride; + weight_t h0_val = initial_state[state_idx * params.initial_state_dstate_stride]; + // Modify: deltaB[0]*u[0] -> deltaA[0]*h0 + deltaB[0]*u[0] + delta_b_u = delta_a * h0_val + delta_b_u; + } + } + + thread_data[i] = make_float2(delta_a, delta_b_u); if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -229,6 +244,21 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Pytorch's implementation of complex exp (which calls thrust) is very slow complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; + + // Incorporate initial state for chunk 0, first timestep (complex case) + if constexpr (kHasInitialState) { + if (chunk == 0 && i == 0 && threadIdx.x == 0) { + // For complex, initial_state is stored as complex_t (interleaved real/imag) + const complex_t *initial_state_complex = reinterpret_cast(params.initial_state_ptr) + + batch_id * (params.initial_state_batch_stride / 2) + + dim_id * (params.initial_state_d_stride / 2); + complex_t h0_val = initial_state_complex[state_idx * (params.initial_state_dstate_stride / 2)]; + complex_t h0_contrib = delta_a_exp * h0_val; + // B_delta_u_val is already complex_t, add h0_contrib + B_delta_u_val = h0_contrib + B_delta_u_val; + } + } + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { @@ -316,7 +346,8 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; + BOOL_SWITCH(params.initial_state_ptr != nullptr, kHasInitialState, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); dim3 grid(params.batch, params.dim / kNRows); @@ -341,6 +372,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 4c8a38821..3622bdb64 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -160,21 +160,32 @@ def forward(self, hidden_states, inference_params=None): ) else: x, z = xz.chunk(2, dim=1) - # Compute short convolution + # Compute short convolution, state continuity logic is inference-only if conv_state is not None: - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) + k = self.d_conv - 1 + conv_inputs = conv_state[:, :, -k:] + x_input = torch.cat([conv_inputs, x], dim=2) + conv_state.copy_(F.pad(x_input, (self.d_conv - x_input.shape[-1], 0))[:, :, -self.d_conv:]) # Update state (B D W) + else: + x_input = x if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + x_conv = self.conv1d(x_input) + if conv_state is not None: + x = self.act(x_conv[:, :, k:k+seqlen]) + else: + x = self.act(x_conv[..., :seqlen]) else: assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x=x, + x_conv = causal_conv1d_fn( + x=x_input, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, ) + if conv_state is not None: + x = x_conv[:, :, k:k+seqlen] + else: + x = x_conv # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension @@ -186,6 +197,7 @@ def forward(self, hidden_states, inference_params=None): B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() assert self.activation in ["silu", "swish"] + # Ability to pass initial state to kernel in inference - it will incorporate exp(deltaA[0]) * h0 into the first state y = selective_scan_fn( x, dt, @@ -197,6 +209,7 @@ def forward(self, hidden_states, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, + initial_state=ssm_state, # Kernel will incorporate this as exp(deltaA[0]) * h0 ) if ssm_state is not None: y, last_state = y diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359c..7a4623556 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -24,7 +24,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, initial_state=None): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -37,13 +37,15 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp C = C.contiguous() if z is not None and z.stride(-1) != 1: z = z.contiguous() + if initial_state is not None and initial_state.stride(-1) != 1: + initial_state = initial_state.contiguous() if B.dim() == 3: B = rearrange(B, "b dstate l -> b 1 dstate l") ctx.squeeze_B = True if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, initial_state) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) @@ -104,12 +106,14 @@ def rms_norm_forward( def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, initial_state=None): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. + + initial_state: Optional (batch, dim, dstate) initial SSM state """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, initial_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, diff --git a/tests/test_mamba_chunk_processing.py b/tests/test_mamba_chunk_processing.py new file mode 100644 index 000000000..71c86b83e --- /dev/null +++ b/tests/test_mamba_chunk_processing.py @@ -0,0 +1,95 @@ +import torch +import pytest + +from mamba_ssm.models.mixer_seq_simple import MixerModel +from mamba_ssm.utils.generation import InferenceParams + + +def _make_mamba(d_model=32, n_layers=2, d_state=16, vocab_size=100, device="cuda", dtype=torch.float32): + """Create a simple Mamba model for testing.""" + model = MixerModel( + d_model=d_model, + n_layer=n_layers, + d_intermediate=0, # No MLP for simplicity + vocab_size=vocab_size, + ssm_cfg=dict(layer="Mamba1"), + device=device, + dtype=dtype, + ) + model.eval() + return model + + +def _empty_caches_for_model(model, batch_size, device, dtype): + """Create empty inference caches for a model.""" + max_seqlen = 1024 # Large enough for tests + caches = model.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype) + # Initialize inference params + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=0, + ) + inference_params.key_value_memory_dict = caches + return inference_params + + +@pytest.mark.parametrize("device", ["cuda"]) +def test_one_forward_matches_two_state_continuity_forward(device): + """Test that processing two chunks with state continuity matches full forward pass.""" + torch.manual_seed(0) + B, L, D = 2, 30, 32 + model = _make_mamba(d_model=D, n_layers=2, d_state=16, device=device, dtype=torch.float32) + + vocab_size = 100 + x = torch.randint(0, vocab_size, (B, L), device=device, dtype=torch.long) + L1 = L // 2 + + with torch.no_grad(): + # Full forward pass + gold = model(x) # (B, L, D) + + # Process in two chunks with state continuity + # Use seqlen_offset=0 to use parallel scan with initial_state support + inference_params = _empty_caches_for_model(model, B, device, torch.float32) + y1 = model(x[:, :L1], inference_params=inference_params) # (B, L1, D) + # State is updated in inference_params.key_value_memory_dict + # Process second chunk with same inference_params (state continuity) + y2 = model(x[:, L1:], inference_params=inference_params) # (B, L-L1, D) + got = torch.cat([y1, y2], dim=1) + + assert torch.allclose(got, gold, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("device", ["cuda"]) +def test_forward_matches_steps(device): + """Test that processing two chunks with state continuity matches sequential step-by-step processing.""" + torch.manual_seed(0) + B, L, D = 2, 30, 32 + model = _make_mamba(d_model=D, n_layers=2, d_state=16, device=device, dtype=torch.float32) + + vocab_size = 100 + x = torch.randint(0, vocab_size, (B, L), device=device, dtype=torch.long) + L1 = L // 2 + + with torch.no_grad(): + # Sequential step-by-step processing using forward with seqlen_offset > 0 + # This triggers step() method internally + inference_params_seq = _empty_caches_for_model(model, B, device, torch.float32) + y_list = [] + for t in range(L): + x_t = x[:, t:t+1] # (B, 1) + # seqlen_offset > 0 triggers step() method internally + inference_params_seq.seqlen_offset = t + y_t = model(x_t, inference_params=inference_params_seq) # (B, 1, D) + y_list.append(y_t) + logits_step = torch.cat(y_list, dim=1) # (B, L, D) + + # Process in two chunks using forward with seqlen_offset=0 + # This uses parallel scan with initial_state support + inference_params_chunk = _empty_caches_for_model(model, B, device, torch.float32) + y1 = model(x[:, :L1], inference_params=inference_params_chunk) # (B, L1, D) + y2 = model(x[:, L1:], inference_params=inference_params_chunk) # (B, L-L1, D) + logits_chunk = torch.cat([y1, y2], dim=1) + + assert torch.allclose(logits_chunk, logits_step, atol=1e-5, rtol=1e-5)