Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions INVESTIGATION_22_WARP.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Investigation: SM100 gated `(2, 2)` epilogue warp-shape bug — RESOLVED

## TL;DR

**Bug fixed at the source level.** The `_valid_2cta_m` overrides on
`GemmGatedMixin` and `GemmDGatedMixin` (commit `b10ffed`) are no longer
needed; this branch removes them and replaces the workaround with a real
fix in `quack/gemm_act.py` — a 1-method override on `GemmGatedMixin` that
constructs the aux-out r2s tiled copy with explicit thread + value layouts
so the tiler MN matches aux smem.

## Verification

With `b10ffed`'s overrides reverted on this branch (so the bug *would* fire
without the new fix):

| test | result |
|-------------------------------------------------------|--------------------------------------------------------|
| `solo_ab_min.py` | phase A rel=0.0000 PASS, phase B rel=0.0000 PASS |
| `instr_run.py` original buggy shape (M=32768, E=8) | preact rel=0, postact rel=8.21e-4 PASS |
| 6 (M, H, I, E) × 2 cluster_m configs | All 12 PASS, identical errors between cm=1 and cm=2 |
| `test_untuned_buggy_tiles.py --shapes small` | 208/208 PASS, 0 timeouts (4 shapes × 52 forced configs)|
| `sweep_gated_dgated.py` (216 autotuned shape grid) | 216/216 PASS |

Phase A's monkey-patch in `solo_ab_min.py` is now a no-op (the override
method doesn't exist on the class); even so, with 2-CTA forced on, output
is correct.

## Root cause (recap)

The gated postact tile has **half** the N elements of D's tile (via
`_gated_epi_tile_fn`'s `recast_layout(2, 1, ...)`). The original
construction at `gemm_act.py:104`:

cute.make_tiled_copy_S(aux_atom, tiled_copy_r2s)

inherits **D's full-N tiler MN** (e.g. 64×64) and applies it to aux smem
which is half-N (e.g. 64×32). Per epi-iter, 128 threads × 32 vals/thread =
4096 elements get emitted into a 2048-element smem region — a 2× overlap.

For the (4, 1) epi-warp shape this is harmless: the over-emission has
stride 0 in the smem layout's phantom N-warp dim (since there's only 1
N-warp), so it's a no-op self-overwrite. For the (2, 2) shape, the smem
N-warp dim has stride 1024 — the over-emitted elements land at warp 1's
smem region, clobbering warp 1's data with a duplicate of warp 0's. TMA
then dutifully scatters the duplicated smem to two distinct gmem
positions, producing the observed corruption pattern at gmem[0..15] ==
gmem[64..79].

The non-gated D path is unaffected because aux smem and D's smem have
the same dimensions there (no half-N recast).

The dgated bwd path is unaffected because `GemmDGatedMixin._epi_ops` uses
`TileStore("mAuxOut")` with no `epi_tile_fn` (no half-N recast). The
preventive override that `b10ffed` added on `GemmDGatedMixin` was
empirically unneeded; the sweeps with that override removed all pass.

## The fix

`quack/gemm_act.py` adds an override on `GemmGatedMixin` only:

```python
def epi_make_aux_out_tiled_copy_r2s(self, params, tiled_copy_r2s, tiled_copy_t2r):
if self.arch != 100:
return super().epi_make_aux_out_tiled_copy_r2s(
params, tiled_copy_r2s, tiled_copy_t2r
)
copy_atom_aux_out_r2s = self.epi_make_aux_out_copy_atom_r2s(params, tiled_copy_t2r)
cta_tile_aux_m, _ = self.cta_tile_shape_aux_out_mn
_, num_n_warps, _ = self.epi_smem_warp_shape_mnk()
epi_tile_aux_n = cute.size(params.epi_tile_mAuxOut[1])
vals_per_thread_n = epi_tile_aux_n // num_n_warps
thr_layout = cute.make_layout(
(cta_tile_aux_m, num_n_warps), stride=(1, cta_tile_aux_m)
)
val_layout = cute.make_layout((1, vals_per_thread_n))
return cute.make_tiled_copy_tv(copy_atom_aux_out_r2s, thr_layout, val_layout)
```

Threading is `(cta_tile_aux_m, num_n_warps)` with stride `(1, cta_tile_aux_m)`
— 128 threads laid out as 1 thread per (M-row, N-warp) cell. Each thread
holds `vals_per_thread_n = size(epi_tile_aux_n) / num_n_warps` values
along N. Total = 128 × `vals_per_thread_n` = aux smem per stage exactly,
no overlap. SM90/SM120 fall back to the original construction (the
Layout-typed `epi_tile_n` is SM100-specific via `compute_epilogue_tile_shape`).

## What was removed in this branch

- `GemmGatedMixin._valid_2cta_m -> (256,)` (workaround, no longer needed).
- `GemmDGatedMixin._valid_2cta_m -> (256,)` (preventive workaround,
empirically unneeded — dgated has no half-N recast and no (2, 2) bug).
- `GemmSm100._valid_2cta_m()` method indirection (introduced by `b10ffed`
to support the workaround).

## Reproduction

```bash
git checkout explore-22-warp
CACHE=$(mktemp -d /tmp/quack_explore_XXXX)
CUDA_VISIBLE_DEVICES=0 QUACK_CACHE_DIR=$CACHE QUACK_CACHE_ENABLED=0 \
python instr_run.py
# CLUSTER_M=2 (default) reproduces the previously-buggy cocktail; both
# phase A and phase B of solo_ab_min.py now PASS with rel=0.
```
76 changes: 76 additions & 0 deletions instr_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Tiny repro to capture instrumentation prints from the gated forward path
at the buggy cocktail (tile_m=128, cm=2, clc=True, gather=True).

Uses small M to keep output manageable. The instrumentation print()s in
quack/gemm_base.py (D path) and quack/epi_ops.py (TileStore.to_params, aux
path) will fire during kernel construction and dump epi_tile + smem_layout
+ tma_atom for both D and aux side-by-side.
"""
import os
import sys
import torch

from quack.gemm_config import GemmConfig
from quack.gemm_interface import gemm_gated_tuned, gemm_gated_ref


def main():
import os
M = int(os.environ.get("M", "4096"))
H = int(os.environ.get("H", "256"))
I = int(os.environ.get("I", "128"))
E = int(os.environ.get("E", "4"))
device = torch.device("cuda:0")
dtype = torch.float16
g = torch.Generator(device=device).manual_seed(0)
counts = torch.full((E,), M // E, dtype=torch.int32, device=device)
cu = torch.zeros(E + 1, dtype=torch.int32, device=device)
cu[1:] = torch.cumsum(counts, dim=0).to(torch.int32)
T = M // 4
x = (0.02 * torch.randn(T, H, generator=g, device=device, dtype=torch.float32)).to(dtype)
A_idx = torch.randint(0, T, (M,), dtype=torch.int32, device=device, generator=g)
w = torch.empty(E, 2 * I, H, dtype=torch.float32, device=device)
torch.nn.init.normal_(w, mean=0.0, std=0.02, generator=g)
w1 = w.to(dtype).permute(1, 2, 0).permute(2, 1, 0)

cluster_m = int(os.environ.get("CLUSTER_M", "2"))
cfg = GemmConfig(
tile_m=128, tile_n=256, cluster_m=cluster_m, cluster_n=1,
swap_ab=False, max_swizzle_size=8,
is_dynamic_persistent=True, use_tma_gather=True,
pingpong=False, device_capacity=10,
)
print(f"\n>>> cluster_m={cluster_m} (warp-shape: {(2,2) if cluster_m==2 else (4,1)})\n", flush=True)
pre = torch.empty(M, 2 * I, dtype=dtype, device=device)
post = torch.empty(M, I, dtype=dtype, device=device)

print("\n========== INVOKING gemm_gated_tuned.fn (buggy cocktail) ==========\n", flush=True)
gemm_gated_tuned.fn(
x, w1, pre, post, None, None, "swiglu", cu, A_idx, False, config=cfg,
)
torch.cuda.synchronize()
print("\n========== KERNEL EXECUTION DONE ==========\n", flush=True)

pre_ref, post_ref = gemm_gated_ref(
x, w1, bias=None, activation="swiglu",
cu_seqlens_m=cu, A_idx=A_idx,
store_preact=True, concat_layout=None,
)
pre_diff = (pre.float() - pre_ref.float()).abs()
post_diff = (post.float() - post_ref.float()).abs()
print(f"\npreact rel = {pre_diff.max().item() / max(pre_ref.float().abs().max().item(), 1e-12):.4e}")
print(f"postact rel = {post_diff.max().item() / max(post_ref.float().abs().max().item(), 1e-12):.4e}")
# Inspect output values at row 0, columns 0..15: which pattern of corruption?
print("\n postact row=0 cols=0..15:", post[0, :16].float().tolist())
print(" postact_ref row=0 cols=0..15:", post_ref[0, :16].float().tolist())
print("\n postact row=0 cols=64..79:", post[0, 64:80].float().tolist())
print(" postact_ref row=0 cols=64..79:", post_ref[0, 64:80].float().tolist())
# Check if postact has zeros (skipped writes) or shifted/scrambled values.
n_zeros = (post == 0).sum().item()
print(f"\n postact: total elems = {post.numel()}, n_zeros = {n_zeros} ({100*n_zeros/post.numel():.1f}%)")
n_zeros_ref = (post_ref == 0).sum().item()
print(f" postact_ref: n_zeros = {n_zeros_ref}")


if __name__ == "__main__":
main()
38 changes: 38 additions & 0 deletions quack/gemm_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,44 @@ class GemmGatedMixin(GemmActMixin):
TileStore("mAuxOut", epi_tile_fn=_gated_epi_tile_fn),
)

def epi_make_aux_out_tiled_copy_r2s(self, params, tiled_copy_r2s, tiled_copy_t2r):
"""Build the register-to-shared tiled copy used by gated aux outputs.

Unlike the non-gated path, the gated postact tile has half the N elements
of D (via `_gated_epi_tile_fn`'s `recast_layout(2, 1, ...)`). The
straightforward `make_tiled_copy_S(aux_atom, tiled_copy_r2s)` inherits D's
full-N tiler MN, which over-emits by 2x when applied to the half-N aux
smem. For the (4, 1) epi-warp shape (cta_tile_m != 64 or 1-CTA) this is
harmless because the over-emission has stride 0 in smem (a phantom
N-warp dim), but for the (2, 2) shape (cta_tile_m=64 + 2-CTA on SM100)
the over-emission has the warp-N stride and corrupts warp 1's smem
region with a duplicate of warp 0's data.

Build the aux r2s tiled copy explicitly to match aux's
(cta_tile_aux_m, size(epi_tile_aux_n)) tile: 1 thread per (M, N-warp)
position, each thread holding `size(epi_tile_aux_n) / num_n_warps`
values along N. That places every thread's writes into a single warp's
smem region with no aliasing across warps. Only applied for SM100 (the
(2, 2) layout is SM100-specific).
"""
if self.arch != 100:
return super().epi_make_aux_out_tiled_copy_r2s(
params, tiled_copy_r2s, tiled_copy_t2r
)
copy_atom_aux_out_r2s = self.epi_make_aux_out_copy_atom_r2s(params, tiled_copy_t2r)
cta_tile_aux_m, _ = self.cta_tile_shape_aux_out_mn
_, num_n_warps, _ = self.epi_smem_warp_shape_mnk()
# epi_tile_aux_n size: the N mode of epi_tile_mAuxOut may be a Layout
# (e.g. (16,2):(1,64)) when the (2, 2) warp shape is in effect, or an int
# otherwise. cute.size() handles both.
epi_tile_aux_n = cute.size(params.epi_tile_mAuxOut[1])
vals_per_thread_n = epi_tile_aux_n // num_n_warps
thr_layout = cute.make_layout(
(cta_tile_aux_m, num_n_warps), stride=(1, cta_tile_aux_m)
)
val_layout = cute.make_layout((1, vals_per_thread_n))
return cute.make_tiled_copy_tv(copy_atom_aux_out_r2s, thr_layout, val_layout)

def epi_to_underlying_arguments(
self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None
) -> GemmActMixin.EpilogueParams:
Expand Down
82 changes: 82 additions & 0 deletions solo_ab_min.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Minimal A/B test for the gated SM100 fix.

Phase A — _valid_2cta_m returns (128, 256): the original buggy default.
Phase B — _valid_2cta_m returns (256,): the patched default.

Each phase runs in its own subprocess with its own QUACK_CACHE_DIR so the
disk-backed jit_cache (whose key doesn't include use_2cta_instrs) can't serve
a stale cross-phase kernel.

Usage: CUDA_VISIBLE_DEVICES=<gpu> python solo_ab_min.py
"""
import json, os, shutil, subprocess, sys, tempfile, torch


def child(phase):
valid = (128, 256) if phase == "A" else (256,)
from quack.gemm_act import GemmGatedMixin
GemmGatedMixin._valid_2cta_m = lambda self, _v=valid: _v

from quack.gemm_config import GemmConfig
from quack.gemm_interface import gemm_gated_tuned, gemm_gated_ref

M, H, I, E = 32768, 1024, 512, 8
dtype, dev = torch.float16, torch.device("cuda:0")
g = torch.Generator(device=dev).manual_seed(0)
counts = torch.full((E,), M // E, dtype=torch.int32, device=dev)
cu = torch.zeros(E + 1, dtype=torch.int32, device=dev)
cu[1:] = torch.cumsum(counts, dim=0).to(torch.int32)
T = M // 4
x = (0.02 * torch.randn(T, H, generator=g, device=dev, dtype=torch.float32)).to(dtype)
A_idx = torch.randint(0, T, (M,), dtype=torch.int32, device=dev, generator=g)
w = torch.empty(E, 2 * I, H, dtype=torch.float32, device=dev)
torch.nn.init.normal_(w, mean=0.0, std=0.02, generator=g)
w1 = w.to(dtype).permute(1, 2, 0).permute(2, 1, 0)

cfg = GemmConfig(tile_m=128, tile_n=256, cluster_m=2, cluster_n=1,
swap_ab=False, max_swizzle_size=8,
is_dynamic_persistent=True, use_tma_gather=True,
pingpong=False, device_capacity=10)
pre, post = torch.empty(M, 2 * I, dtype=dtype, device=dev), torch.empty(M, I, dtype=dtype, device=dev)
gemm_gated_tuned.fn(x, w1, pre, post, None, None, "swiglu", cu, A_idx, False, config=cfg)
torch.cuda.synchronize()
pre_ref, post_ref = gemm_gated_ref(x, w1, bias=None, activation="swiglu",
cu_seqlens_m=cu, A_idx=A_idx,
store_preact=True, concat_layout=None)

print("PHASE_RESULT " + json.dumps({
"phase": phase, "valid_2cta_m": list(valid),
"preact_max_abs": (pre.float() - pre_ref.float()).abs().max().item(),
"preact_max_ref": pre_ref.float().abs().max().item(),
"postact_max_abs": (post.float() - post_ref.float()).abs().max().item(),
"postact_max_ref": post_ref.float().abs().max().item(),
}), flush=True)


def run_phase(phase):
cache = tempfile.mkdtemp(prefix=f"quack_cache_{phase}_")
env = {**os.environ, "QUACK_CACHE_DIR": cache}
try:
out = subprocess.run([sys.executable, "-u", __file__, phase],
capture_output=True, text=True, env=env, timeout=600)
finally:
shutil.rmtree(cache, ignore_errors=True)
for line in out.stdout.splitlines():
if line.startswith("PHASE_RESULT "):
return json.loads(line[len("PHASE_RESULT "):])
sys.exit(f"phase {phase} produced no result\n{out.stdout}\n{out.stderr}")


def main():
if len(sys.argv) > 1 and sys.argv[1] in ("A", "B"):
child(sys.argv[1]); return

a, b = run_phase("A"), run_phase("B")
for d in (a, b):
rel = d["preact_max_abs"] / max(d["preact_max_ref"], 1e-12)
print(f"phase {d['phase']} _valid_2cta_m={str(tuple(d['valid_2cta_m'])):<10} "
f"preact rel={rel:.4e} ({'FAIL' if rel > 0.05 else 'PASS'})")


if __name__ == "__main__":
main()
Loading