Skip to content

Fix the SwiGLU/dSwiGLU bug for the (2, 2) epilogue layout#133

Open
GarlGuo wants to merge 4 commits into
mainfrom
explore-22-warp
Open

Fix the SwiGLU/dSwiGLU bug for the (2, 2) epilogue layout#133
GarlGuo wants to merge 4 commits into
mainfrom
explore-22-warp

Conversation

@GarlGuo
Copy link
Copy Markdown
Member

@GarlGuo GarlGuo commented May 7, 2026

No description provided.

tridao and others added 4 commits May 6, 2026 12:36
Reverts the _valid_2cta_m overrides on GemmGatedMixin and GemmDGatedMixin
so the bug fires, plus adds Python print() instrumentation in:
- quack/gemm_base.py: D-path TMA atom inputs
- quack/epi_ops.py: TileStore aux-path inputs and outputs
- quack/gemm_act.py: epi_visit_subtile register layouts (compile-time
  tracing) and epi_setup_aux_out tiled-copy + partition_D dump

INVESTIGATION_22_WARP.md captures the smoking gun: tiled_copy_aux_out_r2s
(built via make_tiled_copy_S(aux_atom, tiled_copy_r2s)) inherits D's full-N
tiler MN of 64x64, but is then applied to aux smem of 64x32. Each smem
position is written by two threads -- warp 1's data gets clobbered by
warp 0's, then TMA copies the duplicated smem to two distinct gmem
positions, producing the observed gmem[0..15] == gmem[64..79] corruption.

This is exploration only. Not for merge.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the _valid_2cta_m=(256,) workaround on GemmGatedMixin (and
preventive workaround on GemmDGatedMixin) with a source-level fix in
GemmGatedMixin.epi_make_aux_out_tiled_copy_r2s.

Root cause: make_tiled_copy_S(aux_atom, tiled_copy_r2s) inherited D's
full-N tiler MN (64x64) and applied it to aux's half-N smem (64x32),
emitting 128 threads x 32 vals = 4096 elements into a 2048-element smem
region. For (4, 1) the over-emission had stride 0 (harmless self-
overwrite); for (2, 2) it had warp-N stride 1024 (corrupted warp 1's
region with a duplicate of warp 0's data).

Fix: build the aux r2s tiled copy via make_tiled_copy_tv with explicit
layouts -- 128 threads as (cta_tile_aux_m, num_n_warps), each holding
size(epi_tile_aux_n)/num_n_warps values along N. Tiler MN now matches
aux smem exactly. SM90/SM120 keep the original construction since the
Layout-typed epi_tile_n is SM100-specific.

Removes the gemm_dact.py override (preventive, dgated has no half-N
recast) and the gemm_sm100.py _valid_2cta_m method indirection.

Verification with the b10ffed workaround reverted:
- solo_ab_min.py: both phases PASS rel=0.0
- instr_run.py original buggy shape: postact rel=8.21e-4 PASS
- 12 (M,H,I,E) x cluster_m combos: identical errors cm=1 vs cm=2
- test_untuned_buggy_tiles.py --shapes small: 208/208 PASS (4 shapes
  x 52 forced configs, fwd+bwd)
- sweep_gated_dgated.py: 216/216 PASS

INVESTIGATION_22_WARP.md captures the full investigation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@GarlGuo
Copy link
Copy Markdown
Member Author

GarlGuo commented May 7, 2026

@tridao I will check if all tests are passed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants