Refactor GEMM for flexible work, operand routing, and epilogue policies #111
Open
santoshmo wants to merge 2 commits into
…ant dispatch sites
- Add __extract_mlir_values__ to WorkDesc so it can participate in CuTe DSL while-loop state management (required by the new WorkDesc type replacing cutlass.utils.WorkTileInfo).
- Append problem_args=None to compiled_fn() calls in gemm_act.py, gemm_dact.py, gemm_norm_act.py, gemm_sq_reduce.py, and gemm_symmetric.py for both SM100 and SM90 paths.
Tested: 7,668+ tests passing on B200 (SM100) across test_linear, test_gemm_symmetric, test_linear_varlen_k, test_gemm_256x512, and test_linear_varlen_m.
91ef7e8 to 8de1110
Motivation
QuACK's current mixins and hooks are good for local specialization — customizing epilogue math inside a fixed control-flow skeleton. They are not enough for structural specialization. Grouped GEMM, Split-K, and richer epilogue dispatch don't just change a few math ops — they change what a tile means, where operands come from, how K is partitioned, and how results are committed.
Today, kernel() decides:
That logic is embedded directly in the arch kernels. The current abstraction boundary is "customize the math". The next class of features needs "customize the meaning of work", "customize the mapping from work to operands", and "customize the finalization/store policy".
Without proper abstraction boundaries, every new structural feature does one of two bad things: adds more hooks into already-large kernels, or forks kernel()/epilogue() wholesale. Both scale poorly. GemmSymmetricMixin already demonstrates this: it had to override the entire epilogue loop (~170 lines) just to change store-commit behavior for symmetric tiles. That's the signal that the hook surface is too low-level in the wrong places.
What this PR does
Introduces three new abstraction layers that lift the extension point to where structural behavior actually lives, without replacing the existing hook model:
1. WorkDesc: richer work descriptor (gemm_work.py)
Replaces cutlass.utils.WorkTileInfo (just tile_idx + is_valid_tile) with a descriptor that carries enough metadata for grouped and split-K execution:
All tile schedulers updated to return WorkDesc. For current plain GEMM, extra fields are defaults, so there is no behavioral change.
2. ProblemAdapter: per-problem operand routing (gemm_problem_adapter.py)
Moves problem identity and operand provenance out of the architecture kernels via virtual methods on GemmSm90:
- problem_get_problem_idx(params, work): which logical problem a tile belongs to
- problem_get_batch_A/B(params, tensor, varlen_manager, work): locate A/B tensors
- problem_get_batch_epi(params, tensor, varlen_manager, work): locate output tensors
- problem_get_len_k(params, varlen_manager, work): K dimension for the current problem
Default implementations delegate to VarlenManager, preserving current behavior with zero change to generated PTX. GroupedProblemAdapterMixin provides grouped variants, applied at compile time:
3. EpiloguePlan: composable epilogue traversal (gemm_epilogue_plan.py)
Separates the reusable epilogue loop (subtile iteration, C prefetch pipelining, accumulator transport) from pluggable policy:
- epi_plan_make_tile_layout(): subtile iteration order
- epi_plan_commit(): D/postact store behavior (default vs symmetric vs future split-K partial/final)
GemmSymmetricMixin is simplified from ~170 lines (full epilogue override) to ~30 lines (just epi_plan_commit override). Features that change store policy no longer need to duplicate traversal logic.
The key insight
Every call site in every warp role that previously did:
Now does:
This is a mechanical, one-to-one replacement across ~20 call sites in 3 arch files. The diff is large in line count but the pattern is uniform and reviewable.
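As a rough, plain-Python illustration of this call-site indirection (not the PR's actual code), the sketch below routes operand lookup through problem_get_* hooks and swaps in grouped behavior with a mixin. Only the hook names, tile_idx, and is_valid_tile come from the description above; the Toy* classes, the problem_idx field, the offset_batch helper, and the dict-based params are assumptions made for the example.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyWorkDesc:
    """Toy stand-in for WorkDesc; problem_idx is a hypothetical extra field."""
    tile_idx: Tuple[int, int, int]  # (m, n, batch) tile coordinates
    is_valid_tile: bool = True
    problem_idx: int = 0            # default value, unused by plain GEMM


class ToyVarlenManager:
    """Hypothetical helper standing in for VarlenManager."""

    def offset_batch(self, tensor, batch_idx):
        return tensor[batch_idx]


class ToyKernel:
    """Default hooks delegate to the varlen manager."""

    def problem_get_problem_idx(self, params, work):
        return work.tile_idx[2]

    def problem_get_batch_A(self, params, tensor, varlen_manager, work):
        return varlen_manager.offset_batch(tensor, work.tile_idx[2])

    def problem_get_len_k(self, params, varlen_manager, work):
        return params["k"]

    def load_role(self, params, tensor_a, varlen_manager, work):
        # Call site: asks the hooks instead of indexing operands directly.
        a = self.problem_get_batch_A(params, tensor_a, varlen_manager, work)
        k = self.problem_get_len_k(params, varlen_manager, work)
        return a, k


class ToyGroupedMixin:
    """Grouped hooks: operands and K come from per-problem tables."""

    def problem_get_problem_idx(self, params, work):
        return work.problem_idx

    def problem_get_batch_A(self, params, tensor, varlen_manager, work):
        return tensor[work.problem_idx]

    def problem_get_len_k(self, params, varlen_manager, work):
        return params["k_per_problem"][work.problem_idx]


# Build the grouped kernel class dynamically; load_role itself is unchanged.
ToyGroupedKernel = type("ToyGroupedKernel", (ToyGroupedMixin, ToyKernel), {})
```

In this toy version the warp-role code (load_role) never changes; only the hooks it calls do, which is the property the mechanical replacement is meant to buy.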
What the new boundary buys
- GemmSm90/GemmSm100 stay responsible for architecture execution
- ProblemAdapter owns which problem is being executed and where operands come from
- EpiloguePlan owns how outputs are traversed and committed
Future features can change the right axis without touching unrelated machinery. Grouped GEMM becomes a ~60-line mixin. Split-K will specialize WorkDesc fields + epi_plan_commit for partial/final stores.
Files changed
- quack/gemm_work.py (new): WorkDesc definition and make_work_desc() factory
- quack/gemm_problem_adapter.py (new): ProblemAdapter types, default + grouped implementations
- quack/gemm_epilogue_plan.py (new): run_epilogue_plan, default_epi_commit, symmetric_epi_commit
- quack/tile_scheduler.py: WorkDesc instead of WorkTileInfo
- quack/gemm_sm90.py: problem_get_* methods, thread problem_params through __call__ and kernel
- quack/gemm_sm100.py: problem_params through __call__ and kernel
- quack/gemm_sm120.py: problem_params through kernel
- quack/gemm.py: grouped path with dynamic class creation and dispatch
- quack/gemm_tvm_ffi_utils.py: problem_args to compile_gemm_kernel
- quack/gemm_interface.py: gemm_grouped() host wrapper
- quack/__init__.py: gemm, gemm_grouped
- quack/gemm_symmetric.py: epi_plan_commit
- quack/gemm_act.py, gemm_dact.py, gemm_norm_act.py, gemm_sq_reduce.py: problem_args=None to dispatch calls
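The file list above names run_epilogue_plan, default_epi_commit, and symmetric_epi_commit in gemm_epilogue_plan.py. The toy sketch below shows how a shared traversal loop might be separated from a pluggable commit policy. The signatures, the dict-based accumulator and output, the make_tile_layout helper, and the "also write the mirrored coordinate" reading of the symmetric policy are all assumptions for illustration, not the PR's implementation.

```python
def default_epi_commit(out, coord, value):
    """Store one subtile result at its own coordinate."""
    out[coord] = value


def symmetric_epi_commit(out, coord, value):
    """Assumed symmetric policy: also store to the mirrored coordinate."""
    m, n = coord
    out[(m, n)] = value
    out[(n, m)] = value


def make_tile_layout(num_m, num_n):
    """Hypothetical subtile order: row-major over a num_m x num_n grid."""
    return [(m, n) for m in range(num_m) for n in range(num_n)]


def run_epilogue_plan(acc, layout, commit, out):
    """Shared traversal: visit subtiles in layout order and delegate stores."""
    for coord in layout:
        if coord in acc:  # skip subtiles this toy accumulator did not compute
            commit(out, coord, acc[coord])
    return out


# Usage: the same traversal with two different store policies.
acc = {(0, 0): 1.0, (0, 1): 2.0, (1, 1): 3.0}  # upper-triangle results only
layout = make_tile_layout(2, 2)
plain_out = run_epilogue_plan(acc, layout, default_epi_commit, {})
sym_out = run_epilogue_plan(acc, layout, symmetric_epi_commit, {})
```

Under these assumptions a new store policy, such as a split-K partial/final commit, would be another small function passed to the same loop rather than a copy of the traversal.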