diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index db56b686ba5..a65bbe248e2 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -22,7 +22,7 @@ import re from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, List, Optional +from typing import ClassVar, Dict, List, Optional ################################################################################ @@ -287,6 +287,37 @@ class DebugConfig: verbose: bool = False +################################################################################ +############################## MultimethodLoraConfig ########################### +################################################################################ + + +@dataclass +class MultimethodLoraConfig: + """Configuration for exporting multiple methods to a single .pte file. + + Maps method names to optional LoRA configurations. A None value means + the method uses base model weights. + + Attributes: + methods: Dict mapping method names to optional LoRA configs. + Empty dict disables multimethod export. + + Example: + MultimethodLoraConfig(methods={ + "forward": None, # base model + "lora_forward": lora_config, # LoRA variant + }) + """ + + methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict) + + @property + def enabled(self) -> bool: + """Returns True if multimethod_lora export is configured.""" + return len(self.methods) > 0 + + ################################################################################ ############################# QuantizationConfig ############################### ################################################################################ @@ -543,6 +574,9 @@ class LlmConfig: model: ModelConfig = field(default_factory=ModelConfig) export: ExportConfig = field(default_factory=ExportConfig) debug: DebugConfig = field(default_factory=DebugConfig) + multimethod_lora: MultimethodLoraConfig = field( + default_factory=MultimethodLoraConfig + ) quantization: QuantizationConfig = field(default_factory=QuantizationConfig) backend: BackendConfig = field(default_factory=BackendConfig)