# torch2mlx

Translate PyTorch neural network models to Apple's MLX framework.

Scope: torch2mlx converts models for inference on Apple Silicon. Training support (including a Lightning-compatible MLX Trainer) is on the roadmap.

## Why

PyTorch models don't run natively on Apple Silicon's GPU/Neural Engine. MLX does — but porting a model means manually transposing weight layouts, renaming state dict keys, rewriting forward() calls, and debugging silent numerical mismatches.

torch2mlx automates the mechanical parts:

- Weight conversion — dispatches the correct transposition per layer type (Conv2d needs `[O,I,H,W] → [O,H,W,I]`, Linear is identity, etc.)
- State dict surgery — converts PyTorch's flat dot-separated keys to MLX's nested dicts, with safetensors as the interchange format
- Portability analysis — reports, before you start porting, what percentage of the model converts automatically and what needs manual work
- MLX templates — hand-written reference implementations for common patterns (transformer blocks, conv stacks, MLPs)
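To make the first bullet concrete, here is a minimal numpy-only sketch of the Conv2d rule (illustrative only — the function name is not torch2mlx's actual API):

```python
import numpy as np

def conv2d_weight_to_mlx(w: np.ndarray) -> np.ndarray:
    """PyTorch stores Conv2d weights as [O, I, H, W]; MLX expects [O, H, W, I],
    so the in-channel axis moves to the end."""
    return np.transpose(w, (0, 2, 3, 1))

w_torch = np.zeros((16, 8, 3, 3), dtype=np.float32)  # 16 out, 8 in, 3x3 kernel
print(conv2d_weight_to_mlx(w_torch).shape)  # (16, 3, 3, 8)
```

Linear layers need no such move because both frameworks store them as `[out, in]`.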

## Quickstart

```sh
pip install torch2mlx            # core (numpy + safetensors only)
pip install "torch2mlx[all]"     # with torch + mlx + dev tools
```

### Python API

```python
import torch2mlx

# Analyze portability before converting
report = torch2mlx.analyze(model)
print(f"Coverage: {report.coverage:.0%}")

# Convert a PyTorch model → safetensors
torch2mlx.convert(model, "weights.safetensors")

# Load into MLX
params = torch2mlx.load_converted("weights.safetensors")
mlx_model.load_weights(list(params.items()))
```

### CLI

```sh
# Convert with portability report
python -m torch2mlx model.pt output/

# Analyze only (no conversion)
python -m torch2mlx model.pt --analyze-only
```

You can also pass a pre-extracted state dict (numpy arrays with dot-separated keys) instead of a live `torch.nn.Module` — no torch installation is required for the conversion step itself.
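The expected shape of such a dict is just dot-separated string keys mapping to numpy arrays — for example (hand-built, with a hypothetical sanity check that is not part of torch2mlx's API):

```python
import numpy as np

def looks_like_flat_state_dict(d: dict) -> bool:
    """Hypothetical check for the torch-free input format: dot-separated
    string keys, numpy array values."""
    return all(isinstance(k, str) and isinstance(v, np.ndarray)
               for k, v in d.items())

state_dict = {
    "encoder.0.weight": np.ones((64, 128), dtype=np.float32),
    "encoder.0.bias": np.zeros(64, dtype=np.float32),
}
print(looks_like_flat_state_dict(state_dict))  # True
```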

## How it works

torch2mlx walks the PyTorch module tree, looks up each layer in a registry to find its MLX equivalent and weight transposition rule, applies the transpositions using numpy only (no framework imports during conversion), and saves the result as safetensors. A separate analyzer inspects the model's forward() source for non-convertible patterns (in-place mutation, custom autograd, etc.) and reports blockers before you invest time porting.

```text
src/torch2mlx/
├── registry.py          # torch.nn.X → mlx.nn.X dispatch table
├── op_mapping.py        # torch.cat → mx.concatenate etc. + dtype mappings
├── weight_converter.py  # Per-layer transposition rules (numpy only)
├── state_dict.py        # Flat keys ↔ nested dict + safetensors I/O
├── analyzer.py          # Portability report: % convertible, blockers
├── converter.py         # End-to-end orchestration
└── templates/           # Hand-written MLX module implementations
```
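The flat→nested key conversion that `state_dict.py` performs can be sketched in a few lines (a simplified illustration, not the module's actual code):

```python
def unflatten(flat: dict) -> dict:
    """Turn PyTorch-style flat keys ('block.fc.weight') into the nested
    dicts MLX parameters use ({'block': {'fc': {'weight': ...}}})."""
    nested: dict = {}
    for key, value in flat.items():
        *parents, leaf = key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested

print(unflatten({"block.fc.weight": 1, "block.fc.bias": 2}))
# {'block': {'fc': {'weight': 1, 'bias': 2}}}
```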

## What's supported

72 layer types, 30 ops, 12 dtype mappings, 7 weight transposition rules — covering Linear, Conv1d/2d, ConvTranspose1d/2d, BatchNorm, LayerNorm, RMSNorm, Embedding, MultiheadAttention, GroupNorm, InstanceNorm, pooling (MaxPool/AvgPool 1d/2d/3d, AdaptiveAvgPool2d), Transformer encoder/decoder, common activations (GELU, ReLU, SiLU, Tanh, Sigmoid, LeakyReLU, Softmax), and tensor ops (einsum, matmul, reshape, squeeze/unsqueeze, reductions, etc.).

Works with `torch.compile()` — compiled models convert identically to uncompiled ones.
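One reason this is possible: `torch.compile` wraps the module in an `OptimizedModule` whose state-dict keys gain an `_orig_mod.` prefix, so a key normalization along these lines (a sketch, not necessarily torch2mlx's internals) makes compiled and uncompiled models look the same:

```python
def strip_compile_prefix(key: str, prefix: str = "_orig_mod.") -> str:
    """Drop the prefix torch.compile adds to state-dict keys, if present."""
    return key[len(prefix):] if key.startswith(prefix) else key

print(strip_compile_prefix("_orig_mod.encoder.weight"))  # encoder.weight
print(strip_compile_prefix("encoder.weight"))            # unchanged
```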

See docs/support-matrix.md for the full table.

## Tested HuggingFace models

The analyzer reports 100% coverage on all 36 tested architectures:

| Category | Models |
|---|---|
| Encoder | BERT, RoBERTa, DistilBERT, ALBERT, DeBERTa, DeBERTa-v3, Electra, MPNet, Longformer, Funnel, CamemBERT, Data2Vec-Text |
| Decoder / Causal LM | GPT-2, GPT-Neo, OPT, BLOOM, Qwen2, Pythia, CodeGen, Falcon |
| Encoder-Decoder | T5, BART, Pegasus |
| Vision | ViT, CLIP, Swin Transformer, ConvNeXt, DINOv2, BEiT, SegFormer, MobileNetV2, ResNet |
| Speech | Whisper, Wav2Vec2, HuBERT |
| Other | XLNet |

Not supported (architectural blockers): RNNs/LSTMs (stateful, out of scope), Conv3d (MLX lacks it), and in-place mutation patterns (`+=`, `.copy_()` — MLX arrays are immutable).
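Blockers like in-place mutation can be spotted statically. A toy version of such a check over `forward()` source, using Python's `ast` module (illustrative — not the analyzer's actual implementation):

```python
import ast
import textwrap

def find_inplace_ops(source: str) -> list[int]:
    """Flag line numbers using augmented assignment (+=) or .copy_() calls,
    which have no equivalent on MLX's immutable arrays."""
    tree = ast.parse(textwrap.dedent(source))
    hits = []
    for node in ast.walk(tree):
        if isinstance(node, ast.AugAssign):
            hits.append(node.lineno)
        elif (isinstance(node, ast.Call)
              and isinstance(node.func, ast.Attribute)
              and node.func.attr == "copy_"):
            hits.append(node.lineno)
    return hits

src = """
def forward(self, x):
    x += self.bias
    x.copy_(x * 2)
    return x
"""
print(find_inplace_ops(src))  # [3, 4]
```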

## Numerical equivalence

Three end-to-end validation examples in `examples/` show that converted models produce matching outputs within tolerance:

| Example | Architecture | Max logit diff | MLX speedup |
|---|---|---|---|
| validate_mnist.py | CNN (Conv2d, MaxPool2d, Linear) | < 1e-5 | ~3x |
| validate_transformer.py | Transformer (Attention, FFN, LayerNorm) | < 1e-5 | ~2x |
| validate_resnet.py | ResNet (Conv2d, BatchNorm, skip connections) | < 1e-5 | ~6x |

Each script trains a small model in PyTorch, converts via torch2mlx, loads into an equivalent MLX model, and compares predictions — 100% agreement across all three.
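The comparison each script performs boils down to a check like this (a simplified numpy sketch; the tolerance matches the table above):

```python
import numpy as np

def outputs_match(torch_logits: np.ndarray, mlx_logits: np.ndarray,
                  atol: float = 1e-5) -> bool:
    """Max absolute logit difference under tolerance, plus identical argmax
    predictions (i.e. 100% agreement)."""
    max_diff = float(np.abs(torch_logits - mlx_logits).max())
    same_preds = bool((torch_logits.argmax(-1) == mlx_logits.argmax(-1)).all())
    return max_diff < atol and same_preds

a = np.array([[0.1, 0.9], [0.8, 0.2]])
print(outputs_match(a, a + 1e-7))  # True
```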

## Templates

Hand-written MLX implementations for common architecture patterns:

| Template | Description |
|---|---|
| MLP | Linear stacks with configurable activation, dropout, residual connections |
| TransformerBlock | Self-attention + FFN + LayerNorm (pre-norm and post-norm) |
| ConvBlock | Conv + normalization + activation |
| ConvStack | Stacked ConvBlocks with channel progression |
| AdaptiveAvgPool2d | Dynamic kernel/stride computation for adaptive average pooling |

These are reference implementations, not auto-generated. Use them directly or as a starting point for hand-porting custom architectures.
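For instance, the dynamic kernel/stride computation behind the AdaptiveAvgPool2d template can be sketched like this (one common formulation, and a simplifying assumption — it is exact only when the input size divides evenly by the output size):

```python
def adaptive_pool_params(in_size: int, out_size: int) -> tuple[int, int]:
    """Express adaptive average pooling as a fixed pool: the stride walks the
    input in out_size steps, and the kernel covers what the last step needs."""
    stride = in_size // out_size
    kernel = in_size - (out_size - 1) * stride
    return kernel, stride

print(adaptive_pool_params(8, 2))   # (4, 4)
print(adaptive_pool_params(7, 3))   # (3, 2)
```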

## Progress

| Phase | Status | Highlights |
|---|---|---|
| P0 — Layer & op coverage | Done | 62 layer mappings, 30 op mappings, 7 transposition rules, 12 dtype mappings |
| P1 — CLI & API | Done | `python -m torch2mlx`, public API (convert, analyze, export), e2e tests, 3 numerical equivalence examples |
| P2 — Polish | Done | PyPI metadata, support-matrix cleanup, dtype registry, torch.compile interop |
| P3 — HuggingFace validation | Done | 22/22 models at 100% analyzer coverage, weight round-trip (MLX→PyTorch) |
| Training support | Planned | Lightning-compatible MLX Trainer — see roadmap |

### Current numbers

| Metric | Value |
|---|---|
| Layer types | 72 |
| Op mappings | 30 |
| Dtype mappings | 12 |
| Transposition rules | 7 (+ reverse for round-trip) |
| Templates | 5 (MLP, Transformer, ConvBlock, ConvStack, AdaptiveAvgPool2d) |
| Tests | 301 |
| HuggingFace models tested | 36/36 at 100% |

## Roadmap

torch2mlx currently targets inference-only conversion of feed-forward architectures.

Planned next:

- Auto template generation — generate MLX module stubs from torch module trees
- Decorator API — `@torch2mlx.export` for compile-style annotation
- Weight streaming — convert large models without loading the full state dict into memory
- Training support — Lightning-compatible MLX Trainer where users provide an MLX-native forward() while weights, optimizers, schedulers, and the training loop are automated

See next-steps.md for detailed plans including the three-level Lightning integration strategy.

## Development

```sh
pip install -e ".[all]"          # Install with torch + mlx + dev deps
python -m pytest                 # Run tests (301 tests)
ruff check src/                  # Lint
ruff format src/ tests/          # Format
```

## License

Apache 2.0
