Summary
Branch torch-compile-compat makes mujoco-torch compatible with torch.compile. The compiled path produces numerically identical results to eager, and steady-state compiled performance is at parity with eager (~1.0-1.1x on CPU).
Current state
| Metric | Before | After |
|---|---|---|
| Recompilations (per 3 warm-up steps) | 130+ | ~20 (bounded) |
| Graph breaks (unique skipped frames) | 42+ | 10 |
| Compiled vs eager speedup (CPU, 1000 steps) | 0.5x | ~1.0-1.1x |
| Warm-up compile time | ~20s | ~35s |
| Numerical accuracy vs eager | identical | identical |
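The recompilation figures above come from Dynamo specializing each compiled graph on input shapes. The sketch below shows this with a custom torch.compile backend that only counts its invocations. It is not mujoco-torch code: counting_backend and per_group_work are hypothetical names.

A Python loop over three groups of different sizes compiles three times. The same loop over zero-padded groups compiles once, which is the idea behind padding option (1) for the scan loops under Remaining work. Passing dynamic=False is an assumption made so the counts are deterministic.

```python
import torch
import torch._dynamo
import torch.fx
import torch.nn.functional as F

compile_calls = 0  # incremented each time Dynamo hands a new graph to the backend


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Custom torch.compile backend: count the call, then run the captured
    # FX graph unchanged (no Inductor codegen, so no C++ toolchain needed).
    global compile_calls
    compile_calls += 1
    return gm.forward


def per_group_work(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one scan callback applied to one group.
    return (x * 2.0).sum(dim=-1)


def count_compiles(groups):
    """Run per_group_work on each group under torch.compile; return (#compiles, outputs)."""
    global compile_calls
    torch._dynamo.reset()  # drop cached graphs so each experiment starts clean
    compile_calls = 0
    # dynamic=False forces shape specialization, so every new shape recompiles.
    fn = torch.compile(per_group_work, backend=counting_backend, dynamic=False)
    outputs = [fn(g) for g in groups]
    return compile_calls, outputs


# Ragged groups: the Python loop sees three distinct input shapes.
ragged = [torch.randn(n, 4) for n in (2, 3, 5)]
ragged_compiles, ragged_out = count_compiles(ragged)

# Zero-pad every group to the same number of rows: one shape, one graph.
max_rows = max(g.shape[0] for g in ragged)
padded = [F.pad(g, (0, 0, 0, max_rows - g.shape[0])) for g in ragged]
padded_compiles, padded_out = count_compiles(padded)

print(ragged_compiles, padded_compiles)
```

Under default settings, Dynamo may mark the varying dimension as dynamic after seeing a second shape, so the ragged count can come out lower. The padded loop stays at one compile either way, at the cost of wasted work on the zero rows.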
Benchmark (ant model, 1000 steps, CPU)
| Implementation | Time (1000 steps) | Throughput |
|---|---|---|
| MuJoCo (C) | 8.6 ms | 116,038 steps/s |
| mujoco-torch eager | 24353.6 ms | 41 steps/s |
| mujoco-torch compiled | 23655.4 ms | 42 steps/s |
| MJX (JAX jit) | 2201.0 ms | 454 steps/s |
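The throughput column and the "~1.0-1.1x" claim can be checked directly from the reported times. Throughput is 1000 steps divided by elapsed seconds, and speedup is the ratio of wall-clock times. This is a plain arithmetic sketch; the dictionary keys are just labels for the rows above.

```python
STEPS = 1000  # each benchmark run simulates 1000 steps of the ant model

# Reported wall-clock times in milliseconds (CPU), copied from the benchmark above.
timings_ms = {
    "MuJoCo (C)": 8.6,
    "mujoco-torch eager": 24353.6,
    "mujoco-torch compiled": 23655.4,
    "MJX (JAX jit)": 2201.0,
}


def steps_per_second(total_ms: float, steps: int = STEPS) -> float:
    """Throughput = steps / elapsed seconds."""
    return steps / (total_ms / 1000.0)


throughput = {name: steps_per_second(ms) for name, ms in timings_ms.items()}

# Speedup of compiled over eager is the ratio of their wall-clock times.
compile_speedup = timings_ms["mujoco-torch eager"] / timings_ms["mujoco-torch compiled"]
# How many times longer the compiled mujoco-torch run takes than MJX.
mjx_ratio = timings_ms["mujoco-torch compiled"] / timings_ms["MJX (JAX jit)"]

for name, sps in throughput.items():
    print(f"{name:24s} {sps:10.0f} steps/s")
print(f"compiled vs eager: {compile_speedup:.2f}x; compiled / MJX time: {mjx_ratio:.1f}x")
```

On this run the compiled path is about 1.03x faster than eager, the low end of the stated range. It is still roughly 10.7x slower than MJX. The C row recomputes to about 116,279 steps/s from the rounded 8.6 ms. The reported 116,038 steps/s presumably comes from the unrounded time, roughly 8.62 ms.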
Fixes applied
Phase 1: Core compatibility (commit de115ed)
_Context/_LSPoint/_LSContext to NamedTuples for while_loop compatibility; fixed tensor aliasing and dtype mismatches.
torch.LongTensor; cached grouping indices.
GeomType enums to avoid numpy key lookups.
named_scope decorator — caused ~12 unnecessary recompilations per step.
_check_output — skips validation during compilation.
replace() — uses clone._tensordict._tensordict.update(kwargs) to avoid per-key _set_str recompilations.
Phase 2: Graph break elimination (commit 935d7dd)
Replaced Model.__getattribute__ override with a targeted names property descriptor. The __getattribute__ override caused graph breaks on every attribute access on Model instances, making torch.compile skip entire frames (_advance, _euler).
Pre-built scan caches in device_put — scan.flat / scan.body_tree contain Python for loops over groups with varying tensor shapes that cause O(groups) recompilations. Grouping caches are now pre-built in device_put, so these functions don't need @torch.compiler.disable at runtime. The scan callbacks are compiled via torch.vmap internally.
@torch.compiler.disable on _collide_hfield_geoms — heightfield collision contains non-traceable operations.
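The Phase 2 descriptor change can be illustrated in plain Python, without torch. HookedModel and DescriptorModel are hypothetical classes, not mujoco-torch's Model. The NUL-terminated name buffer is an assumption modeled loosely on MuJoCo's names array.

The difference is how far the custom code reaches. The __getattribute__ override runs on every attribute read, which the PR reports caused graph breaks on every Model access. The property only intercepts reads of names.

```python
def _decode_names(blob: bytes, adrs) -> tuple:
    # MuJoCo-style name buffer (assumed): NUL-terminated names addressed by start offset.
    return tuple(blob[a:blob.index(b"\x00", a)].decode() for a in adrs)


class HookedModel:
    """Old pattern (illustrative): a __getattribute__ override on the class."""

    hook_calls = 0  # how many times the override has run

    def __init__(self, names_blob: bytes, name_adr: tuple):
        self.names_blob = names_blob
        self.name_adr = name_adr
        self.nbody = len(name_adr)

    def __getattribute__(self, attr):
        # Runs for every attribute read (nbody, name_adr, ...), not only "names",
        # so a tracer must step through this Python hook on each access.
        type(self).hook_calls += 1
        if attr == "names":
            blob = object.__getattribute__(self, "names_blob")
            adrs = object.__getattribute__(self, "name_adr")
            return _decode_names(blob, adrs)
        return object.__getattribute__(self, attr)


class DescriptorModel:
    """New pattern (illustrative): a property scoped to the one derived field."""

    def __init__(self, names_blob: bytes, name_adr: tuple):
        self.names_blob = names_blob
        self.name_adr = name_adr
        self.nbody = len(name_adr)

    @property
    def names(self) -> tuple:
        # Only reads of `names` run custom code; nbody and name_adr remain
        # ordinary instance attributes.
        return _decode_names(self.names_blob, self.name_adr)


blob = b"world\x00torso\x00leg\x00"
adrs = (0, 6, 12)
print(HookedModel(blob, adrs).names, DescriptorModel(blob, adrs).names)
```

Both classes return the same names. Reading nbody, name_adr and names on a HookedModel runs the override three times. On a DescriptorModel, only the names read executes custom code.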
Remaining work
Compilability
Reduce scan recompilations (high impact) — scan.flat and scan.body_tree use Python for loops over groups with varying tensor shapes, causing O(groups) recompilations that fragment the compiled graph. Options: (1) pad tensors to uniform shapes for a single batched vmap call, (2) use torch._higher_order_ops.scan for the tree traversal, (3) compile individual scan callbacks separately.
Make TensorClass constructors traceable (tensordict fix) — TensorClass.__init__ raises RuntimeError during tracing, causing graph breaks in collision and constraint code. Requires an upstream fix in tensordict. (pytorch/tensordict@c07686f)
CUDA graph support (mode="reduce-overhead")
Fix tensor aliasing in _take() — CUDA graph capture fails because _take returns views that alias the input tensor's storage. CUDA graph replay requires non-aliased outputs. Options: (1) clone the sliced tensor, (2) restructure scan.flat/scan.body_tree to avoid aliased returns, (3) wait for PyTorch upstream support.
Performance: smooth.py cleanup
Remove redundant clone in factor_m — qld = qld / qld[...] already creates a new tensor, so the subsequent .clone() was unnecessary. Also replaced clone-per-iteration with out-of-place scatter(). (PR [smooth] Performance cleanup + reduce-overhead benchmarks #29)
solve_m — replaced x = x.clone(); x[idx] = val with x = x.scatter(0, idx, val), matching the pattern already used in solver.py. (PR [smooth] Performance cleanup + reduce-overhead benchmarks #29)
Moved np.nonzero/np.isin from per-step tendon() to one-time device_put(), stored as UnbatchedTensor fields on the Model. (PR [smooth] Performance cleanup + reduce-overhead benchmarks #29)
Precompute tendon tensors — moved torch.as_tensor()/torch.arange() from per-step tendon() to one-time device_put(), stored as UnbatchedTensor fields. tendon() now does zero numpy computation and zero tensor creation at runtime. (PR [smooth] Performance cleanup + reduce-overhead benchmarks #29)
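The smooth.py out-of-place rewrites and the _take() aliasing issue can be checked on toy tensors. This is a sketch only: the helper names are hypothetical, and nothing is copied from smooth.py, solver.py or _take(). It shows three things:

- scatter() reproduces the clone-then-assign result without mutating its input.
- Division already returns fresh storage.
- A slice view aliases its source, while a clone of the slice does not.

```python
import torch


def set_entries_clone(x: torch.Tensor, idx: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    # Pattern the PR removed from solve_m: copy, then mutate the copy in place.
    x = x.clone()
    x[idx] = val
    return x


def set_entries_scatter(x: torch.Tensor, idx: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    # Pattern the PR switched to: Tensor.scatter (not scatter_) is out of place.
    # Assumes 1-D tensors and unique indices; with duplicate indices the
    # written value is not guaranteed.
    return x.scatter(0, idx, val)


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # True when both tensors point at the same underlying buffer (aliasing).
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
idx = torch.tensor([0, 2])
val = torch.tensor([10.0, 30.0])

a = set_entries_clone(x, idx, val)
b = set_entries_scatter(x, idx, val)
print(torch.equal(a, b), torch.equal(x, torch.tensor([1.0, 2.0, 3.0, 4.0])))

# factor_m point: division already allocates a new tensor, so a .clone() after it is redundant.
qld = x / x[0]
print(shares_storage(qld, x))

# _take() point: basic slicing returns a view aliasing the input's storage;
# cloning the slice yields an independent buffer.
view = x[1:3]
copy = x[1:3].clone()
print(shares_storage(view, x), shares_storage(copy, x))
```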
Testing & benchmarking
GPU benchmarks — pytest-benchmark suite in benchmarks/ with 5 models x 5 batch sizes, comparing C, loop, vmap, compile, compile+H4, compile+H4+H8, and MJX.