From 5d2c37579d01b557581aedaf8d15203473619dcb Mon Sep 17 00:00:00 2001 From: jonah Date: Mon, 2 Feb 2026 10:03:50 -0800 Subject: [PATCH 1/8] initial --- forge_cute_py/kernels/softmax_online.py | 0 pyproject.toml | 6 +++--- 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 forge_cute_py/kernels/softmax_online.py diff --git a/forge_cute_py/kernels/softmax_online.py b/forge_cute_py/kernels/softmax_online.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index a44eeec..e2d98fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,12 @@ dependencies = [ [tool.uv.sources] torch = [ - { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] [[tool.uv.index]] -name = "pytorch-cu130" -url = "https://download.pytorch.org/whl/cu130" +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" explicit = true From e011fac4b4140eff98ec9599d1cb2a50931d185c Mon Sep 17 00:00:00 2001 From: jonah Date: Tue, 3 Feb 2026 12:06:49 -0800 Subject: [PATCH 2/8] initial --- forge_cute_py/kernels/softmax_online.py | 207 ++++++++++++++++++++++++ forge_cute_py/kernels/testing.py | 124 ++++++++++++++ forge_cute_py/ops/softmax_online.py | 38 ++++- tests/test_softmax_online.py | 9 +- 4 files changed, 367 insertions(+), 11 deletions(-) create mode 100644 forge_cute_py/kernels/testing.py diff --git a/forge_cute_py/kernels/softmax_online.py b/forge_cute_py/kernels/softmax_online.py index e69de29..7998c7a 100644 --- a/forge_cute_py/kernels/softmax_online.py +++ b/forge_cute_py/kernels/softmax_online.py @@ -0,0 +1,207 @@ +import torch +import cutlass +import cutlass.cute as cute +import cuda.bindings.driver as cuda +from cutlass import BFloat16, Float16, Float32 +from cutlass.cute.runtime import from_dlpack +from cutlass import const_expr + + +class SoftmaxOnlineLoop: + def __init__(self, dtype): + self.dtype = dtype + self.num_warps = 1 + self.threads_per_block = self.num_warps * 32 + self.NEG_INF = Float32(float('-inf')) + + @cute.jit + def __call__(self, gInput: cute.Tensor, gOutput: cute.Tensor, stream: cuda.CUstream = None): + M, N = gInput.shape + thr_layout = cute.make_layout((self.threads_per_block,), stride=(1,)) + val_layout = cute.make_layout((1,), stride=(1,)) + tiler_mn_1d, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + tiler_mn = (1, tiler_mn_1d[0]) + gX = cute.zipped_divide(gInput, tiler_mn) + gY = cute.zipped_divide(gOutput, tiler_mn) + + self.kernel( + gX, gY, N + ).launch( + grid=(cute.size(gX, mode=[1, 0]), 1, 1), # RestM + block=(cute.size(tv_layout, mode=[0, 0])), # threads per block + stream=stream + ) + + @cute.kernel + def kernel(self, gInput, gOutput, N): + maxValue = self.NEG_INF + sumValue = Float32(0.0) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdimx, _, _ = cute.arch.block_dim() + (TileM, TileN), (RestM, RestN) = gInput.shape + + for i in range(RestN): + idx = i * bdimx + tidx + value = Float32(gInput[(0, tidx), (bidx, i)]) if idx < N else self.NEG_INF + + curMax = cute.arch.warp_reduction_max(value) + prevMax = maxValue + maxValue = cute.arch.fmax(prevMax, curMax) + + scale = cute.math.exp(prevMax - maxValue) + scale_data = cute.math.exp(value - maxValue) + curSum = cute.arch.warp_reduction_sum(scale_data) + sumValue = sumValue * scale + curSum + + for i in range(RestN): + idx = i * bdimx + tidx + if idx < N: + value = gInput[(0, tidx), (bidx, i)].to(self.dtype) + data = cute.math.exp(value - maxValue) / sumValue + gOutput[(0, tidx), (bidx, i)] = data.to(self.dtype) + + +class SoftmaxOnline: + def __init__(self, dtype, N: int): + self.dtype = dtype + self.num_warps = 1 + self.threads_per_block = self.num_warps * 32 + self.NEG_INF = Float32(float('-inf')) + self.N = N + + self.bits_read = 128 + self.vec_load_size = self.bits_read // self.dtype.width + self.threads_per_row = 32 + self.num_warps = 4 + self.num_threads = self.num_warps * self.threads_per_row + + @cute.jit + def __call__(self, gInput: cute.Tensor, gOutput: cute.Tensor, stream: cuda.CUstream = None): + + blocks_vector_N = cute.ceil_div(self.N, self.bits_read // self.dtype.width) + blocks_over_N = cute.ceil_div(blocks_vector_N, self.threads_per_row) + tiler_mn = (self.num_warps, self.vec_load_size * blocks_over_N * self.threads_per_row) # [4, ~N] + + copy_op = cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, self.dtype, num_bits_per_copy=self.bits_read) + thr_layout = cute.make_ordered_layout( + (self.num_warps, self.threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, self.vec_load_size)) + tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + blocks = cute.ceil_div(gInput.shape[0], tiler_mn[0]) + self.kernel( + gInput, gOutput, tiler_mn, tiled_copy + ).launch( + grid=(blocks, 1, 1), + block=(self.num_threads, 1, 1), + stream=stream + ) + + # type hints are not optional!!!! + @cute.kernel + def kernel(self, gInput: cute.Tensor, gOutput: cute.Tensor, tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + gX = cute.local_tile(gInput, tiler_mn, (bidx, 0)) + gY = cute.local_tile(gOutput, tiler_mn, (bidx, 0)) + # this thread is response for vectorized loads, striding 4 * 32 across the row + tidxSlice = tiled_copy.get_slice(tidx) + tidxIndices = tidxSlice.partition_S(gX) + tidxRegs = cute.make_rmem_tensor_like(tidxIndices) + cute.autovec_copy(tidxIndices, tidxRegs) + + tidxValues = tidxRegs.load() + tidLocalMax = tidxValues.reduce(cute.ReductionOp.MAX, init_val=self.NEG_INF, reduction_profile=0) + rowMax = cute.arch.warp_reduction_max(tidLocalMax) + tidScaledLocalSum = cute.math.exp(tidxValues - rowMax) + tidLocalSum = tidScaledLocalSum.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + rowSum = cute.arch.warp_reduction_sum(tidLocalSum) + + writeValues = cute.math.exp(tidxValues - rowMax) / rowSum + tidxRegs.store(writeValues) + tidxOutIndices = tidxSlice.partition_D(gY) + cute.autovec_copy(tidxRegs, tidxOutIndices) + + +def benchmark(loopless=True): + import time + + dim = -1 + M, N = 4096, 768 + dtype = torch.float32 + dtype_map = { + torch.float16: Float16, + torch.float32: Float32, + torch.bfloat16: BFloat16, + } + cute_dtype = dtype_map[dtype] + + x = torch.randn(M, N, device='cuda', dtype=dtype) + output = torch.zeros_like(x) + + if loopless: + dx = x + dy = output + m = cute.sym_int() + input_cute = cute.runtime.make_fake_compact_tensor( + cute_dtype, (m, N), stride_order=(1, 0) + ) + output_cute = cute.runtime.make_fake_compact_tensor( + cute_dtype, (m, N), stride_order=(1, 0) + ) + softmax = SoftmaxOnline(dtype_map[dtype], N) + fn = cute.compile( + softmax, input_cute, output_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + fn(x, output) + else: + dx = from_dlpack(x, enable_tvm_ffi=True) + dy = from_dlpack(output, enable_tvm_ffi=True) + softmax = SoftmaxOnlineLoop(dtype_map[dtype]) + fn = cute.compile( + softmax, dx, dy, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + fn(dx, dy) + + + print("Correctness check:") + expected = torch.nn.functional.softmax(x, dim=-1) + is_close = torch.allclose(output, expected, rtol=1e-3, atol=1e-3) + print(f" dim=-1: {'✓ PASS' if is_close else '✗ FAIL'}") + if not is_close: + max_diff = (output - expected).abs().max().item() + print(f" max diff: {max_diff}") + + print("\nBenchmarks:") + + # Warmup + for _ in range(10): + fn(dx, dy) + torch.cuda.synchronize() + + # Benchmark our softmax + start = time.perf_counter() + for _ in range(100): + fn(dx, dy) + torch.cuda.synchronize() + print(f" softmax_online dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") + + # Compare to PyTorch + start = time.perf_counter() + for _ in range(100): + _ = torch.nn.functional.softmax(x, dim=-1) + torch.cuda.synchronize() + print(f" torch.softmax dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") + +benchmark(loopless=False) +benchmark(loopless=True) \ No newline at end of file diff --git a/forge_cute_py/kernels/testing.py b/forge_cute_py/kernels/testing.py new file mode 100644 index 0000000..dbda4cc --- /dev/null +++ b/forge_cute_py/kernels/testing.py @@ -0,0 +1,124 @@ +import torch +import cutlass +import cutlass.cute as cute +import cuda.bindings.driver as cuda +from cutlass import BFloat16, Float16, Float32 +from cutlass.cute.runtime import from_dlpack +from cutlass import const_expr +from cute_viz import render_tiled_copy_svg +import inspect + + +def _get_tiled_copy(vecsize, dtype, N): + """ + Adapted from quack's tiles_copy_2d() + Reference: https://github.com/Dao-AILab/quack/blob/2e62faaeb6271a780a1360e6c96a003492e47eed/quack/copy_utils.py#L98 + """ + threads_per_row = 32 + num_threads = 128 + # thread groups (of size 32 each) needed to cover N // vecsize + num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row) + + # each tile covers [4, ~N] + tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row) + + num_copy_bits = vecsize * dtype.width + copy_op = cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, vecsize)) + tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + render_tiled_copy_svg(tiled_copy, tiler_mn, "my_copy_layout.svg") + print(f"tild copy: {tiled_copy}") + return tiler_mn, tiled_copy, threads_per_row + + +@cute.jit +def test_jit(mX, mY): + dtype=Float32 + vecsize = 128 // dtype.width + tiler_mn, tiled_copy, threads_per_row = _get_tiled_copy(vecsize=vecsize, dtype=dtype, N=mX.shape[1]) + num_threads = tiled_copy.size + + kernel(mX, mY, tiler_mn, tiled_copy, threads_per_row).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1], + block=[num_threads, 1, 1] + ) + +@cute.kernel +def kernel( + mX: cute.Tensor, + mO: cute.Tensor, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], +): + # tv_layout = (thread_layout, value_layout) = ((threads_per_row, num_rows), vec_size) + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) # (tileM, tileN) + # TODO: vectorized store + # gO = cute.local_tile(mO, cute.select(tiler_mn, mode=[0]), (bidx,)) # (tileM,) + + thr_copy_X = tiled_copy.get_slice(tidx) + # gmem -> rmem + tXgX = thr_copy_X.partition_S(gX) + + tXrX = cute.make_rmem_tensor_like(tXgX) + cute.autovec_copy(tXgX, tXrX) + + # reduce with higher precision for numerical stability + x = tXrX.load() + val = x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + + val = cute.arch.warp_reduction_sum(val) + if tidx == 64 and bidx == 0: + print(f"slice:\n{thr_copy_X}") + print(f"partition: {tXgX}") + print(f"x tyep: {type(x)}") + print(f"val: {val}") + print(f"gx: {gX}") + + lane_id = cute.arch.lane_idx() + warp_id = cute.arch.warp_idx() + + warps_per_row = threads_per_row // cute.arch.WARP_SIZE + + row_idx = warp_id // warps_per_row + col_idx = warp_id % warps_per_row + + # TODO: vetorized store + if lane_id == 0 and col_idx == 0: + mO[row_idx + tiler_mn[0] * bidx] = val + + +def test(): + M, N = 512, 1024 + X = torch.randn(M, N, dtype=torch.float32, device='cuda') + Y = torch.empty((N,), dtype=torch.float32, device='cuda') + + dX = from_dlpack(X) + dY = from_dlpack(Y) + fn = cute.compile( + test_jit, + dX, + dY + ) + fn(dX, dY) + + +from cute_viz import render_layout_svg + +@cute.jit +def visual(): + U = ((2,2),4,(9,(3,3))) + layout = cute.make_layout(U) + render_layout_svg(layout, "test.svg") + + + +visual() \ No newline at end of file diff --git a/forge_cute_py/ops/softmax_online.py b/forge_cute_py/ops/softmax_online.py index e1248e2..17be2f3 100644 --- a/forge_cute_py/ops/softmax_online.py +++ b/forge_cute_py/ops/softmax_online.py @@ -1,4 +1,7 @@ import torch +from cutlass import BFloat16, Float16, Float32 +import cutlass.cute as cute +from forge_cute_py.kernels.softmax_online import SoftmaxOnline @torch.library.custom_op("forge_cute_py::_softmax_fwd", mutates_args={"out"}) @@ -21,13 +24,34 @@ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None: dim = dim if dim >= 0 else x.ndim + dim assert dim in [0, 1], f"dim must be 0 or 1 for 2D tensors, got {dim}" - # For now, use reference implementation - # Future: call kernel implementation when available - from forge_cute_py.ref import softmax_online as softmax_online_ref - - result = softmax_online_ref(x, dim=dim) - out.copy_(result) - + dtype_map = { + torch.float16: Float16, + torch.float32: Float32, + torch.bfloat16: BFloat16, + } + + if x.dtype not in dtype_map: + raise ValueError(f"Unsupported dtype: {x.dtype}") + + cute_dtype = dtype_map[x.dtype] + compile_key = (cute_dtype, x.shape[1]) + + if compile_key not in _softmax_fwd.compile_cache: + m = cute.sym_int() + # n = cute.sym_int() + n = x.shape[1] + input_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, n), stride_order=(1, 0)) + output_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, n), stride_order=(1, 0)) + # Compile and cache the kernel + _softmax_fwd.compile_cache[compile_key] = cute.compile( + SoftmaxOnline(cute_dtype, n), + input_cute, + output_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + _softmax_fwd.compile_cache[compile_key](x, out) _softmax_fwd.compile_cache = {} diff --git a/tests/test_softmax_online.py b/tests/test_softmax_online.py index 068b324..31f968d 100644 --- a/tests/test_softmax_online.py +++ b/tests/test_softmax_online.py @@ -4,9 +4,10 @@ from forge_cute_py.ops import softmax_online from forge_cute_py.ref import softmax_online as ref_softmax_online +dims = [-1] @pytest.mark.parametrize("shape", [(4, 8), (2, 128)]) -@pytest.mark.parametrize("dim", [-1, 0, 1]) +@pytest.mark.parametrize("dim", dims) @pytest.mark.parametrize( "dtype, atol, rtol", [ @@ -25,7 +26,7 @@ def test_softmax_online_correctness(shape, dim, dtype, atol, rtol): @pytest.mark.parametrize("shape", [(4, 8), (2, 128)]) -@pytest.mark.parametrize("dim", [-1, 0, 1]) +@pytest.mark.parametrize("dim", dims) @pytest.mark.parametrize( "dtype, atol, rtol", [ @@ -91,7 +92,7 @@ def test_softmax_online_extreme_values(input_dtype): @pytest.mark.parametrize("shape", [(4, 8), (16, 128), (32, 256)]) -@pytest.mark.parametrize("dim", [-1, 0, 1]) +@pytest.mark.parametrize("dim", dims) @pytest.mark.parametrize( "dtype, atol, rtol", [ @@ -120,7 +121,7 @@ def test_softmax_online_backward(shape, dim, dtype, atol, rtol): @pytest.mark.parametrize("shape", [(4, 8), (16, 128)]) -@pytest.mark.parametrize("dim", [-1, 1]) +@pytest.mark.parametrize("dim", dims) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_softmax_online_backward_torch_compile(shape, dim, dtype): """Test backward pass works with torch.compile.""" From e6205489e9d200bcadc3eb35f769536ac767b25d Mon Sep 17 00:00:00 2001 From: jonah Date: Tue, 3 Feb 2026 20:47:37 -0800 Subject: [PATCH 3/8] remove testing file --- forge_cute_py/kernels/testing.py | 124 ------------------------------- 1 file changed, 124 deletions(-) delete mode 100644 forge_cute_py/kernels/testing.py diff --git a/forge_cute_py/kernels/testing.py b/forge_cute_py/kernels/testing.py deleted file mode 100644 index dbda4cc..0000000 --- a/forge_cute_py/kernels/testing.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -import cutlass -import cutlass.cute as cute -import cuda.bindings.driver as cuda -from cutlass import BFloat16, Float16, Float32 -from cutlass.cute.runtime import from_dlpack -from cutlass import const_expr -from cute_viz import render_tiled_copy_svg -import inspect - - -def _get_tiled_copy(vecsize, dtype, N): - """ - Adapted from quack's tiles_copy_2d() - Reference: https://github.com/Dao-AILab/quack/blob/2e62faaeb6271a780a1360e6c96a003492e47eed/quack/copy_utils.py#L98 - """ - threads_per_row = 32 - num_threads = 128 - # thread groups (of size 32 each) needed to cover N // vecsize - num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row) - - # each tile covers [4, ~N] - tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row) - - num_copy_bits = vecsize * dtype.width - copy_op = cute.nvgpu.CopyUniversalOp() - copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) - thr_layout = cute.make_ordered_layout( - (num_threads // threads_per_row, threads_per_row), - order=(1, 0), - ) - val_layout = cute.make_layout((1, vecsize)) - tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) - render_tiled_copy_svg(tiled_copy, tiler_mn, "my_copy_layout.svg") - print(f"tild copy: {tiled_copy}") - return tiler_mn, tiled_copy, threads_per_row - - -@cute.jit -def test_jit(mX, mY): - dtype=Float32 - vecsize = 128 // dtype.width - tiler_mn, tiled_copy, threads_per_row = _get_tiled_copy(vecsize=vecsize, dtype=dtype, N=mX.shape[1]) - num_threads = tiled_copy.size - - kernel(mX, mY, tiler_mn, tiled_copy, threads_per_row).launch( - grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1], - block=[num_threads, 1, 1] - ) - -@cute.kernel -def kernel( - mX: cute.Tensor, - mO: cute.Tensor, - tiler_mn: cute.Shape, - tiled_copy: cute.TiledCopy, - threads_per_row: cutlass.Constexpr[int], -): - # tv_layout = (thread_layout, value_layout) = ((threads_per_row, num_rows), vec_size) - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) # (tileM, tileN) - # TODO: vectorized store - # gO = cute.local_tile(mO, cute.select(tiler_mn, mode=[0]), (bidx,)) # (tileM,) - - thr_copy_X = tiled_copy.get_slice(tidx) - # gmem -> rmem - tXgX = thr_copy_X.partition_S(gX) - - tXrX = cute.make_rmem_tensor_like(tXgX) - cute.autovec_copy(tXgX, tXrX) - - # reduce with higher precision for numerical stability - x = tXrX.load() - val = x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) - - val = cute.arch.warp_reduction_sum(val) - if tidx == 64 and bidx == 0: - print(f"slice:\n{thr_copy_X}") - print(f"partition: {tXgX}") - print(f"x tyep: {type(x)}") - print(f"val: {val}") - print(f"gx: {gX}") - - lane_id = cute.arch.lane_idx() - warp_id = cute.arch.warp_idx() - - warps_per_row = threads_per_row // cute.arch.WARP_SIZE - - row_idx = warp_id // warps_per_row - col_idx = warp_id % warps_per_row - - # TODO: vetorized store - if lane_id == 0 and col_idx == 0: - mO[row_idx + tiler_mn[0] * bidx] = val - - -def test(): - M, N = 512, 1024 - X = torch.randn(M, N, dtype=torch.float32, device='cuda') - Y = torch.empty((N,), dtype=torch.float32, device='cuda') - - dX = from_dlpack(X) - dY = from_dlpack(Y) - fn = cute.compile( - test_jit, - dX, - dY - ) - fn(dX, dY) - - -from cute_viz import render_layout_svg - -@cute.jit -def visual(): - U = ((2,2),4,(9,(3,3))) - layout = cute.make_layout(U) - render_layout_svg(layout, "test.svg") - - - -visual() \ No newline at end of file From 745d5183a0ebd225be01085270508bfa5d6e924a Mon Sep 17 00:00:00 2001 From: jonah Date: Wed, 4 Feb 2026 13:18:32 -0800 Subject: [PATCH 4/8] started backward --- forge_cute_py/kernels/softmax_online.py | 119 +++++++++++++++++++++++- forge_cute_py/ops/softmax_online.py | 31 +++++- tests/test_softmax_online.py | 13 ++- 3 files changed, 155 insertions(+), 8 deletions(-) diff --git a/forge_cute_py/kernels/softmax_online.py b/forge_cute_py/kernels/softmax_online.py index 7998c7a..90671f5 100644 --- a/forge_cute_py/kernels/softmax_online.py +++ b/forge_cute_py/kernels/softmax_online.py @@ -1,3 +1,4 @@ +from this import d import torch import cutlass import cutlass.cute as cute @@ -7,6 +8,85 @@ from cutlass import const_expr +class SoftmaxOnlineBackward: + def __init__(self, dtype, N: int): + self.dtype = dtype + self.num_warps = 4 + self.bits_read = 128 + self.vec_load_size = self.bits_read // dtype.width + self.warp_size = 32 + self.threads_per_block = self.num_warps * self.warp_size + self.N = N # N is static at compile time, M is dynamic + + @cute.jit + def __call__(self, dY: cute.Tensor, y: cute.Tensor, dx: cute.Tensor, stream=None): + blocks_over_N = cute.ceil_div(self.N, self.vec_load_size * self.warp_size) + tiler_mn = ( # full covering tile + self.num_warps, + self.vec_load_size * self.warp_size * blocks_over_N + ) + + copy_op = cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, self.dtype, num_bits_per_copy=self.bits_read) + + thr_layout = cute.make_ordered_layout( + (self.num_warps, self.warp_size), + order=(1, 0) # cols move faster + ) + val_layout = cute.make_layout((1, self.vec_load_size)) + tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + blocks = cute.ceil_div(dY.shape[0], self.num_warps) + self.kernel( + dY, y, dx, tiler_mn, tiled_copy + ).launch( + grid=(blocks, 1, 1), + block=(self.threads_per_block, 1, 1), + # stream=stream + ) + + # type hints are not optional!!!! + @cute.kernel + def kernel( + self, dY: cute.Tensor, y: cute.Tensor, dX: cute.Tensor, + tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy + ): + # Compute gradient (numerically stable) + # dot_product = (dy * y).sum(dim=dim, keepdim=True) + # result = y * (dy - dot_product) + # dx.copy_(result) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + dy_tile = cute.local_tile(dY, tiler_mn, (bidx, 0)) + y_tile = cute.local_tile(y, tiler_mn, (bidx, 0)) + dx_tile = cute.local_tile(dX, tiler_mn, (bidx, 0)) + + tidxSlice = tiled_copy.get_slice(tidx) + + dy_idx = tidxSlice.partition_S(dy_tile) + y_idx = tidxSlice.partition_S(y_tile) + dx_idx = tidxSlice.partition_D(dx_tile) + + dy_regs = cute.make_rmem_tensor_like(dy_idx) + y_regs = cute.make_rmem_tensor_like(y_idx) + + cute.autovec_copy(dy_idx, dy_regs) + cute.autovec_copy(y_idx, y_regs) + + dy_data = dy_regs.load() + y_data = y_regs.load() + + tidx_local_dp = dy_data * y_data + tidx_local_sum = tidx_local_dp.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + row_dp_sum = cute.arch.warp_reduction_sum(tidx_local_sum) + + result = y_data * (dy_data - row_dp_sum) + dy_regs.store(result) + cute.autovec_copy(dy_regs, dx_idx) + + class SoftmaxOnlineLoop: def __init__(self, dtype): self.dtype = dtype @@ -203,5 +283,40 @@ def benchmark(loopless=True): torch.cuda.synchronize() print(f" torch.softmax dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") -benchmark(loopless=False) -benchmark(loopless=True) \ No newline at end of file + +def bench_back(): + dim = -1 + M, N = 256, 256 + dtype = torch.float32 + dtype_map = { + torch.float16: Float16, + torch.float32: Float32, + torch.bfloat16: BFloat16, + } + cute_dtype = dtype_map[dtype] + dy = torch.randn(M, N, device='cuda', dtype=dtype) + y = torch.randn(M, N, device='cuda', dtype=dtype) + dx = torch.randn(M, N, device='cuda', dtype=dtype) + + m = cute.sym_int() + dy_cute = cute.runtime.make_fake_compact_tensor( + cute_dtype, (m, N), stride_order=(1, 0) + ) + y_cute = cute.runtime.make_fake_compact_tensor( + cute_dtype, (m, N), stride_order=(1, 0) + ) + dx_cute = cute.runtime.make_fake_compact_tensor( + cute_dtype, (m, N), stride_order=(1, 0) + ) + softmax = SoftmaxOnlineBackward(dtype_map[dtype], N) + fn = cute.compile( + softmax, dy_cute, y_cute, dx_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + fn(dy, y, dx) + +# benchmark(loopless=False) +# benchmark(loopless=True) + +# bench_back() \ No newline at end of file diff --git a/forge_cute_py/ops/softmax_online.py b/forge_cute_py/ops/softmax_online.py index 17be2f3..198dfbc 100644 --- a/forge_cute_py/ops/softmax_online.py +++ b/forge_cute_py/ops/softmax_online.py @@ -1,7 +1,7 @@ import torch from cutlass import BFloat16, Float16, Float32 import cutlass.cute as cute -from forge_cute_py.kernels.softmax_online import SoftmaxOnline +from forge_cute_py.kernels.softmax_online import SoftmaxOnline, SoftmaxOnlineBackward @torch.library.custom_op("forge_cute_py::_softmax_fwd", mutates_args={"out"}) @@ -89,9 +89,32 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: assert dim in [0, 1], f"dim must be 0 or 1 for 2D, got {dim}" # Compute gradient (numerically stable) - dot_product = (dy * y).sum(dim=dim, keepdim=True) - result = y * (dy - dot_product) - dx.copy_(result) + # dot_product = (dy * y).sum(dim=dim, keepdim=True) + # result = y * (dy - dot_product) + # dx.copy_(result) + + dtype_map = { + torch.float16: Float16, + torch.float32: Float32, + torch.bfloat16: BFloat16, + } + + cute_dtype = dtype_map[dy.dtype] + compile_key = (cute_dtype, dy.shape[1]) + + if compile_key not in _softmax_backward.compile_cache: + m = cute.sym_int() + n = dy.shape[1] + dy_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, n), stride_order=(1, 0)) + y_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, n), stride_order=(1, 0)) + dx_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, n), stride_order=(1, 0)) + # Compile and cache the kernel + _softmax_backward.compile_cache[compile_key] = cute.compile( + SoftmaxOnlineBackward(cute_dtype, n), + dy_cute, y_cute, dx_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) _softmax_backward.compile_cache = {} diff --git a/tests/test_softmax_online.py b/tests/test_softmax_online.py index 31f968d..edd8b91 100644 --- a/tests/test_softmax_online.py +++ b/tests/test_softmax_online.py @@ -104,20 +104,29 @@ def test_softmax_online_extreme_values(input_dtype): def test_softmax_online_backward(shape, dim, dtype, atol, rtol): """Test backward pass against PyTorch reference.""" # Create inputs with gradients enabled (scale by 0.1 to avoid overflow) + + shape = (128, 256) + dim = -1 + dtype = torch.float32 + atol = rtol = 1e-4 + x = (0.1 * torch.randn(*shape, device="cuda", dtype=dtype)).requires_grad_(True) x_ref = x.detach().clone().requires_grad_(True) # Forward pass out = softmax_online(x, dim=dim) out_ref = ref_softmax_online(x_ref, dim=dim) - torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) + # torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) + assert torch.allclose(out, out_ref, atol=atol, rtol=rtol) # Backward pass dy = torch.randn_like(out) torch.cuda.synchronize() # Critical: prevents autograd timing issues (dx,) = torch.autograd.grad(out, x, grad_outputs=dy) (dx_ref,) = torch.autograd.grad(out_ref, x_ref, grad_outputs=dy) - torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol) + + assert torch.allclose(dx, dx_ref, atol=atol, rtol=rtol) + # torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol) @pytest.mark.parametrize("shape", [(4, 8), (16, 128)]) From 33549bb93c18ab0281769b4edc1f63b418649ef9 Mon Sep 17 00:00:00 2001 From: jonah Date: Wed, 4 Feb 2026 14:39:53 -0800 Subject: [PATCH 5/8] backwards works --- forge_cute_py/kernels/softmax_online.py | 2 +- forge_cute_py/ops/softmax_online.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/forge_cute_py/kernels/softmax_online.py b/forge_cute_py/kernels/softmax_online.py index 90671f5..2e47cb5 100644 --- a/forge_cute_py/kernels/softmax_online.py +++ b/forge_cute_py/kernels/softmax_online.py @@ -42,7 +42,7 @@ def __call__(self, dY: cute.Tensor, y: cute.Tensor, dx: cute.Tensor, stream=None ).launch( grid=(blocks, 1, 1), block=(self.threads_per_block, 1, 1), - # stream=stream + stream=stream ) # type hints are not optional!!!! diff --git a/forge_cute_py/ops/softmax_online.py b/forge_cute_py/ops/softmax_online.py index 198dfbc..8024336 100644 --- a/forge_cute_py/ops/softmax_online.py +++ b/forge_cute_py/ops/softmax_online.py @@ -115,6 +115,9 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) + _softmax_backward.compile_cache[compile_key]( + dy, y, dx + ) _softmax_backward.compile_cache = {} From 9eb0b846fa4707d4f8081550ab175f0d9934d025 Mon Sep 17 00:00:00 2001 From: jonah Date: Fri, 6 Feb 2026 06:39:21 -0800 Subject: [PATCH 6/8] add benchmark script --- bench/benchmark_online_softmax.py | 161 ++++++++++++++++++++++++++++ forge_cute_py/ops/softmax_online.py | 2 + 2 files changed, 163 insertions(+) create mode 100644 bench/benchmark_online_softmax.py diff --git a/bench/benchmark_online_softmax.py b/bench/benchmark_online_softmax.py new file mode 100644 index 0000000..85f5046 --- /dev/null +++ b/bench/benchmark_online_softmax.py @@ -0,0 +1,161 @@ +"""Benchmark softmax_online op against torch.softmax and torch.compile(torch.softmax).""" + +import argparse + +import torch + +from forge_cute_py.ops.softmax_online import softmax_online, softmax_fwd, softmax_bwd +from forge_cute_py.util.bench import do_bench, estimate_bandwidth, summarize_times + +SHORT_M = [128, 512, 2048, 8192] +SHORT_N = [1024, 2048, 4096, 8192] + +LONG_M = [64, 128, 256] +LONG_N = [16384, 32768, 65536, 131072] + +DEFAULT_DTYPES = ["float16", "bfloat16", "float32"] + + +def parse_int_list(s: str) -> list[int]: + return [int(x.strip()) for x in s.split(",")] + + +def parse_str_list(s: str) -> list[str]: + return [x.strip() for x in s.split(",")] + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark softmax_online op") + parser.add_argument("--long", action="store_true", help="Use long-N benchmark suite (small M, large N)") + parser.add_argument("--m-sizes", type=parse_int_list, default=None) + parser.add_argument("--n-sizes", type=parse_int_list, default=None) + parser.add_argument("--dtypes", type=parse_str_list, default=DEFAULT_DTYPES) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--iterations", type=int, default=100) + args = parser.parse_args() + + if args.m_sizes is None: + args.m_sizes = LONG_M if args.long else SHORT_M + if args.n_sizes is None: + args.n_sizes = LONG_N if args.long else SHORT_N + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required for benchmarking") + + gpu_name = torch.cuda.get_device_name(0) + suite = "long" if args.long else "short" + print(f"softmax_online benchmarks [{suite}] ({gpu_name})") + print() + + header = ( + f"{'M':>6} {'N':>6} {'Dtype':<10} {'Op':<18} {'Pass':<5} " + f"{'p50 (ms)':>10} {'BW (GB/s)':>10} {'vs torch':>10}" + ) + print(header) + print("-" * len(header)) + + for m in args.m_sizes: + for n in args.n_sizes: + for dtype_str in args.dtypes: + dtype = getattr(torch, dtype_str) + x = torch.randn(m, n, device="cuda", dtype=dtype) + assert n % 32 == 0, f"Inner dimension N must be a multiple of 32, got {n}" + elem = x.element_size() + + # --- Forward bandwidth: read input + write output --- + fwd_bytes = 2 * m * n * elem + + # --- torch.softmax fwd baseline --- + torch_fn = lambda: torch.softmax(x, dim=-1) + torch_times = do_bench(torch_fn, warmup=args.warmup, rep=args.iterations) + torch_stats = summarize_times(torch_times) + torch_fwd_p50 = torch_stats["p50_ms"] + torch_fwd_bw = estimate_bandwidth(fwd_bytes, torch_fwd_p50) + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'torch.softmax':<18} {'fwd':<5} " + f"{torch_fwd_p50:>10.4f} {torch_fwd_bw:>10.2f} {1.0:>10.2f}x" + ) + + # --- torch.compile fwd --- + try: + compiled_ref = torch.compile(lambda t: torch.softmax(t, dim=-1)) + compiled_ref(x) + fn = lambda: compiled_ref(x) + compiled_times = do_bench(fn, warmup=args.warmup, rep=args.iterations) + compiled_stats = summarize_times(compiled_times) + compiled_p50 = compiled_stats["p50_ms"] + compiled_bw = estimate_bandwidth(fwd_bytes, compiled_p50) + ratio = compiled_p50 / torch_fwd_p50 if torch_fwd_p50 > 0 else float("inf") + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'torch.compile':<18} {'fwd':<5} " + f"{compiled_p50:>10.4f} {compiled_bw:>10.2f} {ratio:>10.2f}x" + ) + except Exception as e: + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'torch.compile':<18} {'fwd':<5} " + f"{'ERROR':>10} {'':>10} {'':>10} {e}" + ) + + # --- softmax_online fwd --- + try: + softmax_fwd(x, dim=-1) + fn = lambda: softmax_fwd(x, dim=-1) + times = do_bench(fn, warmup=args.warmup, rep=args.iterations) + stats = summarize_times(times) + p50 = stats["p50_ms"] + bw = estimate_bandwidth(fwd_bytes, p50) + ratio = p50 / torch_fwd_p50 if torch_fwd_p50 > 0 else float("inf") + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'softmax_online':<18} {'fwd':<5} " + f"{p50:>10.4f} {bw:>10.2f} {ratio:>10.2f}x" + ) + except Exception as e: + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'softmax_online':<18} {'fwd':<5} " + f"{'ERROR':>10} {'':>10} {'':>10} {e}" + ) + + # --- Backward pass benchmarks --- + # Pre-compute softmax output y and fake upstream gradient dy + y = torch.softmax(x, dim=-1) + dy = torch.randn_like(y) + + # Backward bandwidth: read dy + read y + write dx = 3 * M * N * elem + bwd_bytes = 3 * m * n * elem + + # --- torch backward baseline --- + torch_bwd_fn = lambda: torch._softmax_backward_data(dy, y, -1, x.dtype) + torch_bwd_times = do_bench(torch_bwd_fn, warmup=args.warmup, rep=args.iterations) + torch_bwd_stats = summarize_times(torch_bwd_times) + torch_bwd_p50 = torch_bwd_stats["p50_ms"] + torch_bwd_bw = estimate_bandwidth(bwd_bytes, torch_bwd_p50) + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'torch.softmax':<18} {'bwd':<5} " + f"{torch_bwd_p50:>10.4f} {torch_bwd_bw:>10.2f} {1.0:>10.2f}x" + ) + + # --- softmax_online bwd --- + try: + y_ours = softmax_fwd(x, dim=-1) + softmax_bwd(dy, y_ours, dim=-1) + fn = lambda: softmax_bwd(dy, y_ours, dim=-1) + times = do_bench(fn, warmup=args.warmup, rep=args.iterations) + stats = summarize_times(times) + p50 = stats["p50_ms"] + bw = estimate_bandwidth(bwd_bytes, p50) + ratio = p50 / torch_bwd_p50 if torch_bwd_p50 > 0 else float("inf") + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'softmax_online':<18} {'bwd':<5} " + f"{p50:>10.4f} {bw:>10.2f} {ratio:>10.2f}x" + ) + except Exception as e: + print( + f"{m:>6} {n:>6} {dtype_str:<10} {'softmax_online':<18} {'bwd':<5} " + f"{'ERROR':>10} {'':>10} {'':>10} {e}" + ) + + print() + + +if __name__ == "__main__": + main() diff --git a/forge_cute_py/ops/softmax_online.py b/forge_cute_py/ops/softmax_online.py index 8024336..46d9f6d 100644 --- a/forge_cute_py/ops/softmax_online.py +++ b/forge_cute_py/ops/softmax_online.py @@ -23,6 +23,7 @@ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None: # Normalize dim to positive index dim = dim if dim >= 0 else x.ndim + dim assert dim in [0, 1], f"dim must be 0 or 1 for 2D tensors, got {dim}" + assert x.shape[1] % 32 == 0, f"Inner dimension N must be a multiple of 32, got {x.shape[1]}" dtype_map = { torch.float16: Float16, @@ -87,6 +88,7 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: # Normalize dim dim = dim if dim >= 0 else dy.ndim + dim assert dim in [0, 1], f"dim must be 0 or 1 for 2D, got {dim}" + assert dy.shape[1] % 32 == 0, f"Inner dimension N must be a multiple of 32, got {dy.shape[1]}" # Compute gradient (numerically stable) # dot_product = (dy * y).sum(dim=dim, keepdim=True) From 28b840d24f502d04f15c6ff5bf5302af54051f54 Mon Sep 17 00:00:00 2001 From: jonah Date: Fri, 6 Feb 2026 06:41:39 -0800 Subject: [PATCH 7/8] ruff --- bench/benchmark_online_softmax.py | 2 +- forge_cute_py/kernels/softmax_online.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/bench/benchmark_online_softmax.py b/bench/benchmark_online_softmax.py index 85f5046..4fe4992 100644 --- a/bench/benchmark_online_softmax.py +++ b/bench/benchmark_online_softmax.py @@ -4,7 +4,7 @@ import torch -from forge_cute_py.ops.softmax_online import softmax_online, softmax_fwd, softmax_bwd +from forge_cute_py.ops.softmax_online import softmax_fwd, softmax_bwd from forge_cute_py.util.bench import do_bench, estimate_bandwidth, summarize_times SHORT_M = [128, 512, 2048, 8192] diff --git a/forge_cute_py/kernels/softmax_online.py b/forge_cute_py/kernels/softmax_online.py index 2e47cb5..4a2d332 100644 --- a/forge_cute_py/kernels/softmax_online.py +++ b/forge_cute_py/kernels/softmax_online.py @@ -1,11 +1,8 @@ -from this import d import torch -import cutlass import cutlass.cute as cute import cuda.bindings.driver as cuda from cutlass import BFloat16, Float16, Float32 from cutlass.cute.runtime import from_dlpack -from cutlass import const_expr class SoftmaxOnlineBackward: From 34d42625f71589a14cf1043bce63bce3801d09ba Mon Sep 17 00:00:00 2001 From: jonah Date: Sat, 7 Feb 2026 07:51:07 -0800 Subject: [PATCH 8/8] ruff --- bench/benchmark_online_softmax.py | 4 +- forge_cute_py/kernels/softmax_online.py | 150 ++++++++++++------------ forge_cute_py/ops/softmax_online.py | 13 +- tests/test_softmax_online.py | 1 + 4 files changed, 89 insertions(+), 79 deletions(-) diff --git a/bench/benchmark_online_softmax.py b/bench/benchmark_online_softmax.py index 4fe4992..ef9b527 100644 --- a/bench/benchmark_online_softmax.py +++ b/bench/benchmark_online_softmax.py @@ -26,7 +26,9 @@ def parse_str_list(s: str) -> list[str]: def main(): parser = argparse.ArgumentParser(description="Benchmark softmax_online op") - parser.add_argument("--long", action="store_true", help="Use long-N benchmark suite (small M, large N)") + parser.add_argument( + "--long", action="store_true", help="Use long-N benchmark suite (small M, large N)" + ) parser.add_argument("--m-sizes", type=parse_int_list, default=None) parser.add_argument("--n-sizes", type=parse_int_list, default=None) parser.add_argument("--dtypes", type=parse_str_list, default=DEFAULT_DTYPES) diff --git a/forge_cute_py/kernels/softmax_online.py b/forge_cute_py/kernels/softmax_online.py index 4a2d332..1a57db3 100644 --- a/forge_cute_py/kernels/softmax_online.py +++ b/forge_cute_py/kernels/softmax_online.py @@ -18,9 +18,9 @@ def __init__(self, dtype, N: int): @cute.jit def __call__(self, dY: cute.Tensor, y: cute.Tensor, dx: cute.Tensor, stream=None): blocks_over_N = cute.ceil_div(self.N, self.vec_load_size * self.warp_size) - tiler_mn = ( # full covering tile + tiler_mn = ( # full covering tile self.num_warps, - self.vec_load_size * self.warp_size * blocks_over_N + self.vec_load_size * self.warp_size * blocks_over_N, ) copy_op = cute.nvgpu.CopyUniversalOp() @@ -28,26 +28,26 @@ def __call__(self, dY: cute.Tensor, y: cute.Tensor, dx: cute.Tensor, stream=None thr_layout = cute.make_ordered_layout( (self.num_warps, self.warp_size), - order=(1, 0) # cols move faster + order=(1, 0), # cols move faster ) val_layout = cute.make_layout((1, self.vec_load_size)) tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) blocks = cute.ceil_div(dY.shape[0], self.num_warps) - self.kernel( - dY, y, dx, tiler_mn, tiled_copy - ).launch( - grid=(blocks, 1, 1), - block=(self.threads_per_block, 1, 1), - stream=stream + self.kernel(dY, y, dx, tiler_mn, tiled_copy).launch( + grid=(blocks, 1, 1), block=(self.threads_per_block, 1, 1), stream=stream ) - + # type hints are not optional!!!! @cute.kernel def kernel( - self, dY: cute.Tensor, y: cute.Tensor, dX: cute.Tensor, - tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy - ): + self, + dY: cute.Tensor, + y: cute.Tensor, + dX: cute.Tensor, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + ): # Compute gradient (numerically stable) # dot_product = (dy * y).sum(dim=dim, keepdim=True) # result = y * (dy - dot_product) @@ -61,7 +61,7 @@ def kernel( dx_tile = cute.local_tile(dX, tiler_mn, (bidx, 0)) tidxSlice = tiled_copy.get_slice(tidx) - + dy_idx = tidxSlice.partition_S(dy_tile) y_idx = tidxSlice.partition_S(y_tile) dx_idx = tidxSlice.partition_D(dx_tile) @@ -76,7 +76,9 @@ def kernel( y_data = y_regs.load() tidx_local_dp = dy_data * y_data - tidx_local_sum = tidx_local_dp.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + tidx_local_sum = tidx_local_dp.reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0 + ) row_dp_sum = cute.arch.warp_reduction_sum(tidx_local_sum) result = y_data * (dy_data - row_dp_sum) @@ -89,7 +91,7 @@ def __init__(self, dtype): self.dtype = dtype self.num_warps = 1 self.threads_per_block = self.num_warps * 32 - self.NEG_INF = Float32(float('-inf')) + self.NEG_INF = Float32(float("-inf")) @cute.jit def __call__(self, gInput: cute.Tensor, gOutput: cute.Tensor, stream: cuda.CUstream = None): @@ -100,15 +102,13 @@ def __call__(self, gInput: cute.Tensor, gOutput: cute.Tensor, stream: cuda.CUstr tiler_mn = (1, tiler_mn_1d[0]) gX = cute.zipped_divide(gInput, tiler_mn) gY = cute.zipped_divide(gOutput, tiler_mn) - - self.kernel( - gX, gY, N - ).launch( - grid=(cute.size(gX, mode=[1, 0]), 1, 1), # RestM + + self.kernel(gX, gY, N).launch( + grid=(cute.size(gX, mode=[1, 0]), 1, 1), # RestM block=(cute.size(tv_layout, mode=[0, 0])), # threads per block - stream=stream + stream=stream, ) - + @cute.kernel def kernel(self, gInput, gOutput, N): maxValue = self.NEG_INF @@ -126,12 +126,12 @@ def kernel(self, gInput, gOutput, N): curMax = cute.arch.warp_reduction_max(value) prevMax = maxValue maxValue = cute.arch.fmax(prevMax, curMax) - + scale = cute.math.exp(prevMax - maxValue) scale_data = cute.math.exp(value - maxValue) curSum = cute.arch.warp_reduction_sum(scale_data) sumValue = sumValue * scale + curSum - + for i in range(RestN): idx = i * bdimx + tidx if idx < N: @@ -145,7 +145,7 @@ def __init__(self, dtype, N: int): self.dtype = dtype self.num_warps = 1 self.threads_per_block = self.num_warps * 32 - self.NEG_INF = Float32(float('-inf')) + self.NEG_INF = Float32(float("-inf")) self.N = N self.bits_read = 128 @@ -159,7 +159,10 @@ def __call__(self, gInput: cute.Tensor, gOutput: cute.Tensor, stream: cuda.CUstr blocks_vector_N = cute.ceil_div(self.N, self.bits_read // self.dtype.width) blocks_over_N = cute.ceil_div(blocks_vector_N, self.threads_per_row) - tiler_mn = (self.num_warps, self.vec_load_size * blocks_over_N * self.threads_per_row) # [4, ~N] + tiler_mn = ( + self.num_warps, + self.vec_load_size * blocks_over_N * self.threads_per_row, + ) # [4, ~N] copy_op = cute.nvgpu.CopyUniversalOp() copy_atom = cute.make_copy_atom(copy_op, self.dtype, num_bits_per_copy=self.bits_read) @@ -171,36 +174,42 @@ def __call__(self, gInput: cute.Tensor, gOutput: cute.Tensor, stream: cuda.CUstr tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) blocks = cute.ceil_div(gInput.shape[0], tiler_mn[0]) - self.kernel( - gInput, gOutput, tiler_mn, tiled_copy - ).launch( - grid=(blocks, 1, 1), - block=(self.num_threads, 1, 1), - stream=stream + self.kernel(gInput, gOutput, tiler_mn, tiled_copy).launch( + grid=(blocks, 1, 1), block=(self.num_threads, 1, 1), stream=stream ) - + # type hints are not optional!!!! @cute.kernel - def kernel(self, gInput: cute.Tensor, gOutput: cute.Tensor, tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy): + def kernel( + self, + gInput: cute.Tensor, + gOutput: cute.Tensor, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + ): tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() gX = cute.local_tile(gInput, tiler_mn, (bidx, 0)) gY = cute.local_tile(gOutput, tiler_mn, (bidx, 0)) # this thread is response for vectorized loads, striding 4 * 32 across the row - tidxSlice = tiled_copy.get_slice(tidx) + tidxSlice = tiled_copy.get_slice(tidx) tidxIndices = tidxSlice.partition_S(gX) tidxRegs = cute.make_rmem_tensor_like(tidxIndices) cute.autovec_copy(tidxIndices, tidxRegs) tidxValues = tidxRegs.load() - tidLocalMax = tidxValues.reduce(cute.ReductionOp.MAX, init_val=self.NEG_INF, reduction_profile=0) + tidLocalMax = tidxValues.reduce( + cute.ReductionOp.MAX, init_val=self.NEG_INF, reduction_profile=0 + ) rowMax = cute.arch.warp_reduction_max(tidLocalMax) tidScaledLocalSum = cute.math.exp(tidxValues - rowMax) - tidLocalSum = tidScaledLocalSum.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + tidLocalSum = tidScaledLocalSum.reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0 + ) rowSum = cute.arch.warp_reduction_sum(tidLocalSum) - writeValues = cute.math.exp(tidxValues - rowMax) / rowSum + writeValues = cute.math.exp(tidxValues - rowMax) / rowSum tidxRegs.store(writeValues) tidxOutIndices = tidxSlice.partition_D(gY) cute.autovec_copy(tidxRegs, tidxOutIndices) @@ -218,25 +227,23 @@ def benchmark(loopless=True): torch.bfloat16: BFloat16, } cute_dtype = dtype_map[dtype] - - x = torch.randn(M, N, device='cuda', dtype=dtype) + + x = torch.randn(M, N, device="cuda", dtype=dtype) output = torch.zeros_like(x) if loopless: dx = x dy = output m = cute.sym_int() - input_cute = cute.runtime.make_fake_compact_tensor( - cute_dtype, (m, N), stride_order=(1, 0) - ) - output_cute = cute.runtime.make_fake_compact_tensor( - cute_dtype, (m, N), stride_order=(1, 0) - ) + input_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, N), stride_order=(1, 0)) + output_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, N), stride_order=(1, 0)) softmax = SoftmaxOnline(dtype_map[dtype], N) fn = cute.compile( - softmax, input_cute, output_cute, + softmax, + input_cute, + output_cute, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", + options="--enable-tvm-ffi", ) fn(x, output) else: @@ -244,13 +251,14 @@ def benchmark(loopless=True): dy = from_dlpack(output, enable_tvm_ffi=True) softmax = SoftmaxOnlineLoop(dtype_map[dtype]) fn = cute.compile( - softmax, dx, dy, + softmax, + dx, + dy, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", + options="--enable-tvm-ffi", ) fn(dx, dy) - print("Correctness check:") expected = torch.nn.functional.softmax(x, dim=-1) is_close = torch.allclose(output, expected, rtol=1e-3, atol=1e-3) @@ -258,21 +266,21 @@ def benchmark(loopless=True): if not is_close: max_diff = (output - expected).abs().max().item() print(f" max diff: {max_diff}") - + print("\nBenchmarks:") - + # Warmup for _ in range(10): fn(dx, dy) torch.cuda.synchronize() - + # Benchmark our softmax start = time.perf_counter() for _ in range(100): fn(dx, dy) torch.cuda.synchronize() print(f" softmax_online dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") - + # Compare to PyTorch start = time.perf_counter() for _ in range(100): @@ -289,31 +297,29 @@ def bench_back(): torch.float16: Float16, torch.float32: Float32, torch.bfloat16: BFloat16, - } + } cute_dtype = dtype_map[dtype] - dy = torch.randn(M, N, device='cuda', dtype=dtype) - y = torch.randn(M, N, device='cuda', dtype=dtype) - dx = torch.randn(M, N, device='cuda', dtype=dtype) + dy = torch.randn(M, N, device="cuda", dtype=dtype) + y = torch.randn(M, N, device="cuda", dtype=dtype) + dx = torch.randn(M, N, device="cuda", dtype=dtype) m = cute.sym_int() - dy_cute = cute.runtime.make_fake_compact_tensor( - cute_dtype, (m, N), stride_order=(1, 0) - ) - y_cute = cute.runtime.make_fake_compact_tensor( - cute_dtype, (m, N), stride_order=(1, 0) - ) - dx_cute = cute.runtime.make_fake_compact_tensor( - cute_dtype, (m, N), stride_order=(1, 0) - ) + dy_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, N), stride_order=(1, 0)) + y_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, N), stride_order=(1, 0)) + dx_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, N), stride_order=(1, 0)) softmax = SoftmaxOnlineBackward(dtype_map[dtype], N) fn = cute.compile( - softmax, dy_cute, y_cute, dx_cute, + softmax, + dy_cute, + y_cute, + dx_cute, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", + options="--enable-tvm-ffi", ) fn(dy, y, dx) + # benchmark(loopless=False) # benchmark(loopless=True) -# bench_back() \ No newline at end of file +# bench_back() diff --git a/forge_cute_py/ops/softmax_online.py b/forge_cute_py/ops/softmax_online.py index 46d9f6d..ac274e7 100644 --- a/forge_cute_py/ops/softmax_online.py +++ b/forge_cute_py/ops/softmax_online.py @@ -35,7 +35,7 @@ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None: raise ValueError(f"Unsupported dtype: {x.dtype}") cute_dtype = dtype_map[x.dtype] - compile_key = (cute_dtype, x.shape[1]) + compile_key = (cute_dtype, x.shape[1]) if compile_key not in _softmax_fwd.compile_cache: m = cute.sym_int() @@ -54,6 +54,7 @@ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None: _softmax_fwd.compile_cache[compile_key](x, out) + _softmax_fwd.compile_cache = {} @@ -102,7 +103,7 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: } cute_dtype = dtype_map[dy.dtype] - compile_key = (cute_dtype, dy.shape[1]) + compile_key = (cute_dtype, dy.shape[1]) if compile_key not in _softmax_backward.compile_cache: m = cute.sym_int() @@ -113,13 +114,13 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: # Compile and cache the kernel _softmax_backward.compile_cache[compile_key] = cute.compile( SoftmaxOnlineBackward(cute_dtype, n), - dy_cute, y_cute, dx_cute, + dy_cute, + y_cute, + dx_cute, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) - _softmax_backward.compile_cache[compile_key]( - dy, y, dx - ) + _softmax_backward.compile_cache[compile_key](dy, y, dx) _softmax_backward.compile_cache = {} diff --git a/tests/test_softmax_online.py b/tests/test_softmax_online.py index edd8b91..fd80a6e 100644 --- a/tests/test_softmax_online.py +++ b/tests/test_softmax_online.py @@ -6,6 +6,7 @@ dims = [-1] + @pytest.mark.parametrize("shape", [(4, 8), (2, 128)]) @pytest.mark.parametrize("dim", dims) @pytest.mark.parametrize(