NVIDIA/CUDA-first LeWM training and inference runtime.
This repo is focused on one model family: LeWM image world models from
stable-worldmodel. The runtime target is Linux with NVIDIA hardware, CUDA,
cuDNN, nvJPEG, NVDECODE, and Candle CUDA tensors. The hot path is:
image/video observation -> CUDA preprocess -> LeWM encode -> candidate rollout -> cost -> action
Performance is the primary acceptance criterion. The repo is not a portability layer, and non-Linux/non-NVIDIA targets are intentionally out of scope.
Runtime work should keep media buffers, preprocessed tensors, embeddings, candidate action batches, rollouts, costs, and selected actions in the Rust/Candle CUDA path. Python is included for bootstrap, checkpoint conversion, data export, and parity checks against the official implementation. Python is not the deployment runtime.
When Candle lacks a needed NVIDIA primitive, the preferred direction is a focused Candle CUDA op, a direct NVIDIA library binding, or a CUDA-compatible crate that preserves device residency.
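
The last three stages of the hot path (candidate rollout -> cost -> action) can be illustrated with a toy cross-entropy-method loop. This is a minimal pure-Python sketch under stated assumptions, not the repo's planner: the scalar latent state, the additive `rollout` dynamics, the squared-distance cost, and every function name are hypothetical stand-ins for the LeWM predictor and goal cost, which run on Candle CUDA tensors in the real runtime.

```python
import random
import statistics


def rollout(state, actions):
    """Toy latent dynamics (assumption): each action is added to the state."""
    for a in actions:
        state = state + a
    return state


def cem_plan(state, goal, horizon=5, samples=128, iterations=3,
             elite_frac=0.1, seed=7):
    """Return (first action, cost) of the best action sequence found by CEM."""
    rng = random.Random(seed)
    mean = [0.0] * horizon
    std = [1.0] * horizon
    n_elite = max(2, int(samples * elite_frac))
    best = None  # (cost, action sequence)
    for _ in range(iterations):
        # Sample candidate action sequences around the current distribution.
        candidates = [
            [rng.gauss(m, s) for m, s in zip(mean, std)]
            for _ in range(samples)
        ]
        # Score each candidate by the squared distance of its rollout to the goal.
        scored = sorted(
            ((rollout(state, c) - goal) ** 2, c) for c in candidates
        )
        if best is None or scored[0][0] < best[0]:
            best = scored[0]
        # Refit the sampling distribution to the lowest-cost candidates.
        elites = [c for _, c in scored[:n_elite]]
        mean = [statistics.fmean(col) for col in zip(*elites)]
        std = [statistics.pstdev(col) + 1e-6 for col in zip(*elites)]
    return best[1][0], best[0]


action, cost = cem_plan(state=0.0, goal=2.0)
print(f"selected first action = {action:.4f}, best cost = {cost:.6f}")
```

The structure is the same regardless of backend: a batch of candidates is sampled, rolled out, scored, and reduced to one action, which is why keeping all four steps on the device avoids host round trips.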
- LeWM model runtime: ViT-Tiny encoder, projector, action encoder, predictor, latent rollout, goal embedding, goal cost, and session caching.
- LeWM planning: CEM, MPPI, and iCEM over Candle CUDA tensors.
- NVIDIA image/video ingest: nvJPEG decode into CUDA tensors, packed RGB/BGR CUDA preprocessing, NV12 CUDA preprocessing, and NVDECODE capability/parser plumbing.
- LeWM training surface: PLDM, VCReg, temporal-straightening losses, batch-loss API, AdamW training CLIs, PushT HDF5 dataset streaming, and safetensors save/reload.
- Python bootstrap tooling: official stable-worldmodel[train] package via uv, checkpoint conversion, PushT batch export, Python parity fixture export, and Python-vs-Rust image-planning benchmark scripts.
- Hugging Face checkpoint download is available with --features hub.

The audited upstream stable-worldmodel commit is tracked in
docs/upstream-stable-worldmodel.md.
- Linux host with NVIDIA GPU
- CUDA toolkit and driver libraries
- cuDNN available to Candle
- libnvjpeg.so
- libnvcuvid.so
- Rust toolchain from rust-toolchain.toml
- uv

The build script requires libnvjpeg.so and libnvcuvid.so. Set CUDA_HOME,
CUDA_PATH, or NVIDIA_VIDEO_CODEC_SDK_PATH if they are not under standard
system library paths.
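
Before building, it can help to check where the two libraries would be found. The following is a rough diagnostic sketch, not the build script's actual search logic: the probed subdirectories and the fallback system directories are assumptions, and `find_library` is a hypothetical helper. Only the three environment variable names come from the text above.

```python
import os
from pathlib import Path

# Environment variables named in the text above.
ENV_VARS = ("CUDA_HOME", "CUDA_PATH", "NVIDIA_VIDEO_CODEC_SDK_PATH")
# Subdirectories probed under each root (assumption, not the build script's list).
SUBDIRS = ("", "lib", "lib64")
# Fallback system directories (assumption).
SYSTEM_DIRS = ("/usr/lib/x86_64-linux-gnu", "/usr/lib64", "/usr/local/cuda/lib64")


def find_library(name, env=None, system_dirs=SYSTEM_DIRS):
    """Return the first directory that contains `name`, or None."""
    env = os.environ if env is None else env
    roots = [Path(env[v]) for v in ENV_VARS if env.get(v)]
    candidates = [root / sub for root in roots for sub in SUBDIRS]
    candidates += [Path(d) for d in system_dirs]
    for directory in candidates:
        if (directory / name).is_file():
            return directory
    return None


for lib in ("libnvjpeg.so", "libnvcuvid.so"):
    print(lib, "->", find_library(lib) or "not found")
```

If either library prints "not found", pointing one of the three variables at the directory tree that contains it is the first thing to try.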
cargo check --locked --all-targets
cargo test --locked

With Hugging Face Hub checkpoint download:

cargo check --locked --features hub --all-targets

The repo includes .python-version, pyproject.toml, and uv.lock.
pyproject.toml defines the supported Python range and dependencies, including
stable-worldmodel[train].
uv sync --locked --no-dev

Convert a PyTorch state dict to safetensors:

uv run --locked --no-dev \
python tools/convert_state_dict_safetensors.py \
--input /path/to/weights.pt \
--output target/model.safetensors

Export a deterministic CUDA fixture from the official Python implementation:

uv run --locked --no-dev \
python tools/export_lewm_fixture.py \
--model quentinll/lewm-pusht \
--device cuda \
--output target/lewm-pusht-python-cuda.npz

Compare Rust/Candle CUDA against that fixture:

cargo run --release --locked --features hub --bin lewm-compare-fixture -- \
--device cuda \
--fixture target/lewm-pusht-python-cuda.npz \
--hf-repo quentinll/lewm-pusht

Run checkpoint-backed planning from fixture tensors:

cargo run --release --locked --features hub --bin lewm-plan-fixture -- \
--device cuda \
--fixture target/lewm-pusht-python-cuda.npz \
--hf-repo quentinll/lewm-pusht \
--planner icem \
--samples 128 \
--iterations 3 \
--seed 7

Validation snapshot on 2026-06-03, RTX 4090, quentinll/lewm-pusht,
PyTorch 2.12.0+cu130, CUDA 13.0:

| Output | Max Abs |
|---|---|
| emb | 5.731881e-4 |
| act_emb | 4.768372e-7 |
| pred | 7.328391e-4 |
| rollout | 6.533712e-4 |
| cost | 5.619049e-3 |

Cost argmin was stable for the fixture batch.
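
The two checks reported above, maximum absolute difference per output and agreement of the cost argmin, reduce to a few lines. This is a minimal sketch on plain Python lists, not the lewm-compare-fixture implementation; the helper names are hypothetical and the sample cost vectors are made up for illustration, not taken from the fixture.

```python
def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two flat sequences."""
    if len(reference) != len(candidate):
        raise ValueError("length mismatch")
    return max(abs(r - c) for r, c in zip(reference, candidate))


def argmin(values):
    """Index of the smallest value (first one on ties)."""
    return min(range(len(values)), key=values.__getitem__)


# Hypothetical flattened cost vectors, not values from the fixture.
python_cost = [3.20, 1.10, 2.75, 1.90]
rust_cost = [3.2004, 1.0997, 2.7501, 1.9002]

diff = max_abs_diff(python_cost, rust_cost)
same_choice = argmin(python_cost) == argmin(rust_cost)
print(f"max abs = {diff:.6e}, argmin stable = {same_choice}")
```

The argmin check matters more than the raw difference for planning: a cost error that does not reorder the candidates selects the same action.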
Snapshot on 2026-06-03, RTX 4090, quentinll/lewm-pusht, CUDA 13.0,
planner=icem, samples=1024, iterations=5, horizon=5, history_size=3.
Metric is synchronized CUDA p50 wall time after 2 warmup runs and 5 measured
runs. Python is vanilla stable-worldmodel LeWM through PyTorch; Rust is
lewm-plan-images with nvJPEG decode, Candle CUDA encode/rollout/scoring, and
Rust-native planning. In this image-input PushT benchmark, Rust/Candle is faster
across the hot path: 3-4x for media decode/preprocess, 1.37-1.51x for image
encoding, 1.13x for iCEM planning, and 1.66x for selected-score evaluation.
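
The measurement protocol described above (2 warmup runs, 5 measured runs, synchronized p50) can be sketched as follows. This is an illustrative harness, not the repo's benchmark code: `p50_wall_time`, `speedup`, and the `sync` callback are hypothetical names. For a PyTorch workload, `torch.cuda.synchronize` would be a natural `sync`; here it defaults to a no-op so the sketch runs on any machine.

```python
import statistics
import time


def p50_wall_time(fn, sync=lambda: None, warmup=2, runs=5):
    """Median wall time in seconds of fn() over `runs` timed calls."""
    for _ in range(warmup):
        fn()  # warmup calls are executed but not timed
    samples = []
    for _ in range(runs):
        sync()  # drain earlier device work before starting the clock
        start = time.perf_counter()
        fn()
        sync()  # wait for queued device work before stopping the clock
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline_seconds, candidate_seconds):
    """Ratio above 1.0 means the candidate is faster than the baseline."""
    return baseline_seconds / candidate_seconds


def workload():
    # Stand-in CPU workload for demonstration only.
    return sum(i * i for i in range(10_000))


print(f"p50 = {p50_wall_time(workload) * 1e3:.3f} ms")
```

Without the synchronization calls, asynchronous kernel launches would return before the device finishes and the timings would understate the real cost.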
Plan from JPEG current/goal images through nvJPEG, CUDA preprocessing, LeWM, and Rust-native planning:
cargo run --release --locked --features hub --bin lewm-plan-images -- \
--device cuda \
--hf-repo quentinll/lewm-pusht \
--current current.jpg \
--goal goal.jpg \
--planner icem \
--samples 1024 \
--iterations 5 \
--output target/reports/lewm-pusht-plan.html

Compare Rust/Candle LeWM losses against official Python CUDA losses:

uv run --locked --no-dev \
python tools/export_lewm_training_loss_fixture.py \
--device cuda \
--output target/lewm-training-loss-python-cuda.npz
cargo run --release --locked --bin lewm-compare-training-loss -- \
--device cuda \
--fixture target/lewm-training-loss-python-cuda.npz \
--tolerance 1e-5

Validation snapshot on 2026-06-03, RTX 4090:

| Loss | Max Abs |
|---|---|
| idm_loss | 0 |
| temp_align_loss | 1.192093e-7 |
| std_loss | 0 |
| std_t_loss | 0 |
| cov_loss | 2.980232e-8 |
| cov_t_loss | 0 |
| temporal_straightening_loss | 0 |

Export a PushT image/action batch and run a Rust/Candle CUDA training step:
uv run --locked --no-dev \
python tools/export_pusht_lewm_training_batch.py \
--output target/pusht-lewm-training-batch.npz \
--batch-size 2 \
--history-size 3 \
--action-block 5 \
--seed 7
cargo run --release --locked --bin lewm-train-batch -- \
--device cuda \
--batch-npz target/pusht-lewm-training-batch.npz \
--steps 10 \
--lr 1e-5 \
--output target/pusht-lewm-trained.safetensors

Train LeWM from the PushT HDF5 dataset without Python in the data/training path:

cargo run --release --locked --bin lewm-train-pusht -- \
--device cuda \
--dataset-h5 ~/.stable_worldmodel/pusht_expert_train.h5 \
--epochs 100 \
--batch-size 64 \
--history-size 3 \
--action-block 5 \
--output-dir target/pusht-from-scratch

lewm-train-pusht reads pusht_expert_train.h5 natively through Rust HDF5
with in-process Blosc filter support. It reproduces the Python exporter dataset
semantics: valid row selection from episode_idx, step_idx, and ep_len;
image history rows at row + idx * action_block; flattened action blocks; and
dataset-wide action mean/std normalization. Because PushT H5 pixels are already
decoded RGB arrays, the optimized path is HDF5 host reads, raw RGB
host-to-CUDA transfer, CUDA resize/normalize/history assembly, and LeWM
training on Candle CUDA tensors. It does not use nvJPEG or NVDECODE.
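
The dataset semantics listed above can be sketched on plain Python lists. Several details here are assumptions rather than facts from the exporter: that ep_len is indexed by episode id, that a start row is valid when `history_size * action_block` rows remain in its episode, that each action block starts at its history frame's row, and that normalization uses the population standard deviation. All function names are hypothetical. Only the `row + idx * action_block` history indexing comes directly from the text.

```python
import statistics


def valid_rows(episode_idx, step_idx, ep_len, history_size, action_block):
    """Start rows whose full window stays inside one episode (assumed rule)."""
    span = history_size * action_block
    return [
        row for row, (ep, step) in enumerate(zip(episode_idx, step_idx))
        if step + span <= ep_len[ep]
    ]


def history_rows(row, history_size, action_block):
    """Image rows for one sample: row + idx * action_block."""
    return [row + idx * action_block for idx in range(history_size)]


def action_stats(actions):
    """Dataset-wide per-dimension mean and std (population std, assumption)."""
    columns = list(zip(*actions))
    mean = [statistics.fmean(c) for c in columns]
    std = [statistics.pstdev(c) or 1.0 for c in columns]
    return mean, std


def flat_action_block(actions, start, action_block, mean, std):
    """Normalize `action_block` consecutive actions and flatten them."""
    block = actions[start:start + action_block]
    return [(a - m) / s for step in block for a, m, s in zip(step, mean, std)]


# Two toy episodes of 20 and 12 steps with 2-D actions (made-up data).
ep_len = [20, 12]
episode_idx = [0] * 20 + [1] * 12
step_idx = list(range(20)) + list(range(12))
actions = [[float(i), float(-i)] for i in range(32)]

rows = valid_rows(episode_idx, step_idx, ep_len, history_size=3, action_block=5)
mean, std = action_stats(actions)
sample = flat_action_block(actions, rows[0], 5, mean, std)
print(rows, history_rows(rows[0], 3, 5), len(sample))
```

With these toy inputs the second episode contributes no samples, because 12 steps are fewer than the 15 rows one window spans under the assumed validity rule.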
The trainer writes metrics.jsonl, dataset-summary.json, model-config.json,
training-state.json, latest.safetensors, periodic
checkpoint-step-*.safetensors files, optimizer.safetensors, periodic
optimizer-step-*.safetensors files, final.safetensors, and
final-optimizer.safetensors. Use --init-safetensors for a weights-only warm
start. Use --resume-dir for exact continuation from latest.safetensors,
optimizer.safetensors, and training-state.json; the trainer resumes from the
saved global_step, which maps deterministically back to the same epoch shuffle
and next batch.
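
One way a saved global_step can map back to the same epoch shuffle and next batch is to make the shuffle a pure function of the seed and the epoch number. The sketch below illustrates that idea only; it is not the trainer's implementation. The seed mixing, the choice to drop the last partial batch, and all function names are assumptions.

```python
import random


def epoch_order(num_samples, epoch, seed=7):
    """Shuffle that depends only on (seed, epoch), so it can be rebuilt on resume."""
    order = list(range(num_samples))
    random.Random(seed * 1_000_003 + epoch).shuffle(order)
    return order


def resume_position(global_step, num_samples, batch_size):
    """Split a global step into (epoch, batch index within that epoch)."""
    steps_per_epoch = num_samples // batch_size  # drops the last partial batch (assumption)
    return divmod(global_step, steps_per_epoch)


def next_batch(global_step, num_samples, batch_size, seed=7):
    """Sample indices of the batch that follows `global_step` completed steps."""
    epoch, batch_in_epoch = resume_position(global_step, num_samples, batch_size)
    order = epoch_order(num_samples, epoch, seed)
    start = batch_in_epoch * batch_size
    return order[start:start + batch_size]


epoch, batch_in_epoch = resume_position(global_step=25, num_samples=640, batch_size=64)
print(f"resume at epoch {epoch}, batch {batch_in_epoch}")
```

Because no shuffle state needs to be serialized, the step counter in the training state is enough to reconstruct the data order, provided the batch size and dataset are unchanged between runs.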
cargo run --release --locked --bin lewm-train-pusht -- \
--device cuda \
--dataset-h5 ~/.stable_worldmodel/pusht_expert_train.h5 \
--resume-dir target/pusht-from-scratch \
--epochs 100 \
--batch-size 64 \
--history-size 3 \
--action-block 5 \
--output-dir target/pusht-from-scratch

Run the PushT environment demo through Rust planning:

uv run --locked --no-dev \
python tools/run_pusht_lewm_rust_demo.py \
--hf-repo quentinll/lewm-pusht \
--planner icem \
--replans 2 \
--output-dir target/reports/pusht-lewm-demo

Run the same demo with a locally trained Rust checkpoint:

uv run --locked --no-dev \
python tools/run_pusht_lewm_rust_demo.py \
--weights target/pusht-from-scratch/latest.safetensors \
--config target/pusht-from-scratch/model-config.json \
--planner icem \
--history-size 1 \
--replans 2 \
--output-dir target/reports/pusht-from-scratch-demo

Run Python-vs-Rust image-planning benchmark tooling:

uv run --locked --no-dev \
python tools/benchmark_lewm_plan_images_python.py \
--model quentinll/lewm-pusht \
--current current.jpg \
--goal goal.jpg \
--output target/bench/lewm-plan-images-python.json