diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359..81799d1d 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -300,20 +300,6 @@ def backward(ctx, dout): delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous() delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps) delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() - if b_rms_weight is not None: - # Recompute & RMSNorm B - B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous() - B = rms_norm_forward( - B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps - ) - B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - if c_rms_weight is not None: - # Recompute & RMSNorm C - C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous() - C = rms_norm_forward( - C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps - ) - C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk).