[MISC] Speed up rigid constraint solver by adjusting cholesky tile size.#2827

Merged
duburcqa merged 32 commits into Genesis-Embodied-AI:main from hughperkins:hp/cholesky-tile-factor
May 24, 2026

Conversation

@hughperkins
Collaborator

@hughperkins hughperkins commented May 23, 2026

Description

Interestingly, even though this only targets dex_hand with 32x32 (dofs 64 or so), this PR also modifies shared memory usage:

For g1_fall (n_dofs=35):
• Upstream rounds up to multiple of 32: ceil(35/32)*32 = 64
• This branch rounds up to multiple of cholesky_tile_size=16: ceil(35/16)*16 = 48

This reduction in shared memory usage then plausibly led to increased occupancy.

In any case, we see +7% FPS on g1_fall, even though g1_fall continues to use 16x16 tiles.
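
The rounding arithmetic above can be reproduced in a few lines of Python. This is a minimal sketch: `padded_dofs` is a hypothetical helper name, not a function from the Genesis source, and the footprint ratio assumes a dense square matrix of `padded * padded` entries.

```python
def padded_dofs(n_dofs: int, tile: int) -> int:
    """Round n_dofs up to the next multiple of tile (integer ceil-division)."""
    return -(-n_dofs // tile) * tile


# g1_fall has n_dofs=35 according to the description above.
upstream = padded_dofs(35, 32)  # upstream always rounds to a multiple of 32
branch = padded_dofs(35, 16)    # this branch rounds to a multiple of the tile size

# Assumption for illustration: storage scales with padded**2 (dense square matrix).
footprint_ratio = (branch * branch) / (upstream * upstream)
print(upstream, branch, footprint_ratio)
```

Under that assumption the padded matrix for g1_fall shrinks to a bit over half its previous size, which is consistent with the occupancy explanation but does not prove it.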

Replaces 16x16 tile Cholesky in func_cholesky_factor_direct_tiled and
func_cholesky_and_solve_fused_tiled with a 32x32 register-tile version
(genesis/utils/_tile32.py, ported from _tile16.py by mechanical 16 -> 32
expansion). Kernels now run at block_dim=32 (full warp) instead of 16
(sub-warp), eliminating the half-warp idle penalty that the
cholesky_mjw_vs_gs_2026may21 doc identified as a major driver of the 2.52x
compute-only gap vs MJWarp.

For dex_hand (n_dofs=62, tiled_n_dofs=64) the outer tile-block grid drops
from 4x4=16 blocks at block_dim=16 to 2x2=4 blocks at block_dim=32 — same
total FFMA work, twice the warp-execution throughput, fewer inter-tile
sync boundaries. Other 32-multiple tile sizes (g1_fall at 32, franka at
32, etc.) collapse to 1x1=1 block.
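
The block counts quoted above follow directly from the padded size and the tile width. A small sketch, assuming a square grid of tile blocks as the message describes (`tile_block_grid` is a hypothetical name used only for illustration):

```python
def tile_block_grid(tiled_n_dofs: int, block_dim: int) -> tuple[int, int]:
    """Return (tiles per side, total tile blocks) for a square padded matrix."""
    if tiled_n_dofs % block_dim != 0:
        raise ValueError("tiled_n_dofs must be a multiple of block_dim")
    per_side = tiled_n_dofs // block_dim
    return per_side, per_side * per_side


# dex_hand: tiled_n_dofs=64, compared at both tile widths.
for block_dim in (16, 32):
    per_side, total = tile_block_grid(64, block_dim)
    print(f"block_dim={block_dim}: {per_side}x{per_side}={total} blocks")
```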

Diverges from T18-T20 pack-2 (which got 1.38x on factor kernel but lost
-8.45 % overall on dex_hand): one env per warp here (no pack-2 inter-tile
sync penalty, no skip-unchanged conflict).

The standalone cholesky_solve_tiled kernel (block_dim=64, 2 warps per
block) is unchanged. The non-tiled batch path (func_cholesky_factor_direct_batch)
is unchanged.

Includes tests/test_tile32_cholesky.py for direct numerical-equivalence
validation against numpy on a 32x32 SPD matrix.
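
For readers unfamiliar with tiled Cholesky, the idea of a numerical-equivalence check can be sketched in plain Python. This is not the PR's kernel or its test: it is a textbook right-looking blocked factorization on nested lists, compared against an unblocked reference, and every function name here is hypothetical.

```python
import math
import random


def make_spd(n, seed=0):
    """Build a symmetric positive-definite matrix A = B B^T + n I."""
    rng = random.Random(seed)
    b = [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(n)]
    a = [[sum(b[i][k] * b[j][k] for k in range(n)) for j in range(n)] for i in range(n)]
    for i in range(n):
        a[i][i] += n
    return a


def cholesky_unblocked(a):
    """Reference column-by-column Cholesky; returns lower-triangular L."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for j in range(n):
        l[j][j] = math.sqrt(a[j][j] - sum(l[j][k] ** 2 for k in range(j)))
        for i in range(j + 1, n):
            l[i][j] = (a[i][j] - sum(l[i][k] * l[j][k] for k in range(j))) / l[j][j]
    return l


def cholesky_blocked(a, tile):
    """Right-looking blocked Cholesky over tile x tile blocks."""
    n = len(a)
    if n % tile != 0:
        raise ValueError("matrix size must be a multiple of the tile size")
    w = [row[:] for row in a]  # work on a copy
    for k0 in range(0, n, tile):
        k1 = k0 + tile
        # 1. Factor the diagonal tile.
        for j in range(k0, k1):
            w[j][j] = math.sqrt(w[j][j] - sum(w[j][k] ** 2 for k in range(k0, j)))
            for i in range(j + 1, k1):
                w[i][j] = (w[i][j] - sum(w[i][k] * w[j][k] for k in range(k0, j))) / w[j][j]
        # 2. Triangular solve for the panel below the diagonal tile.
        for i in range(k1, n):
            for j in range(k0, k1):
                w[i][j] = (w[i][j] - sum(w[i][k] * w[j][k] for k in range(k0, j))) / w[j][j]
        # 3. Rank-tile update of the trailing lower triangle.
        for i in range(k1, n):
            for j in range(k1, i + 1):
                w[i][j] -= sum(w[i][k] * w[j][k] for k in range(k0, k1))
    return [[w[i][j] if j <= i else 0.0 for j in range(n)] for i in range(n)]


N = 32
A = make_spd(N)
L_ref = cholesky_unblocked(A)
L_blk = cholesky_blocked(A, tile=16)
max_diff = max(abs(L_ref[i][j] - L_blk[i][j]) for i in range(N) for j in range(N))
recon_err = max(
    abs(sum(L_blk[i][k] * L_blk[j][k] for k in range(N)) - A[i][j])
    for i in range(N)
    for j in range(N)
)
print(max_diff, recon_err)
```

Both checks (agreement with the reference factor and reconstruction of A from L L^T) are the kind of assertion such a test would typically make; the tolerances a GPU test needs in single precision would be looser than for this double-precision sketch.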
…oops

Two changes on top of F3-S1:

1. _tile32.cholesky_: split the column-update dot reduction from 2 chains
   (dot0/dot1, ~16 FMAs deep at k=31) to 4 chains (dot0/dot1/dot2/dot3,
   ~8 FMAs deep at k=31). Cuts the back-to-back FMA dependency chain length
   another 2x for the 32-deep tile; _tile16's 2-way split was sufficient at
   16-deep but at 32-deep we want more ILP.

2. solver.py func_cholesky_factor_direct_tiled and _fused_tiled: outer
   tile-block loops wrapped in qd.static(range(N_BLOCKS)) where N_BLOCKS is
   qd.static(tiled_n_dofs // 32). For dex_hand (tiled_n_dofs=64) this
   constant-folds the entire kb/ib/jb tile-block structure to N_BLOCKS=2.
   k0/i0/j0 also qd.static so the compiler can constant-propagate them into
   load3d/store3d/_resolve_vec3d call sites.

S1 alone landed +2.62 % FPS on dex_hand (5-run, ±0.18 %), just shy of the
+3 % decision gate; S2 aims to push past it via tighter compile-time
folding and longer ILP.
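
The chain-depth figures in item 1 can be illustrated with a plain-Python dot product that keeps several independent accumulators. This is only a structural sketch: Python does not expose instruction-level parallelism, and `dot_split` / `chain_depth` are hypothetical names, not code from `_tile32.py`.

```python
def dot_split(xs, ys, n_chains):
    """Dot product using n_chains independent accumulators, filled round-robin."""
    acc = [0.0] * n_chains
    for i, (x, y) in enumerate(zip(xs, ys)):
        acc[i % n_chains] += x * y  # each chain only depends on its own previous value
    return sum(acc)  # the cross-chain adds happen once, at the end


def chain_depth(k, n_chains):
    """Longest run of dependent multiply-adds when k terms are split over n_chains."""
    return -(-k // n_chains)


xs = [float(i) for i in range(31)]
ys = [2.0] * 31
print(dot_split(xs, ys, 2), dot_split(xs, ys, 4))
print(chain_depth(31, 2), chain_depth(31, 4))
```

At k=31 the dependent chain is 16 long with two accumulators and 8 long with four, matching the "~16" and "~8" figures in the message.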
…t-split

S2 (8d6ab56) blew up compile time from 107s -> 928s (8.7x) and ran into
pytest timeout on 4 of 5 dex_hand bench runs. The remaining run measured
-3.20% FPS vs main; this is dominated by the compile-time-induced register
spill / suboptimal codegen, not a real regression of the algorithm.

Diagnosis: qd.static-wrapping the kb/ib/jb tile-block loops on top of the
already-32-way-unrolled inner _ger_sub / _resolve_vec3d / cholesky_/
solve_triangular_ register-cascade ops generated >60k AST nodes for the
factor + fused funcs combined. Quadrants AST -> PTX path apparently doesn't
scale that far without spilling.

S2b keeps the small surgical change that's safe (4-way dot-split inside
_tile32.cholesky_, ~8-deep FMA chains vs 16-deep at k=31) and drops the
problematic static unroll. solver.py reverts to S1 (5659357) state.

S2b bench landed at +2.13 % +/- 0.37 % vs WandB main 23012 = -0.47 % vs S1
(within combined CI 0.55 % so within noise, but trending negative). 4-way
dot-split provides no gain over 2-way at the 32-deep POTRF chain, likely
because the GPU scheduler already hides the FMA latency at 2-way and the
extra accumulator regs / cross-chain adds eat the gain.

Reverting to clean S1 state: tile32 + block_dim=32 + 2-way dot-split (same
as _tile16). Branch head is now functionally equivalent to the original
F3-S1 commit (5659357) on solver.py; _tile32.py is the clean port.

This is the proposed final state for a PR. Headline: +2.62 % FPS on
dex_hand vs WandB main 79a0e9b (5-run, +/-0.18 %), from killing the
sub-warp execution penalty of the 16x16 register-tile Cholesky.
… of r0..r31 named fields

Same algorithm as S1 (F3-final), but the per-thread tile row is now a single
vector<32, dtype> field 'r' instead of 32 named scalar fields r0..r31. This
eliminates every 32-way 'if k == N: self.rN = val' if-cascade that hand-rolled
the runtime register-index dispatch in the named-field variant.

With qd.static-folded indices (the common path -- inside cholesky_'s nested
qd.static loops, eye_, _load3d, _store3d, _ger_sub, and from _resolve_vec3d
call sites in solver.py) the 'self.r[k]' access lowers to the same direct
register reference as 'self.rN' did via getattr in the named-field _r helper.
With runtime indices (in _get_col / _set_col / _trsm) it lowers to a 32-way
switch -- the same lowering the hand-rolled cascade produced. Generated PTX
should be byte-identical or near-byte-identical to F3-S1.

What changes is the *source* / *AST* node count: the named-field variant
emitted 64 lines (32 if-stmts + 32 assigns) per cascade site, walked by the
AST transformer even though all but one branch was dead-folded. The
vector-storage variant emits a single 'self.r[<k>] = val' line per site.
Across 7 cascade sites + cholesky_'s 2 cascades-per-outer-k-iteration, that
collapses the tile-method AST from O(2k nodes) to O(few hundred). The
_tile32.py file itself shrunk from 1077 to 505 lines.

Goal: drop the +20s S1 vs main compile-time cost back toward zero without
touching the runtime path.
…unc bodies

The first F4-A bench failed all 5 dex_hand runs with:

    UserWarning: [PURE.VIOLATION] WARNING: Accessing global variable _TILE
    <class 'int'> _TILE is in global vars, therefore violates pure

Triggered inside _trsm (and likely every other qd.func that referenced _TILE).
The original _tile32.py used a literal 32 inside qd.func bodies for exactly
this reason; module-level _TILE was OK to reference outside qd.func (at the
class factory level for vector(_TILE, dtype) and result.SIZE = _TILE).

Replace the in-func _TILE references with literal 32; keep the out-of-func
ones (which are evaluated at Python class-build time, not inside the AST
transformer).

F4-A vec32 storage compiled 10s faster (-50% of S1's overhead) but cost
-19% runtime FPS on dex_hand: the vec32 did not register-promote on cuda
7.x and fell back to local memory. F4-B uses 4 separate qd.types.vector(8, dtype)
sub-banks 'b0..b3', each small enough to reliably register-promote (matching
the 12x12 / 144-element per-thread matrices that quadrants's per-thread
linalg ops register-promote).

All hot indexing in this module is static (qd.static unrolls in cholesky_,
_ger_sub, eye_, _load3d, _store3d), so the 4-way sub-bank dispatch (kb =
k // 8, ko = k % 8) folds at trace time to direct sub-vector + intra-vector
scalar access. The trace-time _static_read helper resolves the static-bank
dispatch in pure Python (not @qd.func), producing a single field-access AST
node per call site rather than a 4-way cascade. Write sites use an explicit
'if kb == N: self.bN[ko] = val' 4-way cascade — folded the same way.

Only _get_col / _set_col / _trsm carry the runtime 4-way cascade, where it
collapses to a switch over 4 banks (vs the original 32-way switch). _trsm
itself is unchanged in structure since the cascade now lives in _get_col /
_set_col.

Source size: 601 lines (vs S1 1077, vec32 505). Cascade lines reduced ~10x
vs S1.
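
The bank/offset arithmetic described above (kb = k // 8, ko = k % 8) can be sketched in plain Python. This models only the index mapping and the 4-way write cascade; it does not model Quadrants tracing or register promotion, and `BankedRow` and `bank_index` are hypothetical names.

```python
BANK_WIDTH = 8  # elements per sub-bank, as in vector(8, dtype)
N_BANKS = 4     # b0..b3, 4 * 8 = 32 elements per tile row


def bank_index(k: int) -> tuple[int, int]:
    """Map a row index k in [0, 32) to (bank, offset within bank)."""
    if not 0 <= k < BANK_WIDTH * N_BANKS:
        raise IndexError(f"k={k} is outside the 32-element tile row")
    return k // BANK_WIDTH, k % BANK_WIDTH


class BankedRow:
    """Plain-Python stand-in for a tile row stored as four 8-wide banks."""

    def __init__(self) -> None:
        self.b0 = [0.0] * BANK_WIDTH
        self.b1 = [0.0] * BANK_WIDTH
        self.b2 = [0.0] * BANK_WIDTH
        self.b3 = [0.0] * BANK_WIDTH

    def write(self, k: int, val: float) -> None:
        kb, ko = bank_index(k)
        # 4-way cascade in the style of 'if kb == N: self.bN[ko] = val'.
        if kb == 0:
            self.b0[ko] = val
        if kb == 1:
            self.b1[ko] = val
        if kb == 2:
            self.b2[ko] = val
        if kb == 3:
            self.b3[ko] = val

    def read(self, k: int) -> float:
        kb, ko = bank_index(k)
        val = 0.0
        if kb == 0:
            val = self.b0[ko]
        if kb == 1:
            val = self.b1[ko]
        if kb == 2:
            val = self.b2[ko]
        if kb == 3:
            val = self.b3[ko]
        return val


row = BankedRow()
for k in range(32):
    row.write(k, float(k))
print(bank_index(31), row.read(31))
```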
…thon helper)

The previous F4-B failed with 'Quadrants Expression object is not subscriptable'
because the _static_read helper was a pure-Python function that tried to do
'self.b0[off]' outside the @qd.func AST context — quadrants Expression
objects don't implement Python __getitem__, only the qd AST transformer can
emit Subscript nodes against them.

Fix: drop the helper, inline the 4-way 'if kb == 0: self.b0[ko] ... elif
kb == 3: self.b3[ko]' dispatch directly inside cholesky_'s qd.static-unrolled
outer/inner loops. kb = k // 8 and ko = k % 8 are python ints (k is a python
int from qd.static), so the if-cascade folds at trace time — Python evaluates
the const predicate during qd.static unroll, only the matching branch enters
the AST. Same lowering as the original named-field approach, just with vec8
bank storage instead of 32 scalar fields.

Also dropped the now-unused self_k_post re-read: after the diag write, only
the tid==k lane mutates its register; the tid > k lanes still hold the
original loaded col-k value, so reusing the self_k SSA above is correct.
…s not elif)

The previous attempt failed at trace time with 'Name self_k is not defined'
— quadrants treats variables introduced inside an if/elif chain as
locally-scoped to that branch, not propagated to the outer scope, even when
every branch assigns to the same name. Matches the pattern used in
_tile16.cholesky_ where 'diag_val = qd.cast(0.0, dtype)' is pre-declared
before the if-cascade that assigns to it.

Fix: pre-declare self_k = qd.cast(0.0, dtype) and my_col = qd.cast(0.0, dtype)
before each respective 4-way cascade, and switch the cascade to separate
'if kb == N:' statements (matching the original tile16 pattern) rather than
elif/else. At trace time with kb being a python int from qd.static, each
'if kb == N' folds to True/False; only the matching branch emits AST. Same
single-branch lowering as the original named-field cascade, just with
self.bN[ko] writes instead of self.rN writes.

Both vec-storage attempts regressed runtime by ~18% on dex_hand:

  S1     (named r0..r31): 23614 FPS (+2.62%), compile 107.1s
  F4-A   (vec32):         18662 FPS (-18.90%), compile  97.4s
  F4-B   (vec8 x 4):      18915 FPS (-17.80%), compile 141.5s
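
The percentages in this table are consistent with a baseline of 23012 FPS. The sketch below assumes that the "WandB main 23012" figure quoted in an earlier commit message is the dex_hand FPS of main; that reading is an inference, not something the messages state outright.

```python
# Assumption: 23012 is main's dex_hand FPS ("WandB main 23012" in an earlier message).
BASELINE_FPS = 23012


def pct_vs_main(fps: float, baseline: float = BASELINE_FPS) -> float:
    """Relative FPS change versus the baseline, in percent, rounded to 2 places."""
    return round((fps / baseline - 1.0) * 100.0, 2)


results = {"S1": 23614, "F4-A": 18662, "F4-B": 18915}
deltas = {name: pct_vs_main(fps) for name, fps in results.items()}
print(deltas)
```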

The vec32 single-field layout fell back to local memory on cuda 7.x (gpu
register file doesn't promote 32-element per-thread vectors). The vec8 x 4
sub-bank layout also regressed — runtime hit suggests vec8 fields inside a
qd.dataclass also don't reliably register-promote (vs vec8 locals in
quadrants's per-thread linalg ops which do).  F4-B was additionally SLOWER
to compile than S1 because the 4-way bank-dispatch cascades emit more AST
than the 32-way named-field cascades after qd.static folding (each cascade
emits 4 if-statements regardless of whether only one branch is live, plus
the pre-declared self_k = qd.cast(0.0, dtype) lines, plus the bank-vec
subscript expressions are heavier than direct field references).

Conclusion: vec storage is not viable for compile-time-only optimization
of _tile32 without runtime regression on cuda 7.x. S1 remains the best
known state. Compile-time-only F4 effort is a null result; documented in
perso_hugh/doc/cholesky_tile32_2026may22.md.

Keeping the F4-A/B history in git so the next investigation has the failure
data on hand (e.g. when quadrants gets reliable register promotion for vec
fields, F4-B's 4-way cascade approach should immediately compile faster than
S1).

Add a build-time dispatch in rigid_solver.py that selects the register-tile
width for the Hessian Cholesky kernels:

  cholesky_tile_size = 32 if self.n_dofs >= 49 else 16

Rationale (see perso_hugh/doc/cholesky_tile32_2026may22.md):
- T=32 wins for large problems where the sub-warp penalty matters
  (dex_hand n_dofs=62: +2.6 %, box_pyramid_6 (~70 cholesky dofs): +4.7 %).
- T=16 wins when n_dofs lands in a padding-unfavorable band, where T=32
  rounds up to a much larger padded tile while T=16 fits tightly
  (g1_fall n_dofs=35: T=32 regressed -2.9 %; box_pyramid_3 (~24 dofs): -4.3 %;
   small problems too: box_pyramid_1 -9.7 %).

Implementation:
- New static field static_rigid_sim_config.cholesky_tile_size (default 32).
- Rename existing T=32 funcs to _t32 suffix; loop names get _t32 suffix.
- Restore T=16 versions from origin/main 79a0e9b as _t16 functions;
  their loop names get _t16 suffix so both can coexist.
- New unsuffixed dispatcher funcs (func_cholesky_factor_direct_tiled,
  func_cholesky_and_solve_fused_tiled) qd.static-switch between the
  two paths based on the static field. All existing call sites
  continue to call the unsuffixed dispatcher unchanged.
- Fix tiled_n_dofs rounding to use the chosen tile size as the modulus
  (was always rounded up to multiples of 32, now multiples of T).

Only the Hessian Cholesky is affected. Mass-matrix Cholesky (M = L L^T)
and other tiled kernels still use T=32 unconditionally.

Local AST + import sanity OK; cluster bench is the real test.

Original rule from the padded-volume + sub-warp model:
  T = 16 if n_dofs in [1..16] or [33..48] else 32

The previous commit used a simpler 'n_dofs >= 49 -> T=32' threshold which would
have regressed the [17..32] band where a single 32-lane tile beats two
sequential 16-lane tiles. The simpler rule was tempting because box_pyramid_3
(-4.29 % T=32 vs T=16) looked like a [17..32] regression, but pyramid_3's
actual Cholesky-N was never measured (it could easily be smaller, e.g. 12,
in the [1..16] band where the 4-band rule already picks T=16).

The two solidly measured endpoints both agree with the 4-band rule:
  - dex_hand (n_dofs=62) -> T=32, +2.6 %
  - g1_fall (n_dofs=35)  -> T=16, +2.9 %
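
The difference between the two selection rules can be made explicit with a short comparison. This is a sketch of the rules exactly as written in the commit messages; the function names are hypothetical and the real dispatch lives in rigid_solver.py.

```python
def tile_size_threshold(n_dofs: int) -> int:
    """The simpler rule from the previous commit: T=32 only from 49 dofs upward."""
    return 32 if n_dofs >= 49 else 16


def tile_size_four_band(n_dofs: int) -> int:
    """The 4-band rule: T=16 in [1..16] and [33..48], T=32 otherwise."""
    return 16 if (1 <= n_dofs <= 16 or 33 <= n_dofs <= 48) else 32


# Where do the two rules disagree?
differ = [n for n in range(1, 129) if tile_size_threshold(n) != tile_size_four_band(n)]
print(differ[0], differ[-1])

# The two measured endpoints quoted above.
print(tile_size_four_band(62), tile_size_four_band(35))
```

The rules disagree only on the [17..32] band, which is the band the message says the simpler threshold would have regressed.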
… this branch

Tightening to 120c per the project line-width target. Several runs hit
hard word-boundary limits ("left-looking" anchors the break on the
T=32 algorithm comments at 113c, two below the script's 5-char-slack
threshold); accepted as structural and not reworded further.

This test was added by S1 (5659357). It validates the Tile32x32Cholesky
primitive directly via a one-warp kernel, mirroring the upstream
quadrants test_simt.py style. Until S1 lands in main, it lives in
perso_hugh/prot/test_tile32_cholesky.py so the dispatch branch stays
focused on the production code change.

Three sites referenced perso_hugh/doc/* paths or a GitHub URL into the
private fork:
- array_class.py / rigid_solver.py: docstring/comment pointers I added on
  this branch for the Cholesky dispatch context.
- rigid_solver.py: pre-existing 'perso_hugh/doc/linesearch_shuffle.md'
  pointer in _should_transpose_constraint_layout's docstring.
- test_rigid_physics_analytical_vs_gjk.py: pre-existing internal-only
  GitHub URL.

Keeping perso_hugh references out of genesis sources so the tree stays
self-contained for upstream review.
…lytical_vs_gjk.py

These lines were pre-existing in main and are not part of this PR's
scope; reverting their accidental deletion.
…_solver.py

Pre-existing on main; not part of this PR's scope.

The T=16 dispatch variants in solver.py had drifted from upstream/main
(docstring shortened, several inline comments stripped, three multi-line
comments rewrapped). Restore them so the diff against upstream is limited
to intentional changes only: the _t16 suffix on the function names and
qd.loop_config names, plus the docstring note pointing to the dispatcher.

The four explicit @qd.func definitions (factor_direct + fused, x2 tile
sizes) were byte-identical modulo 16/32 literals, Tile{16,32}x{16,32}Cholesky
references, block_dim, and qd.loop_config name. Replace them with two
factory functions that take T and TileCls as Python locals captured by
closure; Quadrants treats them as compile-time constants when parsing
the inner @qd.func (range(T) unrolls, block_dim=T constant-folds), so
each factory call produces the same compiled kernel as the explicit
version.

Module-level bindings func_cholesky_*_t16 / _t32 are preserved, so the
dispatcher and external callers are unchanged.

Net: 392 lines removed, 198 added (-194 LoC).

Validated by a minimal Quadrants-only smoke that the closure factory
returns valid @qd.func objects with per-specialization unrolled
range(T). GPU compile + perf parity vs the explicit version still needs
a cluster A/B run before this is merge-ready.

Replace the four-function split (factor_direct_t{16,32} + fused_t{16,32}
+ two dispatchers) with two unified @qd.func bodies that derive T from
static_rigid_sim_config.cholesky_tile_size via qd.static, and pick the
tile class with qd.static if/else at the two call sites that need it
(.eye() and .zeros() of the diagonal/off-diagonal tiles).

No closure captures: T and LOG2_T are local vars bound from qd.static(...)
of templated static_rigid_sim_config fields, which is the exact pattern
already used for MAX_DOFS / ENABLE_WARP_REDUCTION elsewhere in this
file. The tile classes are referenced directly (module-level imports),
which the PURE check exempts for non-(int/float/Field) globals
(quadrants/lang/ast/ast_transformer.py:96).

qd.static dead-eliminates the unused tile-class branch at parse time,
so only one class reaches Quadrants IR per build -- matches the prior
two-function specialization without needing two separate @qd.funcs.

Replaces the closure-factory attempt (00e5f76) which violated
fastcache purity on T (uppercase int -> warning, treated as error by
genesis/tests/conftest.py:439).

Field access on a qd.template() dataclass already yields the value at
compile time, so the qd.static wrapper around the T binding is
redundant. qd.static is still applied to the if-clauses
(`if qd.static(T == 32):`) where it's needed to dead-eliminate the
unused tile-class branch at parse time.

This reverts 416c895. Empirically, field access on qd.template() does
NOT return a Python int at parse time -- it returns a Quadrants
Expression that the compiler resolves later. Without qd.static(...),
block_dim=T fails:

  TypeError: block_dim(): incompatible function arguments.
  Invoked with: ..., <qd.Quadrants Expression>

(observed on cluster, dex_hand benchmark run on hp/cholesky-tile-factor).
qd.static(...) is what extracts the Python int at parse time, same as
the established MAX_DOFS / ENABLE_WARP_REDUCTION pattern in this file.

The qd.static if-else inside a @qd.func body creates a scope: variables
assigned inside don't escape, so L_kk/L_ik weren't visible after the
branch (verified on cluster: 'Name "L_kk" is not defined' at line 1917).

Replace with a parse-time Python ternary computed before the loop:

    T = qd.static(static_rigid_sim_config.cholesky_tile_size)
    TileCls = qd.static(Tile32x32Cholesky if T == 32 else Tile16x16Cholesky)

Since T is a Python int (from qd.static of a templated dataclass field),
the conditional resolves at parse time to a single Python class. TileCls
is then used the same way the explicit Tile16x16Cholesky / Tile32x32Cholesky
references were used in the four-function version.

The Python ternary form

    TileCls = Tile32x32Cholesky if qd.static(T == 32) else Tile16x16Cholesky

fails because Quadrants treats the assignment as a runtime expression
and can't make a constant from a class object (verified on cluster:
'Invalid constant scalar data type: ..._Tile32x32CholeskyProxy').

Solution: lift TileCls to a qd.template() kernel parameter on a new
internal _cholesky_factor_direct_tiled_impl (and matching fused
variant), and add a thin dispatcher wrapper that passes the right
class based on static_rigid_sim_config.cholesky_tile_size:

    @qd.func
    def func_cholesky_factor_direct_tiled(...):
        if qd.static(static_rigid_sim_config.cholesky_tile_size == 32):
            _cholesky_factor_direct_tiled_impl(..., Tile32x32Cholesky)
        else:
            _cholesky_factor_direct_tiled_impl(..., Tile16x16Cholesky)

qd.template() parameters are kernel parameters (not closure captures
or globals), so they don't trigger the PURE check that bit the closure
factory. One body per logical function, two thin dispatchers.

External callers still see func_cholesky_factor_direct_tiled /
func_cholesky_and_solve_fused_tiled unchanged.

git diff upstream/main flagged three deletions:
- the FIXME about migrating back to slice indexing once _tile16.py
  lands in Quadrants (restored in the direct variant body).
- the full func_cholesky_factor_direct_tiled docstring (algorithm
  description, padding semantics, DOF threshold). Moved back onto
  the dispatcher wrapper, with '16x16' generalised to 'TxT' and the
  n_dofs-based dispatch rule appended.
- the func_cholesky_and_solve_fused_tiled docstring. Same treatment.

Also drop the inline 'TileCls is a parse-time Python class binding'
comment I added; the TileCls=qd.template() docstring on the impl
already covers it.

git diff aligns line-by-line by similarity, so keeping the long
algorithm docstring on the body (which moved from
func_cholesky_factor_direct_tiled to _cholesky_factor_direct_tiled_impl)
collapses the upstream change to just:
- function rename + TileCls qd.template() parameter
- 16 -> T literal substitutions
- one added paragraph documenting the dispatch rule

The dispatcher wrappers now carry only a one-line pointer.
@github-actions

🔴 Benchmark Regression Detected ➡️ Report

@hughperkins hughperkins changed the title [MISC] 32x32 Cholesky tile v2. [MISC] 32x32 Cholesky tile v2 to increase solver speed. May 23, 2026
@github-actions

🔴 Benchmark Regression Detected ➡️ Report

@hughperkins hughperkins marked this pull request as ready for review May 23, 2026 22:51
@github-actions

🔴 Benchmark Regression Detected ➡️ Report

@duburcqa duburcqa changed the title [MISC] 32x32 Cholesky tile v2 to increase solver speed. [MISC] Speed up rigid constraint solver by tuning cholesky tile size. May 24, 2026
@duburcqa duburcqa changed the title [MISC] Speed up rigid constraint solver by tuning cholesky tile size. [MISC] Speed up rigid constraint solver by adjusting cholesky tile size. May 24, 2026
@duburcqa duburcqa merged commit b28fdb8 into Genesis-Embodied-AI:main May 24, 2026
22 of 23 checks passed
@hughperkins
Collaborator Author

Thanks! 🙌

hughperkins added a commit to hughperkins/Genesis that referenced this pull request May 25, 2026
…-AI#2827 baseline

Genesis-Embodied-AI#2827 widened the Cholesky kernel to Tile32x32 for n_dofs >= 17, tightening
the per-warp register budget. Cap=4 (which won +1.07 % on the pre-Genesis-Embodied-AI#2827
Tile16x16 baseline) now regresses dex_hand by -1.42 % on cluster (RTX PRO
6000) even though the qfrc kernel itself got faster (-77 % kernel time).

Re-tuned on post-Genesis-Embodied-AI#2827 main (cluster A/B, 6 rounds, sd 41-46 FPS/run):
  cap=4: -1.42 % (~18 SEMs significant negative)
  cap=2: +0.69 % (~6 SEMs significant positive)  <- WINNER
  cap=1: +0.63 % (~5 SEMs significant positive)

Cap=2 still covers dex_hand active n_con ~ 55 fully (capacity 64); larger
scenes hit the tail's global re-read path (unchanged from baseline).

See perso_hugh/doc/p5_qfrc_register_cache_2026may23.md for full A/B and
diff-profile data.