Skip to content

Commit ee0c191

Browse files
author
Github Executorch
committed
Support multimethod in export_llama_lib
Pull Request resolved: #17231

TODO: add CI test.

Note: multimethod export is currently limited to:
- XNNPACK or the portable ops library
- LoRA only (arbitrary nn.Modules in each method are not supported)
- if quantization is enabled, LoRA models must share quantization schemes at source-transformation time
- no PT2E quantization, as each model could have slightly different results after calibration

ghstack-source-id: 338339885
@exported-using-ghexport

Differential Revision: [D92315602](https://our.internmc.facebook.com/intern/diff/D92315602/)
1 parent e54c208 commit ee0c191

5 files changed

Lines changed: 291 additions & 1 deletion

File tree

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -exu
9+
# shellcheck source=/dev/null
10+
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
11+
12+
# Rebuild the ExecuTorch libraries from a clean cmake-out tree by running the
# `llm-release` CMake workflow.
cmake_install_executorch_libraries() {
  printf '%s\n' "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
  # Remove any previous build tree so stale artifacts cannot be picked up.
  rm -rf cmake-out
  cmake --workflow llm-release
}
17+
18+
# Initialize the tokenizers submodule, then build the llama runner through the
# `llama-cpu` make target.
cmake_build_llama_runner() {
  printf '%s\n' "Building llama runner"
  # pushd/popd (rather than a subshell) so the directory stack is still echoed
  # to the log exactly as before.
  pushd extension/llm/tokenizers
  printf '%s\n' "Updating tokenizers submodule"
  git submodule update --init
  popd
  make llama-cpu
}
26+
27+
# Remove everything this script downloaded or generated: the Hugging Face
# model and adapter directories, exported .pte files and captured runner output.
cleanup_files() {
  echo "Deleting downloaded and generated files"
  # ${VAR:?} aborts with an error instead of expanding to "/" when a path
  # variable is empty or unset, so a failed download can never turn this into
  # `rm -rf /`.
  rm -rf "${HF_QWEN_PATH:?}/"
  rm -rf "${HF_ADAPTER_PATH:?}/"
  # The ./ prefix keeps file names that start with '-' from being parsed as
  # rm options.
  rm -rf ./*.pte
  rm -f ./result*.txt
}
34+
35+
# Download LoRA adapter.
python -m pip install -q huggingface_hub
HF_ADAPTER_REPO="lucylq/qwen3_06B_lora_math"
# download_hf_hub.sh's stdout is captured and used as the local adapter
# directory below.
HF_ADAPTER_PATH=$(
  bash "$(dirname "${BASH_SOURCE[0]}")/download_hf_hub.sh" \
    --model_id "${HF_ADAPTER_REPO}" \
    --files "adapter_config.json" "adapter_model.safetensors"
)

# Download base model (for tokenizer path).
HF_QWEN_PATH=$(python -c "from huggingface_hub import snapshot_download; print(snapshot_download('unsloth/Qwen3-0.6B'))")
echo "Model downloaded to: $HF_QWEN_PATH"

### EXPORT MULTIMETHOD PTE ###
# Set environment variables for OmegaConf interpolation in yaml.
export LORA_ADAPTER_CHECKPOINT="${HF_ADAPTER_PATH}/adapter_model.safetensors"
export LORA_ADAPTER_CONFIG="${HF_ADAPTER_PATH}/adapter_config.json"

# The script runs under `set -u`, so a bare $PYTHON_EXECUTABLE would abort
# here with "unbound variable" when the caller did not export it. Fall back
# to `python`, the interpreter already used for the downloads above.
# Intentionally unquoted so a multi-word launcher keeps working.
${PYTHON_EXECUTABLE:-python} -m extension.llm.export.export_llm \
  --config examples/models/qwen3/config/qwen3_multimethod.yaml

### BUILD LLAMA RUNNER ###
cmake_install_executorch_libraries
cmake_build_llama_runner
59+
60+
# Runner constants.
# Held in an array so each flag stays one argument even when the tokenizer
# path contains whitespace or glob characters.
RUNTIME_ARGS=(
  "--tokenizer_path=${HF_QWEN_PATH}/"
  --temperature=0
  --seq_len=100
  --warmup=1
)
PROMPT="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant"

# Expected outputs.
EXPECTED_LORA_PREFIX="
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
To calculate 15% of 80"

EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
<think>
Okay, so I need to calculate 15% of 80."

# Run one method of multimethod_qwen.pte and verify the start of its output.
#   $1 - label used in log lines (e.g. "Test 1")
#   $2 - method name passed to the runner via --method_name
#   $3 - file that receives the runner's stdout
#   $4 - text the captured output must start with
# On a mismatch the generated files are cleaned up and the script exits 1.
run_method_test() {
  local label="$1" method="$2" result_file="$3" expected_prefix="$4"
  local now result

  now=$(date +"%H:%M:%S")
  echo "${label}: Multimethod ${method}. Starting at ${now}"
  cmake-out/examples/models/llama/llama_main \
    --model_path=multimethod_qwen.pte \
    --method_name="${method}" \
    --prompt="${PROMPT}" \
    "${RUNTIME_ARGS[@]}" > "${result_file}"
  now=$(date +"%H:%M:%S")
  echo "Finished at ${now}"

  result=$(cat "${result_file}")
  echo "Expected result prefix: ${expected_prefix}"
  echo "Actual result: ${result}"
  # Quoting the right-hand side makes it a literal string; the trailing *
  # turns the comparison into a prefix match.
  if [[ "${result}" == "${expected_prefix}"* ]]; then
    echo "${label} (${method}): Success"
  else
    echo "${label} (${method}): Failure"
    cleanup_files
    exit 1
  fi
}

### TEST 1: Run lora_forward method ###
run_method_test "Test 1" lora_forward result_lora.txt "${EXPECTED_LORA_PREFIX}"

### TEST 2: Run base_forward method ###
run_method_test "Test 2" base_forward result_base.txt "${EXPECTED_BASE_PREFIX}"

echo "Multimethod tests passed!"
cleanup_files

examples/models/llama/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ fbcode_target(_kind = runtime.python_library,
148148
fbcode_target(_kind = runtime.python_library,
149149
name = "export_library",
150150
srcs = [
151+
"convert_weights.py",
151152
"export_llama.py",
152153
"export_llama_lib.py",
153154
"model.py",

examples/models/llama/export_llama_lib.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020
from importlib import resources as _resources
2121
from json import JSONDecodeError
2222
from pathlib import Path
23-
from typing import Callable, List, Optional, Union
23+
from typing import Callable, Dict, List, Optional, Union
2424

2525
import torch
26+
from torch.export import ExportedProgram
2627

2728
from executorch.devtools.backend_debug import print_delegation_info
2829
from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
2930
from executorch.examples.models.llama.hf_download import (
3031
download_and_convert_hf_checkpoint,
3132
)
33+
from executorch.exir import to_edge_transform_and_lower
3234
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
3335
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
3436
from executorch.extension.llm.export.config.llm_config import LlmConfig
@@ -844,6 +846,28 @@ def _validate_args(llm_config):
844846
"Shared embedding is only supported with torchao quantization."
845847
)
846848

849+
if llm_config.multimethod.enabled:
    # Multimethod export takes each method's LoRA setting from
    # multimethod.methods, so a separately configured base.lora is rejected.
    if llm_config.base.lora is not None:
        raise ValueError(
            "Cannot use both base.lora and multimethod.methods. "
            "Use multimethod.methods for all LoRA variants."
        )
    if llm_config.quantization.pt2e_quantize is not None:
        raise ValueError(
            "PT2E quantization is not supported with multimethod export."
        )
    # Any delegate other than XNNPACK is rejected up front.
    if (
        llm_config.backend.coreml.enabled
        or llm_config.backend.vulkan.enabled
        or llm_config.backend.qnn.enabled
        or llm_config.backend.mps.enabled
        or llm_config.backend.openvino.enabled
    ):
        # Adjacent literals are concatenated verbatim, so the first one must
        # end with ". " to keep the two sentences apart in the message.
        raise ValueError(
            "Multimethod export only supports XNNPACK backend or portable ops. "
            "Please disable other backends (coreml, vulkan, qnn, mps, openvino)."
        )
870+
847871

848872
def _to_edge_and_lower_llama_xnnpack(
849873
builder_exported,
@@ -1107,9 +1131,121 @@ def _to_edge_and_lower_llama( # noqa: C901
11071131
return builder
11081132

11091133

1134+
def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List]:
    """Collect the XNNPACK partitioners used when lowering a multimethod export.

    Returns:
        ``None`` when ``backend.xnnpack.enabled`` is false. Otherwise a list
        containing the dynamic-quant-only partitioner, followed by the
        non-dynamic-quant-only partitioner when ``backend.xnnpack.extended_ops``
        is set.
    """
    xnnpack_config = llm_config.backend.xnnpack
    if not xnnpack_config.enabled:
        return None

    # extended_ops is only consulted once XNNPACK itself is enabled.
    selected = [get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)]
    if xnnpack_config.extended_ops:
        selected.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
        )
    return selected
1146+
1147+
1148+
def _get_output_filename(llm_config: LlmConfig, modelname: str, output_dir: str, dtype: DType) -> str:
    """Resolve the path the exported .pte file is written to.

    An explicit ``export.output_name`` takes precedence: a name that already
    ends in ``.pte`` is returned verbatim (``output_dir`` is not applied),
    any other name becomes ``<output_dir>/<output_name>.pte``. Without an
    explicit name the result is ``<output_dir>/<modelname>.pte``, where the
    model name gets an ``_h`` suffix for fp16.
    """
    requested_name = llm_config.export.output_name
    if requested_name:
        if requested_name.endswith(".pte"):
            return requested_name
        return f"{output_dir}/{requested_name}.pte"

    # The fp16 marker only affects the default, model-derived file name.
    stem = f"{modelname}_h" if dtype == DType.fp16 else modelname
    return f"{output_dir}/{stem}.pte"
1161+
1162+
1163+
def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
    """
    Export every configured method into a single multi-method .pte file.

    ``llm_config.multimethod.methods`` maps a method name to the value that is
    assigned to ``base.lora`` for that method (``None`` means no LoRA). Each
    entry is prepared and exported on its own, so export time grows with the
    number of methods; all resulting programs are then lowered together.

    Notes:
        - Only XNNPACK partitioners are applied during lowering (see
          ``_get_xnnpack_partitioners``); with XNNPACK disabled no partitioner
          is passed.
        - The edge compile config, the constant methods (``metadata``) and
          the output name, directory and dtype are all taken from the first
          method's builder.
        - NOTE(review): the original docstring states that the final .pte
          deduplicates shared weights automatically; that happens outside
          this function -- confirm.

    Returns:
        The first method's builder, after conversion to ExecuTorch and after
        the .pte file has been saved.
    """
    methods = llm_config.multimethod.methods
    num_methods = len(methods)
    logging.info(
        f"Multimethod export: exporting {num_methods} method(s). "
        "Each method requires separate model instantiation and export."
    )

    additional_passes = (
        [InitializedMutableBufferPass(["kv_cache_pos"])]
        if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS
        else []
    )

    # One exported program per method name; the first builder is kept because
    # it drives lowering, serialization and naming below.
    programs: Dict[str, ExportedProgram] = {}
    primary_builder = None

    for method_name, lora_config in methods.items():
        logging.info(f"Exporting method: {method_name}")

        # Work on a private copy so the caller's config is left untouched.
        per_method_config = copy.deepcopy(llm_config)
        per_method_config.base.lora = lora_config
        # The copy describes a single method, so its method table is cleared
        # (the original comment says this avoids infinite recursion).
        per_method_config.multimethod.methods = {}

        method_builder = _prepare_for_llama_export(per_method_config).export()
        method_builder.run_canonical_optimizations()
        programs[method_name] = method_builder._export(
            method_builder.pre_autograd_graph_module
        )

        if primary_builder is None:
            primary_builder = method_builder

    assert primary_builder is not None, "No methods to export"

    # Lower all methods in one call so they end up in the same program.
    partitioners = _get_xnnpack_partitioners(llm_config)
    edge_manager = to_edge_transform_and_lower(
        programs,
        partitioner=partitioners,
        compile_config=primary_builder._get_edge_config(),
        constant_methods=primary_builder.metadata,
    )

    # Reuse the first builder to convert the combined edge program and save it.
    primary_builder.edge_manager = edge_manager
    primary_builder = primary_builder.to_executorch(passes=additional_passes)

    output_file = _get_output_filename(
        llm_config,
        primary_builder.modelname,
        primary_builder.output_dir,
        primary_builder.dtype,
    )
    primary_builder.save_to_pte(output_file)

    return primary_builder
1240+
1241+
11101242
def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
11111243
_validate_args(llm_config)
11121244

1245+
# Check for multimethod export
1246+
if llm_config.multimethod.enabled:
1247+
return _export_llama_multimethod(llm_config)
1248+
11131249
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
11141250
llm_config
11151251
)

examples/models/llama/runner/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def define_common_targets():
4747
"//executorch/examples/models/llama/tokenizer:tiktoken",
4848
"//pytorch/tokenizers:llama2c_tokenizer",
4949
"//pytorch/tokenizers:hf_tokenizer",
50+
"//pytorch/tokenizers:regex_lookahead",
5051
] + (_get_operator_lib(aten)) + ([
5152
# Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
5253
# Therefore enable it explicitly for now to avoid failing tests
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Multimethod export config: one base model exported as two methods
# (with and without a LoRA adapter), lowered with XNNPACK.
# NOTE(review): presumably examples/models/qwen3/config/qwen3_multimethod.yaml,
# the path the CI script passes via --config -- confirm.
base:
  model_class: "qwen3_0_6b"
  params: "examples/models/qwen3/config/0_6b_config.json"
  # JSON string, hence the single quotes around it.
  metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}'

model:
  use_kv_cache: true
  use_sdpa_with_kv_cache: true

export:
  # No ".pte" suffix here; the export code appends it and places the file
  # under the output directory.
  output_name: multimethod_qwen

backend:
  xnnpack:
    enabled: true

quantization:
  qmode: "8da4w"
  group_size: 32

multimethod:
  # Maps method name -> LoRA settings for that method (null = no LoRA).
  methods:
    # LoRA method - adapter paths from environment variables
    lora_forward:
      adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
      adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
    # Base method - no LoRA
    base_forward: null

0 commit comments

Comments
 (0)