Export PyTorch models to WebNN format using torch dynamo IR.
This is an early-stage experimental implementation for research and exploration. Many features are incomplete, untested, or may change significantly.
```bash
# Clone the repository
git clone https://github.com/yourusername/webnn_torch_export.git
cd webnn_torch_export

# Install in editable mode with dev dependencies
pip install -e ".[dev]"

# Optional: run the test suite
pytest
```

Or install the package with pip:

```bash
pip install webnn_torch_export
```

Export graph only:
```python
from webnn_torch_export import export_model
import torch
import torch.nn as nn

# Create your model
model = nn.Conv2d(3, 16, kernel_size=3)
input_tensor = torch.randn(1, 3, 28, 28)

# Export with debug output
compiled_model, exporter = export_model(model, input_tensor, debug=True)

# Save exported graph
exporter.save_to_file('my_export.json')

# Access exported graphs programmatically
for graph in exporter.exported_graphs:
    print(graph['nodes'])
```

Export graph + weights:
```python
from webnn_torch_export import export_model_with_weights, load_weights_from_safetensors
import torch
import torch.nn as nn

# Create and export your model.
# Conv2d(3, 16, 3) turns a 28x28 input into a (1, 16, 26, 26) feature map,
# so it is flattened before the Linear layer.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 26 * 26, 10),
)
input_tensor = torch.randn(1, 3, 28, 28)

# Export both graph and weights
compiled_model, exporter = export_model_with_weights(
    model=model,
    example_input=input_tensor,
    graph_path="model_graph.json",
    weights_path="model_weights.safetensors",
    debug=False,
)

# Later: load weights into a fresh model with the same architecture
new_model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 26 * 26, 10),
)
load_weights_from_safetensors(new_model, "model_weights.safetensors")
```
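The weights file uses the safetensors layout: an 8-byte little-endian unsigned integer N, then N bytes of JSON describing each tensor's `dtype`, `shape`, and `data_offsets`, then the raw tensor bytes. Because of that, the exported tensors can be listed with the standard library alone, without loading PyTorch. This is a minimal sketch; `read_safetensors_header` is a hypothetical helper, not part of `webnn_torch_export` or of the `safetensors` package.

```python
import json
import struct
from typing import Any, Dict


def read_safetensors_header(path: str) -> Dict[str, Any]:
    """Return {tensor_name: {"dtype", "shape", "data_offsets"}} without reading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian u64 giving the length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional entry holding free-form string metadata, not a tensor.
    header.pop("__metadata__", None)
    return header
```

For example, `for name, info in read_safetensors_header("model_weights.safetensors").items(): print(name, info["dtype"], info["shape"])` prints one line per exported tensor, which is a quick way to confirm what the export wrote before loading it into a model.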
From the command line:

```bash
# Using the installed command
webnn-export

# Or run directly
python -m webnn_torch_export.exporter
```

The `CustomExporter` class is a Dynamo backend that:
- Receives FX graphs from PyTorch's compilation process
- Converts them to a custom format (JSON)
- Provides debug output to understand graph structure
- Maintains execution compatibility
Key methods:
- `export_graph()`: main callback that receives FX graphs
- `_convert_fx_to_custom_format()`: converts an FX graph to JSON
- `save_to_file()`: exports graphs to JSON files
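The backend pattern described above can be sketched in a few lines. `MinimalGraphRecorder` is a hypothetical, simplified stand-in, not the project's `CustomExporter`: it stores each FX graph as a JSON-friendly dict with the same keys as the exported JSON sample (`name`, `op`, `target`, `args`, `kwargs`) and returns `gm.forward`, so the compiled model keeps running eagerly. It only reads `gm.graph.nodes` and each node's `name`, `op`, `target`, `args`, and `kwargs`, so the class itself does not need to import torch.

```python
import json
from typing import Any, Callable, Dict, List


class MinimalGraphRecorder:
    """Hypothetical Dynamo backend that records FX graphs as JSON-friendly dicts."""

    def __init__(self) -> None:
        self.exported_graphs: List[Dict[str, Any]] = []

    def export_graph(self, gm: Any, example_inputs: List[Any]) -> Callable[..., Any]:
        # Dynamo calls the backend with a torch.fx.GraphModule and its example inputs.
        self.exported_graphs.append({"nodes": [self._node_to_dict(n) for n in gm.graph.nodes]})
        # Returning the graph's forward keeps execution identical to the traced graph.
        return gm.forward

    @staticmethod
    def _arg_to_json(arg: Any) -> Any:
        # FX nodes expose .name, so they are referenced by name; other values become strings.
        return arg.name if hasattr(arg, "name") else str(arg)

    def _node_to_dict(self, node: Any) -> Dict[str, Any]:
        return {
            "name": node.name,
            "op": node.op,
            "target": str(node.target),
            "args": [self._arg_to_json(a) for a in node.args],
            "kwargs": {k: self._arg_to_json(v) for k, v in node.kwargs.items()},
        }

    def save_to_file(self, path: str) -> None:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.exported_graphs, f, indent=2)
```

To use it, pass the bound method as the backend, e.g. `torch.compile(model, backend=recorder.export_graph)`, then call the compiled model once: Dynamo only invokes the backend on the first call, and a graph break can produce more than one graph, which is why the graphs are kept in a list.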
Single operator tests (`tests/test_single_ops.py`):
- `test_conv2d_export()`: tests Conv2d export
- `test_matmul_export()`: tests matmul export
- `test_linear_export()`: tests Linear layer export
- `test_conv_with_different_configs()`: parametrized tests for various Conv2d configurations
- `test_exported_graph_structure()`: validates exported graph structure

Integration tests (`tests/test_mnist_integration.py`):
- `SimplerMNISTClassifier`: Conv + ReLU + Linear
- `MNISTClassifier`: full classifier with 2 conv blocks
- `test_simple_mnist_export()`: exports the simple model
- `test_full_mnist_export()`: exports the full model
- `test_mnist_inference_consistency()`: tests consistency across multiple runs
- `test_mnist_batch_size_invariance()`: tests with different batch sizes
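A structural check in the spirit of `test_exported_graph_structure()` can also be run on any exported graph outside the test suite. The sketch below is hypothetical (not the project's test) and assumes the node layout shown in the exported JSON sample later in this README. It relies on FX graphs being topologically ordered, so a node may only reference nodes that appear before it; string args that do not name any node (literals such as `'1e-5'`) are skipped.

```python
from typing import Any, Dict, List, Set

REQUIRED_KEYS = {"name", "op", "target", "args", "kwargs"}
FX_OPCODES = {"placeholder", "get_attr", "call_function", "call_module", "call_method", "output"}


def check_graph_structure(graph: Dict[str, Any]) -> List[str]:
    """Return human-readable problems found in one exported graph (empty list if none)."""
    problems: List[str] = []
    nodes = graph.get("nodes", [])
    all_names = {node.get("name") for node in nodes}
    defined: Set[str] = set()
    for i, node in enumerate(nodes):
        missing = REQUIRED_KEYS - set(node)
        if missing:
            problems.append(f"node {i}: missing keys {sorted(missing)}")
            continue
        name = node["name"]
        if name in defined:
            problems.append(f"{name}: duplicate node name")
        if node["op"] not in FX_OPCODES:
            problems.append(f"{name}: unknown opcode {node['op']!r}")
        for arg in node["args"]:
            # FX graphs are topologically ordered, so a reference must point backwards.
            # Strings that do not name any node are treated as literals and skipped.
            if isinstance(arg, str) and arg in all_names and arg not in defined:
                problems.append(f"{name}: uses {arg!r} before it is defined")
        defined.add(name)
    return problems
```

An empty list means the graph passed every check; a literal string that happens to equal a later node's name would be reported as a false positive, which is an accepted limitation of this sketch.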
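A Dynamo backend is any callable that takes an FX `GraphModule` plus example inputs and returns a callable:

```python
import torch
import torch.nn as nn

def custom_backend(gm: torch.fx.GraphModule, example_inputs):
    # Your export logic here (e.g. walk gm.graph.nodes)
    return gm  # a GraphModule is callable, so execution is unchanged

model = nn.Conv2d(3, 16, kernel_size=3)
compiled_model = torch.compile(model, backend=custom_backend)
```

When Dynamo compiles a model, it produces an FX graph with nodes representing: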
- `placeholder`: Input tensors
- `call_function`: Function calls (e.g., `torch.relu`, `torch.matmul`)
- `call_module`: Module invocations (e.g., `conv1`, `fc1`)
- `call_method`: Tensor method calls (e.g., `x.flatten()`)
- `get_attr`: Fetches of parameters, buffers, or other module attributes
- `output`: Return values
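To see at a glance which of these node kinds a graph contains, a small helper can group the exported nodes by opcode. This is a sketch that assumes the exported JSON layout shown later in this README (a `nodes` list of dicts with `name`, `op`, and `target`), the same structure exposed by `exporter.exported_graphs`; `summarize_graph` is a hypothetical name, not part of the package.

```python
from collections import Counter
from typing import Any, Dict, List

# FX opcodes that perform computation (everything except inputs, attribute fetches, and outputs).
COMPUTE_OPS = ("call_function", "call_module", "call_method")


def summarize_graph(graph: Dict[str, Any]) -> Dict[str, Any]:
    """Group the nodes of one exported graph by FX opcode."""
    nodes: List[Dict[str, Any]] = graph["nodes"]
    return {
        "inputs": [n["name"] for n in nodes if n["op"] == "placeholder"],
        "attributes": [n["name"] for n in nodes if n["op"] == "get_attr"],
        # Count each (opcode, target) pair so repeated ops appear once with a tally.
        "ops": Counter((n["op"], str(n["target"])) for n in nodes if n["op"] in COMPUTE_OPS),
    }
```

Calling `summarize_graph(exporter.exported_graphs[0])` after an export lists the graph inputs and the distinct operations a WebNN converter would need to handle.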
```text
PyTorch Model → torch.compile() → Dynamo → FX Graph → Custom Backend → Export Format
                                                             ↓
                                                     Your Export Logic
```
With `debug=True`, the exporter prints:
- The complete FX graph representation
- The generated Python code
- Individual node details:
  - Node name and operation type
  - Target function/module
  - Arguments and keyword arguments
  - Tensor metadata (shapes, dtypes)
Example debug output:

```text
================================================================================
DYNAMO EXPORT CALLBACK TRIGGERED
================================================================================
Graph Module:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%conv1,), kwargs = {})
    return (relu,)
Node: x
  Op: placeholder
  Target: x
...
```
Example exported JSON:

```json
{
  "nodes": [
    {
      "name": "x",
      "op": "placeholder",
      "target": "x",
      "args": [],
      "kwargs": {}
    },
    {
      "name": "conv1",
      "op": "call_module",
      "target": "conv1",
      "module": "conv1",
      "args": ["x"],
      "kwargs": {}
    }
  ],
  "graph_str": "graph(): ...",
  "code": "def forward(self, x): ..."
}
```

When you export a model with unsupported operations, you'll get a clear error message showing exactly what's missing:
```text
================================================================================
UNSUPPORTED OPERATION
================================================================================
Operation: layer_norm
Node: layer_norm_output
Target: <function layer_norm at 0x...>
Schema: aten::layer_norm(Tensor input, int[] normalized_shape, ...)
Args: ['input_tensor', '[3072]', 'weight', 'bias', '1e-5']
Kwargs: {}
================================================================================
This operation is not yet supported in WebNN export.
To add support, update webnn_op_mappings.py with a mapping for this operation.
```
This makes it easy to incrementally add support for operations as you need them.
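The name `TARGET_CONTAINS_TO_CONVERTER` suggests a converter is selected when its key is contained in the node's target string (for example, `"layer_norm"` appears in `<function layer_norm at 0x...>`). Assuming that, the lookup and the error above can be sketched as follows; `find_converter` and `UnsupportedOperationError` are hypothetical names, not the package's API.

```python
from typing import Callable, Dict

ConverterFn = Callable[..., str]
BAR = "=" * 80


class UnsupportedOperationError(NotImplementedError):
    """Raised when no mapping key matches a node's target."""


def find_converter(node_name: str, target: object, table: Dict[str, ConverterFn]) -> ConverterFn:
    target_str = str(target)
    # Dicts keep insertion order, so the first inserted key contained in the target wins.
    for key, converter in table.items():
        if key in target_str:
            return converter
    raise UnsupportedOperationError(
        f"\n{BAR}\nUNSUPPORTED OPERATION\n{BAR}\n"
        f"Node: {node_name}\nTarget: {target_str}\n{BAR}\n"
        "This operation is not yet supported in WebNN export."
    )
```

Because matching is by substring, a short key can shadow a longer one: `"relu"` also matches a `relu6` target, so in this sketch more specific keys such as `"relu6"` must be inserted before `"relu"`.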
Quick steps:
- Run your export and read the error showing the unsupported operation.
- Add a mapping in `webnn_torch_export/webnn_op_mappings.py`:
  ```python
  TARGET_CONTAINS_TO_CONVERTER: Dict[str, ConverterFn] = {
      # ... existing mappings ...
      "layer_norm": lambda gen, node, output, inputs: gen._convert_layer_norm(node, output, inputs),
  }
  ```
- Implement the converter in `webnn_torch_export/webnn_generator.py`:
  ```python
  def _convert_layer_norm(self, node: fx.Node, output: str, inputs: List[str]) -> str:
      """Convert LayerNorm to WebNN"""
      input_tensor = inputs[0] if inputs else 'unknown'
      # ... conversion logic ...
      return f'[{output}] = layerNormalization({input_tensor}, ...);'
  ```
- Test: run the export again, then repeat for the next unsupported operation.
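Instead of fixing one error per export run, every unmatched target in a saved graph can be listed in a single pass. This sketch assumes the exported JSON layout shown earlier and assumes, as the name `TARGET_CONTAINS_TO_CONVERTER` suggests, that a mapping key matches when it is a substring of the target; `find_unsupported` is hypothetical. It only checks `call_function` and `call_method` nodes, because a `call_module` target such as `conv1` is an attribute name rather than an operation, and resolving it to a module type is left out here.

```python
from typing import Any, Dict, Iterable, List

CHECKED_OPS = ("call_function", "call_method")


def find_unsupported(graph: Dict[str, Any], keys: Iterable[str]) -> Dict[str, List[str]]:
    """Map each unmatched target string to the names of the nodes that use it."""
    keys = list(keys)
    missing: Dict[str, List[str]] = {}
    for node in graph["nodes"]:
        if node["op"] not in CHECKED_OPS:
            continue
        target = str(node["target"])
        if not any(key in target for key in keys):
            missing.setdefault(target, []).append(node["name"])
    return missing
```

Passing the mapping dict itself as `keys` works because iterating a dict yields its keys, e.g. `find_unsupported(graph, TARGET_CONTAINS_TO_CONVERTER)`.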
For detailed guidance, see `ADDING_OPS.md`, a comprehensive guide covering:
- How to map PyTorch operations to WebNN
- Common patterns (activations, normalization, matrix ops)
- Step-by-step walkthrough with examples
- WebNN operation reference
- Debugging tips
Modify `_convert_fx_to_custom_format()` to output your desired format:

```python
def _convert_fx_to_custom_format(self, gm):
    # Convert to your format (protobuf, flatbuffer, etc.).
    # convert_to_my_format is a placeholder for your own serializer.
    my_format = convert_to_my_format(gm.graph)
    return my_format
```
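As one illustration of an alternative layout, the sketch below rewrites the exported node dicts into an index-based form in which node references become integer positions in the node list. Integer indices are more compact than repeated string names, which matters once the graph is written to a binary format. `to_indexed_format` and `dump_compact` are hypothetical helpers that operate on the JSON-style dict (the layout shown in the exported JSON sample), not on `gm.graph` directly.

```python
import json
from typing import Any, Dict, List


def to_indexed_format(graph: Dict[str, Any]) -> Dict[str, Any]:
    """Rewrite node-name references in args as integer indices into the node list."""
    nodes: List[Dict[str, Any]] = graph["nodes"]
    index = {node["name"]: i for i, node in enumerate(nodes)}
    indexed = []
    for node in nodes:
        args = [
            # A string that names a node becomes a reference; anything else is a constant.
            {"ref": index[a]} if isinstance(a, str) and a in index else {"const": a}
            for a in node["args"]
        ]
        indexed.append({
            "op": node["op"],
            "target": str(node["target"]),
            "args": args,
            # kwargs are copied as-is; node references inside them are not rewritten.
            "kwargs": node.get("kwargs", {}),
        })
    return {"names": [n["name"] for n in nodes], "nodes": indexed}


def dump_compact(graph: Dict[str, Any]) -> str:
    # Separators without spaces give the most compact JSON text.
    return json.dumps(to_indexed_format(graph), separators=(",", ":"))
```

Keeping the original names in a separate `names` list preserves debuggability while the node records themselves only carry indices.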
Running tests:

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=webnn_torch_export --cov-report=html

# Run specific markers
pytest -m "not slow"
pytest -m integration
```

Building:

```bash
# Build distribution
python -m build

# Install locally
pip install -e ".[dev]"
```

Requirements:
- PyTorch 2.0+ (for `torch.compile` support)
- Python 3.8+
Debugging tips:
- Start with `debug=True` to see full graph output
- Use single operator tests to understand individual operations
- Check node metadata for tensor shapes and types
- Verify correctness by comparing original vs compiled outputs
- Examine exported JSON to understand graph structure
Next steps:
- Add support for more operators (pooling, normalization, etc.)
- Implement graph optimization passes
- Add serialization to binary formats (protobuf, flatbuffer)
- Handle dynamic shapes
- Support quantized models
- Add execution validation tests
License: Apache License 2.0 (see the LICENSE file)