Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a3a42e4
Update
manuelcandales Apr 14, 2026
1c965c6
Update
manuelcandales Apr 14, 2026
1be53ab
Update
manuelcandales Apr 14, 2026
47cbe76
Update
manuelcandales Apr 14, 2026
805a09d
Update
manuelcandales Apr 14, 2026
5306c5a
Update
manuelcandales Apr 14, 2026
638edaa
Update
manuelcandales Apr 14, 2026
958712e
Update
manuelcandales Apr 14, 2026
eba74c4
Update
manuelcandales Apr 14, 2026
c9ecdde
Update
manuelcandales Apr 14, 2026
c222005
Update
manuelcandales Apr 14, 2026
e7a7acc
Update
manuelcandales Apr 14, 2026
5530242
Update
manuelcandales Apr 14, 2026
59f88db
Update
manuelcandales Apr 14, 2026
1fbb94f
Update
manuelcandales Apr 14, 2026
60ca500
Update
manuelcandales Apr 14, 2026
d70d646
Update
manuelcandales Apr 14, 2026
d80da37
Update
manuelcandales Apr 14, 2026
f8ff857
Update
manuelcandales Apr 16, 2026
4632a83
Update
manuelcandales Apr 20, 2026
98d2f81
Update
manuelcandales Apr 20, 2026
95fb7f9
Update
manuelcandales Apr 20, 2026
440f7fc
Update
manuelcandales Apr 20, 2026
525e67b
Update
manuelcandales Apr 20, 2026
f4f616e
Update
manuelcandales Apr 20, 2026
b8e1201
Update
manuelcandales Apr 20, 2026
9ce837a
Update
manuelcandales Apr 20, 2026
bd12247
Update
manuelcandales Apr 20, 2026
248115a
Update
manuelcandales Apr 20, 2026
ee865c3
Update
manuelcandales Apr 20, 2026
36d45ef
Update
manuelcandales Apr 20, 2026
08a9fa2
Update
manuelcandales Apr 20, 2026
9000488
Update
manuelcandales Apr 20, 2026
a060d19
Update
manuelcandales Apr 20, 2026
01c3ce5
Update
manuelcandales Apr 20, 2026
0c1a88b
Update
manuelcandales Apr 20, 2026
2c56804
Update
manuelcandales Apr 20, 2026
7b480b3
Update
manuelcandales Apr 20, 2026
933122c
Update
manuelcandales Apr 20, 2026
9def0ed
Update
manuelcandales Apr 20, 2026
01ecf6a
Update
manuelcandales Apr 20, 2026
1766789
Update
manuelcandales Apr 20, 2026
21057d6
Update
manuelcandales Apr 20, 2026
7423226
Update
manuelcandales Apr 20, 2026
4b791ea
Update
manuelcandales Apr 20, 2026
ff92256
Update
manuelcandales Apr 20, 2026
b9b75e3
Update
manuelcandales Apr 20, 2026
f8ebcfb
Update
manuelcandales Apr 21, 2026
4cf31c8
Update
manuelcandales Apr 21, 2026
ba0e56e
Update
manuelcandales Apr 21, 2026
187e4f5
Update
manuelcandales Apr 21, 2026
23bec62
Update
manuelcandales Apr 21, 2026
c53ecc6
Update
manuelcandales Apr 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 141 additions & 7 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,34 @@ def _prepare_and_quantize_mlx(model, config, args):
pack_all_switch_linears(model)


def _prepare_and_quantize_metal(model, config, args):
"""Metal: apply source transforms, quantize experts + non-expert layers."""
import executorch.backends.apple.metal.ops.gated_delta_rule # noqa: F401
import executorch.backends.apple.metal.ops.gather_qmv # noqa: F401
from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
metal_source_transformations,
quantize_experts_metal,
)

# Quantize expert weights to Metal-compatible INT4 format
if args.qlinear:
quantize_experts_metal(model, config, args.qlinear_group_size)

if args.qlinear:
from executorch.extension.llm.export.quantize import quantize_model_

# skip_incompatible_shapes skips shared_expert_gate (N=1, N%4!=0)
quantize_model_(
model,
qlinear_config=args.qlinear,
qlinear_group_size=args.qlinear_group_size,
skip_incompatible_shapes=True,
)

_materialize_buffers(model, config)
metal_source_transformations(model, config=config)


def load_and_quantize(args): # noqa: C901
"""Load model from checkpoint, optionally quantize.

Expand Down Expand Up @@ -152,6 +180,11 @@ def load_and_quantize(args): # noqa: C901
)
_prepare_and_quantize_mlx(model, config, args)

elif backend == "metal":
if args.prequantized:
raise ValueError("Metal backend does not support --prequantized.")
_prepare_and_quantize_metal(model, config, args)

elif backend == "cuda":
if args.prequantized:
return load_prequantized_model(
Expand Down Expand Up @@ -516,6 +549,8 @@ def export_and_lower(model, config, args):

if backend == "mlx":
_export_mlx(model, config, args)
elif backend == "metal":
_export_metal(model, config, args)
else:
_export_cuda(model, config, args)

Expand Down Expand Up @@ -600,6 +635,100 @@ def _export_mlx(model, config, args):
print("Done!")


def _export_metal(model, config, args):
"""Export model to .pte via torch.export + Metal backend."""
import torch._inductor.config as inductor_config

from executorch.backends.apple.metal.metal_backend import MetalBackend
from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.exir.passes import MemoryPlanningPass
from torch.export import Dim, export

inductor_config.coordinate_descent_tuning = False
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"

# --- Decode method (T=1, static shape) ---
print("Exporting decode method...")
decode_tokens = torch.tensor([[0]], dtype=torch.long)
decode_pos = torch.tensor([0], dtype=torch.long)
with torch.no_grad():
decode_ep = export(model, (decode_tokens, decode_pos), strict=True)
print("Decode export successful!")

# --- Prefill method (T>=2, dynamic shape) ---
print("Exporting prefill method...")
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
prefill_dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
with torch.no_grad():
prefill_ep = export(
model,
(prefill_tokens, prefill_pos),
dynamic_shapes=prefill_dynamic_shapes,
strict=True,
)
print("Prefill export successful!")

# Lower with Metal backend
print("Lowering to ExecuTorch with Metal...")
metadata = {
"get_max_seq_len": config.max_seq_len,
"get_vocab_size": config.vocab_size,
"get_n_layers": config.num_hidden_layers,
"use_kv_cache": True,
"use_sdpa_with_kv_cache": False,
"enable_dynamic_shape": True,
}
et_prog = to_edge_transform_and_lower(
{"decode": decode_ep, "prefill": prefill_ep},
partitioner={
"decode": [
MetalPartitioner(
[MetalBackend.generate_method_name_compile_spec("decode")]
)
],
"prefill": [
MetalPartitioner(
[MetalBackend.generate_method_name_compile_spec("prefill")]
)
],
},
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods=metadata,
)
et_program = et_prog.to_executorch(
config=ExecutorchBackendConfig(
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
),
)

# Save .pte
os.makedirs(args.output_dir, exist_ok=True)
pte_path = os.path.join(args.output_dir, "model.pte")
print(f"Saving to {pte_path}...")
with open(pte_path, "wb") as f:
et_program.write_to_file(f)
size_mb = os.path.getsize(pte_path) / (1024 * 1024)
print(f"Saved {size_mb:.1f} MB")

if et_program._tensor_data:
et_program.write_tensor_data_to_file(args.output_dir)
print(f"Saved tensor data to {args.output_dir}/")

print("Done!")


def _export_cuda(model, config, args):
"""Export model to .pte via torch.export + CUDA backend.

Expand Down Expand Up @@ -739,10 +868,8 @@ def _export_cuda(model, config, args):
# ---------------------------------------------------------------------------


def main():
parser = argparse.ArgumentParser(
description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)"
)
def main(): # noqa: C901
parser = argparse.ArgumentParser(description="Export Qwen3.5 MoE to ExecuTorch")
parser.add_argument(
"--model-dir",
default=None,
Expand All @@ -760,13 +887,13 @@ def main():
parser.add_argument(
"--backend",
default="cuda",
choices=["cuda", "mlx"],
help="Backend for export: cuda (default) or mlx.",
choices=["cuda", "mlx", "metal"],
help="Backend for export: cuda (default), mlx, or metal.",
)
parser.add_argument(
"--qlinear",
default=None,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantize linear layers.",
)
parser.add_argument(
Expand Down Expand Up @@ -841,6 +968,13 @@ def main():
if args.turboquant:
parser.error("--turboquant is not supported with --backend mlx")

if args.backend == "metal":
if args.turboquant:
parser.error("--turboquant is not supported with --backend metal")

if args.qlinear == "fpa4w" and args.backend != "metal":
parser.error("--qlinear=fpa4w can only be used with --backend=metal")

model, config = load_and_quantize(args)

if args.backend == "cuda":
Expand Down
3 changes: 3 additions & 0 deletions extension/llm/export/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def _check_shape_compatible(m, fqn, config_name, group_size, skip_incompatible_s
shape = m.weight.shape
if config_name == "nvfp4":
compatible = shape[-2] % group_size == 0 and shape[-1] % group_size == 0
elif config_name == "fpa4w":
# MPS UIntx kernel requires N % 4 == 0 when M > 1 (e.g. prefill)
compatible = shape[-1] % group_size == 0 and shape[-2] % 4 == 0
elif group_size != 0:
compatible = shape[-1] % group_size == 0
else:
Expand Down
Loading