Skip to content

[MISC] 32x32 Cholesky tile.#2826

Closed
hughperkins wants to merge 20 commits into
Genesis-Embodied-AI:mainfrom
hughperkins:hp/cholesky-tile-dispatch
Closed

[MISC] 32x32 Cholesky tile.#2826
hughperkins wants to merge 20 commits into
Genesis-Embodied-AI:mainfrom
hughperkins:hp/cholesky-tile-dispatch

Conversation

@hughperkins
Copy link
Copy Markdown
Collaborator

Description

Related Issue

Resolves Genesis-Embodied-AI/Genesis#

Motivation and Context

How Has This Been / Can This Be Tested?

Screenshots (if appropriate):

Checklist:

  • I read the CONTRIBUTING document.
  • I followed the Submitting Code Changes section of CONTRIBUTING document.
  • I tagged the title correctly (including BUG FIX/FEATURE/MISC/BREAKING)
  • I updated the documentation accordingly or no change is needed.
  • I tested my changes and added instructions on how to test it for reviewers.
  • I have added tests to cover my changes.
  • All new and existing tests passed.

Replaces 16x16 tile Cholesky in func_cholesky_factor_direct_tiled and
func_cholesky_and_solve_fused_tiled with a 32x32 register-tile version
(genesis/utils/_tile32.py, ported from _tile16.py by mechanical 16 -> 32
expansion). Kernels now run at block_dim=32 (full warp) instead of 16
(sub-warp), eliminating the half-warp idle penalty that the
cholesky_mjw_vs_gs_2026may21 doc identified as a major driver of the 2.52x
compute-only gap vs MJWarp.

For dex_hand (n_dofs=62, tiled_n_dofs=64) the outer tile-block grid drops
from 4x4=16 blocks at block_dim=16 to 2x2=4 blocks at block_dim=32 — same
total FFMA work, twice the warp-execution throughput, fewer inter-tile
sync boundaries. Other 32-multiple tile sizes (g1_fall at 32, franka at
32, etc.) collapse to 1x1=1 block.

Diverges from T18-T20 pack-2 (which got 1.38x on factor kernel but lost
-8.45 % overall on dex_hand): one env per warp here (no pack-2 inter-tile
sync penalty, no skip-unchanged conflict).

The standalone cholesky_solve_tiled kernel (block_dim=64, 2 warps per
block) is unchanged. The non-tiled batch path (func_cholesky_factor_direct_batch)
is unchanged.

Includes tests/test_tile32_cholesky.py for direct numerical-equivalence
validation against numpy on a 32x32 SPD matrix.
…oops

Two changes on top of F3-S1:

1. _tile32.cholesky_: split the column-update dot reduction from 2 chains
   (dot0/dot1, ~16 FMAs deep at k=31) to 4 chains (dot0/dot1/dot2/dot3,
   ~8 FMAs deep at k=31). Cuts the back-to-back FMA dependency chain length
   another 2x for the 32-deep tile; _tile16's 2-way split was sufficient at
   16-deep but at 32-deep we want more ILP.

2. solver.py func_cholesky_factor_direct_tiled and _fused_tiled: outer
   tile-block loops wrapped in qd.static(range(N_BLOCKS)) where N_BLOCKS is
   qd.static(tiled_n_dofs // 32). For dex_hand (tiled_n_dofs=64) this
   constant-folds the entire kb/ib/jb tile-block structure to N_BLOCKS=2.
   k0/i0/j0 also qd.static so the compiler can constant-propagate them into
   load3d/store3d/_resolve_vec3d call sites.

S1 alone landed +2.62 % FPS on dex_hand (5-run, ±0.18 %), just shy of the
+3 % decision gate; S2 aims to push past it via tighter compile-time
folding and longer ILP.
…t-split

S2 (8d6ab56) blew up compile time from 107s -> 928s (8.7x) and ran into
pytest timeout on 4 of 5 dex_hand bench runs. The remaining run measured
-3.20% FPS vs main; this is dominated by the compile-time-induced register
spill / suboptimal codegen, not a real regression of the algorithm.

Diagnosis: qd.static-wrapping the kb/ib/jb tile-block loops on top of the
already-32-way-unrolled inner _ger_sub / _resolve_vec3d / cholesky_/
solve_triangular_ register-cascade ops generated >60k AST nodes for the
factor + fused funcs combined. Quadrants AST -> PTX path apparently doesn't
scale that far without spilling.

S2b keeps the small surgical change that's safe (4-way dot-split inside
_tile32.cholesky_, ~8-deep FMA chains vs 16-deep at k=31) and drops the
problematic static unroll. solver.py reverts to S1 (5659357) state.
S2b bench landed at +2.13 % +/- 0.37 % vs WandB main 23012 = -0.47 % vs S1
(within combined CI 0.55 % so within noise, but trending negative). 4-way
dot-split provides no gain over 2-way at the 32-deep POTRF chain, likely
because the GPU scheduler already hides the FMA latency at 2-way and the
extra accumulator regs / cross-chain adds eat the gain.

Reverting to clean S1 state: tile32 + block_dim=32 + 2-way dot-split (same
as _tile16). Branch head is now functionally equivalent to the original
F3-S1 commit (5659357) on solver.py; _tile32.py is the clean port.

This is the proposed final state for a PR. Headline: +2.62 % FPS on
dex_hand vs WandB main 79a0e9b (5-run, +/-0.18 %), from killing the
sub-warp execution penalty of the 16x16 register-tile Cholesky.
… of r0..r31 named fields

Same algorithm as S1 (F3-final), but the per-thread tile row is now a single
vector<32, dtype> field 'r' instead of 32 named scalar fields r0..r31. This
eliminates every 32-way 'if k == N: self.rN = val' if-cascade that hand-rolled
the runtime register-index dispatch in the named-field variant.

With qd.static-folded indices (the common path -- inside cholesky_'s nested
qd.static loops, eye_, _load3d, _store3d, _ger_sub, and from _resolve_vec3d
call sites in solver.py) the 'self.r[k]' access lowers to the same direct
register reference as 'self.rN' did via getattr in the named-field _r helper.
With runtime indices (in _get_col / _set_col / _trsm) it lowers to a 32-way
switch -- the same lowering the hand-rolled cascade produced. Generated PTX
should be byte-identical or near-byte-identical to F3-S1.

What changes is the *source* / *AST* node count: the named-field variant
emitted 64 lines (32 if-stmts + 32 assigns) per cascade site, walked by the
AST transformer even though all but one branch was dead-folded. The
vector-storage variant emits a single 'self.r[<k>] = val' line per site.
Across 7 cascade sites + cholesky_'s 2 cascades-per-outer-k-iteration, that
collapses the tile-method AST from O(2k nodes) to O(few hundred). The
_tile32.py file itself shrunk from 1077 to 505 lines.

Goal: drop the +20s S1 vs main compile-time cost back toward zero without
touching the runtime path.
…unc bodies

The first F4-A bench failed all 5 dex_hand runs with:

    UserWarning: [PURE.VIOLATION] WARNING: Accessing global variable _TILE
    <class 'int'> _TILE is in global vars, therefore violates pure

Triggered inside _trsm (and likely every other qd.func that referenced _TILE).
The original _tile32.py used a literal 32 inside qd.func bodies for exactly
this reason; module-level _TILE was OK to reference outside qd.func (at the
class factory level for vector(_TILE, dtype) and result.SIZE = _TILE).

Replace the in-func _TILE references with literal 32; keep the out-of-func
ones (which are evaluated at Python class-build time, not inside the AST
transformer).
F4-A vec32 storage compiled 10s faster (-50% of S1's overhead) but cost
-19% runtime FPS on dex_hand — the vec32 didn't register-promote on cuda
7.x, fell back to local memory. F4-B uses 4 separate qd.types.vector(8, dtype)
sub-banks 'b0..b3', each small enough to reliably register-promote (matching
the 12x12 / 144-element per-thread matrices that quadrants's per-thread
linalg ops register-promote).

All hot indexing in this module is static (qd.static unrolls in cholesky_,
_ger_sub, eye_, _load3d, _store3d), so the 4-way sub-bank dispatch (kb =
k // 8, ko = k % 8) folds at trace time to direct sub-vector + intra-vector
scalar access. The trace-time _static_read helper resolves the static-bank
dispatch in pure Python (not @qd.func), producing a single field-access AST
node per call site rather than a 4-way cascade. Write sites use an explicit
'if kb == N: self.bN[ko] = val' 4-way cascade — folded the same way.

Only _get_col / _set_col / _trsm carry the runtime 4-way cascade, where it
collapses to a switch over 4 banks (vs the original 32-way switch). _trsm
itself is unchanged in structure since the cascade now lives in _get_col /
_set_col.

Source size: 601 lines (vs S1 1077, vec32 505). Cascade lines reduced ~10x
vs S1.
…thon helper)

The previous F4-B failed with 'Quadrants Expression object is not subscriptable'
because the _static_read helper was a pure-Python function that tried to do
'self.b0[off]' outside the @qd.func AST context — quadrants Expression
objects don't implement Python __getitem__, only the qd AST transformer can
emit Subscript nodes against them.

Fix: drop the helper, inline the 4-way 'if kb == 0: self.b0[ko] ... elif
kb == 3: self.b3[ko]' dispatch directly inside cholesky_'s qd.static-unrolled
outer/inner loops. kb = k // 8 and ko = k % 8 are python ints (k is a python
int from qd.static), so the if-cascade folds at trace time — Python evaluates
the const predicate during qd.static unroll, only the matching branch enters
the AST. Same lowering as the original named-field approach, just with vec8
bank storage instead of 32 scalar fields.

Also dropped the now-unused self_k_post re-read: after the diag write, only
the tid==k lane mutates its register; the tid > k lanes still hold the
original loaded col-k value, so reusing the self_k SSA above is correct.
…s not elif)

The previous attempt failed at trace time with 'Name self_k is not defined'
— quadrants treats variables introduced inside an if/elif chain as
locally-scoped to that branch, not propagated to the outer scope, even when
every branch assigns to the same name. Matches the pattern used in
_tile16.cholesky_ where 'diag_val = qd.cast(0.0, dtype)' is pre-declared
before the if-cascade that assigns to it.

Fix: pre-declare self_k = qd.cast(0.0, dtype) and my_col = qd.cast(0.0, dtype)
before each respective 4-way cascade, and switch the cascade to separate
'if kb == N:' statements (matching the original tile16 pattern) rather than
elif/else. At trace time with kb being a python int from qd.static, each
'if kb == N' folds to True/False; only the matching branch emits AST. Same
single-branch lowering as the original named-field cascade, just with
self.bN[ko] writes instead of self.rN writes.
Both vec-storage attempts regressed runtime by ~18% on dex_hand:

  S1     (named r0..r31): 23614 FPS (+2.62%), compile 107.1s
  F4-A   (vec32):         18662 FPS (-18.90%), compile  97.4s
  F4-B   (vec8 x 4):      18915 FPS (-17.80%), compile 141.5s

The vec32 single-field layout fell back to local memory on cuda 7.x (gpu
register file doesn't promote 32-element per-thread vectors). The vec8 x 4
sub-bank layout also regressed — runtime hit suggests vec8 fields inside a
qd.dataclass also don't reliably register-promote (vs vec8 locals in
quadrants's per-thread linalg ops which do).  F4-B was additionally SLOWER
to compile than S1 because the 4-way bank-dispatch cascades emit more AST
than the 32-way named-field cascades after qd.static folding (each cascade
emits 4 if-statements regardless of whether only one branch is live, plus
the pre-declared self_k = qd.cast(0.0, dtype) lines, plus the bank-vec
subscript expressions are heavier than direct field references).

Conclusion: vec storage is not viable for compile-time-only optimization
of _tile32 without runtime regression on cuda 7.x. S1 remains the best
known state. Compile-time-only F4 effort is a null result; documented in
perso_hugh/doc/cholesky_tile32_2026may22.md.

Keeping the F4-A/B history in git so the next investigation has the failure
data on hand (e.g. when quadrants gets reliable register promotion for vec
fields, F4-B's 4-way cascade approach should immediately compile faster than
S1).
Add a build-time dispatch in rigid_solver.py that selects the register-tile
width for the Hessian Cholesky kernels:

  cholesky_tile_size = 32 if self.n_dofs >= 49 else 16

Rationale (see perso_hugh/doc/cholesky_tile32_2026may22.md):
- T=32 wins for large problems where the sub-warp penalty matters
  (dex_hand n_dofs=62: +2.6 %, box_pyramid_6 (~70 cholesky dofs): +4.7 %).
- T=16 wins when n_dofs lands in a padding-unfavorable band, where T=32
  rounds up to a much larger padded tile while T=16 fits tightly
  (g1_fall n_dofs=35: T=32 regressed -2.9 %; box_pyramid_3 (~24 dofs): -4.3 %;
   small problems too: box_pyramid_1 -9.7 %).

Implementation:
- New static field static_rigid_sim_config.cholesky_tile_size (default 32).
- Rename existing T=32 funcs to _t32 suffix; loop names get _t32 suffix.
- Restore T=16 versions from origin/main 79a0e9b as _t16 functions;
  their loop names get _t16 suffix so both can coexist.
- New unsuffixed dispatcher funcs (func_cholesky_factor_direct_tiled,
  func_cholesky_and_solve_fused_tiled) qd.static-switch between the
  two paths based on the static field. All existing call sites
  continue to call the unsuffixed dispatcher unchanged.
- Fix tiled_n_dofs rounding to use the chosen tile size as the modulus
  (was always rounded up to multiples of 32, now multiples of T).

Only the Hessian Cholesky is affected. Mass-matrix Cholesky (M = L L^T)
and other tiled kernels still use T=32 unconditionally.

Local AST + import sanity OK; cluster bench is the real test.
Original rule from the padded-volume + sub-warp model:
  T = 16 if n_dofs in [1..16] or [33..48] else 32

The previous commit used a simpler 'n_dofs >= 49 -> T=32' threshold which would
have regressed the [17..32] band where a single 32-lane tile beats two
sequential 16-lane tiles. The simpler rule was tempting because box_pyramid_3
(-4.29 % T=32 vs T=16) looked like a [17..32] regression, but pyramid_3's
actual Cholesky-N was never measured (it could easily be smaller, e.g. 12,
in the [1..16] band where the 4-band rule already picks T=16).

The two solidly measured endpoints both agree with the 4-band rule:
  - dex_hand (n_dofs=62) -> T=32, +2.6 %
  - g1_fall (n_dofs=35)  -> T=16, +2.9 %
@hughperkins
Copy link
Copy Markdown
Collaborator Author

Ran benchmarks twice:

20260523-s1-dispatch-run2 20260523-s1-dispatch

=> small, but consistent, benefit on dex_hand (~2%) and g1_fall (~5%)

… this branch

Tightening to 120c per the project line-width target. Several runs hit
hard word-boundary limits ("left-looking" anchors the break on the
T=32 algorithm comments at 113c, two below the script's 5-char-slack
threshold); accepted as structural and not reworded further.
This test was added by S1 (5659357). It validates the Tile32x32Cholesky
primitive directly via a one-warp kernel, mirroring the upstream
quadrants test_simt.py style. Until S1 lands in main, it lives in
perso_hugh/prot/test_tile32_cholesky.py so the dispatch branch stays
focused on the production code change.
Three sites referenced perso_hugh/doc/* paths or a GitHub URL into the
private fork:
- array_class.py / rigid_solver.py: docstring/comment pointers I added on
  this branch for the Cholesky dispatch context.
- rigid_solver.py: pre-existing 'perso_hugh/doc/linesearch_shuffle.md'
  pointer in _should_transpose_constraint_layout's docstring.
- test_rigid_physics_analytical_vs_gjk.py: pre-existing internal-only
  GitHub URL.

Keeping perso_hugh references out of genesis sources so the tree stays
self-contained for upstream review.
…lytical_vs_gjk.py

These lines were pre-existing in main and are not part of this PR's
scope; reverting their accidental deletion.
…_solver.py

Pre-existing on main; not part of this PR's scope.
The T=16 dispatch variants in solver.py had drifted from upstream/main
(docstring shortened, several inline comments stripped, three multi-line
comments rewrapped). Restore them so the diff against upstream is limited
to intentional changes only: the _t16 suffix on the function names and
qd.loop_config names, plus the docstring note pointing to the dispatcher.
@github-actions
Copy link
Copy Markdown

🔴 Benchmark Regression Detected ➡️ Report

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant