|
20 | 20 | from importlib import resources as _resources |
21 | 21 | from json import JSONDecodeError |
22 | 22 | from pathlib import Path |
23 | | -from typing import Callable, List, Optional, Union |
| 23 | +from typing import Callable, Dict, List, Optional, Union |
24 | 24 |
|
25 | 25 | import torch |
| 26 | +from torch.export import ExportedProgram |
26 | 27 |
|
27 | 28 | from executorch.devtools.backend_debug import print_delegation_info |
28 | 29 | from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func |
29 | 30 | from executorch.examples.models.llama.hf_download import ( |
30 | 31 | download_and_convert_hf_checkpoint, |
31 | 32 | ) |
| 33 | +from executorch.exir import to_edge_transform_and_lower |
32 | 34 | from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass |
33 | 35 | from executorch.extension.llm.export.builder import DType, LLMEdgeManager |
34 | 36 | from executorch.extension.llm.export.config.llm_config import LlmConfig |
@@ -844,6 +846,28 @@ def _validate_args(llm_config): |
844 | 846 | "Shared embedding is only supported with torchao quantization." |
845 | 847 | ) |
846 | 848 |
|
| 849 | + if llm_config.multimethod.enabled: |
| 850 | + if llm_config.base.lora is not None: |
| 851 | + raise ValueError( |
| 852 | + "Cannot use both base.lora and multimethod.methods. " |
| 853 | + "Use multimethod.methods for all LoRA variants." |
| 854 | + ) |
| 855 | + if llm_config.quantization.pt2e_quantize is not None: |
| 856 | + raise ValueError( |
| 857 | + "PT2E quantization is not supported with multimethod export." |
| 858 | + ) |
| 859 | + if ( |
| 860 | + llm_config.backend.coreml.enabled |
| 861 | + or llm_config.backend.vulkan.enabled |
| 862 | + or llm_config.backend.qnn.enabled |
| 863 | + or llm_config.backend.mps.enabled |
| 864 | + or llm_config.backend.openvino.enabled |
| 865 | + ): |
| 866 | + raise ValueError( |
| 867 | + "Multimethod export only supports XNNPACK backend or portable ops" |
| 868 | + "Please disable other backends (coreml, vulkan, qnn, mps, openvino)." |
| 869 | + ) |
| 870 | + |
847 | 871 |
|
848 | 872 | def _to_edge_and_lower_llama_xnnpack( |
849 | 873 | builder_exported, |
@@ -1107,9 +1131,121 @@ def _to_edge_and_lower_llama( # noqa: C901 |
1107 | 1131 | return builder |
1108 | 1132 |
|
1109 | 1133 |
|
def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List]:
    """Collect the XNNPACK partitioners to use when lowering a multimethod export.

    Args:
        llm_config: Export configuration; only ``backend.xnnpack.enabled`` and
            ``backend.xnnpack.extended_ops`` are read.

    Returns:
        ``None`` when XNNPACK is disabled. Otherwise a list holding the
        dynamic-quant-only partitioner, followed by the general partitioner
        when ``extended_ops`` is set.
    """
    xnnpack_cfg = llm_config.backend.xnnpack
    if not xnnpack_cfg.enabled:
        return None

    # The dynamic-quant-only partitioner is always placed first in the list.
    selected = [get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)]
    if xnnpack_cfg.extended_ops:
        selected.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=False))
    return selected
| 1146 | + |
| 1147 | + |
def _get_output_filename(
    llm_config: LlmConfig, modelname: str, output_dir: str, dtype: DType
) -> str:
    """Work out the path of the .pte file to write.

    Args:
        llm_config: Export configuration; ``export.output_name`` takes
            precedence over the model-derived name when it is set.
        modelname: Base model name used when no explicit output name is given.
        output_dir: Directory prefix for names that are not already ``.pte`` paths.
        dtype: Export dtype; ``DType.fp16`` adds an ``_h`` suffix to ``modelname``.

    Returns:
        ``export.output_name`` unchanged if it already ends with ``.pte``
        (``output_dir`` is not prepended in that case); otherwise
        ``<output_dir>/<name>.pte`` where ``<name>`` is the explicit output
        name or, failing that, the (possibly suffixed) model name.
    """
    default_stem = f"{modelname}_h" if dtype == DType.fp16 else modelname

    requested = llm_config.export.output_name
    if not requested:
        return f"{output_dir}/{default_stem}.pte"
    if requested.endswith(".pte"):
        # A name that already carries the extension is used verbatim.
        return requested
    return f"{output_dir}/{requested}.pte"
| 1161 | + |
| 1162 | + |
def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
    """
    Export multiple methods (base + LoRA variants) to a single .pte file.

    For each ``(method_name, lora_config)`` entry in
    ``llm_config.multimethod.methods``:
    - If the LoraConfig is None: use the base model
    - If a LoraConfig is provided: create the model with LoRA weights

    Every entry is prepared and exported on its own deep copy of
    ``llm_config``; the resulting exported programs are then lowered together
    with ``to_edge_transform_and_lower`` and saved as one .pte file.

    Args:
        llm_config: Export configuration. ``multimethod.methods`` maps method
            names to an optional LoRA config.

    Returns:
        The builder created for the first method, after its ``edge_manager``
        has been replaced by the jointly lowered program, converted with
        ``to_executorch`` and saved to disk.

    Raises:
        ValueError: If ``llm_config.multimethod.methods`` is empty.

    Limitations:
    - Only XNNPACK backend is supported for multimethod export.
    - PT2E quantization is not supported.
    - Each method is exported separately; export time scales linearly
      with the number of methods.
    - The final .pte file deduplicates shared weights automatically.
    """
    methods = llm_config.multimethod.methods
    # Validate with a real exception instead of `assert`, which is stripped
    # under `python -O`, and fail before any model is built.
    if not methods:
        raise ValueError("No methods to export")

    num_methods = len(methods)
    logging.info(
        f"Multimethod export: exporting {num_methods} method(s). "
        "Each method requires separate model instantiation and export."
    )

    additional_passes = []
    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

    # Build dict of exported programs, one per method name.
    method_to_program: Dict[str, ExportedProgram] = {}
    # Only the first builder is kept; it supplies the edge config and
    # metadata, and performs the final to_executorch/save steps.
    first_builder = None

    for method_name, lora_config in methods.items():
        logging.info(f"Exporting method: {method_name}")

        # Work on a private copy so the caller's config is never mutated.
        method_config = copy.deepcopy(llm_config)
        method_config.base.lora = lora_config
        # Disable multimethod to avoid infinite recursion
        method_config.multimethod.methods = {}

        # Load and prepare model for this method
        builder = _prepare_for_llama_export(method_config)
        builder = builder.export()
        builder.run_canonical_optimizations()

        # Get the exported program
        exported_program = builder._export(builder.pre_autograd_graph_module)
        method_to_program[method_name] = exported_program

        if first_builder is None:
            first_builder = builder

    # Cannot trigger after the up-front check; kept as an explicit guard so
    # the Optional is narrowed without relying on `assert`.
    if first_builder is None:
        raise ValueError("No methods to export")

    # Get partitioners based on backend config
    partitioners = _get_xnnpack_partitioners(llm_config)

    # Lower all methods together using multimethod API
    edge_config = first_builder._get_edge_config()
    edge_manager = to_edge_transform_and_lower(
        method_to_program,
        partitioner=partitioners,
        compile_config=edge_config,
        constant_methods=first_builder.metadata,
    )

    # Convert to executorch and save
    first_builder.edge_manager = edge_manager
    first_builder = first_builder.to_executorch(passes=additional_passes)

    output_file = _get_output_filename(
        llm_config,
        first_builder.modelname,
        first_builder.output_dir,
        first_builder.dtype,
    )
    first_builder.save_to_pte(output_file)

    return first_builder
| 1240 | + |
| 1241 | + |
1110 | 1242 | def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 |
1111 | 1243 | _validate_args(llm_config) |
1112 | 1244 |
|
| 1245 | + # Check for multimethod export |
| 1246 | + if llm_config.multimethod.enabled: |
| 1247 | + return _export_llama_multimethod(llm_config) |
| 1248 | + |
1113 | 1249 | pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( |
1114 | 1250 | llm_config |
1115 | 1251 | ) |
|
0 commit comments