
torch.compile compatibility: progress and remaining work #1

@vmoens


Summary

Branch torch-compile-compat makes mujoco-torch compatible with torch.compile. The compiled path produces numerically identical results to eager, and steady-state compiled performance is at parity with eager (~1.0-1.1x on CPU).

Fixes applied

Phase 1: Core compatibility (commit de115ed)

  • Solver: Converted _Context/_LSPoint/_LSContext to NamedTuples for while_loop compatibility; fixed tensor aliasing and dtype mismatches.
  • Scan: Converted numpy index arrays to torch.LongTensor; cached grouping indices.
  • Collision: Pre-resolved collision functions and GeomType enums to avoid numpy key lookups.
  • Removed named_scope decorator — caused ~12 unnecessary recompilations per step.
  • Conditional _check_output — skips validation during compilation.
  • Bulk replace() — uses clone._tensordict._tensordict.update(kwargs) to avoid per-key _set_str recompilations.

Phase 2: Graph break elimination (commit 935d7dd)

  • Replaced the Model.__getattribute__ override with a targeted names property descriptor. The override caused graph breaks on every attribute access on Model instances, making torch.compile skip entire frames (_advance, _euler).
  • Pre-built scan caches in device_put: scan.flat and scan.body_tree contain Python for loops over groups with varying tensor shapes, which cause O(groups) recompilations. The grouping caches are now pre-built in device_put, so these functions no longer need @torch.compiler.disable at runtime. The scan callbacks are compiled via torch.vmap internally.
  • @torch.compiler.disable on _collide_hfield_geoms — heightfield collision contains non-traceable operations.
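The difference between the two designs can be sketched in plain Python. Everything in this sketch is hypothetical: both classes, the _decode helper, and the MuJoCo-style null-separated names blob are stand-ins, not mujoco-torch's actual Model.

```python
def _decode(names_blob: bytes, name_adr: list) -> list:
    # MuJoCo-style storage: one null-terminated blob plus start offsets.
    return [names_blob[a:names_blob.index(b"\x00", a)].decode() for a in name_adr]


class ModelBefore:
    """Hypothetical 'before': a custom __getattribute__ sees every lookup."""

    def __init__(self, names_blob, name_adr):
        self.names_blob = names_blob
        self.name_adr = name_adr

    def __getattribute__(self, key):
        # Runs on *every* access (model.names_blob, model.name_adr, ...),
        # even though only `names` needs special handling.
        if key == "names":
            return _decode(object.__getattribute__(self, "names_blob"),
                           object.__getattribute__(self, "name_adr"))
        return object.__getattribute__(self, key)


class ModelAfter:
    """Hypothetical 'after': only `names` is computed; the rest are plain reads."""

    def __init__(self, names_blob, name_adr):
        self.names_blob = names_blob
        self.name_adr = name_adr

    @property
    def names(self):
        # The descriptor is consulted only when `names` itself is accessed.
        return _decode(self.names_blob, self.name_adr)
```

Both classes return the same names. In ModelAfter, however, every other field access goes through ordinary attribute lookup, which is what the issue reports torch.compile handles without breaking the frame.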

Current state

| Metric | Before | After |
| --- | --- | --- |
| Recompilations (per 3 warm-up steps) | 130+ | ~20 (bounded) |
| Graph breaks (unique skipped frames) | 42+ | 10 |
| Compiled vs. eager speedup (CPU, 1000 steps) | 0.5x | ~1.0-1.1x |
| Warm-up compile time | ~20 s | ~35 s |
| Numerical accuracy | identical | identical |

Benchmark (ant model, 1000 steps, CPU)

MuJoCo (C)                         8.6 ms  (116,038 steps/s)
mujoco-torch eager              24353.6 ms  (41 steps/s)
mujoco-torch compiled           23655.4 ms  (42 steps/s)
MJX (JAX jit)                    2201.0 ms  (454 steps/s)
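The throughput and speedup figures can be recomputed directly from these wall-clock timings. The following is a small arithmetic check using only the numbers listed in the benchmark.

```python
# Wall-clock time for 1000 steps (ms), copied from the benchmark listing.
N_STEPS = 1000
times_ms = {
    "MuJoCo (C)": 8.6,
    "mujoco-torch eager": 24353.6,
    "mujoco-torch compiled": 23655.4,
    "MJX (JAX jit)": 2201.0,
}


def steps_per_s(ms: float) -> float:
    # steps / seconds
    return N_STEPS / (ms / 1000.0)


# Ratio of times: > 1 means the second configuration is faster.
compiled_vs_eager = times_ms["mujoco-torch eager"] / times_ms["mujoco-torch compiled"]
mjx_vs_compiled = times_ms["mujoco-torch compiled"] / times_ms["MJX (JAX jit)"]
```

This gives compiled/eager ≈ 1.03x, inside the ~1.0-1.1x range reported above, and MJX ≈ 10.7x faster than compiled mujoco-torch on this CPU run. Recomputing the C figure from the rounded 8.6 ms gives about 116,279 steps/s rather than the listed 116,038, so the listed millisecond values are rounded.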

Remaining work

Compilability

  • Reduce scan recompilations (high impact) — scan.flat and scan.body_tree use Python for loops over groups with varying tensor shapes, causing O(groups) recompilations that fragment the compiled graph. Options: (1) pad tensors to uniform shapes for a single batched vmap call, (2) use torch._higher_order_ops.scan for the tree traversal, (3) compile individual scan callbacks separately.
  • Make TensorClass constructors traceable (tensordict fix) — TensorClass.__init__ raises RuntimeError during tracing, causing graph breaks in collision and constraint code. Requires an upstream fix in tensordict (pytorch/tensordict@c07686f).
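Option (1) for the scan recompilations (padding to a uniform shape and making one batched vmap call) can be sketched under a few assumptions. First, each group's rows are stored contiguously in one flat tensor. Second, every group has a per-group parameter tensor. Third, the callback acts row by row, so zero-padded rows cannot affect real ones. The names callback, scan_loop, build_pad_index and scan_padded are hypothetical and do not reflect mujoco-torch's scan API. The padding index is built once, in the spirit of the caches pre-built in device_put, so the per-step function only sees fixed shapes.

```python
import torch


def callback(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-group callback: x is (rows, D), w is a per-group scalar.
    # It acts row by row, so padded rows never influence real rows.
    return torch.sin(x) * w


def scan_loop(flat_x, sizes, weights):
    # Current pattern: one call per group. Each distinct group size is a new
    # input shape, which under torch.compile can trigger a recompilation.
    chunks = torch.split(flat_x, sizes)
    return torch.cat([callback(c, weights[i]) for i, c in enumerate(chunks)])


def build_pad_index(sizes):
    # Built once, outside the hot path: the row of the padded
    # (num_groups * max_len) layout that each real row occupies.
    max_len = max(sizes)
    idx = [g * max_len + r for g, n in enumerate(sizes) for r in range(n)]
    return torch.tensor(idx, dtype=torch.long), max_len


def scan_padded(flat_x, sizes, weights, pad_index, max_len):
    num_groups, dim = len(sizes), flat_x.shape[1]
    padded = flat_x.new_zeros(num_groups * max_len, dim)
    padded = padded.index_copy(0, pad_index, flat_x)  # scatter real rows
    # Single batched call over all groups: input shape is (G, max_len, D).
    out = torch.vmap(callback)(padded.view(num_groups, max_len, dim), weights)
    # Gather real rows back into the original flat order.
    return out.reshape(num_groups * max_len, dim).index_select(0, pad_index)
```

The cost of this approach is wasted work on padded rows when group sizes are very uneven. If groups differ in behavior rather than only in parameters, the callback would also need to branch on a tensor-valued group tag.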

CUDA graph support (mode="reduce-overhead")

  • Fix tensor aliasing in _take() — CUDA graph capture fails because _take returns views that alias the input tensor's storage. CUDA graph replay requires non-aliased outputs. Options: (1) clone the sliced tensor, (2) restructure scan.flat/scan.body_tree to avoid aliased returns, (3) wait for PyTorch upstream support.
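Option (1) trades an extra copy for independence from the input's storage. The sketch below uses hypothetical take_view/take_copy helpers, not the branch's actual _take, to show the difference with basic slicing, which returns a view. Indexing with a LongTensor (advanced indexing) already returns a copy, so the aliasing only arises on slicing paths.

```python
import torch


def take_view(x: torch.Tensor, start: int, stop: int) -> torch.Tensor:
    # Hypothetical slicing _take: basic slicing returns a view that
    # shares storage with x.
    return x[start:stop]


def take_copy(x: torch.Tensor, start: int, stop: int) -> torch.Tensor:
    # Option (1): clone so the output owns its own storage.
    return x[start:stop].clone()


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Compare the base address of the underlying storages.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
```

The clone costs one extra device copy per call, which is why the issue lists restructuring scan.flat/scan.body_tree as an alternative.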

Performance: smooth.py cleanup

Compile warmup

Testing & benchmarking

  • GPU benchmarks — pytest-benchmark suite in benchmarks/ with 5 models x 5 batch sizes, comparing C, loop, vmap, compile, compile+H4, compile+H4+H8, and MJX.
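When comparing eager against compiled variants, the warm-up steps (where torch.compile traces and recompiles) should stay out of the timed region. Otherwise a ~35 s compile dominates a 1000-step run. Below is a minimal harness sketch in which step_fn, state and sync are hypothetical arguments. On GPU, sync should be torch.cuda.synchronize, because kernels launch asynchronously. pytest-benchmark's benchmark.pedantic also accepts a warmup_rounds argument for the same purpose.

```python
import time


def steps_per_second(step_fn, state, n_steps=1000, warmup_steps=3, sync=None):
    """Steady-state throughput of step_fn, excluding warm-up.

    sync: optional callable run before each timer read (e.g.
    torch.cuda.synchronize on GPU) so queued work is actually counted.
    """
    # Warm-up: one-time costs such as compilation land here, untimed.
    for _ in range(warmup_steps):
        state = step_fn(state)
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_steps):
        state = step_fn(state)
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return n_steps / elapsed, state
```

The final state is returned so the caller can check that the timed steps actually ran.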
