Refactor GEMM for flexible work, operand routing, and epilogue policies#111

Open
santoshmo wants to merge 2 commits into
Dao-AILab:mainfrom
santoshmo:refactor/gemm-flexible-work-epilogue

@santoshmo
Contributor

Motivation

QuACK's current mixins and hooks are good for local specialization — customizing epilogue math inside a fixed control-flow skeleton. They are not enough for structural specialization. Grouped GEMM, Split-K, and richer epilogue dispatch don't just change a few math ops — they change what a tile means, where operands come from, how K is partitioned, and how results are committed.

Today, kernel() decides:

  • Which logical problem a tile belongs to
  • How to find A/B/C/D for that problem
  • What K range to run
  • When and how to store outputs

That logic is embedded directly in the arch kernels. The current abstraction boundary is "customize the math". The next class of features needs "customize the meaning of work", "customize the mapping from work to operands", and "customize the finalization/store policy".

Without proper abstraction boundaries, every new structural feature does one of two bad things: adds more hooks into already-large kernels, or forks kernel() / epilogue() wholesale. Both scale poorly. GemmSymmetricMixin already demonstrates this — it had to override the entire epilogue loop (~170 lines) just to change store-commit behavior for symmetric tiles. That's the signal that the hook surface is too low-level in the wrong places.

What this PR does

Introduces three new abstraction layers that lift the extension point to where structural behavior actually lives, without replacing the existing hook model:

1. WorkDesc — richer work descriptor (gemm_work.py)

Replaces cutlass.utils.WorkTileInfo (just tile_idx + is_valid_tile) with a descriptor that carries enough metadata for grouped and split-K execution:

class WorkDesc(NamedTuple):
    tile_coord_mnkl: cute.Coord
    problem_idx: Int32
    k_tile_begin: Int32 = Int32(0)
    k_tile_count: Optional[Int32] = None
    split_k_idx: Int32 = Int32(0)
    split_k_parts: Int32 = Int32(1)
    is_final_split: Boolean = Boolean(True)
    is_valid_tile: Boolean = Boolean(False)

All tile schedulers are updated to return WorkDesc. For the current plain GEMM path, the extra fields keep their defaults, so behavior is unchanged.
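To make the role of the extra fields concrete, here is a minimal host-side sketch of how a scheduler might fill them in for split-K. It is not code from this PR: `WorkDescSketch` and `split_k_work` are hypothetical names, plain Python ints and bools stand in for `Int32` and `Boolean`, and the partitioning rule (ceil-divided contiguous chunks, with the last chunk marked final) is an assumption.

```python
from typing import List, NamedTuple, Optional, Tuple


class WorkDescSketch(NamedTuple):
    """Plain-Python stand-in for WorkDesc (hypothetical, host-side only)."""
    tile_coord_mnkl: Tuple[int, int, int, int]
    problem_idx: int
    k_tile_begin: int = 0
    k_tile_count: Optional[int] = None  # None is read here as "full K range"
    split_k_idx: int = 0
    split_k_parts: int = 1
    is_final_split: bool = True
    is_valid_tile: bool = False


def split_k_work(tile_m: int, tile_n: int, batch: int,
                 total_k_tiles: int, parts: int) -> List[WorkDescSketch]:
    """Split one output tile's K range into contiguous chunks (assumed policy)."""
    per_part = -(-total_k_tiles // parts)  # ceil division
    work = []
    for idx in range(parts):
        begin = min(idx * per_part, total_k_tiles)
        count = min(per_part, total_k_tiles - begin)
        work.append(WorkDescSketch(
            tile_coord_mnkl=(tile_m, tile_n, 0, batch),
            problem_idx=batch,
            k_tile_begin=begin,
            k_tile_count=count,
            split_k_idx=idx,
            split_k_parts=parts,
            is_final_split=(idx == parts - 1),
            is_valid_tile=True,
        ))
    return work


# Plain GEMM: only the coordinate, problem index, and validity are set.
plain = WorkDescSketch(tile_coord_mnkl=(0, 1, 0, 2), problem_idx=2,
                       is_valid_tile=True)

# Split-K: 10 K tiles spread over 3 parts for the same output tile.
parts = split_k_work(tile_m=0, tile_n=1, batch=2, total_k_tiles=10, parts=3)
```

In this sketch a plain GEMM tile and a split-K tile have the same shape, so a consumer that only reads `tile_coord_mnkl` and `is_valid_tile` does not need to know which kind of work it was handed.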

2. ProblemAdapter — per-problem operand routing (gemm_problem_adapter.py)

Moves problem identity and operand provenance out of the architecture kernels via virtual methods on GemmSm90:

  • problem_get_problem_idx(params, work) — which logical problem a tile belongs to
  • problem_get_batch_A/B(params, tensor, varlen_manager, work) — locate A/B tensors
  • problem_get_batch_epi(params, tensor, varlen_manager, work) — locate output tensors
  • problem_get_len_k(params, varlen_manager, work) — K dimension for the current problem

Default implementations delegate to VarlenManager, preserving current behavior with zero change to generated PTX. GroupedProblemAdapterMixin provides grouped variants, applied at compile time:

GemmCls = type("GroupedGemmSm100", (GroupedProblemAdapterMixin, GemmSm100), {})
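The composition pattern can be sketched in plain Python. Everything below is a stand-in written for illustration: `GemmBase`, `GroupedAdapterMixinSketch`, `VarlenManagerStub`, the dict-based `work` and `params`, and the `load_A` helper are hypothetical and do not reflect the real class bodies. Only the method names `problem_get_problem_idx` and `problem_get_batch_A` and the `type(...)` composition come from the description above.

```python
class VarlenManagerStub:
    """Stand-in for VarlenManager: picks one batch slice of an operand."""

    def offset_batch_A(self, tensor, batch_idx):
        return tensor[batch_idx]


class GemmBase:
    """Stand-in for an arch kernel class such as GemmSm90 (hypothetical body)."""

    def problem_get_problem_idx(self, params, work):
        # Default: the problem is the batch (L) coordinate of the tile.
        return work["tile_coord_mnkl"][3]

    def problem_get_batch_A(self, params, tensor, varlen_manager, work):
        # Default: delegate to the varlen manager, as before the refactor.
        idx = self.problem_get_problem_idx(params, work)
        return varlen_manager.offset_batch_A(tensor, idx)

    def load_A(self, params, tensor, varlen_manager, work):
        # The kernel body only talks to the problem_get_* hooks.
        return self.problem_get_batch_A(params, tensor, varlen_manager, work)


class GroupedAdapterMixinSketch:
    """Grouped variant: operands come from a per-problem table in params."""

    def problem_get_problem_idx(self, params, work):
        return work["problem_idx"]

    def problem_get_batch_A(self, params, tensor, varlen_manager, work):
        idx = self.problem_get_problem_idx(params, work)
        return params["A_per_problem"][idx]


# Same composition pattern as above: the mixin comes first in the MRO,
# so its hooks win while load_A is inherited unchanged.
GroupedGemm = type("GroupedGemm", (GroupedAdapterMixinSketch, GemmBase), {})

work = {"tile_coord_mnkl": (0, 0, 0, 1), "problem_idx": 2}
batched_A = ["A_batch0", "A_batch1"]
params = {"A_per_problem": ["A_p0", "A_p1", "A_p2"]}

default_result = GemmBase().load_A(None, batched_A, VarlenManagerStub(), work)
grouped_result = GroupedGemm().load_A(params, batched_A, VarlenManagerStub(), work)
```

The point of the sketch is that `load_A`, which stands in for the kernel body, is identical in both classes. Only the hook methods differ, and the mixin wins because it precedes the base class in the method resolution order.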

3. EpiloguePlan — composable epilogue traversal (gemm_epilogue_plan.py)

Separates the reusable epilogue loop (subtile iteration, C prefetch pipelining, accumulator transport) from pluggable policy:

  • epi_plan_make_tile_layout() — subtile iteration order
  • epi_plan_commit() — D/postact store behavior (default vs symmetric vs future split-K partial/final)

GemmSymmetricMixin is simplified from ~170 lines (full epilogue override) to ~30 lines (just epi_plan_commit override). Features that change store policy no longer need to duplicate traversal logic.
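The split between traversal and commit policy can be illustrated with a small host-side sketch. The names `run_epilogue_sketch`, `default_commit_sketch`, and `symmetric_commit_sketch` are hypothetical, subtiles are plain nested lists, and the symmetric behavior shown here (also writing a transposed copy to the mirrored tile, with the subtile index reused unchanged) is a simplifying assumption about what a symmetric commit might do, not the PR's implementation.

```python
from typing import Callable, Dict, List, Tuple

Subtile = List[List[float]]
Output = Dict[Tuple[int, int], Dict[int, Subtile]]


def run_epilogue_sketch(acc: Dict[int, Subtile], layout: List[int],
                        commit: Callable[[int, Subtile], None]) -> None:
    """Reusable traversal: visit subtiles in plan order, delegate the store."""
    for sub_idx in layout:
        commit(sub_idx, acc[sub_idx])


def default_commit_sketch(out: Output, tile_mn: Tuple[int, int]):
    """Store each subtile at its own tile coordinate."""
    def commit(sub_idx: int, values: Subtile) -> None:
        out.setdefault(tile_mn, {})[sub_idx] = values
    return commit


def symmetric_commit_sketch(out: Output, tile_mn: Tuple[int, int]):
    """Store the subtile and, for off-diagonal tiles, a transposed mirror copy."""
    m, n = tile_mn
    store = default_commit_sketch(out, tile_mn)

    def commit(sub_idx: int, values: Subtile) -> None:
        store(sub_idx, values)
        if m != n:
            transposed = [list(col) for col in zip(*values)]
            out.setdefault((n, m), {})[sub_idx] = transposed
    return commit


acc = {0: [[1.0, 2.0], [3.0, 4.0]], 1: [[5.0, 6.0], [7.0, 8.0]]}
layout = [0, 1]  # subtile iteration order chosen by the plan

out_default: Output = {}
run_epilogue_sketch(acc, layout, default_commit_sketch(out_default, (0, 1)))

out_symmetric: Output = {}
run_epilogue_sketch(acc, layout, symmetric_commit_sketch(out_symmetric, (0, 1)))
```

Both runs share the same traversal function. Only the commit callable changes, which mirrors how a store-policy feature would override `epi_plan_commit` without duplicating the loop.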

The key insight

Every call site in every warp role that previously did:

batch_idx = tile_coord_mnkl[3]
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)

Now does:

batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
mA_mk = self.problem_get_batch_A(problem_params, mA_mkl, varlen_manager, work_tile)

This is a mechanical, one-to-one replacement across ~20 call sites in 3 arch files. The diff is large in line count but the pattern is uniform and reviewable.

What the new boundary buys

  • GemmSm90 / GemmSm100 stay responsible for architecture execution
  • ProblemAdapter owns which problem is being executed and where operands come from
  • EpiloguePlan owns how outputs are traversed and committed

Future features can change the right axis without touching unrelated machinery. Grouped GEMM becomes a ~60-line mixin. Split-K will specialize WorkDesc fields + epi_plan_commit for partial/final stores.

Files changed

| File | Change |
| --- | --- |
| quack/gemm_work.py (new) | WorkDesc definition and make_work_desc() factory |
| quack/gemm_problem_adapter.py (new) | ProblemAdapter types, default + grouped implementations |
| quack/gemm_epilogue_plan.py (new) | run_epilogue_plan, default_epi_commit, symmetric_epi_commit |
| quack/tile_scheduler.py | Return WorkDesc instead of WorkTileInfo |
| quack/gemm_sm90.py | Add problem_get_* methods, thread problem_params through __call__ and kernel |
| quack/gemm_sm100.py | Thread problem_params through __call__ and kernel |
| quack/gemm_sm120.py | Thread problem_params through kernel |
| quack/gemm.py | Add grouped path with dynamic class creation and dispatch |
| quack/gemm_tvm_ffi_utils.py | Add problem_args to compile_gemm_kernel |
| quack/gemm_interface.py | Add gemm_grouped() host wrapper |
| quack/__init__.py | Export gemm, gemm_grouped |
| quack/gemm_symmetric.py | Simplified from ~170 to ~30 lines using epi_plan_commit |
| quack/gemm_act.py, gemm_dact.py, gemm_norm_act.py, gemm_sq_reduce.py | Append problem_args=None to dispatch calls |

Santosh Mohan and others added 2 commits April 18, 2026 19:53
…ant dispatch sites

- Add __extract_mlir_values__ to WorkDesc so it can participate in CuTe
  DSL while-loop state management (required by the new WorkDesc type
  replacing cutlass.utils.WorkTileInfo).
- Append problem_args=None to compiled_fn() calls in gemm_act.py,
  gemm_dact.py, gemm_norm_act.py, gemm_sq_reduce.py, and
  gemm_symmetric.py for both SM100 and SM90 paths.

Tested: 7,668+ tests passing on B200 (SM100) across test_linear,
test_gemm_symmetric, test_linear_varlen_k, test_gemm_256x512, and
test_linear_varlen_m.
@santoshmo santoshmo force-pushed the refactor/gemm-flexible-work-epilogue branch from 91ef7e8 to 8de1110 Compare April 19, 2026 02:57
@thakkarV thakkarV requested review from thakkarV and tridao and removed request for thakkarV April 21, 2026 15:13