Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,15 +897,15 @@ def _validate_args(llm_config):
"Shared embedding is only supported with torchao quantization."
)

if llm_config.multimethod_lora.enabled:
if llm_config.multimethod.enabled:
if llm_config.base.lora_config is not None:
raise ValueError(
"Cannot use both base.lora_config and multimethod_lora.methods. "
"Use multimethod_lora.methods for all LoRA variants."
"Cannot use both base.lora_config and multimethod.methods. "
"Use multimethod.methods for all LoRA variants."
)
if llm_config.quantization.pt2e_quantize is not None:
raise ValueError(
"PT2E quantization is not supported with multimethod_lora export."
"PT2E quantization is not supported with multimethod export."
)
if (
llm_config.backend.coreml.enabled
Expand All @@ -915,7 +915,7 @@ def _validate_args(llm_config):
or llm_config.backend.openvino.enabled
):
raise ValueError(
"multimethod_lora export only supports XNNPACK backend or portable ops"
"multimethod export only supports XNNPACK backend or portable ops. "
"Please disable other backends (coreml, vulkan, qnn, mps, openvino)."
)

Expand Down Expand Up @@ -1230,7 +1230,7 @@ def _to_edge_and_lower_llama( # noqa: C901


def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List[Partitioner]]:
"""Get XNNPACK partitioners for multimethod_lora export."""
"""Get XNNPACK partitioners for multimethod export."""
partitioners = []

# Order matters here, dynamic quantization should be applied first when
Expand Down Expand Up @@ -1268,20 +1268,20 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
"""
Export multiple methods (base + LoRA variants) to a single .pte file.

For each method in llm_config.multimethod_lora.methods:
For each method in llm_config.multimethod.methods:
- If method.lora_config is None: use base model
- If method.lora_config is provided: create model with LoRA weights

Limitations:
- Only XNNPACK backend is supported for multimethod_lora export.
- Only XNNPACK backend is supported for multimethod export.
- PT2E quantization is not supported.
- Each method is exported separately; export time scales linearly
with the number of methods.
- The final .pte file deduplicates shared weights automatically.
"""
num_methods = len(llm_config.multimethod_lora.methods)
num_methods = len(llm_config.multimethod.methods)
logging.info(
f"multimethod_lora export: exporting {num_methods} method(s). "
f"multimethod export: exporting {num_methods} method(s). "
"Each method requires separate model instantiation and export."
)

Expand All @@ -1293,14 +1293,14 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
method_to_program: Dict[str, ExportedProgram] = {}
first_builder = None

for method_name, lora_config in llm_config.multimethod_lora.methods.items():
logging.info(f"Exporting method: {method_name}")
for method in llm_config.multimethod.methods:
logging.info(f"Exporting method: {method.method_name}")

# Create a copy of config with this method's LoRA setting
method_config = copy.deepcopy(llm_config)
method_config.base.lora_config = lora_config
# Disable multimethod_lora to avoid infinite recursion
method_config.multimethod_lora.methods = {}
method_config.base.lora_config = method.lora_config
# Disable multimethod to avoid infinite recursion
method_config.multimethod.methods = []

# Load and prepare model for this method
builder = _prepare_for_llama_export(method_config)
Expand All @@ -1309,7 +1309,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:

# Get the exported program
exported_program = builder._export(builder.pre_autograd_graph_module)
method_to_program[method_name] = exported_program
method_to_program[method.method_name] = exported_program

if first_builder is None:
first_builder = builder
Expand All @@ -1319,7 +1319,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
# Get partitioners based on backend config
partitioners = _get_xnnpack_partitioners(llm_config)

# Lower all methods together using multimethod_lora API
# Lower all methods together using multimethod API
edge_config = first_builder._get_edge_config()
edge_manager = to_edge_transform_and_lower(
method_to_program,
Expand All @@ -1333,7 +1333,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
first_builder.edge_manager = edge_manager
first_builder = first_builder.to_executorch(
passes=additional_passes,
share_mutable_buffers=llm_config.multimethod_lora.share_mutable_buffers,
share_mutable_buffers=llm_config.multimethod.share_mutable_buffers,
)

output_file = _get_output_filename(
Expand All @@ -1350,8 +1350,8 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
_validate_args(llm_config)

# Check for multimethod_lora export
if llm_config.multimethod_lora.enabled:
# Check for multimethod export
if llm_config.multimethod.enabled:
return _export_llama_multimethod(llm_config)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
Expand Down
11 changes: 6 additions & 5 deletions examples/models/qwen3/config/qwen3_multimethod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ quantization:
qmode: "8da4w"
group_size: 32

multimethod_lora:
multimethod:
methods:
# LoRA method - adapter paths from environment variables
lora_forward:
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
- method_name: lora_forward
lora_config:
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
# Base method - no LoRA
base_forward: null
- method_name: base_forward
share_mutable_buffers: True
45 changes: 29 additions & 16 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, Dict, List, Optional
from typing import ClassVar, List, Optional


################################################################################
Expand Down Expand Up @@ -293,37 +293,52 @@ class DebugConfig:


################################################################################
############################## MultimethodLoraConfig ###########################
############################## MultimethodConfig ###############################
################################################################################


@dataclass
class MultimethodLoraConfig:
class MethodConfig:
"""Configuration for exporting a single method to a .pte file.
Only the fields listed below are set per method; every other setting is
inherited from the top-level config (for example, the yaml file).

Attributes:
method_name: Name of the method to export.
lora_config: Optional LoRA configuration.
"""

method_name: str
lora_config: Optional[LoraConfig] = None


@dataclass
class MultimethodConfig:
"""Configuration for exporting multiple methods to a single .pte file.

Maps method names to optional LoRA configurations. A None value means
the method uses base model weights.
Holds a list of method configs, as well as global options that apply
across all methods.

Attributes:
methods: Dict mapping method names to optional LoRA configs.
Empty dict disables multimethod_lora export.
methods: List of MethodConfig objects with method name and config
for each method. An empty list disables multimethod export.
share_mutable_buffers: Whether to share mutable buffers across methods.
If True, sets all mutable buffers to mem_id=2. Mutable buffers with
the same FQN (fully qualified name) will have the same offset.

Example:
MultimethodLoraConfig(methods={
"forward": None, # base model
"lora_forward": lora_config, # LoRA variant
})
MultimethodConfig(methods=[
MethodConfig("forward", lora_config=None), # base model
MethodConfig("lora_forward", lora_config=lora_config), # LoRA variant
])
"""

methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict)
methods: List[MethodConfig] = field(default_factory=list)
share_mutable_buffers: bool = False

@property
def enabled(self) -> bool:
"""Returns True if multimethod_lora export is configured."""
"""Returns True if multimethod export is configured."""
return len(self.methods) > 0


Expand Down Expand Up @@ -611,9 +626,7 @@ class LlmConfig:
model: ModelConfig = field(default_factory=ModelConfig)
export: ExportConfig = field(default_factory=ExportConfig)
debug: DebugConfig = field(default_factory=DebugConfig)
multimethod_lora: MultimethodLoraConfig = field(
default_factory=MultimethodLoraConfig
)
multimethod: MultimethodConfig = field(default_factory=MultimethodConfig)
quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
backend: BackendConfig = field(default_factory=BackendConfig)

Expand Down
Loading