From b10ffed69de5391e5f73ef70cf42b6fd985d7fb7 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 May 2026 18:21:19 +0700 Subject: [PATCH 1/3] fix the SM100 2CTA issue --- AI/complex64_design_notes.md | 255 +++++++++++++ quack/complex.py | 292 +++++++++++++++ quack/gemm_act.py | 11 + quack/gemm_dact.py | 6 + quack/gemm_sm100.py | 11 +- tests/test_complex.py | 668 +++++++++++++++++++++++++++++++++++ 6 files changed, 1242 insertions(+), 1 deletion(-) create mode 100644 AI/complex64_design_notes.md create mode 100644 quack/complex.py create mode 100644 tests/test_complex.py diff --git a/AI/complex64_design_notes.md b/AI/complex64_design_notes.md new file mode 100644 index 00000000..83033a69 --- /dev/null +++ b/AI/complex64_design_notes.md @@ -0,0 +1,255 @@ +# `quack.complex.Complex64` design notes + +A native-feeling `complex64` element type for CuTe-DSL kernels, without +patching the C++ MLIR bindings. Lives in `quack/complex.py`. Validated by +`tests/test_complex.py` (44 tests covering constructors, arithmetic matrix, +methods, class invariants, helpers, the c_pointers boundary, special / edge +values, auto-promotion, and a smem round-trip). FFT-correctness tests that +exercise Complex64 through real kernels live in `tests/test_fft.py`. + +This file records *why* the design looks the way it does and what was tried +and rejected. The production native C2C FFT path in `quack/fft.py` already +uses Complex64 for register values and IO -- see +`AI/complex64_fft_migration_plan.md` for the cutover history and remaining +work. + +--- + +## The fundamental constraint + +`cute.MemRefType.get(value_type, layout_type)` -- the C++ MLIR-binding call +that backs every `make_tensor`, `make_rmem_tensor`, `recast_tensor`, and smem +`allocate_tensor` -- enforces an element-type **allowlist**: the value_type +must be one of `int`, `float`, `ptr`, or `sparse_elem`. The natural choice +for "complex single-precision" -- `complex` -- gets: + +``` +TypeError: expects value type to be int, float, ptr or sparse_elem, + but got 'complex' +``` + +This is enforced in `_cute_ir` (the compiled binding), not in Python. We +have no Python-level escape hatch. + +The other plausible "unique" MLIR types we tried have the same fate: + +| MLIR element type | allowlisted? | unique to Complex64? | +| --------------------- | :----------: | --------------------------------- | +| `complex` | ❌ | unique but rejected | +| `tuple` | ❌ | unique but rejected | +| `f64` | ✅ | shared with `Float64` | +| `i64` | ✅ | shared with `Int64` / `Uint64` | +| `vector<2xf32>`, etc. | ❌ (by family)| unique but rejected | + +So the only viable storage MLIR types are ones already claimed by some other +Python `Numeric`. We picked **`f64`**. + +--- + +## The Python type identity is recoverable + +`f64` is the storage type, but the kernel writer wants `tensor[i]` to return +a `Complex64`-typed value, not a `Float64`. The JIT-side `_Tensor` derives +`element_type` from the MLIR memref via `Numeric.from_mlir_type`, which is a +many-to-one lookup -- f64 always becomes `Float64`. So whatever we tell the +constructor about Python type identity must survive the round-trip. + +We do this two ways: + +1. **At `make_fake_tensor` / `make_rmem_tensor` time**, the dtype is passed + explicitly and stored on the runtime tensor's `_dtype` attribute. As long + as `Complex64.mlir_type == T.f64()`, the MLIR memref is f64 and the + Python `_dtype` stays `Complex64`. ✅ +2. **At `SmemAllocator.allocate_tensor` / `cute.recast_tensor` / pointer-to- + tensor `make_tensor` time**, the resulting JIT tensor is built from an + MLIR value with no carried dtype. Its `_dtype` falls back to + `Numeric.from_mlir_type(f64)` = `Float64`. ❌ -- needs a re-tag. + +The fix is a one-line wrapper, exposed as `allocate_smem_complex(...)` and +`recast_to_complex64(...)`. Both internally do +`t._dtype = Complex64; return t`. + +Without the re-tag, scalar writes like `smem[i] = c` (where `c: Complex64`) +flow through `_cvt_to_dest` (`cute/tensor.py:308`), which calls +`data.to(self.element_type)` = `Complex64.to(Float64)` -- the latter does a +cvtf-style real-number conversion on the f64 packed bits, then truncates to +f32, then stores. Output is silently corrupted (specifically: high 32 bits +zeroed, low 32 bits hold the f32 truncation of the bit-packed-as-f64 value). + +This is the *only* sharp edge the wrapper class can't smooth over, and it +shows up exactly at MLIR boundaries. + +--- + +## Why inherit from `Float32` + +The arithmetic operators (`__add__`/`__mul__`/...) on Complex64 unpack to +two Float32 lanes, compute, repack. Easy when Complex64 is on the LEFT -- +Python calls `Complex64.__mul__(other)` first. + +For `Float32 * Complex64`, Python normally calls `Float32.__mul__(complex)` +first. Without intervention, that goes through `Numeric._binary_op_type_promote` +which sees the operands as `(f32, i64)` or `(f32, f64)` (depending on the +Complex64.mlir_type) and either: + +- **Inherit from `Int64`** (mlir_type=i64): kind mismatch fires at store time + with `"type mismatch, store f64 to CxI64"`. Loud error, no corruption. +- **Inherit from `Float64`** (mlir_type=f64): same kind, same width -- promotion + picks Complex64 as result type, then `op(lhs.value, rhs.value)` calls + `arith.mulf` on two f64 SSAs. Stored bits are valid f64s but interpret the + packed-bit pattern as a real number. **Silent corruption.** +- **Inherit from `Float32`** (mlir_type=f64, width=64 override): `Complex64` is + a strict subclass of `Float32`, so Python's subclass-precedence rule fires + `Complex64.__rmul__(Float32)` *before* `Float32.__mul__(Complex64)`. We + control the dispatch and do the right thing. ✅ + +`Float32` inheritance is the only one of the three that gives both directions +of mixed arithmetic without monkey-patching. The structural cost: `isinstance( +complex_val, Float32) is True`. Only one site in cutlass touches this +(`nvvm_wrappers.py:441`, an fmax helper); not a problem in practice. + +--- + +## Why same-type init must short-circuit first + +`_cvt_to_dest` calls `data.to(element_type)` on every tensor write. If both +sides are `Complex64`, this becomes `Complex64(complex_instance)`. If our +`__init__` falls through to the generic-`Numeric` branch (`Float32(x)` then +re-pack), it treats the existing packed-bits Complex64 as if it were a +real-valued Float, extracts the f32 truncation of the bit pattern, and packs +that as `(real_view, 0)`. This was the section-8b bug. + +The fix is a one-line check at the top of `__init__`: + +```python +if isinstance(x, Complex64): + Numeric.__init__(self, x.value) + return +``` + +--- + +## What's been validated + +Complex64 type itself (`tests/test_complex.py`, 44 tests): + +- All constructor branches: Python `complex`, two-arg `(re, im)`, same-type + copy, `ir.Value` of f64/f32, `Float32` instance, other `Numeric` (Int32, + Float64), Python int/float, plus the bad-input `TypeError` path. +- Arithmetic matrix: `Complex64 OP X` for `X` in `{Complex64, Float32, int, + float}`, both directions; `__neg__`; randomized correctness sweep against + Python complex reference. +- Methods: `conj()` (involution + `c * c.conj()` is real), `real()`, `imag()`, + `from_re_im` alias. +- Class invariants: `width=64`, `is_float=True`, `mlir_type == T.f64()`, + `isinstance(c, Float32) is True`, `is_same_kind(Float32 / Float64) is True`. +- Helpers: `complex_storage(t)` (passthrough for f64, view for complex64, + `TypeError` for other dtypes), `recast_to_complex64`, `allocate_smem_complex`. +- `__c_pointers__`: static-value path works, dynamic-SSA path raises. +- Edge values round-trip exactly through the bitcast plumbing: zero, -0.0, + +/-inf, NaN, denormal, very large magnitudes, mixed signs. +- Auto-promotion via `_cvt_to_dest`: writing a `Float32` or `Float64` to a + `Complex64` tensor lands as `(value, 0)` (pinned because future changes to + `Numeric` promotion could silently break this). +- Smem round-trip (gmem -> rmem -> smem -> sync -> smem -> rmem -> gmem) + using `allocate_smem_complex`; serves as a regression test for the + dtype-loss bug at the smem allocation boundary. + +Through real kernels (`tests/test_fft.py`): + +- tvm-ffi compile path (`--enable-tvm-ffi` + `complex_storage(t)` boundary). +- Standalone radix-8 FFT butterfly matches `torch.fft.fft` numerically. +- N=64 = 8x8 Cooley-Tukey FFT with smem transpose matches numerically. +- The production native FFT class (N=2..8192) operates on Complex64 register + fragments end-to-end and matches `torch.fft.fft` (existing + `test_fft_native_power_of_two_matches_torch` parametrizations). + +## What hasn't been validated + +- TMA (`SM90`-style bulk loads / cp.async.bulk) carrying Complex64 elements. +- `cute.copy` with explicit copy atoms typed as Complex64; the production + FFT path keeps smem in `Float32` layouts and recasts at the rmem boundary. +- Tensor-core (MMA) operations -- not expected to be relevant for FFT but + noted for completeness. +- Typed-Complex64 smem layouts for the FFT path. The fast paths still allocate + `Float32` smem because the tuned interleaved / split-real-imag / swizzled + layouts are easier to express as scalar f32. Whether a typed-Complex64 smem + variant wins on any specific N is an open profiling question (see the + remaining work in `AI/complex64_fft_migration_plan.md`). +- Mixed Complex64 / Float32 arithmetic when `Float32` is a TENSOR (not just + a scalar) -- e.g. multiplying a Complex64 register by a Float32 lane from + another rmem tensor. Should work via the `__rmul__` path but not tested + directly. + +--- + +## Sharp edges to keep in mind + +1. **MLIR boundary loses dtype.** Any code path that constructs a JIT tensor + from MLIR (smem alloc, recast, ptr-to-tensor) loses the Complex64 tag. + Use the wrappers in `quack/complex.py` or call `_retag_as_complex64(t)` + manually. +2. **`isinstance(c, Float32) is True`.** Only one cutlass site checks this + (`nvvm_wrappers.py:441`). Easy to grep for if behavior is suspicious. +3. **`is_same_kind(Complex64, Float32) is True`** because both have + `is_float=True`. So `tensor[i] = some_float32` will auto-promote via + `_cvt_to_dest` -> `Complex64(float32_instance)` -> `(re=value, im=0)`. + That's friendly ergonomics for scalar real-valued writes, but if you + accidentally store an f32 expecting "complex bits", it'll be silently + wrapped. Add an explicit type check in `__setitem__` if you want to + harden a particular kernel. +4. **`Float32 OP Complex64` only works** because `Complex64` is a strict + subclass of `Float32`, triggering Python's reflected-operator precedence. + Don't break the inheritance. +5. **`Numeric.from_mlir_type` is a single global registry.** Anything + monkey-patching it for Complex64 → Float64 mapping must avoid disturbing + the existing Float64 mapping (we don't currently patch it; we use the + per-tensor re-tag instead). +6. **Dynamic Complex64 SSA values can't be passed as kernel args.** + `__c_pointers__` raises if `self.value` is an SSA. Static (Python + complex literal) values work fine. This is consistent with how `Float32` + handles the same case. +7. **Twiddle generation builds an O(N) `arith.select` cascade** when the + index is dynamic. For small N (8, 16, 32, 64) that's fine; for larger N + use the LUT pattern via `_twiddle_from_lut_cx` (and friends: + `_twiddle_cx`, `_twiddle_binary_cx`, `_twiddle_in_range_cx`) in + `quack/fft.py`. These wrap the existing tuple-returning helpers and + return a single `Complex64` value. + +--- + +## File layout + +``` +quack/complex.py + class Complex64(Float32, width=64, mlir_type=T.f64) + __init__(x, im=None) # complex / Float32 / Numeric / ir.Value(f32|f64) + _pack_ssa(re, im) # static: two f32 SSAs -> packed f64 SSA + _unpack(self) # f64 -> (Float32, Float32) + from_re_im(re, im) # public alias of _pack_ssa-based ctor + real() / imag() / conj() + __add__/__radd__/__sub__/__rsub__/__mul__/__rmul__/__neg__ + __c_pointers__ # 8-byte scalar-arg path + + allocate_smem_complex(allocator, layout) + recast_to_complex64(src) + complex_storage(t) # torch.complex64 -> torch.float64 view + + _register_with_tvm_ffi() # called at import time + +quack/fft.py # Complex64-native FFT path; *_cx primitives + # (_fft{2,4,8,16,32}_inplace_cx*, _mul_j_cx, + # _apply_stage_twiddle_cx, + # _mul_by_base_twiddle_powers_cx, + # _twiddle{,_binary,_in_range,_from_lut}_cx, + # plus pure-Complex64 smem helpers) +``` + +``` +tests/test_complex.py # 44 tests for Complex64 the type +tests/test_fft.py # FFT-correctness tests using Complex64 + +AI/complex64_design_notes.md # this file (design rationale + sharp edges) +AI/complex64_fft_migration_plan.md # cutover history and remaining work +AI/fft_optimization_notes.md # broader FFT perf notes +``` diff --git a/quack/complex.py b/quack/complex.py new file mode 100644 index 00000000..cc76056c --- /dev/null +++ b/quack/complex.py @@ -0,0 +1,292 @@ +"""Complex64 element type for CuTe-DSL kernels. + +Single-precision complex (re + imj) carried as f64-packed bits (re in the low +32 bits, im in the high 32 bits). f64 is on `cute.MemRefType`'s element-type +allowlist; the natural `complex` MLIR type is not. Arithmetic methods +unpack each f64 into two Float32 lanes, compute, and repack -- the bitcasts +are folded out by ptxas. + +Inherits from `Float32` (with `width=64, mlir_type=T.f64` overrides) so that +Python's subclass-precedence rule routes `Float32 OP Complex64` to our +reflected `__r*__` operators before Numeric's promotion logic sees the +operands. Without this, Float32-on-the-LEFT would silently promote through +Float64 conversion and corrupt the packed bits. + +Boundary convention (tvm-ffi): the compiled kernel's ABI sees f64 storage. +At the call site, pass `torch.complex64` tensors as `t.view(torch.float64)` +-- use `complex_storage(t)` for the conversion. + +See `AI/complex64_design_notes.md` for the why and what's been validated. +""" + +from __future__ import annotations + +import ctypes + +import numpy as np +import torch + +import cutlass.cute as cute +from cutlass import Float32, Numeric +from cutlass._mlir import ir +from cutlass._mlir.dialects import arith +from cutlass._mlir.extras import types as T +from cutlass.base_dsl._mlir_helpers.arith import bitcast as _bitcast +from cutlass.base_dsl.typing import FloatMeta + + +class Complex64(Float32, metaclass=FloatMeta, width=64, mlir_type=T.f64): + """Complex64 carried as f64-packed bits (re in low 32, im in high 32). + + `tensor.element_type is Complex64` inside the kernel; indexing returns + Complex64 instances; `+`, `-`, `*`, `__neg__`, and `conj()` work natively. + """ + + def __init__(self, x, im=None, *, loc=None, ip=None): + # Two-arg lane form: Complex64(re, im) packs (re, im) into f64 bits. + # Coerce both through Float32 so int / float / Float32 / ir.Value(f32) + # all work as inputs. + if im is not None: + # Static fast path: both args are Python int/float -- no MLIR + # context needed (lets host-side code call Complex64(2.5, -1.5)). + if isinstance(x, (int, float)) and isinstance(im, (int, float)): + Complex64.__init__(self, complex(x, im), loc=loc, ip=ip) + return + re_ssa = Float32(x).ir_value() + im_ssa = Float32(im).ir_value() + Numeric.__init__(self, Complex64._pack_ssa(re_ssa, im_ssa)) + return + + # Same-type copy MUST be checked first. `_cvt_to_dest` (cute/tensor.py) + # calls `data.to(element_type)` on every tensor write, which becomes + # `Complex64(complex_instance)`; falling through to the generic-Numeric + # branch below would re-pack as (real_view_of_packed_bits, 0) and + # silently corrupt the data. + if isinstance(x, Complex64): + Numeric.__init__(self, x.value) + return + + if isinstance(x, complex): + f64_val = _pack_python_complex(x) + Numeric.__init__(self, f64_val) + return + + if isinstance(x, ir.Value): + if x.type == T.f64(): + # Already in our storage form (loaded from a Complex64 tensor, + # or output of _pack_ssa). + Numeric.__init__(self, x) + return + if x.type == T.f32(): + packed = Complex64._pack_ssa(x, arith.constant(T.f32(), 0.0)) + Numeric.__init__(self, packed) + return + raise TypeError(f"Complex64: ir.Value of unsupported type {x.type}") + + if isinstance(x, Numeric): + # Float32, Int32, Float64, etc. -> coerce real lane through Float32, + # imag lane = 0. Float32(Float32) is a no-op, so this also handles + # the Float32 case cleanly. + re_ssa = Float32(x).ir_value() + packed = Complex64._pack_ssa(re_ssa, arith.constant(T.f32(), 0.0)) + Numeric.__init__(self, packed) + return + + if isinstance(x, (int, float)): + Complex64.__init__(self, complex(x, 0.0), loc=loc, ip=ip) + return + + raise TypeError(f"Complex64: unsupported source type {type(x)}") + + # ---- packing / unpacking primitives -------------------------------- + + @staticmethod + def _pack_ssa(re_f32, im_f32): + """Pack two f32 SSA lanes into one f64 SSA value (re lo, im hi).""" + re_i32 = _bitcast(re_f32, T.i32()) + im_i32 = _bitcast(im_f32, T.i32()) + re_i64 = arith.extui(T.i64(), re_i32) + im_i64 = arith.extui(T.i64(), im_i32) + hi = arith.shli(im_i64, arith.constant(T.i64(), 32)) + return _bitcast(arith.ori(re_i64, hi), T.f64()) + + def _unpack(self): + """Split self -> (re_f32, im_f32) as Float32 SSA values.""" + i64_ssa = _bitcast(self.ir_value(), T.i64()) + lo32 = arith.trunci(T.i32(), i64_ssa) + hi32 = arith.trunci(T.i32(), arith.shrui(i64_ssa, arith.constant(T.i64(), 32))) + return Float32(_bitcast(lo32, T.f32())), Float32(_bitcast(hi32, T.f32())) + + @staticmethod + def from_re_im(re: Float32, im: Float32) -> "Complex64": + """Build a Complex64 from two Float32 SSA lanes. + + Equivalent to `Complex64(re, im)`; kept as an explicit name for the + hot-path call sites that want to skip the Float32 coercion in __init__. + """ + return Complex64(Complex64._pack_ssa(re.ir_value(), im.ir_value())) + + # Internal alias used by arithmetic methods. + _from_re_im = from_re_im + + # ---- accessors ------------------------------------------------------ + + def real(self) -> Float32: + re, _ = self._unpack() + return re + + def imag(self) -> Float32: + _, im = self._unpack() + return im + + def conj(self) -> "Complex64": + re, im = self._unpack() + return Complex64._from_re_im(re, -im) + + # ---- arithmetic ----------------------------------------------------- + + def __add__(self, other, *, loc=None, ip=None): + a_re, a_im = self._unpack() + b_re, b_im = _other_lanes(other) + return Complex64._from_re_im(a_re + b_re, a_im + b_im) + + def __radd__(self, other, *, loc=None, ip=None): + return self.__add__(other, loc=loc, ip=ip) + + def __sub__(self, other, *, loc=None, ip=None): + a_re, a_im = self._unpack() + b_re, b_im = _other_lanes(other) + return Complex64._from_re_im(a_re - b_re, a_im - b_im) + + def __rsub__(self, other, *, loc=None, ip=None): + a_re, a_im = self._unpack() + b_re, b_im = _other_lanes(other) + return Complex64._from_re_im(b_re - a_re, b_im - a_im) + + def __mul__(self, other, *, loc=None, ip=None): + a_re, a_im = self._unpack() + if isinstance(other, Complex64): + b_re, b_im = other._unpack() + return Complex64._from_re_im(a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re) + # Real scalar: (re, im) * s = (re*s, im*s) + s = Float32(other) + return Complex64._from_re_im(a_re * s, a_im * s) + + def __rmul__(self, other, *, loc=None, ip=None): + return self.__mul__(other, loc=loc, ip=ip) + + def __neg__(self, *, loc=None, ip=None): + re, im = self._unpack() + return Complex64._from_re_im(-re, -im) + + # ---- runtime arg passing ------------------------------------------- + + def __c_pointers__(self): + # Scalar Complex64 args travel as 8 bytes (the packed-as-f64 value). + if not isinstance(self.value, float): + raise ValueError( + "Complex64 with a dynamic SSA value cannot be passed as a " + "kernel argument; only static values are supported" + ) + return [ctypes.cast(ctypes.pointer(ctypes.c_double(self.value)), ctypes.c_void_p)] + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _pack_python_complex(c: complex) -> float: + """Compute the f64 representation of a complex bit-packed (re, im).""" + re_b = int(np.float32(c.real).view(np.uint32)) + im_b = int(np.float32(c.imag).view(np.uint32)) + return float(np.uint64((im_b << 32) | re_b).view(np.float64)) + + +def _other_lanes(other): + """Unpack the RHS of a binary op into (re_f32, im_f32) Float32 lanes.""" + if isinstance(other, Complex64): + return other._unpack() + return Float32(other), Float32(0.0) + + +def _retag_as_complex64(t): + """Restore `t.element_type is Complex64` after a code path that derived it + from MLIR (where complex64 collapses to Float64 / Int64 because + `Numeric.from_mlir_type` is a many-to-one lookup).""" + t._dtype = Complex64 + return t + + +# --------------------------------------------------------------------------- +# Public helpers +# --------------------------------------------------------------------------- + + +def allocate_smem_complex( + allocator, + layout_or_shape, + byte_alignment: int = 16, + swizzle=None, +): + """Allocate a `Complex64` smem tensor. + + Wraps `cutlass.utils.SmemAllocator.allocate_tensor(Complex64, ...)` and + re-tags the result so `tensor.element_type is Complex64`. Without the + re-tag, the JIT-side tensor's element_type is `Float64` (derived from the + f64 memref) and writes go through `Complex64.to(Float64)` and corrupt the + packed bits. + """ + t = allocator.allocate_tensor( + Complex64, layout_or_shape, byte_alignment=byte_alignment, swizzle=swizzle + ) + return _retag_as_complex64(t) + + +def recast_to_complex64(src: cute.Tensor) -> cute.Tensor: + """Recast any tensor (e.g. Float32, Int64) to a `Complex64` tensor. + + Wraps `cute.recast_tensor(src, Complex64)` and re-tags the result. Same + dtype-loss bug as `allocate_smem_complex` -- the underlying recast goes + through `make_tensor`, which derives element_type from the MLIR memref + (here f64) and gets back Float64. + """ + return _retag_as_complex64(cute.recast_tensor(src, Complex64)) + + +def complex_storage(t: torch.Tensor) -> torch.Tensor: + """View a `torch.complex64` tensor as `torch.float64` with the same memory. + + Compiled kernels declared with `Complex64` element type have an f64 ABI; + use this at the boundary to satisfy tvm-ffi's dtype check without copying. + """ + if t.dtype == torch.float64: + return t + if t.dtype != torch.complex64: + raise TypeError( + f"complex_storage expects torch.complex64 (or torch.float64 for " + f"already-converted storage), got {t.dtype}" + ) + return t.view(torch.float64) + + +# --------------------------------------------------------------------------- +# tvm-ffi registration +# --------------------------------------------------------------------------- + + +def _register_with_tvm_ffi() -> None: + """Teach tvm-ffi that Complex64 has an f64 ABI. + + Both `NumericToTVMFFIDtype` (the type->dtype-string lookup) and + `AcceptableNumericTypesForScalar` (the allowlist for scalar kernel args) + are plain Python collections, so we extend them at import time. + """ + from cutlass.cute import _tvm_ffi_args_spec_converter as _cv + + _cv.NumericToTVMFFIDtype.setdefault(Complex64, "float64") + if Complex64 not in _cv.AcceptableNumericTypesForScalar: + _cv.AcceptableNumericTypesForScalar.append(Complex64) + + +_register_with_tvm_ffi() diff --git a/quack/gemm_act.py b/quack/gemm_act.py index d389a426..5d68cd6c 100644 --- a/quack/gemm_act.py +++ b/quack/gemm_act.py @@ -216,6 +216,17 @@ class GemmGatedMixin(GemmActMixin): TileStore("mAuxOut", epi_tile_fn=_gated_epi_tile_fn), ) + def _valid_2cta_m(self): + # mma_tiler_m=128 with 2-CTA gives cta_tile_m=64, which forces a (2, 2) + # epilogue warp shape in compute_epilogue_tile_shape. The non-contiguous + # epi_tile_n that this layout produces (e.g. shape (32, 2) stride (1, 128)) + # is recast by `_gated_epi_tile_fn` for the half-N postact tile, but the + # resulting smem/TMA partitioning miscomputes preact and postact (the + # corruption builds across persistent-kernel iterations). Until the + # gated epi_visit_subtile and aux-out r2s copy are taught about the + # (2, 2) layout, restrict 2-CTA to mma_tiler_m=256 for gated. + return (256,) + def epi_to_underlying_arguments( self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None ) -> GemmActMixin.EpilogueParams: diff --git a/quack/gemm_dact.py b/quack/gemm_dact.py index d630f1bc..648064b2 100644 --- a/quack/gemm_dact.py +++ b/quack/gemm_dact.py @@ -104,6 +104,12 @@ class GemmDGatedMixin(GemmActMixin): ) _extra_param_fields = (("act_bwd_fn", cutlass.Constexpr, None),) + def _valid_2cta_m(self): + # See GemmGatedMixin._valid_2cta_m: the (2, 2) epilogue warp shape that + # mma_tiler_m=128 + 2-CTA produces breaks the gated/dgated postact path + # on SM100. Restrict 2-CTA to mma_tiler_m=256 here too. + return (256,) + @mlir_namedtuple class EpilogueArguments(NamedTuple): mAuxOut: cute.Tensor diff --git a/quack/gemm_sm100.py b/quack/gemm_sm100.py index df5edcab..a5b85c20 100644 --- a/quack/gemm_sm100.py +++ b/quack/gemm_sm100.py @@ -169,7 +169,7 @@ def __init__( self.sf_vec_size = sf_vec_size self.blockscaled = sf_vec_size is not None assert len(mma_tiler_mnk) in [2, 3], "MMA tiler must be (M, N) or (M, N, K)" - valid_2cta_m = (128, 256) if not self.blockscaled else (256,) + valid_2cta_m = self._valid_2cta_m() self.use_2cta_instrs = cluster_shape_mnk[0] % 2 == 0 and mma_tiler_mnk[0] in valid_2cta_m self.cluster_shape_mnk = cluster_shape_mnk assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1" @@ -239,6 +239,15 @@ def epi_smem_warp_shape_mnk(self): ) return (warp_m, warp_n, 1) + def _valid_2cta_m(self): + """Return the set of mma_tiler[0] values for which 2-CTA MMA is enabled. + + Subclasses override to exclude shapes whose epilogue layout doesn't yet + support certain features (e.g. gated postact with the (2, 2) epilogue + warp shape produced by mma_tiler_m=128 + 2-CTA). + """ + return (128, 256) if not self.blockscaled else (256,) + def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments): """Set up configurations that are dependent on GEMM inputs diff --git a/tests/test_complex.py b/tests/test_complex.py new file mode 100644 index 00000000..6128fd80 --- /dev/null +++ b/tests/test_complex.py @@ -0,0 +1,668 @@ +"""Smoke tests for quack.complex.Complex64 through the tvm-ffi compile path.""" + +import math + +import numpy as np +import pytest +import torch + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Float64, Int32, const_expr +from cutlass._mlir import ir +from cutlass._mlir.extras import types as T + +from quack.complex import ( + Complex64, + allocate_smem_complex, + complex_storage, + recast_to_complex64, +) + + +class _ScaleByComplex: + """out[i, j] = scale_complex * in[i, j], one thread per element.""" + + def __init__(self, threads_per_block: int = 128): + self.threads_per_block = threads_per_block + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mO: cute.Tensor, + scale: Complex64, + stream: cuda.CUstream, + ): + assert mX.element_type is Complex64 + assert mO.element_type is Complex64 + batch, n = mX.shape + threads = self.threads_per_block + grid_x = cute.ceil_div(n, threads) + self.kernel(mX, mO, scale).launch( + grid=[grid_x, batch, 1], + block=[threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel(self, mX: cute.Tensor, mO: cute.Tensor, scale: Complex64): + tidx, _, _ = cute.arch.thread_idx() + bidx_x, bidx_y, _ = cute.arch.block_idx() + i = bidx_x * self.threads_per_block + tidx + if i < mX.shape[1]: + mO[bidx_y, i] = mX[bidx_y, i] * scale + + +def _compile_scale_by_complex(n: int): + batch_sym = cute.sym_int() + x_fake = cute.runtime.make_fake_tensor( + Complex64, (batch_sym, n), stride=(n, 1), assumed_align=8 + ) + o_fake = cute.runtime.make_fake_tensor( + Complex64, (batch_sym, n), stride=(n, 1), assumed_align=8 + ) + return cute.compile( + _ScaleByComplex(), + x_fake, + o_fake, + Complex64(complex(0.0, 0.0)), # placeholder; real value passed at call + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +@pytest.mark.parametrize("n", [128, 256]) +@pytest.mark.parametrize("batch", [1, 4]) +def test_complex_scale_tvm_ffi(batch: int, n: int): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + + torch.manual_seed(0) + x = torch.randn(batch, n, dtype=torch.complex64, device="cuda") + out = torch.empty_like(x) + scale = complex(2.0, 1.0) + + fn = _compile_scale_by_complex(n) + fn(complex_storage(x), complex_storage(out), Complex64(scale)) + torch.cuda.synchronize() + + expected = x * scale + torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5) + + +# --------------------------------------------------------------------------- +# Smem round-trip: gmem -> rmem -> smem -> sync -> smem -> rmem -> gmem. +# Catches the "smem allocation loses Complex64 type tag" bug -- without +# allocate_smem_complex, scalar smem writes silently truncate via Float64.to. +# --------------------------------------------------------------------------- + + +_SMEM_THREADS = 32 +_SMEM_ELEMS_PER_THREAD = 4 +_SMEM_BLOCK_ELEMS = _SMEM_THREADS * _SMEM_ELEMS_PER_THREAD + + +class _SmemRoundTrip: + """Each thread loads 4 contiguous complex from gmem, stores to smem at the + same indices, syncs, reads them back at strided indices, writes to gmem at + the strided indices. Output should equal input (read-stride and write-stride + cancel).""" + + @cute.jit + def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream): + batch = mX.shape[0] + self.kernel(mX, mO).launch( + grid=[batch, 1, 1], + block=[_SMEM_THREADS, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel(self, mX, mO): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + rmem = cute.make_rmem_tensor(_SMEM_ELEMS_PER_THREAD, Complex64) + for i in cutlass.range_constexpr(_SMEM_ELEMS_PER_THREAD): + rmem[i] = mX[bidx, tidx * const_expr(_SMEM_ELEMS_PER_THREAD) + const_expr(i)] + + smem = cutlass.utils.SmemAllocator() + exchange = allocate_smem_complex(smem, cute.make_layout(_SMEM_BLOCK_ELEMS)) + + for i in cutlass.range_constexpr(_SMEM_ELEMS_PER_THREAD): + exchange[tidx * const_expr(_SMEM_ELEMS_PER_THREAD) + const_expr(i)] = rmem[i] + cute.arch.barrier() + for i in cutlass.range_constexpr(_SMEM_ELEMS_PER_THREAD): + rmem[i] = exchange[tidx + const_expr(i * _SMEM_THREADS)] + for i in cutlass.range_constexpr(_SMEM_ELEMS_PER_THREAD): + mO[bidx, tidx + const_expr(i * _SMEM_THREADS)] = rmem[i] + + +def _compile_smem_roundtrip(): + batch_sym = cute.sym_int() + args = [ + cute.runtime.make_fake_tensor( + Complex64, + (batch_sym, _SMEM_BLOCK_ELEMS), + stride=(_SMEM_BLOCK_ELEMS, 1), + assumed_align=8, + ), + cute.runtime.make_fake_tensor( + Complex64, + (batch_sym, _SMEM_BLOCK_ELEMS), + stride=(_SMEM_BLOCK_ELEMS, 1), + assumed_align=8, + ), + ] + return cute.compile( + _SmemRoundTrip(), + *args, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +@pytest.mark.parametrize("batch", [1, 4]) +def test_smem_roundtrip_complex64(batch: int): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + + torch.manual_seed(0) + x = torch.randn(batch, _SMEM_BLOCK_ELEMS, dtype=torch.complex64, device="cuda") + out = torch.empty_like(x) + fn = _compile_smem_roundtrip() + fn(complex_storage(x), complex_storage(out)) + torch.cuda.synchronize() + torch.testing.assert_close(out, x, atol=1e-6, rtol=1e-6) + + +def test_complex_static_unpack_pack_roundtrip(): + """Sanity: a static Python complex packs/unpacks losslessly through f64 bits.""" + import numpy as np + + for c in [ + complex(1.0, -1.0), + complex(3.0, 4.0), + complex(-7.5, 0.25), + complex(0.0, 0.0), + ]: + cx = Complex64(c) + # Round-trip through the value's stored f64 -> bytes -> two f32s + f64_bits = np.float64(cx.value).view(np.uint64).item() + re_bits = np.uint32(f64_bits & 0xFFFFFFFF) + im_bits = np.uint32((f64_bits >> 32) & 0xFFFFFFFF) + re = re_bits.view(np.float32).item() + im = im_bits.view(np.float32).item() + assert re == c.real and im == c.imag, f"{c} -> ({re}, {im})" + + +# =========================================================================== +# Tier 1-8 Complex64 type tests. See AI/complex64_design_notes.md for the +# behaviors being pinned. Most kernel-side tests share the same harness: +# launch a one-thread kernel that writes test results into a small Complex64 +# output buffer; verify on the host. +# =========================================================================== + + +def _f64bits_to_complex(f64_value: float) -> complex: + """Decode the f64 packed-as-complex value into a Python complex.""" + bits = np.float64(f64_value).view(np.uint64).item() + re = np.uint32(bits & 0xFFFFFFFF).view(np.float32).item() + im = np.uint32((bits >> 32) & 0xFFFFFFFF).view(np.float32).item() + return complex(re, im) + + +# --------------------------------------------------------------------------- +# Harness: run a one-thread kernel that writes Complex64 results to gmem. +# --------------------------------------------------------------------------- + + +class _OneShot: + """Generic one-thread kernel that delegates the body to `body_fn`.""" + + def __init__(self, body_fn): + self.body_fn = body_fn + + @cute.jit + def __call__(self, mO: cute.Tensor, stream): + self.kernel(mO).launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream) + + @cute.kernel + def kernel(self, mO: cute.Tensor): + self.body_fn(mO) + + +def _run_oneshot(body_fn, n_outputs: int) -> list[complex]: + """Compile a kernel that writes n_outputs Complex64 values, run it, return them.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + out = torch.zeros(n_outputs, dtype=torch.complex64, device="cuda") + out_fake = cute.runtime.make_fake_tensor(Complex64, (n_outputs,), stride=(1,), assumed_align=8) + fn = cute.compile( + _OneShot(body_fn), + out_fake, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + fn(complex_storage(out)) + torch.cuda.synchronize() + return out.tolist() + + +# --------------------------------------------------------------------------- +# Tier 1: constructor paths +# --------------------------------------------------------------------------- + + +def test_constructor_python_complex(): + """Tier 1.1 -- Complex64(c) for Python complex (host-side).""" + cx = Complex64(complex(3.0, 4.0)) + assert _f64bits_to_complex(cx.value) == complex(3.0, 4.0) + + +def test_constructor_two_arg(): + """Tier 1.2 -- Complex64(re, im) two-arg form (host-side, static floats).""" + cx = Complex64(2.5, -1.5) + assert _f64bits_to_complex(cx.value) == complex(2.5, -1.5) + + +def test_constructor_same_type_copy_kernel(): + """Tier 1.3 -- Complex64(complex64_instance) kernel-side same-type copy.""" + + def body(mO): + a = Complex64(complex(3.0, 4.0)) + mO[0] = Complex64(a) # short-circuit copy + + out = _run_oneshot(body, 1) + assert out[0] == complex(3.0, 4.0) + + +def test_constructor_from_f32_ir_value_kernel(): + """Tier 1.5 -- Complex64(ir.Value of f32) becomes (val, 0).""" + + def body(mO): + f32_val = Float32(2.5) + mO[0] = Complex64(f32_val.ir_value()) + + out = _run_oneshot(body, 1) + assert out[0] == complex(2.5, 0.0) + + +def test_constructor_from_float32_kernel(): + """Tier 1.6 -- Complex64(Float32 instance) becomes (val, 0).""" + + def body(mO): + mO[0] = Complex64(Float32(7.0)) + mO[1] = Complex64(Float32(-3.5)) + + out = _run_oneshot(body, 2) + assert out[0] == complex(7.0, 0.0) + assert out[1] == complex(-3.5, 0.0) + + +def test_constructor_from_other_numeric_kernel(): + """Tier 1.7 -- Complex64(Int32 / Float64) becomes (real, 0).""" + + def body(mO): + mO[0] = Complex64(Int32(5)) + mO[1] = Complex64(Float64(1.5)) + + out = _run_oneshot(body, 2) + assert out[0] == complex(5.0, 0.0) + assert out[1] == complex(1.5, 0.0) + + +def test_constructor_from_python_int_float_kernel(): + """Tier 1.8 -- Complex64(int) and Complex64(float) become (val, 0).""" + + def body(mO): + mO[0] = Complex64(3) + mO[1] = Complex64(2.5) + mO[2] = Complex64(-7) + + out = _run_oneshot(body, 3) + assert out[0] == complex(3.0, 0.0) + assert out[1] == complex(2.5, 0.0) + assert out[2] == complex(-7.0, 0.0) + + +def test_constructor_bad_input_raises(): + """Tier 1.9 -- Complex64('hello') and Complex64(object()) raise TypeError.""" + with pytest.raises(TypeError): + Complex64("hello") + with pytest.raises(TypeError): + Complex64(object()) + + +# --------------------------------------------------------------------------- +# Tier 2: arithmetic matrix +# --------------------------------------------------------------------------- + + +def test_arithmetic_complex_complex(): + """Tier 2.10 -- Complex64 +/-/* Complex64 against Python complex reference.""" + + def body(mO): + a = Complex64(complex(1.0, 2.0)) + b = Complex64(complex(3.0, -1.0)) + mO[0] = a + b + mO[1] = a - b + mO[2] = a * b + mO[3] = -a + + out = _run_oneshot(body, 4) + a = complex(1.0, 2.0) + b = complex(3.0, -1.0) + assert out[0] == a + b + assert out[1] == a - b + assert out[2] == a * b + assert out[3] == -a + + +def test_arithmetic_with_float32_both_directions(): + """Tier 2.11 -- Cx OP Float32 and Float32 OP Cx via subclass-precedence.""" + + def body(mO): + a = Complex64(complex(1.0, 2.0)) + s = Float32(5.0) + mO[0] = a + s # Complex64 + Float32 (rhs) + mO[1] = s + a # Float32 + Complex64 (lhs, via __radd__) + mO[2] = a - s # Complex64 - Float32 + mO[3] = s - a # Float32 - Complex64 (lhs, via __rsub__) + mO[4] = a * Float32(2.0) + mO[5] = Float32(2.0) * a # __rmul__ + + out = _run_oneshot(body, 6) + a = complex(1.0, 2.0) + assert out[0] == a + 5.0 + assert out[1] == 5.0 + a + assert out[2] == a - 5.0 + assert out[3] == 5.0 - a + assert out[4] == a * 2.0 + assert out[5] == 2.0 * a + + +def test_arithmetic_with_python_scalar(): + """Tier 2.10 cont. -- Cx OP {int, float} (via __init__'s int/float branch).""" + + def body(mO): + a = Complex64(complex(1.0, 2.0)) + mO[0] = a + 7.0 # cx + python float + mO[1] = a * 3 # cx * python int + + out = _run_oneshot(body, 2) + a = complex(1.0, 2.0) + assert out[0] == a + 7.0 + assert out[1] == a * 3.0 + + +def test_arithmetic_random_sweep(): + """Tier 2.13 -- (a + b * c) against Python complex reference for random triples.""" + rng = np.random.default_rng(0) + inputs = [ + ( + complex(rng.uniform(-2, 2), rng.uniform(-2, 2)), + complex(rng.uniform(-2, 2), rng.uniform(-2, 2)), + complex(rng.uniform(-2, 2), rng.uniform(-2, 2)), + ) + for _ in range(8) + ] + + def body(mO): + for i, (a_v, b_v, c_v) in enumerate(inputs): + a = Complex64(a_v) + b = Complex64(b_v) + c = Complex64(c_v) + mO[i] = a + b * c + + out = _run_oneshot(body, len(inputs)) + for i, (a_v, b_v, c_v) in enumerate(inputs): + expected = a_v + b_v * c_v + # f32 precision; use loose tolerance + assert abs(complex(out[i]) - expected) < 1e-4, f"i={i}: {out[i]} vs {expected}" + + +# --------------------------------------------------------------------------- +# Tier 3: methods +# --------------------------------------------------------------------------- + + +def test_conj_basic(): + """Tier 3.14 -- conj() flips the imaginary lane.""" + + def body(mO): + a = Complex64(complex(3.0, -4.0)) + mO[0] = a.conj() + mO[1] = a.conj().conj() # idempotent + + out = _run_oneshot(body, 2) + assert out[0] == complex(3.0, 4.0) + assert out[1] == complex(3.0, -4.0) + + +def test_conj_self_product_is_real(): + """Tier 3.14 cont. -- (c * c.conj()).imag == 0; .real == |c|^2.""" + + def body(mO): + b = Complex64(complex(1.5, 2.5)) + mO[0] = b * b.conj() + + out = _run_oneshot(body, 1) + expected_norm_sq = 1.5 * 1.5 + 2.5 * 2.5 # = 8.5 + assert abs(out[0].real - expected_norm_sq) < 1e-5 + assert out[0].imag == 0.0 + + +def test_real_imag_accessors_kernel(): + """Tier 3.15 -- real() and imag() return the correct Float32 lanes.""" + + def body(mO): + a = Complex64(complex(3.0, -4.0)) + # Pack the lanes back into a Complex64 to round-trip them through gmem + mO[0] = Complex64(a.real(), a.imag()) + + out = _run_oneshot(body, 1) + assert out[0] == complex(3.0, -4.0) + + +def test_from_re_im_alias_equivalence(): + """Tier 3.16 -- from_re_im(r, i) and Complex64(r, i) produce identical bits.""" + + def body(mO): + re = Float32(2.0) + im = Float32(-1.5) + mO[0] = Complex64.from_re_im(re, im) + mO[1] = Complex64(re, im) + + out = _run_oneshot(body, 2) + assert out[0] == out[1] + assert out[0] == complex(2.0, -1.5) + + +# --------------------------------------------------------------------------- +# Tier 4: class invariants +# --------------------------------------------------------------------------- + + +def test_complex64_class_attributes(): + """Tier 4.17, 4.18, 4.20, 4.21 -- pure Python class metadata.""" + assert Complex64.width == 64 + assert Complex64.is_float is True + assert isinstance(Complex64(complex(0, 0)), Float32) + # Auto-promotion behavior pins (used by _cvt_to_dest): + assert Complex64.is_same_kind(Float32) is True + assert Complex64.is_same_kind(Float64) is True + + +def test_complex64_mlir_type_is_f64(): + """Tier 4.19 -- mlir_type query needs an MLIR context.""" + with ir.Context() as ctx, ir.Location.unknown(): + ctx.allow_unregistered_dialects = True + assert Complex64.mlir_type == T.f64() + + +# --------------------------------------------------------------------------- +# Tier 5: helpers +# --------------------------------------------------------------------------- + + +def test_recast_to_complex64_kernel(): + """Tier 5.22 -- recast_to_complex64 on an rmem Float32 tensor preserves data.""" + + # body_fn is plain Python called from inside @cute.kernel, so plain + # `range(...)` unrolls at trace time (cutlass.range_constexpr only works + # inside a directly-decorated @cute.jit/@cute.kernel function). + def body(mO): + rmem_f32 = cute.make_rmem_tensor(16, Float32) + for i in range(8): + rmem_f32[2 * i] = Float32(float(i + 1)) + rmem_f32[2 * i + 1] = Float32(-float(i + 1)) + rmem_cx = recast_to_complex64(rmem_f32) + for i in range(8): + mO[i] = rmem_cx[i] + + out = _run_oneshot(body, 8) + expected = [complex(i + 1, -(i + 1)) for i in range(8)] + assert out == expected + + +def test_complex_storage_passthrough_for_float64(): + """Tier 5.24a -- complex_storage(float64 tensor) returns the input unchanged.""" + t = torch.zeros(4, dtype=torch.float64) + assert complex_storage(t) is t + + +def test_complex_storage_view_for_complex64(): + """Tier 5.24b -- complex_storage(complex64) returns a float64 view sharing memory.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA required for view alignment") + c = torch.zeros(4, dtype=torch.complex64, device="cuda") + v = complex_storage(c) + assert v.dtype == torch.float64 + assert v.data_ptr() == c.data_ptr() + assert v.numel() == c.numel() + + +def test_complex_storage_bad_dtype_raises(): + """Tier 5.24c -- complex_storage(other dtype) raises TypeError.""" + with pytest.raises(TypeError): + complex_storage(torch.zeros(4, dtype=torch.float32)) + with pytest.raises(TypeError): + complex_storage(torch.zeros(4, dtype=torch.int32)) + + +# --------------------------------------------------------------------------- +# Tier 6: __c_pointers__ boundary +# --------------------------------------------------------------------------- + + +def test_c_pointers_static_value_works(): + """Tier 6.25 -- static-value Complex64 produces an 8-byte pointer.""" + c = Complex64(complex(2.0, 1.0)) + ptrs = c.__c_pointers__() + assert len(ptrs) == 1 + + +def test_c_pointers_dynamic_value_rejected(): + """Tier 6.26 -- a Complex64 carrying a non-float self.value raises. + + Simulates the "dynamic SSA" case (which can only really arise inside a + kernel context) by mutating self.value to a non-float.""" + c = Complex64(complex(1.0, 0.0)) + c.value = "not-a-float" # simulate dynamic state + with pytest.raises(ValueError, match="dynamic SSA value"): + c.__c_pointers__() + + +# --------------------------------------------------------------------------- +# Tier 7: special / edge values +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "val", + [ + complex(0.0, 0.0), + complex(-0.0, -0.0), + complex(float("inf"), 0.0), + complex(0.0, float("inf")), + complex(float("-inf"), 0.0), + complex(float("nan"), 0.0), + complex(0.0, float("nan")), + complex(1e-38, -1e-38), + complex(1e30, -1e30), + complex(-1.0, -2.0), + ], +) +def test_static_special_value_roundtrip(val): + """Tier 7.27 -- bitcast plumbing preserves edge values exactly. + + Compare against the f32-rounded version of `val`, since Complex64 stores + each lane as f32 by construction. NaN equality is special-cased. + """ + cx = Complex64(val) + decoded = _f64bits_to_complex(cx.value) + expected_re = float(np.float32(val.real)) + expected_im = float(np.float32(val.imag)) + if math.isnan(expected_re): + assert math.isnan(decoded.real) + else: + assert decoded.real == expected_re, f"re: {decoded.real} != {expected_re}" + if math.isnan(expected_im): + assert math.isnan(decoded.imag) + else: + assert decoded.imag == expected_im, f"im: {decoded.imag} != {expected_im}" + + +def test_sign_sensitivity_kernel(): + """Tier 7.28 -- all four sign combinations land at the right (re, im).""" + cases = [ + complex(-1.0, -2.0), + complex(-1.0, 2.0), + complex(1.0, -2.0), + complex(1.0, 2.0), + ] + + def body(mO): + for i, c in enumerate(cases): + mO[i] = Complex64(c) + + out = _run_oneshot(body, len(cases)) + for i, c in enumerate(cases): + assert out[i] == c, f"sign case {i}: {out[i]} != {c}" + + +# --------------------------------------------------------------------------- +# Tier 8: documented-behavior pins (auto-promotion via _cvt_to_dest) +# --------------------------------------------------------------------------- + + +def test_float32_writes_promote_to_complex64(): + """Tier 8.29 -- writing a Float32 to a Complex64 tensor lands as (val, 0).""" + + def body(mO): + rmem = cute.make_rmem_tensor(2, Complex64) + rmem[0] = Float32(5.0) # promote: (5, 0) + rmem[1] = Float32(-2.5) # promote: (-2.5, 0) + for i in range(2): + mO[i] = rmem[i] + + out = _run_oneshot(body, 2) + assert out[0] == complex(5.0, 0.0) + assert out[1] == complex(-2.5, 0.0) + + +def test_float64_writes_promote_to_complex64(): + """Tier 8.30 -- pin behavior of Float64 written to a Complex64 tensor.""" + + def body(mO): + rmem = cute.make_rmem_tensor(1, Complex64) + rmem[0] = Float64(3.14) # whatever Numeric promotion does, pin it + mO[0] = rmem[0] + + out = _run_oneshot(body, 1) + # Currently Numeric.is_same_kind(Float, Float) plus width >= width triggers + # data.to(Complex64), which goes through Float32(Float64) (cvtf truncate to f32), + # then the Numeric branch packs (re=truncated, im=0). + assert out[0] == complex(np.float32(3.14), 0.0) From f56cac34706c1af0b833347510cb801921bfae56 Mon Sep 17 00:00:00 2001 From: GarlGuo Date: Wed, 6 May 2026 17:00:35 -0400 Subject: [PATCH 2/3] [Explore] Localize SM100 gated (2,2) warp-shape bug Reverts the _valid_2cta_m overrides on GemmGatedMixin and GemmDGatedMixin so the bug fires, plus adds Python print() instrumentation in: - quack/gemm_base.py: D-path TMA atom inputs - quack/epi_ops.py: TileStore aux-path inputs and outputs - quack/gemm_act.py: epi_visit_subtile register layouts (compile-time tracing) and epi_setup_aux_out tiled-copy + partition_D dump INVESTIGATION_22_WARP.md captures the smoking gun: tiled_copy_aux_out_r2s (built via make_tiled_copy_S(aux_atom, tiled_copy_r2s)) inherits D's full-N tiler MN of 64x64, but is then applied to aux smem of 64x32. Each smem position is written by two threads -- warp 1's data gets clobbered by warp 0's, then TMA copies the duplicated smem to two distinct gmem positions, producing the observed gmem[0..15] == gmem[64..79] corruption. This is exploration only. Not for merge. Co-Authored-By: Claude Opus 4.7 (1M context) --- INVESTIGATION_22_WARP.md | 111 +++++++++++++++++++++++++++++++++++++++ instr_run.py | 72 +++++++++++++++++++++++++ quack/epi_ops.py | 16 ++++++ quack/gemm_act.py | 32 +++++++---- quack/gemm_base.py | 13 +++++ quack/gemm_dact.py | 6 --- quack/gemm_sm100.py | 11 +--- solo_ab_min.py | 82 +++++++++++++++++++++++++++++ 8 files changed, 316 insertions(+), 27 deletions(-) create mode 100644 INVESTIGATION_22_WARP.md create mode 100644 instr_run.py create mode 100644 solo_ab_min.py diff --git a/INVESTIGATION_22_WARP.md b/INVESTIGATION_22_WARP.md new file mode 100644 index 00000000..deac4de7 --- /dev/null +++ b/INVESTIGATION_22_WARP.md @@ -0,0 +1,111 @@ +# Investigation: SM100 gated `(2, 2)` epilogue warp-shape bug + +## Setup + +Branch: `explore-22-warp` (forked from `fix-gated-dgated` HEAD). + +The `_valid_2cta_m` overrides on `GemmGatedMixin` and `GemmDGatedMixin` +have been **reverted** so the bug fires. Plus `print()` instrumentation in +`quack/gemm_base.py`, `quack/epi_ops.py`, and `quack/gemm_act.py`. + +Repro: `instr_run.py`. Run with fresh `QUACK_CACHE_DIR` and +`QUACK_CACHE_ENABLED=0` to force re-compile each run. + +## Trigger + +`tile_m=128, cluster_m=2, is_dynamic_persistent=True, use_tma_gather=True` +on the gated forward path. With 2-CTA, `cta_tile_m=64`, which forces +`compute_epilogue_tile_shape` to a `(2, 2)` M-warps × N-warps layout. The +non-gated D path with the same warp shape works correctly — the bug is +specific to the gated half-N postact aux-out chain. + +## Localization (the smoking gun) + +`tiled_copy_aux_out_r2s` is built via: + + cute.make_tiled_copy_S(aux_atom, tiled_copy_r2s) + +`make_tiled_copy_S` keeps the source-side threading from `tiled_copy_r2s` +(D's r2s copy) and only swaps the per-atom store op. The Tiler MN is +inherited verbatim — full-N D dimensions, NOT half-N aux dimensions. + +Side-by-side for `tile_m=128, cm=2, swiglu fp16`, `cta_tile_shape=(64,256)`: + +| object | shape / layout | +|---------------------------------------|-------------------------------------------------------------------------| +| D's r2s `tiled_copy_r2s` Tiler MN | `((2,32):(32,1), (2,32):(32,1))` = 64M × 64N | +| Aux's r2s `tiled_copy_aux_out_r2s` | `((2,32):(32,1), (2,32):(32,1))` = 64M × 64N (**same as D**) | +| Aux's r2s TV layout | `((32,2,2),(1,32)):((2,1,64),(0,128))` -- 32 values per thread | +| Aux's smem `sAuxOut.layout` | `((8,8),(16,2),(1,2)):((16,128),(1,1024),(0,2048))` = 64M × 32N | +| D's smem `sD.layout` | `((8,8),(32,2),(1,2)):((32,256),(1,2048),(0,4096))` = 64M × 64N | + +**Mismatch:** the aux r2s copy has a 64×64 tiler producing 32 values per +thread × 128 threads = 4096 elements, but aux smem per stage holds only +64×32 = 2048 elements. Each aux smem position is written by **two +threads** -- warp 0's threads and warp 1's threads collide on the same +smem range. Whichever thread arrives last "wins"; warp 1's data is lost. + +The TMA descriptor for aux *is* correct (it scatters smem regions to gmem +at warp-stride 64). It's just that smem holds duplicated data when the TMA +reads it -- both the (smem) "warp 0 region" and the (smem) "warp 1 region" +hold warp 0's values after the r2s race. TMA then dutifully writes warp 0's +values to gmem `[0..15]` and warp 0's values again to gmem `[64..79]`, +producing the observed: + + postact[0, 0..15] = warp 0's values (correct) + postact[0, 64..79] = warp 0's values (DUPLICATE -- should be warp 1) + +## Why the (4, 1) warp shape works + +For `cluster_m=1` (= `(4, 1)` warp shape), `epi_tile_n` is just `int 32` +(no Layout). After `_gated_epi_tile_fn` halves to `int 16`, aux smem is +flat with only 1 N-warp. The Tiler MN match between D and aux remains +"D's full-N tile = aux's full-N tile" because there's no warp-N split in +either; the per-thread value count of 16 lands cleanly in aux smem with no +collision. + +Per-thread `tRS_rD.layout` is `((1,32),1,1):((0,1),0,0)` for **both** warp +shapes. The bug is purely in the destination-side (smem) partitioning of +`tiled_copy_aux_out_r2s`, not in registers or in `act_fn` indexing. + +## Why D's full-N path is unaffected + +D's smem layout has 64 N elements (twice aux's), with warp 1 at smem +stride 2048. D's r2s tiler `((2,32),(2,32))` produces 32 values per +thread × 128 threads = 4096 elements -- matches D smem per stage exactly. +No collision. + +## Fix direction + +The aux r2s tiled copy must be re-tiled to match aux's tile dimensions +(half N) before being used to partition `sAuxOut`. Two plausible builders: + +1. Build from scratch via `make_tiled_copy_D(aux_atom, sAuxOut.layout)` so + the destination shape comes from aux smem rather than D's r2s. +2. Re-tile `tiled_copy_r2s` to halve its N extent before passing through + `make_tiled_copy_S`. + +Either approach requires careful handling of the per-thread register slice +(`tRS_rAuxOut` has 16 fp32 elements per thread, derived via +`recast_layout(2, 1, tRS_rD.layout)`). The atom returned by +`sm100_utils.get_smem_store_op(aux_layout, aux_dtype, acc_dtype, tiled_copy_t2r)` +is selected based on `tiled_copy_t2r` (D's full-N pattern) -- it likely +needs to be rebuilt from a t2r-equivalent for aux's half-N slice as well. + +This is real cuTeDSL design work. The current `_valid_2cta_m` override on +`GemmGatedMixin` / `GemmDGatedMixin` is the practical workaround; this +investigation explains exactly why the override is needed and what would +need to change to remove it. + +## Reproduction commands + +```bash +git checkout explore-22-warp +CACHE=$(mktemp -d /tmp/quack_explore_XXXX) +CUDA_VISIBLE_DEVICES=0 QUACK_CACHE_DIR=$CACHE QUACK_CACHE_ENABLED=0 \ + python instr_run.py +# CLUSTER_M=1 to compare against the working (4, 1) warp shape. +``` + +The instrumentation prints D-path and aux-path layouts side by side; the +mismatch in Tiler MN vs sAuxOut shape is the smoking gun. diff --git a/instr_run.py b/instr_run.py new file mode 100644 index 00000000..ac243f9b --- /dev/null +++ b/instr_run.py @@ -0,0 +1,72 @@ +"""Tiny repro to capture instrumentation prints from the gated forward path +at the buggy cocktail (tile_m=128, cm=2, clc=True, gather=True). + +Uses small M to keep output manageable. The instrumentation print()s in +quack/gemm_base.py (D path) and quack/epi_ops.py (TileStore.to_params, aux +path) will fire during kernel construction and dump epi_tile + smem_layout ++ tma_atom for both D and aux side-by-side. +""" +import os +import sys +import torch + +from quack.gemm_config import GemmConfig +from quack.gemm_interface import gemm_gated_tuned, gemm_gated_ref + + +def main(): + M, H, I, E = 4096, 256, 128, 4 # small enough for tractable output + device = torch.device("cuda:0") + dtype = torch.float16 + g = torch.Generator(device=device).manual_seed(0) + counts = torch.full((E,), M // E, dtype=torch.int32, device=device) + cu = torch.zeros(E + 1, dtype=torch.int32, device=device) + cu[1:] = torch.cumsum(counts, dim=0).to(torch.int32) + T = M // 4 + x = (0.02 * torch.randn(T, H, generator=g, device=device, dtype=torch.float32)).to(dtype) + A_idx = torch.randint(0, T, (M,), dtype=torch.int32, device=device, generator=g) + w = torch.empty(E, 2 * I, H, dtype=torch.float32, device=device) + torch.nn.init.normal_(w, mean=0.0, std=0.02, generator=g) + w1 = w.to(dtype).permute(1, 2, 0).permute(2, 1, 0) + + cluster_m = int(os.environ.get("CLUSTER_M", "2")) + cfg = GemmConfig( + tile_m=128, tile_n=256, cluster_m=cluster_m, cluster_n=1, + swap_ab=False, max_swizzle_size=8, + is_dynamic_persistent=True, use_tma_gather=True, + pingpong=False, device_capacity=10, + ) + print(f"\n>>> cluster_m={cluster_m} (warp-shape: {(2,2) if cluster_m==2 else (4,1)})\n", flush=True) + pre = torch.empty(M, 2 * I, dtype=dtype, device=device) + post = torch.empty(M, I, dtype=dtype, device=device) + + print("\n========== INVOKING gemm_gated_tuned.fn (buggy cocktail) ==========\n", flush=True) + gemm_gated_tuned.fn( + x, w1, pre, post, None, None, "swiglu", cu, A_idx, False, config=cfg, + ) + torch.cuda.synchronize() + print("\n========== KERNEL EXECUTION DONE ==========\n", flush=True) + + pre_ref, post_ref = gemm_gated_ref( + x, w1, bias=None, activation="swiglu", + cu_seqlens_m=cu, A_idx=A_idx, + store_preact=True, concat_layout=None, + ) + pre_diff = (pre.float() - pre_ref.float()).abs() + post_diff = (post.float() - post_ref.float()).abs() + print(f"\npreact rel = {pre_diff.max().item() / max(pre_ref.float().abs().max().item(), 1e-12):.4e}") + print(f"postact rel = {post_diff.max().item() / max(post_ref.float().abs().max().item(), 1e-12):.4e}") + # Inspect output values at row 0, columns 0..15: which pattern of corruption? + print("\n postact row=0 cols=0..15:", post[0, :16].float().tolist()) + print(" postact_ref row=0 cols=0..15:", post_ref[0, :16].float().tolist()) + print("\n postact row=0 cols=64..79:", post[0, 64:80].float().tolist()) + print(" postact_ref row=0 cols=64..79:", post_ref[0, 64:80].float().tolist()) + # Check if postact has zeros (skipped writes) or shifted/scrambled values. + n_zeros = (post == 0).sum().item() + print(f"\n postact: total elems = {post.numel()}, n_zeros = {n_zeros} ({100*n_zeros/post.numel():.1f}%)") + n_zeros_ref = (post_ref == 0).sum().item() + print(f" postact_ref: n_zeros = {n_zeros_ref}") + + +if __name__ == "__main__": + main() diff --git a/quack/epi_ops.py b/quack/epi_ops.py index c14682cd..246ee7cb 100644 --- a/quack/epi_ops.py +++ b/quack/epi_ops.py @@ -511,9 +511,25 @@ def to_params(self, gemm, args): self._epi_tile_key(): None, } epi_tile = self.epi_tile_fn(gemm, gemm.epi_tile) if self.epi_tile_fn else None + # [INSTRUMENTATION] print the recast tile vs the full tile. + print( + f"[INSTR TileStore.to_params name={self.name}]\n" + f" cta_tile_shape_mnk={gemm.cta_tile_shape_mnk}\n" + f" use_2cta_instrs={gemm.use_2cta_instrs}\n" + f" gemm.epi_tile (full) = {gemm.epi_tile}\n" + f" epi_tile passed to setup = {epi_tile}", + flush=True, + ) tma_atom, tma_tensor, smem_layout, epi_tile_out = setup_epi_tensor( gemm, tensor, epi_tile=epi_tile ) + print( + f"[INSTR TileStore.to_params name={self.name}] post-setup\n" + f" smem_layout = {smem_layout}\n" + f" epi_tile_out = {epi_tile_out}\n" + f" tma_atom = {tma_atom}", + flush=True, + ) return { self._tma_atom_key(): tma_atom, self.name: tma_tensor, diff --git a/quack/gemm_act.py b/quack/gemm_act.py index 5d68cd6c..db790f8c 100644 --- a/quack/gemm_act.py +++ b/quack/gemm_act.py @@ -123,7 +123,16 @@ def epi_setup_aux_out( tiled_copy_aux_out_r2s = self.epi_make_aux_out_tiled_copy_r2s( params, tiled_copy_r2s, tiled_copy_t2r ) + # [INSTRUMENTATION] dump the destination partition shape. + print( + f"[INSTR epi_setup_aux_out]\n" + f" sAuxOut.layout = {sAuxOut.layout}\n" + f" tiled_copy_r2s = {tiled_copy_r2s}\n" + f" tiled_copy_aux_out_r2s = {tiled_copy_aux_out_r2s}", + flush=True, + ) tRS_sAuxOut = tiled_copy_aux_out_r2s.get_slice(tidx).partition_D(sAuxOut) + print(f" tRS_sAuxOut.layout = {tRS_sAuxOut.layout}", flush=True) batch_idx = tile_coord_mnkl[3] copy_aux_out, _, _ = self.epilog_gmem_copy_and_partition( params.tma_atom_mAuxOut, @@ -216,17 +225,6 @@ class GemmGatedMixin(GemmActMixin): TileStore("mAuxOut", epi_tile_fn=_gated_epi_tile_fn), ) - def _valid_2cta_m(self): - # mma_tiler_m=128 with 2-CTA gives cta_tile_m=64, which forces a (2, 2) - # epilogue warp shape in compute_epilogue_tile_shape. The non-contiguous - # epi_tile_n that this layout produces (e.g. shape (32, 2) stride (1, 128)) - # is recast by `_gated_epi_tile_fn` for the half-N postact tile, but the - # resulting smem/TMA partitioning miscomputes preact and postact (the - # corruption builds across persistent-kernel iterations). Until the - # gated epi_visit_subtile and aux-out r2s copy are taught about the - # (2, 2) layout, restrict 2-CTA to mma_tiler_m=256 for gated. - return (256,) - def epi_to_underlying_arguments( self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None ) -> GemmActMixin.EpilogueParams: @@ -265,6 +263,18 @@ def epi_visit_subtile( tRS_rAuxOut_layout = cute.recast_layout(2, 1, tRS_rD.layout) # If we don't have .shape here, the compiler generates local stores and loads tRS_rAuxOut = cute.make_rmem_tensor(tRS_rAuxOut_layout.shape, self.acc_dtype) + # [INSTRUMENTATION] compile-time print of register layouts (fires at JIT trace). + print( + f"[INSTR gated.epi_visit_subtile JIT] arch={self.arch}\n" + f" tRS_rD.layout = {tRS_rD.layout}\n" + f" tRS_rD.shape = {tRS_rD.shape}\n" + f" cute.size(tRS_rD) = {cute.size(tRS_rD)}\n" + f" tRS_rAuxOut_layout = {tRS_rAuxOut_layout}\n" + f" tRS_rAuxOut.layout = {tRS_rAuxOut.layout}\n" + f" tRS_rAuxOut.shape = {tRS_rAuxOut.shape}\n" + f" cute.size(tRS_rAuxOut) = {cute.size(tRS_rAuxOut)}", + flush=True, + ) if const_expr(self.arch != 100): for i in cutlass.range(cute.size(tRS_rAuxOut), unroll_full=True): tRS_rAuxOut[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1]) diff --git a/quack/gemm_base.py b/quack/gemm_base.py index 55cbad13..6f2c29d4 100644 --- a/quack/gemm_base.py +++ b/quack/gemm_base.py @@ -569,6 +569,15 @@ def make_tma_epilogue_atoms_and_tensors( ): tma_atom_d, tma_tensor_d = None, None if const_expr(mD is not None): + # [INSTRUMENTATION] print D-path inputs. + print( + f"[INSTR D-path make_tma_epilogue_atoms_and_tensors]\n" + f" cta_tile_shape_mnk = {self.cta_tile_shape_mnk}\n" + f" use_2cta_instrs = {self.use_2cta_instrs}\n" + f" self.epi_tile = {self.epi_tile}\n" + f" self.epi_smem_layout_staged = {self.epi_smem_layout_staged}", + flush=True, + ) tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( copy_utils.create_ragged_tensor_for_tma(mD, ragged_dim=0, ptr_shift=True) if varlen_m @@ -579,6 +588,10 @@ def make_tma_epilogue_atoms_and_tensors( if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output) else "add", ) + print( + f"[INSTR D-path] tma_atom_d = {tma_atom_d}", + flush=True, + ) tma_atom_c, tma_tensor_c = None, None if const_expr(mC is not None): tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors( diff --git a/quack/gemm_dact.py b/quack/gemm_dact.py index 648064b2..d630f1bc 100644 --- a/quack/gemm_dact.py +++ b/quack/gemm_dact.py @@ -104,12 +104,6 @@ class GemmDGatedMixin(GemmActMixin): ) _extra_param_fields = (("act_bwd_fn", cutlass.Constexpr, None),) - def _valid_2cta_m(self): - # See GemmGatedMixin._valid_2cta_m: the (2, 2) epilogue warp shape that - # mma_tiler_m=128 + 2-CTA produces breaks the gated/dgated postact path - # on SM100. Restrict 2-CTA to mma_tiler_m=256 here too. - return (256,) - @mlir_namedtuple class EpilogueArguments(NamedTuple): mAuxOut: cute.Tensor diff --git a/quack/gemm_sm100.py b/quack/gemm_sm100.py index a5b85c20..df5edcab 100644 --- a/quack/gemm_sm100.py +++ b/quack/gemm_sm100.py @@ -169,7 +169,7 @@ def __init__( self.sf_vec_size = sf_vec_size self.blockscaled = sf_vec_size is not None assert len(mma_tiler_mnk) in [2, 3], "MMA tiler must be (M, N) or (M, N, K)" - valid_2cta_m = self._valid_2cta_m() + valid_2cta_m = (128, 256) if not self.blockscaled else (256,) self.use_2cta_instrs = cluster_shape_mnk[0] % 2 == 0 and mma_tiler_mnk[0] in valid_2cta_m self.cluster_shape_mnk = cluster_shape_mnk assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1" @@ -239,15 +239,6 @@ def epi_smem_warp_shape_mnk(self): ) return (warp_m, warp_n, 1) - def _valid_2cta_m(self): - """Return the set of mma_tiler[0] values for which 2-CTA MMA is enabled. - - Subclasses override to exclude shapes whose epilogue layout doesn't yet - support certain features (e.g. gated postact with the (2, 2) epilogue - warp shape produced by mma_tiler_m=128 + 2-CTA). - """ - return (128, 256) if not self.blockscaled else (256,) - def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments): """Set up configurations that are dependent on GEMM inputs diff --git a/solo_ab_min.py b/solo_ab_min.py new file mode 100644 index 00000000..f35f74d6 --- /dev/null +++ b/solo_ab_min.py @@ -0,0 +1,82 @@ +"""Minimal A/B test for the gated SM100 fix. + +Phase A — _valid_2cta_m returns (128, 256): the original buggy default. +Phase B — _valid_2cta_m returns (256,): the patched default. + +Each phase runs in its own subprocess with its own QUACK_CACHE_DIR so the +disk-backed jit_cache (whose key doesn't include use_2cta_instrs) can't serve +a stale cross-phase kernel. + +Usage: CUDA_VISIBLE_DEVICES= python solo_ab_min.py +""" +import json, os, shutil, subprocess, sys, tempfile, torch + + +def child(phase): + valid = (128, 256) if phase == "A" else (256,) + from quack.gemm_act import GemmGatedMixin + GemmGatedMixin._valid_2cta_m = lambda self, _v=valid: _v + + from quack.gemm_config import GemmConfig + from quack.gemm_interface import gemm_gated_tuned, gemm_gated_ref + + M, H, I, E = 32768, 1024, 512, 8 + dtype, dev = torch.float16, torch.device("cuda:0") + g = torch.Generator(device=dev).manual_seed(0) + counts = torch.full((E,), M // E, dtype=torch.int32, device=dev) + cu = torch.zeros(E + 1, dtype=torch.int32, device=dev) + cu[1:] = torch.cumsum(counts, dim=0).to(torch.int32) + T = M // 4 + x = (0.02 * torch.randn(T, H, generator=g, device=dev, dtype=torch.float32)).to(dtype) + A_idx = torch.randint(0, T, (M,), dtype=torch.int32, device=dev, generator=g) + w = torch.empty(E, 2 * I, H, dtype=torch.float32, device=dev) + torch.nn.init.normal_(w, mean=0.0, std=0.02, generator=g) + w1 = w.to(dtype).permute(1, 2, 0).permute(2, 1, 0) + + cfg = GemmConfig(tile_m=128, tile_n=256, cluster_m=2, cluster_n=1, + swap_ab=False, max_swizzle_size=8, + is_dynamic_persistent=True, use_tma_gather=True, + pingpong=False, device_capacity=10) + pre, post = torch.empty(M, 2 * I, dtype=dtype, device=dev), torch.empty(M, I, dtype=dtype, device=dev) + gemm_gated_tuned.fn(x, w1, pre, post, None, None, "swiglu", cu, A_idx, False, config=cfg) + torch.cuda.synchronize() + pre_ref, post_ref = gemm_gated_ref(x, w1, bias=None, activation="swiglu", + cu_seqlens_m=cu, A_idx=A_idx, + store_preact=True, concat_layout=None) + + print("PHASE_RESULT " + json.dumps({ + "phase": phase, "valid_2cta_m": list(valid), + "preact_max_abs": (pre.float() - pre_ref.float()).abs().max().item(), + "preact_max_ref": pre_ref.float().abs().max().item(), + "postact_max_abs": (post.float() - post_ref.float()).abs().max().item(), + "postact_max_ref": post_ref.float().abs().max().item(), + }), flush=True) + + +def run_phase(phase): + cache = tempfile.mkdtemp(prefix=f"quack_cache_{phase}_") + env = {**os.environ, "QUACK_CACHE_DIR": cache} + try: + out = subprocess.run([sys.executable, "-u", __file__, phase], + capture_output=True, text=True, env=env, timeout=600) + finally: + shutil.rmtree(cache, ignore_errors=True) + for line in out.stdout.splitlines(): + if line.startswith("PHASE_RESULT "): + return json.loads(line[len("PHASE_RESULT "):]) + sys.exit(f"phase {phase} produced no result\n{out.stdout}\n{out.stderr}") + + +def main(): + if len(sys.argv) > 1 and sys.argv[1] in ("A", "B"): + child(sys.argv[1]); return + + a, b = run_phase("A"), run_phase("B") + for d in (a, b): + rel = d["preact_max_abs"] / max(d["preact_max_ref"], 1e-12) + print(f"phase {d['phase']} _valid_2cta_m={str(tuple(d['valid_2cta_m'])):<10} " + f"preact rel={rel:.4e} ({'FAIL' if rel > 0.05 else 'PASS'})") + + +if __name__ == "__main__": + main() From b0c70d4c88897f3beb33f301b9c2e12fd4592e01 Mon Sep 17 00:00:00 2001 From: GarlGuo Date: Wed, 6 May 2026 18:25:51 -0400 Subject: [PATCH 3/3] [Explore] Real fix for SM100 gated (2,2) epi warp shape Replaces the _valid_2cta_m=(256,) workaround on GemmGatedMixin (and preventive workaround on GemmDGatedMixin) with a source-level fix in GemmGatedMixin.epi_make_aux_out_tiled_copy_r2s. Root cause: make_tiled_copy_S(aux_atom, tiled_copy_r2s) inherited D's full-N tiler MN (64x64) and applied it to aux's half-N smem (64x32), emitting 128 threads x 32 vals = 4096 elements into a 2048-element smem region. For (4, 1) the over-emission had stride 0 (harmless self- overwrite); for (2, 2) it had warp-N stride 1024 (corrupted warp 1's region with a duplicate of warp 0's data). Fix: build the aux r2s tiled copy via make_tiled_copy_tv with explicit layouts -- 128 threads as (cta_tile_aux_m, num_n_warps), each holding size(epi_tile_aux_n)/num_n_warps values along N. Tiler MN now matches aux smem exactly. SM90/SM120 keep the original construction since the Layout-typed epi_tile_n is SM100-specific. Removes the gemm_dact.py override (preventive, dgated has no half-N recast) and the gemm_sm100.py _valid_2cta_m method indirection. Verification with the b10ffed workaround reverted: - solo_ab_min.py: both phases PASS rel=0.0 - instr_run.py original buggy shape: postact rel=8.21e-4 PASS - 12 (M,H,I,E) x cluster_m combos: identical errors cm=1 vs cm=2 - test_untuned_buggy_tiles.py --shapes small: 208/208 PASS (4 shapes x 52 forced configs, fwd+bwd) - sweep_gated_dgated.py: 216/216 PASS INVESTIGATION_22_WARP.md captures the full investigation. Co-Authored-By: Claude Opus 4.7 (1M context) --- INVESTIGATION_22_WARP.md | 171 +++++++++++++++++++-------------------- instr_run.py | 6 +- quack/epi_ops.py | 16 ---- quack/gemm_act.py | 59 +++++++++----- quack/gemm_base.py | 13 --- 5 files changed, 125 insertions(+), 140 deletions(-) diff --git a/INVESTIGATION_22_WARP.md b/INVESTIGATION_22_WARP.md index deac4de7..efe591ba 100644 --- a/INVESTIGATION_22_WARP.md +++ b/INVESTIGATION_22_WARP.md @@ -1,111 +1,104 @@ -# Investigation: SM100 gated `(2, 2)` epilogue warp-shape bug +# Investigation: SM100 gated `(2, 2)` epilogue warp-shape bug — RESOLVED -## Setup +## TL;DR -Branch: `explore-22-warp` (forked from `fix-gated-dgated` HEAD). +**Bug fixed at the source level.** The `_valid_2cta_m` overrides on +`GemmGatedMixin` and `GemmDGatedMixin` (commit `b10ffed`) are no longer +needed; this branch removes them and replaces the workaround with a real +fix in `quack/gemm_act.py` — a 1-method override on `GemmGatedMixin` that +constructs the aux-out r2s tiled copy with explicit thread + value layouts +so the tiler MN matches aux smem. -The `_valid_2cta_m` overrides on `GemmGatedMixin` and `GemmDGatedMixin` -have been **reverted** so the bug fires. Plus `print()` instrumentation in -`quack/gemm_base.py`, `quack/epi_ops.py`, and `quack/gemm_act.py`. +## Verification -Repro: `instr_run.py`. Run with fresh `QUACK_CACHE_DIR` and -`QUACK_CACHE_ENABLED=0` to force re-compile each run. +With `b10ffed`'s overrides reverted on this branch (so the bug *would* fire +without the new fix): -## Trigger +| test | result | +|-------------------------------------------------------|--------------------------------------------------------| +| `solo_ab_min.py` | phase A rel=0.0000 PASS, phase B rel=0.0000 PASS | +| `instr_run.py` original buggy shape (M=32768, E=8) | preact rel=0, postact rel=8.21e-4 PASS | +| 6 (M, H, I, E) × 2 cluster_m configs | All 12 PASS, identical errors between cm=1 and cm=2 | +| `test_untuned_buggy_tiles.py --shapes small` | 208/208 PASS, 0 timeouts (4 shapes × 52 forced configs)| +| `sweep_gated_dgated.py` (216 autotuned shape grid) | 216/216 PASS | -`tile_m=128, cluster_m=2, is_dynamic_persistent=True, use_tma_gather=True` -on the gated forward path. With 2-CTA, `cta_tile_m=64`, which forces -`compute_epilogue_tile_shape` to a `(2, 2)` M-warps × N-warps layout. The -non-gated D path with the same warp shape works correctly — the bug is -specific to the gated half-N postact aux-out chain. +Phase A's monkey-patch in `solo_ab_min.py` is now a no-op (the override +method doesn't exist on the class); even so, with 2-CTA forced on, output +is correct. -## Localization (the smoking gun) +## Root cause (recap) -`tiled_copy_aux_out_r2s` is built via: +The gated postact tile has **half** the N elements of D's tile (via +`_gated_epi_tile_fn`'s `recast_layout(2, 1, ...)`). The original +construction at `gemm_act.py:104`: cute.make_tiled_copy_S(aux_atom, tiled_copy_r2s) -`make_tiled_copy_S` keeps the source-side threading from `tiled_copy_r2s` -(D's r2s copy) and only swaps the per-atom store op. The Tiler MN is -inherited verbatim — full-N D dimensions, NOT half-N aux dimensions. - -Side-by-side for `tile_m=128, cm=2, swiglu fp16`, `cta_tile_shape=(64,256)`: - -| object | shape / layout | -|---------------------------------------|-------------------------------------------------------------------------| -| D's r2s `tiled_copy_r2s` Tiler MN | `((2,32):(32,1), (2,32):(32,1))` = 64M × 64N | -| Aux's r2s `tiled_copy_aux_out_r2s` | `((2,32):(32,1), (2,32):(32,1))` = 64M × 64N (**same as D**) | -| Aux's r2s TV layout | `((32,2,2),(1,32)):((2,1,64),(0,128))` -- 32 values per thread | -| Aux's smem `sAuxOut.layout` | `((8,8),(16,2),(1,2)):((16,128),(1,1024),(0,2048))` = 64M × 32N | -| D's smem `sD.layout` | `((8,8),(32,2),(1,2)):((32,256),(1,2048),(0,4096))` = 64M × 64N | - -**Mismatch:** the aux r2s copy has a 64×64 tiler producing 32 values per -thread × 128 threads = 4096 elements, but aux smem per stage holds only -64×32 = 2048 elements. Each aux smem position is written by **two -threads** -- warp 0's threads and warp 1's threads collide on the same -smem range. Whichever thread arrives last "wins"; warp 1's data is lost. - -The TMA descriptor for aux *is* correct (it scatters smem regions to gmem -at warp-stride 64). It's just that smem holds duplicated data when the TMA -reads it -- both the (smem) "warp 0 region" and the (smem) "warp 1 region" -hold warp 0's values after the r2s race. TMA then dutifully writes warp 0's -values to gmem `[0..15]` and warp 0's values again to gmem `[64..79]`, -producing the observed: - - postact[0, 0..15] = warp 0's values (correct) - postact[0, 64..79] = warp 0's values (DUPLICATE -- should be warp 1) - -## Why the (4, 1) warp shape works - -For `cluster_m=1` (= `(4, 1)` warp shape), `epi_tile_n` is just `int 32` -(no Layout). After `_gated_epi_tile_fn` halves to `int 16`, aux smem is -flat with only 1 N-warp. The Tiler MN match between D and aux remains -"D's full-N tile = aux's full-N tile" because there's no warp-N split in -either; the per-thread value count of 16 lands cleanly in aux smem with no -collision. - -Per-thread `tRS_rD.layout` is `((1,32),1,1):((0,1),0,0)` for **both** warp -shapes. The bug is purely in the destination-side (smem) partitioning of -`tiled_copy_aux_out_r2s`, not in registers or in `act_fn` indexing. - -## Why D's full-N path is unaffected - -D's smem layout has 64 N elements (twice aux's), with warp 1 at smem -stride 2048. D's r2s tiler `((2,32),(2,32))` produces 32 values per -thread × 128 threads = 4096 elements -- matches D smem per stage exactly. -No collision. - -## Fix direction - -The aux r2s tiled copy must be re-tiled to match aux's tile dimensions -(half N) before being used to partition `sAuxOut`. Two plausible builders: +inherits **D's full-N tiler MN** (e.g. 64×64) and applies it to aux smem +which is half-N (e.g. 64×32). Per epi-iter, 128 threads × 32 vals/thread = +4096 elements get emitted into a 2048-element smem region — a 2× overlap. + +For the (4, 1) epi-warp shape this is harmless: the over-emission has +stride 0 in the smem layout's phantom N-warp dim (since there's only 1 +N-warp), so it's a no-op self-overwrite. For the (2, 2) shape, the smem +N-warp dim has stride 1024 — the over-emitted elements land at warp 1's +smem region, clobbering warp 1's data with a duplicate of warp 0's. TMA +then dutifully scatters the duplicated smem to two distinct gmem +positions, producing the observed corruption pattern at gmem[0..15] == +gmem[64..79]. + +The non-gated D path is unaffected because aux smem and D's smem have +the same dimensions there (no half-N recast). + +The dgated bwd path is unaffected because `GemmDGatedMixin._epi_ops` uses +`TileStore("mAuxOut")` with no `epi_tile_fn` (no half-N recast). The +preventive override that `b10ffed` added on `GemmDGatedMixin` was +empirically unneeded; the sweeps with that override removed all pass. + +## The fix + +`quack/gemm_act.py` adds an override on `GemmGatedMixin` only: + +```python +def epi_make_aux_out_tiled_copy_r2s(self, params, tiled_copy_r2s, tiled_copy_t2r): + if self.arch != 100: + return super().epi_make_aux_out_tiled_copy_r2s( + params, tiled_copy_r2s, tiled_copy_t2r + ) + copy_atom_aux_out_r2s = self.epi_make_aux_out_copy_atom_r2s(params, tiled_copy_t2r) + cta_tile_aux_m, _ = self.cta_tile_shape_aux_out_mn + _, num_n_warps, _ = self.epi_smem_warp_shape_mnk() + epi_tile_aux_n = cute.size(params.epi_tile_mAuxOut[1]) + vals_per_thread_n = epi_tile_aux_n // num_n_warps + thr_layout = cute.make_layout( + (cta_tile_aux_m, num_n_warps), stride=(1, cta_tile_aux_m) + ) + val_layout = cute.make_layout((1, vals_per_thread_n)) + return cute.make_tiled_copy_tv(copy_atom_aux_out_r2s, thr_layout, val_layout) +``` -1. Build from scratch via `make_tiled_copy_D(aux_atom, sAuxOut.layout)` so - the destination shape comes from aux smem rather than D's r2s. -2. Re-tile `tiled_copy_r2s` to halve its N extent before passing through - `make_tiled_copy_S`. +Threading is `(cta_tile_aux_m, num_n_warps)` with stride `(1, cta_tile_aux_m)` +— 128 threads laid out as 1 thread per (M-row, N-warp) cell. Each thread +holds `vals_per_thread_n = size(epi_tile_aux_n) / num_n_warps` values +along N. Total = 128 × `vals_per_thread_n` = aux smem per stage exactly, +no overlap. SM90/SM120 fall back to the original construction (the +Layout-typed `epi_tile_n` is SM100-specific via `compute_epilogue_tile_shape`). -Either approach requires careful handling of the per-thread register slice -(`tRS_rAuxOut` has 16 fp32 elements per thread, derived via -`recast_layout(2, 1, tRS_rD.layout)`). The atom returned by -`sm100_utils.get_smem_store_op(aux_layout, aux_dtype, acc_dtype, tiled_copy_t2r)` -is selected based on `tiled_copy_t2r` (D's full-N pattern) -- it likely -needs to be rebuilt from a t2r-equivalent for aux's half-N slice as well. +## What was removed in this branch -This is real cuTeDSL design work. The current `_valid_2cta_m` override on -`GemmGatedMixin` / `GemmDGatedMixin` is the practical workaround; this -investigation explains exactly why the override is needed and what would -need to change to remove it. +- `GemmGatedMixin._valid_2cta_m -> (256,)` (workaround, no longer needed). +- `GemmDGatedMixin._valid_2cta_m -> (256,)` (preventive workaround, + empirically unneeded — dgated has no half-N recast and no (2, 2) bug). +- `GemmSm100._valid_2cta_m()` method indirection (introduced by `b10ffed` + to support the workaround). -## Reproduction commands +## Reproduction ```bash git checkout explore-22-warp CACHE=$(mktemp -d /tmp/quack_explore_XXXX) CUDA_VISIBLE_DEVICES=0 QUACK_CACHE_DIR=$CACHE QUACK_CACHE_ENABLED=0 \ python instr_run.py -# CLUSTER_M=1 to compare against the working (4, 1) warp shape. +# CLUSTER_M=2 (default) reproduces the previously-buggy cocktail; both +# phase A and phase B of solo_ab_min.py now PASS with rel=0. ``` - -The instrumentation prints D-path and aux-path layouts side by side; the -mismatch in Tiler MN vs sAuxOut shape is the smoking gun. diff --git a/instr_run.py b/instr_run.py index ac243f9b..1ffd8486 100644 --- a/instr_run.py +++ b/instr_run.py @@ -15,7 +15,11 @@ def main(): - M, H, I, E = 4096, 256, 128, 4 # small enough for tractable output + import os + M = int(os.environ.get("M", "4096")) + H = int(os.environ.get("H", "256")) + I = int(os.environ.get("I", "128")) + E = int(os.environ.get("E", "4")) device = torch.device("cuda:0") dtype = torch.float16 g = torch.Generator(device=device).manual_seed(0) diff --git a/quack/epi_ops.py b/quack/epi_ops.py index 246ee7cb..c14682cd 100644 --- a/quack/epi_ops.py +++ b/quack/epi_ops.py @@ -511,25 +511,9 @@ def to_params(self, gemm, args): self._epi_tile_key(): None, } epi_tile = self.epi_tile_fn(gemm, gemm.epi_tile) if self.epi_tile_fn else None - # [INSTRUMENTATION] print the recast tile vs the full tile. - print( - f"[INSTR TileStore.to_params name={self.name}]\n" - f" cta_tile_shape_mnk={gemm.cta_tile_shape_mnk}\n" - f" use_2cta_instrs={gemm.use_2cta_instrs}\n" - f" gemm.epi_tile (full) = {gemm.epi_tile}\n" - f" epi_tile passed to setup = {epi_tile}", - flush=True, - ) tma_atom, tma_tensor, smem_layout, epi_tile_out = setup_epi_tensor( gemm, tensor, epi_tile=epi_tile ) - print( - f"[INSTR TileStore.to_params name={self.name}] post-setup\n" - f" smem_layout = {smem_layout}\n" - f" epi_tile_out = {epi_tile_out}\n" - f" tma_atom = {tma_atom}", - flush=True, - ) return { self._tma_atom_key(): tma_atom, self.name: tma_tensor, diff --git a/quack/gemm_act.py b/quack/gemm_act.py index db790f8c..8f5d2ad1 100644 --- a/quack/gemm_act.py +++ b/quack/gemm_act.py @@ -123,16 +123,7 @@ def epi_setup_aux_out( tiled_copy_aux_out_r2s = self.epi_make_aux_out_tiled_copy_r2s( params, tiled_copy_r2s, tiled_copy_t2r ) - # [INSTRUMENTATION] dump the destination partition shape. - print( - f"[INSTR epi_setup_aux_out]\n" - f" sAuxOut.layout = {sAuxOut.layout}\n" - f" tiled_copy_r2s = {tiled_copy_r2s}\n" - f" tiled_copy_aux_out_r2s = {tiled_copy_aux_out_r2s}", - flush=True, - ) tRS_sAuxOut = tiled_copy_aux_out_r2s.get_slice(tidx).partition_D(sAuxOut) - print(f" tRS_sAuxOut.layout = {tRS_sAuxOut.layout}", flush=True) batch_idx = tile_coord_mnkl[3] copy_aux_out, _, _ = self.epilog_gmem_copy_and_partition( params.tma_atom_mAuxOut, @@ -225,6 +216,44 @@ class GemmGatedMixin(GemmActMixin): TileStore("mAuxOut", epi_tile_fn=_gated_epi_tile_fn), ) + def epi_make_aux_out_tiled_copy_r2s(self, params, tiled_copy_r2s, tiled_copy_t2r): + """Build the register-to-shared tiled copy used by gated aux outputs. + + Unlike the non-gated path, the gated postact tile has half the N elements + of D (via `_gated_epi_tile_fn`'s `recast_layout(2, 1, ...)`). The + straightforward `make_tiled_copy_S(aux_atom, tiled_copy_r2s)` inherits D's + full-N tiler MN, which over-emits by 2x when applied to the half-N aux + smem. For the (4, 1) epi-warp shape (cta_tile_m != 64 or 1-CTA) this is + harmless because the over-emission has stride 0 in smem (a phantom + N-warp dim), but for the (2, 2) shape (cta_tile_m=64 + 2-CTA on SM100) + the over-emission has the warp-N stride and corrupts warp 1's smem + region with a duplicate of warp 0's data. + + Build the aux r2s tiled copy explicitly to match aux's + (cta_tile_aux_m, size(epi_tile_aux_n)) tile: 1 thread per (M, N-warp) + position, each thread holding `size(epi_tile_aux_n) / num_n_warps` + values along N. That places every thread's writes into a single warp's + smem region with no aliasing across warps. Only applied for SM100 (the + (2, 2) layout is SM100-specific). + """ + if self.arch != 100: + return super().epi_make_aux_out_tiled_copy_r2s( + params, tiled_copy_r2s, tiled_copy_t2r + ) + copy_atom_aux_out_r2s = self.epi_make_aux_out_copy_atom_r2s(params, tiled_copy_t2r) + cta_tile_aux_m, _ = self.cta_tile_shape_aux_out_mn + _, num_n_warps, _ = self.epi_smem_warp_shape_mnk() + # epi_tile_aux_n size: the N mode of epi_tile_mAuxOut may be a Layout + # (e.g. (16,2):(1,64)) when the (2, 2) warp shape is in effect, or an int + # otherwise. cute.size() handles both. + epi_tile_aux_n = cute.size(params.epi_tile_mAuxOut[1]) + vals_per_thread_n = epi_tile_aux_n // num_n_warps + thr_layout = cute.make_layout( + (cta_tile_aux_m, num_n_warps), stride=(1, cta_tile_aux_m) + ) + val_layout = cute.make_layout((1, vals_per_thread_n)) + return cute.make_tiled_copy_tv(copy_atom_aux_out_r2s, thr_layout, val_layout) + def epi_to_underlying_arguments( self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None ) -> GemmActMixin.EpilogueParams: @@ -263,18 +292,6 @@ def epi_visit_subtile( tRS_rAuxOut_layout = cute.recast_layout(2, 1, tRS_rD.layout) # If we don't have .shape here, the compiler generates local stores and loads tRS_rAuxOut = cute.make_rmem_tensor(tRS_rAuxOut_layout.shape, self.acc_dtype) - # [INSTRUMENTATION] compile-time print of register layouts (fires at JIT trace). - print( - f"[INSTR gated.epi_visit_subtile JIT] arch={self.arch}\n" - f" tRS_rD.layout = {tRS_rD.layout}\n" - f" tRS_rD.shape = {tRS_rD.shape}\n" - f" cute.size(tRS_rD) = {cute.size(tRS_rD)}\n" - f" tRS_rAuxOut_layout = {tRS_rAuxOut_layout}\n" - f" tRS_rAuxOut.layout = {tRS_rAuxOut.layout}\n" - f" tRS_rAuxOut.shape = {tRS_rAuxOut.shape}\n" - f" cute.size(tRS_rAuxOut) = {cute.size(tRS_rAuxOut)}", - flush=True, - ) if const_expr(self.arch != 100): for i in cutlass.range(cute.size(tRS_rAuxOut), unroll_full=True): tRS_rAuxOut[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1]) diff --git a/quack/gemm_base.py b/quack/gemm_base.py index 6f2c29d4..55cbad13 100644 --- a/quack/gemm_base.py +++ b/quack/gemm_base.py @@ -569,15 +569,6 @@ def make_tma_epilogue_atoms_and_tensors( ): tma_atom_d, tma_tensor_d = None, None if const_expr(mD is not None): - # [INSTRUMENTATION] print D-path inputs. - print( - f"[INSTR D-path make_tma_epilogue_atoms_and_tensors]\n" - f" cta_tile_shape_mnk = {self.cta_tile_shape_mnk}\n" - f" use_2cta_instrs = {self.use_2cta_instrs}\n" - f" self.epi_tile = {self.epi_tile}\n" - f" self.epi_smem_layout_staged = {self.epi_smem_layout_staged}", - flush=True, - ) tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( copy_utils.create_ragged_tensor_for_tma(mD, ragged_dim=0, ptr_shift=True) if varlen_m @@ -588,10 +579,6 @@ def make_tma_epilogue_atoms_and_tensors( if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output) else "add", ) - print( - f"[INSTR D-path] tma_atom_d = {tma_atom_d}", - flush=True, - ) tma_atom_c, tma_tensor_c = None, None if const_expr(mC is not None): tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(