diff --git a/INVESTIGATION_22_WARP.md b/INVESTIGATION_22_WARP.md new file mode 100644 index 0000000..efe591b --- /dev/null +++ b/INVESTIGATION_22_WARP.md @@ -0,0 +1,104 @@ +# Investigation: SM100 gated `(2, 2)` epilogue warp-shape bug — RESOLVED + +## TL;DR + +**Bug fixed at the source level.** The `_valid_2cta_m` overrides on +`GemmGatedMixin` and `GemmDGatedMixin` (commit `b10ffed`) are no longer +needed; this branch removes them and replaces the workaround with a real +fix in `quack/gemm_act.py` — a 1-method override on `GemmGatedMixin` that +constructs the aux-out r2s tiled copy with explicit thread + value layouts +so the tiler MN matches aux smem. + +## Verification + +With `b10ffed`'s overrides reverted on this branch (so the bug *would* fire +without the new fix): + +| test | result | +|-------------------------------------------------------|--------------------------------------------------------| +| `solo_ab_min.py` | phase A rel=0.0000 PASS, phase B rel=0.0000 PASS | +| `instr_run.py` original buggy shape (M=32768, E=8) | preact rel=0, postact rel=8.21e-4 PASS | +| 6 (M, H, I, E) × 2 cluster_m configs | All 12 PASS, identical errors between cm=1 and cm=2 | +| `test_untuned_buggy_tiles.py --shapes small` | 208/208 PASS, 0 timeouts (4 shapes × 52 forced configs)| +| `sweep_gated_dgated.py` (216 autotuned shape grid) | 216/216 PASS | + +Phase A's monkey-patch in `solo_ab_min.py` is now a no-op (the override +method doesn't exist on the class); even so, with 2-CTA forced on, output +is correct. + +## Root cause (recap) + +The gated postact tile has **half** the N elements of D's tile (via +`_gated_epi_tile_fn`'s `recast_layout(2, 1, ...)`). The original +construction at `gemm_act.py:104`: + + cute.make_tiled_copy_S(aux_atom, tiled_copy_r2s) + +inherits **D's full-N tiler MN** (e.g. 64×64) and applies it to aux smem +which is half-N (e.g. 64×32). Per epi-iter, 128 threads × 32 vals/thread = +4096 elements get emitted into a 2048-element smem region — a 2× overlap. + +For the (4, 1) epi-warp shape this is harmless: the over-emission has +stride 0 in the smem layout's phantom N-warp dim (since there's only 1 +N-warp), so it's a no-op self-overwrite. For the (2, 2) shape, the smem +N-warp dim has stride 1024 — the over-emitted elements land at warp 1's +smem region, clobbering warp 1's data with a duplicate of warp 0's. TMA +then dutifully scatters the duplicated smem to two distinct gmem +positions, producing the observed corruption pattern at gmem[0..15] == +gmem[64..79]. + +The non-gated D path is unaffected because aux smem and D's smem have +the same dimensions there (no half-N recast). + +The dgated bwd path is unaffected because `GemmDGatedMixin._epi_ops` uses +`TileStore("mAuxOut")` with no `epi_tile_fn` (no half-N recast). The +preventive override that `b10ffed` added on `GemmDGatedMixin` was +empirically unneeded; the sweeps with that override removed all pass. + +## The fix + +`quack/gemm_act.py` adds an override on `GemmGatedMixin` only: + +```python +def epi_make_aux_out_tiled_copy_r2s(self, params, tiled_copy_r2s, tiled_copy_t2r): + if self.arch != 100: + return super().epi_make_aux_out_tiled_copy_r2s( + params, tiled_copy_r2s, tiled_copy_t2r + ) + copy_atom_aux_out_r2s = self.epi_make_aux_out_copy_atom_r2s(params, tiled_copy_t2r) + cta_tile_aux_m, _ = self.cta_tile_shape_aux_out_mn + _, num_n_warps, _ = self.epi_smem_warp_shape_mnk() + epi_tile_aux_n = cute.size(params.epi_tile_mAuxOut[1]) + vals_per_thread_n = epi_tile_aux_n // num_n_warps + thr_layout = cute.make_layout( + (cta_tile_aux_m, num_n_warps), stride=(1, cta_tile_aux_m) + ) + val_layout = cute.make_layout((1, vals_per_thread_n)) + return cute.make_tiled_copy_tv(copy_atom_aux_out_r2s, thr_layout, val_layout) +``` + +Threading is `(cta_tile_aux_m, num_n_warps)` with stride `(1, cta_tile_aux_m)` +— 128 threads laid out as 1 thread per (M-row, N-warp) cell. Each thread +holds `vals_per_thread_n = size(epi_tile_aux_n) / num_n_warps` values +along N. Total = 128 × `vals_per_thread_n` = aux smem per stage exactly, +no overlap. SM90/SM120 fall back to the original construction (the +Layout-typed `epi_tile_n` is SM100-specific via `compute_epilogue_tile_shape`). + +## What was removed in this branch + +- `GemmGatedMixin._valid_2cta_m -> (256,)` (workaround, no longer needed). +- `GemmDGatedMixin._valid_2cta_m -> (256,)` (preventive workaround, + empirically unneeded — dgated has no half-N recast and no (2, 2) bug). +- `GemmSm100._valid_2cta_m()` method indirection (introduced by `b10ffed` + to support the workaround). + +## Reproduction + +```bash +git checkout explore-22-warp +CACHE=$(mktemp -d /tmp/quack_explore_XXXX) +CUDA_VISIBLE_DEVICES=0 QUACK_CACHE_DIR=$CACHE QUACK_CACHE_ENABLED=0 \ + python instr_run.py +# CLUSTER_M=2 (default) reproduces the previously-buggy cocktail; both +# phase A and phase B of solo_ab_min.py now PASS with rel=0. +``` diff --git a/instr_run.py b/instr_run.py new file mode 100644 index 0000000..1ffd848 --- /dev/null +++ b/instr_run.py @@ -0,0 +1,76 @@ +"""Tiny repro to capture instrumentation prints from the gated forward path +at the buggy cocktail (tile_m=128, cm=2, clc=True, gather=True). + +Uses small M to keep output manageable. The instrumentation print()s in +quack/gemm_base.py (D path) and quack/epi_ops.py (TileStore.to_params, aux +path) will fire during kernel construction and dump epi_tile + smem_layout ++ tma_atom for both D and aux side-by-side. +""" +import os +import sys +import torch + +from quack.gemm_config import GemmConfig +from quack.gemm_interface import gemm_gated_tuned, gemm_gated_ref + + +def main(): + import os + M = int(os.environ.get("M", "4096")) + H = int(os.environ.get("H", "256")) + I = int(os.environ.get("I", "128")) + E = int(os.environ.get("E", "4")) + device = torch.device("cuda:0") + dtype = torch.float16 + g = torch.Generator(device=device).manual_seed(0) + counts = torch.full((E,), M // E, dtype=torch.int32, device=device) + cu = torch.zeros(E + 1, dtype=torch.int32, device=device) + cu[1:] = torch.cumsum(counts, dim=0).to(torch.int32) + T = M // 4 + x = (0.02 * torch.randn(T, H, generator=g, device=device, dtype=torch.float32)).to(dtype) + A_idx = torch.randint(0, T, (M,), dtype=torch.int32, device=device, generator=g) + w = torch.empty(E, 2 * I, H, dtype=torch.float32, device=device) + torch.nn.init.normal_(w, mean=0.0, std=0.02, generator=g) + w1 = w.to(dtype).permute(1, 2, 0).permute(2, 1, 0) + + cluster_m = int(os.environ.get("CLUSTER_M", "2")) + cfg = GemmConfig( + tile_m=128, tile_n=256, cluster_m=cluster_m, cluster_n=1, + swap_ab=False, max_swizzle_size=8, + is_dynamic_persistent=True, use_tma_gather=True, + pingpong=False, device_capacity=10, + ) + print(f"\n>>> cluster_m={cluster_m} (warp-shape: {(2,2) if cluster_m==2 else (4,1)})\n", flush=True) + pre = torch.empty(M, 2 * I, dtype=dtype, device=device) + post = torch.empty(M, I, dtype=dtype, device=device) + + print("\n========== INVOKING gemm_gated_tuned.fn (buggy cocktail) ==========\n", flush=True) + gemm_gated_tuned.fn( + x, w1, pre, post, None, None, "swiglu", cu, A_idx, False, config=cfg, + ) + torch.cuda.synchronize() + print("\n========== KERNEL EXECUTION DONE ==========\n", flush=True) + + pre_ref, post_ref = gemm_gated_ref( + x, w1, bias=None, activation="swiglu", + cu_seqlens_m=cu, A_idx=A_idx, + store_preact=True, concat_layout=None, + ) + pre_diff = (pre.float() - pre_ref.float()).abs() + post_diff = (post.float() - post_ref.float()).abs() + print(f"\npreact rel = {pre_diff.max().item() / max(pre_ref.float().abs().max().item(), 1e-12):.4e}") + print(f"postact rel = {post_diff.max().item() / max(post_ref.float().abs().max().item(), 1e-12):.4e}") + # Inspect output values at row 0, columns 0..15: which pattern of corruption? + print("\n postact row=0 cols=0..15:", post[0, :16].float().tolist()) + print(" postact_ref row=0 cols=0..15:", post_ref[0, :16].float().tolist()) + print("\n postact row=0 cols=64..79:", post[0, 64:80].float().tolist()) + print(" postact_ref row=0 cols=64..79:", post_ref[0, 64:80].float().tolist()) + # Check if postact has zeros (skipped writes) or shifted/scrambled values. + n_zeros = (post == 0).sum().item() + print(f"\n postact: total elems = {post.numel()}, n_zeros = {n_zeros} ({100*n_zeros/post.numel():.1f}%)") + n_zeros_ref = (post_ref == 0).sum().item() + print(f" postact_ref: n_zeros = {n_zeros_ref}") + + +if __name__ == "__main__": + main() diff --git a/quack/gemm_act.py b/quack/gemm_act.py index d389a42..8f5d2ad 100644 --- a/quack/gemm_act.py +++ b/quack/gemm_act.py @@ -216,6 +216,44 @@ class GemmGatedMixin(GemmActMixin): TileStore("mAuxOut", epi_tile_fn=_gated_epi_tile_fn), ) + def epi_make_aux_out_tiled_copy_r2s(self, params, tiled_copy_r2s, tiled_copy_t2r): + """Build the register-to-shared tiled copy used by gated aux outputs. + + Unlike the non-gated path, the gated postact tile has half the N elements + of D (via `_gated_epi_tile_fn`'s `recast_layout(2, 1, ...)`). The + straightforward `make_tiled_copy_S(aux_atom, tiled_copy_r2s)` inherits D's + full-N tiler MN, which over-emits by 2x when applied to the half-N aux + smem. For the (4, 1) epi-warp shape (cta_tile_m != 64 or 1-CTA) this is + harmless because the over-emission has stride 0 in smem (a phantom + N-warp dim), but for the (2, 2) shape (cta_tile_m=64 + 2-CTA on SM100) + the over-emission has the warp-N stride and corrupts warp 1's smem + region with a duplicate of warp 0's data. + + Build the aux r2s tiled copy explicitly to match aux's + (cta_tile_aux_m, size(epi_tile_aux_n)) tile: 1 thread per (M, N-warp) + position, each thread holding `size(epi_tile_aux_n) / num_n_warps` + values along N. That places every thread's writes into a single warp's + smem region with no aliasing across warps. Only applied for SM100 (the + (2, 2) layout is SM100-specific). + """ + if self.arch != 100: + return super().epi_make_aux_out_tiled_copy_r2s( + params, tiled_copy_r2s, tiled_copy_t2r + ) + copy_atom_aux_out_r2s = self.epi_make_aux_out_copy_atom_r2s(params, tiled_copy_t2r) + cta_tile_aux_m, _ = self.cta_tile_shape_aux_out_mn + _, num_n_warps, _ = self.epi_smem_warp_shape_mnk() + # epi_tile_aux_n size: the N mode of epi_tile_mAuxOut may be a Layout + # (e.g. (16,2):(1,64)) when the (2, 2) warp shape is in effect, or an int + # otherwise. cute.size() handles both. + epi_tile_aux_n = cute.size(params.epi_tile_mAuxOut[1]) + vals_per_thread_n = epi_tile_aux_n // num_n_warps + thr_layout = cute.make_layout( + (cta_tile_aux_m, num_n_warps), stride=(1, cta_tile_aux_m) + ) + val_layout = cute.make_layout((1, vals_per_thread_n)) + return cute.make_tiled_copy_tv(copy_atom_aux_out_r2s, thr_layout, val_layout) + def epi_to_underlying_arguments( self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None ) -> GemmActMixin.EpilogueParams: diff --git a/solo_ab_min.py b/solo_ab_min.py new file mode 100644 index 0000000..f35f74d --- /dev/null +++ b/solo_ab_min.py @@ -0,0 +1,82 @@ +"""Minimal A/B test for the gated SM100 fix. + +Phase A — _valid_2cta_m returns (128, 256): the original buggy default. +Phase B — _valid_2cta_m returns (256,): the patched default. + +Each phase runs in its own subprocess with its own QUACK_CACHE_DIR so the +disk-backed jit_cache (whose key doesn't include use_2cta_instrs) can't serve +a stale cross-phase kernel. + +Usage: CUDA_VISIBLE_DEVICES= python solo_ab_min.py +""" +import json, os, shutil, subprocess, sys, tempfile, torch + + +def child(phase): + valid = (128, 256) if phase == "A" else (256,) + from quack.gemm_act import GemmGatedMixin + GemmGatedMixin._valid_2cta_m = lambda self, _v=valid: _v + + from quack.gemm_config import GemmConfig + from quack.gemm_interface import gemm_gated_tuned, gemm_gated_ref + + M, H, I, E = 32768, 1024, 512, 8 + dtype, dev = torch.float16, torch.device("cuda:0") + g = torch.Generator(device=dev).manual_seed(0) + counts = torch.full((E,), M // E, dtype=torch.int32, device=dev) + cu = torch.zeros(E + 1, dtype=torch.int32, device=dev) + cu[1:] = torch.cumsum(counts, dim=0).to(torch.int32) + T = M // 4 + x = (0.02 * torch.randn(T, H, generator=g, device=dev, dtype=torch.float32)).to(dtype) + A_idx = torch.randint(0, T, (M,), dtype=torch.int32, device=dev, generator=g) + w = torch.empty(E, 2 * I, H, dtype=torch.float32, device=dev) + torch.nn.init.normal_(w, mean=0.0, std=0.02, generator=g) + w1 = w.to(dtype).permute(1, 2, 0).permute(2, 1, 0) + + cfg = GemmConfig(tile_m=128, tile_n=256, cluster_m=2, cluster_n=1, + swap_ab=False, max_swizzle_size=8, + is_dynamic_persistent=True, use_tma_gather=True, + pingpong=False, device_capacity=10) + pre, post = torch.empty(M, 2 * I, dtype=dtype, device=dev), torch.empty(M, I, dtype=dtype, device=dev) + gemm_gated_tuned.fn(x, w1, pre, post, None, None, "swiglu", cu, A_idx, False, config=cfg) + torch.cuda.synchronize() + pre_ref, post_ref = gemm_gated_ref(x, w1, bias=None, activation="swiglu", + cu_seqlens_m=cu, A_idx=A_idx, + store_preact=True, concat_layout=None) + + print("PHASE_RESULT " + json.dumps({ + "phase": phase, "valid_2cta_m": list(valid), + "preact_max_abs": (pre.float() - pre_ref.float()).abs().max().item(), + "preact_max_ref": pre_ref.float().abs().max().item(), + "postact_max_abs": (post.float() - post_ref.float()).abs().max().item(), + "postact_max_ref": post_ref.float().abs().max().item(), + }), flush=True) + + +def run_phase(phase): + cache = tempfile.mkdtemp(prefix=f"quack_cache_{phase}_") + env = {**os.environ, "QUACK_CACHE_DIR": cache} + try: + out = subprocess.run([sys.executable, "-u", __file__, phase], + capture_output=True, text=True, env=env, timeout=600) + finally: + shutil.rmtree(cache, ignore_errors=True) + for line in out.stdout.splitlines(): + if line.startswith("PHASE_RESULT "): + return json.loads(line[len("PHASE_RESULT "):]) + sys.exit(f"phase {phase} produced no result\n{out.stdout}\n{out.stderr}") + + +def main(): + if len(sys.argv) > 1 and sys.argv[1] in ("A", "B"): + child(sys.argv[1]); return + + a, b = run_phase("A"), run_phase("B") + for d in (a, b): + rel = d["preact_max_abs"] / max(d["preact_max_ref"], 1e-12) + print(f"phase {d['phase']} _valid_2cta_m={str(tuple(d['valid_2cta_m'])):<10} " + f"preact rel={rel:.4e} ({'FAIL' if rel > 0.05 else 'PASS'})") + + +if __name__ == "__main__": + main()