|
16 | 16 | import re |
17 | 17 | import shlex |
18 | 18 | from functools import partial |
19 | | - |
20 | 19 | from importlib import resources as _resources |
21 | 20 | from json import JSONDecodeError |
22 | 21 | from pathlib import Path |
23 | | -from typing import Callable, List, Optional, Union |
| 22 | +from typing import Callable, Dict, List, Optional, Union |
24 | 23 |
|
25 | 24 | import torch |
26 | | - |
27 | 25 | from executorch.devtools.backend_debug import print_delegation_info |
28 | 26 | from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func |
29 | 27 | from executorch.examples.models.llama.hf_download import ( |
30 | 28 | download_and_convert_hf_checkpoint, |
31 | 29 | ) |
| 30 | +from executorch.exir import to_edge_transform_and_lower |
| 31 | +from executorch.exir.backend.partitioner import Partitioner |
32 | 32 | from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass |
33 | 33 | from executorch.extension.llm.export.builder import DType, LLMEdgeManager |
34 | 34 | from executorch.extension.llm.export.config.llm_config import LlmConfig |
|
52 | 52 | ) |
53 | 53 | from executorch.util.activation_memory_profiler import generate_memory_trace |
54 | 54 | from omegaconf import DictConfig |
| 55 | +from torch.export import ExportedProgram |
55 | 56 |
|
56 | 57 | from ..model_factory import EagerModelFactory |
57 | 58 | from .source_transformation.apply_spin_quant_r1_r2 import ( |
@@ -852,6 +853,28 @@ def _validate_args(llm_config): |
852 | 853 | "Shared embedding is only supported with torchao quantization." |
853 | 854 | ) |
854 | 855 |
|
| 856 | + if llm_config.multimethod_lora.enabled: |
| 857 | + if llm_config.base.lora_config is not None: |
| 858 | + raise ValueError( |
| 859 | + "Cannot use both base.lora_config and multimethod_lora.methods. " |
| 860 | + "Use multimethod_lora.methods for all LoRA variants." |
| 861 | + ) |
| 862 | + if llm_config.quantization.pt2e_quantize is not None: |
| 863 | + raise ValueError( |
| 864 | + "PT2E quantization is not supported with multimethod_lora export." |
| 865 | + ) |
| 866 | + if ( |
| 867 | + llm_config.backend.coreml.enabled |
| 868 | + or llm_config.backend.vulkan.enabled |
| 869 | + or llm_config.backend.qnn.enabled |
| 870 | + or llm_config.backend.mps.enabled |
| 871 | + or llm_config.backend.openvino.enabled |
| 872 | + ): |
| 873 | + raise ValueError( |
| 874 | + "multimethod_lora export only supports XNNPACK backend or portable ops" |
| 875 | + "Please disable other backends (coreml, vulkan, qnn, mps, openvino)." |
| 876 | + ) |
| 877 | + |
855 | 878 |
|
856 | 879 | def _to_edge_and_lower_llama_xnnpack( |
857 | 880 | builder_exported, |
@@ -946,7 +969,6 @@ def _to_edge_and_lower_llama_tosa( |
946 | 969 | tosa_spec, |
947 | 970 | verbose: bool = False, |
948 | 971 | ) -> LLMEdgeManager: |
949 | | - |
950 | 972 | logging.info("Lowering model using TOSA partitioner") |
951 | 973 |
|
952 | 974 | partitioners = [] |
@@ -1141,9 +1163,126 @@ def _to_edge_and_lower_llama( # noqa: C901 |
1141 | 1163 | return builder |
1142 | 1164 |
|
1143 | 1165 |
|
| 1166 | +def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List[Partitioner]]: |
| 1167 | + """Get XNNPACK partitioners for multimethod_lora export.""" |
| 1168 | + partitioners = [] |
| 1169 | + |
| 1170 | + if llm_config.backend.xnnpack.enabled: |
| 1171 | + partitioners.append( |
| 1172 | + get_xnnpack_partitioner(dynamic_quant_only_partitioner=True) |
| 1173 | + ) |
| 1174 | + if llm_config.backend.xnnpack.extended_ops: |
| 1175 | + partitioners.append( |
| 1176 | + get_xnnpack_partitioner(dynamic_quant_only_partitioner=False) |
| 1177 | + ) |
| 1178 | + |
| 1179 | + return partitioners if partitioners else None |
| 1180 | + |
| 1181 | + |
| 1182 | +def _get_output_filename( |
| 1183 | + llm_config: LlmConfig, modelname: str, output_dir: str, dtype: DType |
| 1184 | +) -> str: |
| 1185 | + """Determine output filename for the .pte file.""" |
| 1186 | + if dtype == DType.fp16: |
| 1187 | + modelname = f"{modelname}_h" |
| 1188 | + |
| 1189 | + if llm_config.export.output_name: |
| 1190 | + output_name = llm_config.export.output_name |
| 1191 | + if output_name.endswith(".pte"): |
| 1192 | + return output_name |
| 1193 | + else: |
| 1194 | + return f"{output_dir}/{output_name}.pte" |
| 1195 | + else: |
| 1196 | + return f"{output_dir}/{modelname}.pte" |
| 1197 | + |
| 1198 | + |
| 1199 | +def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: |
| 1200 | + """ |
| 1201 | + Export multiple methods (base + LoRA variants) to a single .pte file. |
| 1202 | +
|
| 1203 | + For each method in llm_config.multimethod_lora.methods: |
| 1204 | + - If LoraConfig is None: use base model |
| 1205 | + - If LoraConfig is provided: create model with LoRA weights |
| 1206 | +
|
| 1207 | + Limitations: |
| 1208 | + - Only XNNPACK backend is supported for multimethod_lora export. |
| 1209 | + - PT2E quantization is not supported. |
| 1210 | + - Each method is exported separately; export time scales linearly |
| 1211 | + with the number of methods. |
| 1212 | + - The final .pte file deduplicates shared weights automatically. |
| 1213 | + """ |
| 1214 | + num_methods = len(llm_config.multimethod_lora.methods) |
| 1215 | + logging.info( |
| 1216 | + f"multimethod_lora export: exporting {num_methods} method(s). " |
| 1217 | + "Each method requires separate model instantiation and export." |
| 1218 | + ) |
| 1219 | + |
| 1220 | + additional_passes = [] |
| 1221 | + if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: |
| 1222 | + additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] |
| 1223 | + |
| 1224 | + # Build dict of exported programs |
| 1225 | + method_to_program: Dict[str, ExportedProgram] = {} |
| 1226 | + first_builder = None |
| 1227 | + |
| 1228 | + for method_name, lora_config in llm_config.multimethod_lora.methods.items(): |
| 1229 | + logging.info(f"Exporting method: {method_name}") |
| 1230 | + |
| 1231 | + # Create a copy of config with this method's LoRA setting |
| 1232 | + method_config = copy.deepcopy(llm_config) |
| 1233 | + method_config.base.lora_config = lora_config |
| 1234 | + # Disable multimethod_lora to avoid infinite recursion |
| 1235 | + method_config.multimethod_lora.methods = {} |
| 1236 | + |
| 1237 | + # Load and prepare model for this method |
| 1238 | + builder = _prepare_for_llama_export(method_config) |
| 1239 | + builder = builder.export() |
| 1240 | + builder.run_canonical_optimizations() |
| 1241 | + |
| 1242 | + # Get the exported program |
| 1243 | + exported_program = builder._export(builder.pre_autograd_graph_module) |
| 1244 | + method_to_program[method_name] = exported_program |
| 1245 | + |
| 1246 | + if first_builder is None: |
| 1247 | + first_builder = builder |
| 1248 | + |
| 1249 | + assert first_builder is not None, "No methods to export" |
| 1250 | + |
| 1251 | + # Get partitioners based on backend config |
| 1252 | + partitioners = _get_xnnpack_partitioners(llm_config) |
| 1253 | + |
| 1254 | + # Lower all methods together using multimethod_lora API |
| 1255 | + edge_config = first_builder._get_edge_config() |
| 1256 | + edge_manager = to_edge_transform_and_lower( |
| 1257 | + method_to_program, |
| 1258 | + partitioner=partitioners, |
| 1259 | + compile_config=edge_config, |
| 1260 | + constant_methods=first_builder.metadata, |
| 1261 | + generate_etrecord=llm_config.debug.generate_etrecord, |
| 1262 | + ) |
| 1263 | + |
| 1264 | + # Convert to executorch and save |
| 1265 | + first_builder.edge_manager = edge_manager |
| 1266 | + first_builder = first_builder.to_executorch(passes=additional_passes) |
| 1267 | + |
| 1268 | + output_file = _get_output_filename( |
| 1269 | + llm_config, |
| 1270 | + first_builder.modelname, |
| 1271 | + first_builder.output_dir, |
| 1272 | + first_builder.dtype, |
| 1273 | + ) |
| 1274 | + first_builder.save_to_pte(output_file) |
| 1275 | + |
| 1276 | + return first_builder |
| 1277 | + |
| 1278 | + |
1144 | 1279 | def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 |
1145 | 1280 | _validate_args(llm_config) |
1146 | 1281 |
|
| 1282 | + # Check for multimethod_lora export |
| 1283 | + if llm_config.multimethod_lora.enabled: |
| 1284 | + return _export_llama_multimethod(llm_config) |
| 1285 | + |
1147 | 1286 | pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( |
1148 | 1287 | llm_config |
1149 | 1288 | ) |
@@ -1247,23 +1386,12 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 |
1247 | 1386 | if llm_config.debug.profile_memory: |
1248 | 1387 | generate_memory_trace(builder.export_program, "memory_profile.json") |
1249 | 1388 |
|
1250 | | - if builder.dtype == DType.fp16: |
1251 | | - modelname = f"{modelname}_h" |
1252 | | - |
1253 | | - if llm_config.export.output_name: |
1254 | | - modelname = llm_config.export.output_name |
1255 | | - if modelname.endswith(".pte"): |
1256 | | - output_file = modelname |
1257 | | - modelname = modelname[:-4] |
1258 | | - print(f"modelname: {modelname}") |
1259 | | - print(f"output_file: {output_file}") |
1260 | | - else: |
1261 | | - output_file = f"{builder.output_dir}/{modelname}.pte" |
1262 | | - print(f"modelname: {modelname}") |
1263 | | - print(f"output_file: {output_file}") |
1264 | | - else: |
1265 | | - output_file = f"{builder.output_dir}/{modelname}.pte" |
1266 | | - |
| 1389 | + output_file = _get_output_filename( |
| 1390 | + llm_config, |
| 1391 | + modelname, |
| 1392 | + builder.output_dir, |
| 1393 | + builder.dtype, |
| 1394 | + ) |
1267 | 1395 | builder.save_to_pte(output_file) |
1268 | 1396 | return builder |
1269 | 1397 |
|
|
0 commit comments