From 4691d57d54b900c13228b46dd6fe00757221baa4 Mon Sep 17 00:00:00 2001
From: leeyongjun
Date: Tue, 19 May 2026 09:03:32 +0000
Subject: [PATCH] fix(logits): avoid nan in fused softcap

The hand-rolled tanh in fused_softcap_kernel evaluated
(exp(2x) - 1) / (exp(2x) + 1). For large logits exp(2x) overflows to
inf and the ratio becomes inf/inf, i.e. NaN. Compute tanh(x) as
2 * sigmoid(2x) - 1 instead, and add a CUDA regression test that checks
large-magnitude logits against cap * torch.tanh(logits / cap).

Signed-off-by: leeyongjun
---
 .../runtime/layers/logits_processor.py |  5 ++-
 test/runtime/test_logits_processor.py  | 35 +++++++++++++++++++
 2 files changed, 37 insertions(+), 3 deletions(-)
 create mode 100644 test/runtime/test_logits_processor.py

diff --git a/python/tokenspeed/runtime/layers/logits_processor.py b/python/tokenspeed/runtime/layers/logits_processor.py
index 039e49e30..37125f4ac 100755
--- a/python/tokenspeed/runtime/layers/logits_processor.py
+++ b/python/tokenspeed/runtime/layers/logits_processor.py
@@ -537,9 +537,8 @@ def fused_softcap_kernel(
     # Perform operations in-place
     x = x / softcapping_value
 
-    # Manual tanh implementation using exp
-    exp2x = tl.exp(2 * x)
-    x = (exp2x - 1) / (exp2x + 1)
+    # Stable tanh form; the exp ratio overflows to inf/inf for large logits.
+    x = 2 * tl.sigmoid(2 * x) - 1
 
     x = x * softcapping_value
 
diff --git a/test/runtime/test_logits_processor.py b/test/runtime/test_logits_processor.py
new file mode 100644
index 000000000..14f41cb08
--- /dev/null
+++ b/test/runtime/test_logits_processor.py
@@ -0,0 +1,35 @@
+"""Regression tests for logits processing helpers."""
+
+from __future__ import annotations
+
+import os
+import sys
+
+# CI Registration (parsed via AST, runtime no-op)
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from ci_system.ci_register import register_cuda_ci
+
+register_cuda_ci(est_time=90, suite="runtime-1gpu")
+
+import pytest
+import torch
+
+from tokenspeed.runtime.layers.logits_processor import fused_softcap
+
+pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+
+
+def test_fused_softcap_handles_large_logits_without_nan():
+    cap = 30.0
+    logits = torch.tensor(
+        [[5000.0, 2000.0, 1500.0, 100.0, 0.0, -100.0, -1500.0, -5000.0]],
+        device="cuda",
+        dtype=torch.float32,
+    )
+    expected = cap * torch.tanh(logits / cap)
+
+    out = fused_softcap(logits.clone(), cap)
+    torch.cuda.synchronize()
+
+    assert torch.isfinite(out).all()
+    torch.testing.assert_close(out, expected, rtol=1e-5, atol=2e-5)