From 45f25e6dbce3b4865f5f24ef051f47fc0e24aaea Mon Sep 17 00:00:00 2001 From: skaae Date: Sun, 3 May 2026 18:34:33 +0000 Subject: [PATCH] Update angle normalization using remainder function #922 --- mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py b/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py index c6b8596d1..7222cc772 100644 --- a/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py +++ b/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py @@ -75,6 +75,8 @@ def rotary_qk_inference_kernel( # Match angle_dt: tanh(angle_proj) * dt * pi angle_proj = tl.sigmoid(2.0 * angle_proj) * 2.0 - 1.0 # tanh angle = angle_state + angle_proj * dt * 3.141592653589793 # (rotary_dim // 2) + TWO_PI: tl.constexpr = 6.283185307179586 + angle = angle - TWO_PI * tl.floor(angle / TWO_PI) OUT_ANGLE_STATE = OUT_ANGLE_STATE + rd_half * stride_out_angle_state[2] tl.store(OUT_ANGLE_STATE, angle, mask=mask_angle) @@ -254,6 +256,7 @@ def apply_rotary_qk_inference_reference( # Match angle_dt: tanh(angle_proj) * dt * pi angle_proj = torch.tanh(angle_proj) angle = angle_state + angle_proj * dt[:, :, None] * math.pi # (B, N, S) + angle = torch.remainder(angle, 2 * math.pi) angle_state_new = angle angle = angle.unsqueeze(1).expand(-1, mimo_dim, -1, -1) # (B, R, N, S)