From f74e91186422ff9a7a5e0cc6704e516a5fdbd045 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Fri, 13 Feb 2026 12:19:05 -0800 Subject: [PATCH 1/3] Support multimethod in runner Pull Request resolved: https://github.com/pytorch/executorch/pull/17228 Specify which method to call in the runner. ghstack-source-id: 340390608 @exported-using-ghexport Differential Revision: [D92225533](https://our.internmc.facebook.com/intern/diff/D92225533/) --- examples/models/llama/main.cpp | 11 ++++-- examples/models/llama/runner/runner.cpp | 15 +++++--- examples/models/llama/runner/runner.h | 6 ++-- extension/llm/runner/llm_runner_helper.cpp | 25 ++++++++----- extension/llm/runner/llm_runner_helper.h | 9 +++-- .../runner/test/test_text_decoder_runner.cpp | 35 +++++++++++++++++++ extension/llm/runner/text_decoder_runner.cpp | 32 +++++++++++------ extension/llm/runner/text_decoder_runner.h | 25 +++++++++++-- 8 files changed, 125 insertions(+), 33 deletions(-) diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp index 3de47598426..80ece46a1bb 100644 --- a/examples/models/llama/main.cpp +++ b/examples/models/llama/main.cpp @@ -77,6 +77,11 @@ DEFINE_string( "etdump.in", "If an etdump path is provided, generate an ETDump file at the specified path for profiling purposes."); +DEFINE_string( + method_name, + "forward", + "Method name to execute in the model (e.g., 'forward', 'lora_forward')."); + // Helper function to parse comma-separated string lists std::vector parseStringList(const std::string& input) { std::vector result; @@ -145,11 +150,11 @@ int32_t main(int32_t argc, char** argv) { data_paths, temperature, #ifdef ET_EVENT_TRACER_ENABLED - std::move(etdump_gen_ptr) + std::move(etdump_gen_ptr), #else - nullptr + nullptr, #endif - ); + FLAGS_method_name); if (runner == nullptr) { ET_LOG(Error, "Failed to create llama runner"); diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp index d2db805405e..3e26e5334e3 100644 --- a/examples/models/llama/runner/runner.cpp +++ b/examples/models/llama/runner/runner.cpp @@ -37,7 +37,8 @@ std::unique_ptr create_llama_runner( const std::string& tokenizer_path, std::optional data_path, float temperature, - std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) { + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer, + const std::string& method_name) { if (data_path.has_value()) { std::vector data_files; data_files.push_back(data_path.value()); @@ -46,14 +47,16 @@ std::unique_ptr create_llama_runner( tokenizer_path, std::move(data_files), temperature, - std::move(event_tracer)); + std::move(event_tracer), + method_name); } return create_llama_runner( model_path, tokenizer_path, std::vector(), temperature, - std::move(event_tracer)); + std::move(event_tracer), + method_name); } std::unique_ptr create_llama_runner( @@ -61,7 +64,8 @@ std::unique_ptr create_llama_runner( const std::string& tokenizer_path, std::vector data_files, float temperature, - std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) { + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer, + const std::string& method_name) { ET_LOG( Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", @@ -84,7 +88,8 @@ std::unique_ptr create_llama_runner( std::move(tokenizer), data_files, temperature, - std::move(event_tracer)); + std::move(event_tracer), + method_name); } } // namespace example diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h index 10225fcb81d..00d0832908b 100644 --- a/examples/models/llama/runner/runner.h +++ b/examples/models/llama/runner/runner.h @@ -29,14 +29,16 @@ std::unique_ptr create_llama_runner( const std::string& tokenizer_path, std::optional data_path, float temperature = -1.0f, - std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr); + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr, + const std::string& method_name = "forward"); std::unique_ptr create_llama_runner( const std::string& model_path, const std::string& tokenizer_path, std::vector data_files = {}, float temperature = -1.0f, - std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr); + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr, + const std::string& method_name = "forward"); std::unique_ptr load_llama_tokenizer( const std::string& tokenizer_path, diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 13f8d7a9db5..25846a2c5bc 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -182,18 +182,26 @@ std::unique_ptr create_text_llm_runner( const std::string& model_path, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::optional data_path, - float temperature) { + float temperature, + const std::string& method_name) { if (data_path.has_value()) { std::vector data_files; data_files.push_back(data_path.value()); return create_text_llm_runner( - model_path, std::move(tokenizer), std::move(data_files), temperature); + model_path, + std::move(tokenizer), + std::move(data_files), + temperature, + nullptr, + method_name); } return create_text_llm_runner( model_path, std::move(tokenizer), std::vector(), - temperature); + temperature, + nullptr, + method_name); } std::unique_ptr create_text_llm_runner( @@ -201,7 +209,8 @@ std::unique_ptr create_text_llm_runner( std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::vector data_files, float temperature, - std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) { + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer, + const std::string& method_name) { // Sanity check tokenizer if (!tokenizer || !tokenizer->is_loaded()) { ET_LOG(Error, "Tokenizer is null or not loaded"); @@ -236,10 +245,10 @@ std::unique_ptr create_text_llm_runner( // Create IOManager std::unique_ptr io_manager = std::make_unique(*module); - // Create text_decoder_runner. Use a shared_ptr so that it can be shared with - // TextPrefiller and TextTokenGenerator - auto text_decoder_runner = - std::make_unique(module.get(), io_manager.get()); + // Create text_decoder_runner + ET_LOG(Info, "Using method: %s", method_name.c_str()); + auto text_decoder_runner = std::make_unique( + module.get(), io_manager.get(), method_name); // Create text_prefiller auto text_prefiller = std::make_unique( diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index 424567b7c2b..373124d8560 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -95,6 +95,7 @@ ET_EXPERIMENTAL std::unordered_set get_eos_ids( * @param data_path Optional path to additional data required by the model * @param temperature Optional temperature parameter for controlling randomness * (deprecated) + * @param method_name Name of the method to execute in the model * @return std::unique_ptr Initialized TextLLMRunner instance, or * nullptr on failure */ @@ -102,7 +103,8 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( const std::string& model_path, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::optional data_path, - float temperature = -1.0f); + float temperature = -1.0f, + const std::string& method_name = "forward"); /** * @brief Creates a TextLLMRunner instance with dependency injection @@ -116,6 +118,8 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( * @param data_files Vector of paths to additional data required by the model * @param temperature Optional temperature parameter for controlling randomness * (deprecated) + * @param event_tracer Optional event tracer for profiling + * @param method_name Name of the method to execute in the model * @return std::unique_ptr Initialized TextLLMRunner instance, or * nullptr on failure */ @@ -124,7 +128,8 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::vector data_files = {}, float temperature = -1.0f, - std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr); + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr, + const std::string& method_name = "forward"); /** * @brief Creates a MultimodalRunner instance with dependency injection diff --git a/extension/llm/runner/test/test_text_decoder_runner.cpp b/extension/llm/runner/test/test_text_decoder_runner.cpp index 0001509ec55..917467e31fd 100644 --- a/extension/llm/runner/test/test_text_decoder_runner.cpp +++ b/extension/llm/runner/test/test_text_decoder_runner.cpp @@ -47,6 +47,41 @@ class TextDecoderRunnerTest : public Test { std::unique_ptr io_manager_; }; +// Test that method_name defaults to "forward" +TEST_F(TextDecoderRunnerTest, MethodNameDefaultsToForward) { + EXPECT_EQ(runner_->method_name(), "forward"); +} + +// Test that method_name can be set to a custom value +TEST_F(TextDecoderRunnerTest, MethodNameCustomValue) { + auto custom_runner = std::make_unique( + mock_module_.get(), io_manager_.get(), "encode"); + EXPECT_EQ(custom_runner->method_name(), "encode"); +} + +// Test that load() uses method_name (not hardcoded "forward") +TEST_F(TextDecoderRunnerTest, LoadUsesMethodName) { + // Get an available model + const char* model_path = std::getenv("KVCACHE_CACHE_POS"); + if (!model_path) { + GTEST_SKIP() << "No PTE model environment variable set"; + } + auto module = std::make_unique(model_path); + auto load_result = module->load(); + if (load_result != Error::Ok) { + GTEST_SKIP() << "Failed to load model"; + } + + auto io_mgr = std::make_unique(*module); + + // Create runner with a method name that doesn't exist + TextDecoderRunner runner(module.get(), io_mgr.get(), "nonexistent_method"); + + // load() should fail because "nonexistent_method" doesn't exist + auto result = runner.load(); + EXPECT_NE(result, Error::Ok); +} + // Test logits_to_token() method with Float tensor TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) { TensorFactory tf_float; diff --git a/extension/llm/runner/text_decoder_runner.cpp b/extension/llm/runner/text_decoder_runner.cpp index 8d51736ace5..3eb4e346e05 100644 --- a/extension/llm/runner/text_decoder_runner.cpp +++ b/extension/llm/runner/text_decoder_runner.cpp @@ -22,8 +22,13 @@ namespace llm { // NOTE: we observed ~2x loading performance increase on iPhone 15 // and a ~5% improvement on Galaxy S22 by switching to // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors. -TextDecoderRunner::TextDecoderRunner(Module* module, IOManager* io_manager) - : module_(module), io_manager_(io_manager) {} +TextDecoderRunner::TextDecoderRunner( + Module* module, + IOManager* io_manager, + std::string method_name) + : module_(module), + io_manager_(io_manager), + method_name_(std::move(method_name)) {} // This function is functional, meaning it shouldn't modify any state of the // input. It should be safe to call multiple times with the same inputs. The @@ -32,7 +37,7 @@ ::executorch::runtime::Result TextDecoderRunner::step( TensorPtr& tokens, int64_t start_pos) { // ET_LOG(Info, "Input token %" PRIu64, input_token); - auto method_meta_result = module_->method_meta("forward"); + auto method_meta_result = module_->method_meta(method_name_); if (!method_meta_result.ok()) { return method_meta_result.error(); } @@ -44,25 +49,31 @@ ::executorch::runtime::Result TextDecoderRunner::step( if (use_kv_cache) { auto start_pos_tensor_result = populate_start_pos_or_cache_position( - module_, start_pos, cache_positions, tokens->numel(), "forward"); + module_, + start_pos, + cache_positions, + tokens->numel(), + method_name_.c_str()); if (!start_pos_tensor_result.ok()) { return start_pos_tensor_result.error(); } auto start_pos_tensor = std::move(*start_pos_tensor_result); std::vector inputs; - auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor); + auto inputs_res = + io_manager_->prepare_decode(tokens, start_pos_tensor, method_name_); ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error()); inputs = inputs_res.get(); - auto outputs_res = module_->forward(inputs); + auto outputs_res = module_->execute(method_name_, inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); - auto update_err = io_manager_->update_decode(outputs_res.get()); + auto update_err = + io_manager_->update_decode(outputs_res.get(), method_name_); ET_CHECK_OK_OR_RETURN_ERROR(update_err); ET_CHECK_MSG( outputs_res.get().size() == 1, - "More then one output returned from executing LLM."); + "More than one output returned from executing LLM."); ET_CHECK_MSG( outputs_res.get()[0].isTensor(), "Non Tensor Output returned from executing LLM"); @@ -72,11 +83,12 @@ ::executorch::runtime::Result TextDecoderRunner::step( } else { // no kv cache (void)start_pos; // unused - auto outputs_res = module_->forward(tokens); + std::vector inputs{tokens}; + auto outputs_res = module_->execute(method_name_, inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); ET_CHECK_MSG( outputs_res.get().size() == 1, - "More then one output returned from executing LLM."); + "More than one output returned from executing LLM."); ET_CHECK_MSG( outputs_res.get()[0].isTensor(), "Non Tensor Output returned from executing LLM"); diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index 720000185c9..8b855e2924f 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -20,7 +20,10 @@ namespace llm { class ET_EXPERIMENTAL TextDecoderRunner { public: - explicit TextDecoderRunner(Module* module, IOManager* io_manager); + explicit TextDecoderRunner( + Module* module, + IOManager* io_manager, + std::string method_name = "forward"); virtual ~TextDecoderRunner() = default; @@ -40,7 +43,14 @@ class ET_EXPERIMENTAL TextDecoderRunner { * @return The error code. */ virtual ::executorch::runtime::Error load() { - return module_->load_method("forward"); + auto err = module_->load_method(method_name_); + if (err != ::executorch::runtime::Error::Ok) { + ET_LOG( + Error, + "Failed to load method '%s'. Check available methods in the model.", + method_name_.c_str()); + } + return err; } /** @@ -48,7 +58,15 @@ class ET_EXPERIMENTAL TextDecoderRunner { * @return True if the Module is loaded, false otherwise. */ virtual bool is_method_loaded() { - return module_->is_method_loaded("forward"); + return module_->is_method_loaded(method_name_); + } + + /** + * Get the method name used by this runner. + * @return The method name. + */ + const std::string& method_name() const { + return method_name_; } inline void stop() { @@ -79,6 +97,7 @@ class ET_EXPERIMENTAL TextDecoderRunner { */ Module* module_; IOManager* io_manager_; + std::string method_name_; bool should_stop_{false}; }; From be0d894cf45d45429f98ae95aa870f0ab9766030 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Fri, 13 Feb 2026 12:19:06 -0800 Subject: [PATCH 2/3] Introduce LoraConfig Pull Request resolved: https://github.com/pytorch/executorch/pull/17229 Introduce LoraConfig to hold lora parameters such as: - checkpoint - rank - target_modules (e.g. q_proj, k_proj, v_proj, up_proj, down_proj, gate_proj, o_proj) - lora_alpha LoraConfig validation done post-init. LoraConfig can be created with config.json file. Update cases of export_llama_lib to use LoraConfig instead of adapter_checkpoint and adapter_config. NOTE: we may need to extend this to support more customizable features like lora config per layer etc. cc @hakanb ghstack-source-id: 340930093 @exported-using-ghexport Differential Revision: [D92304723](https://our.internmc.facebook.com/intern/diff/D92304723/) --- .ci/scripts/test_lora.sh | 23 +++++---- examples/models/llama/model.py | 49 ++++++++---------- ...3_xnnpack.yaml => qwen3_xnnpack_lora.yaml} | 3 ++ extension/llm/export/config/llm_config.py | 50 +++++++++++++++---- 4 files changed, 74 insertions(+), 51 deletions(-) rename examples/models/qwen3/config/{qwen3_xnnpack.yaml => qwen3_xnnpack_lora.yaml} (74%) diff --git a/.ci/scripts/test_lora.sh b/.ci/scripts/test_lora.sh index 71307ca086e..17e42988c4d 100644 --- a/.ci/scripts/test_lora.sh +++ b/.ci/scripts/test_lora.sh @@ -41,12 +41,14 @@ HF_ADAPTER_PATH=$( --files "adapter_config.json" "adapter_model.safetensors" ) +# Set environment variables for OmegaConf interpolation in yaml. +export LORA_ADAPTER_CHECKPOINT="${HF_ADAPTER_PATH}/adapter_model.safetensors" +export LORA_ADAPTER_CONFIG="${HF_ADAPTER_PATH}/adapter_config.json" + ### SINGLE LORA PTE ### # Export LoRA PTE file. $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ - +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \ - +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \ + --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \ +export.output_name="qwen_lora_math_full.pte" # Capture the path of the downloaded qwen artifacts @@ -93,9 +95,7 @@ fi ### PROGRAM DATA SEPARATION ### # Export LoRA PTE, LoRA PTD, foundation PTD file. $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ - +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \ - +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \ + --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \ +export.output_name="qwen_lora_math.pte" \ +export.foundation_weights_file="qwen_foundation.ptd" \ +export.lora_weights_file="qwen_lora_math.ptd" @@ -108,7 +108,7 @@ cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math.pte --dat NOW=$(date +"%H:%M:%S") echo "Finished at ${NOW}" -RESULT=$(cat result.txt) +RESULT=$(cat result2.txt) if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then echo "Expected result prefix: ${EXPECTED_PREFIX}" echo "Actual result: ${RESULT}" @@ -143,8 +143,11 @@ So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12. The answer is: 12<|im_end|>" # Export Quantized PTE, PTD file, no LoRA. +# override base.lora_config=null to avoid creating a lora model +# and loading lora weights. $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ + --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \ + base.lora_config=null \ +export.output_name="qwen_q.pte" \ +export.foundation_weights_file="qwen_foundation_q.ptd" \ +quantization.qmode="8da4w" \ @@ -152,9 +155,7 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ # Export Quantized LoRA PTE, LoRA PTD, foundation PTD file. $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ - +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \ - +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \ + --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \ +export.output_name="qwen_lora_math_q.pte" \ +export.foundation_weights_file="qwen_foundation_lora_q.ptd" \ +export.lora_weights_file="qwen_lora_math_q.ptd" \ diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 1ec85936f7a..8b35d7d3155 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): checkpoint_path = self.llm_config.base.checkpoint params_path = self.llm_config.base.params - # Adapter checkpoint and config. - adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint - adapter_config_path = self.llm_config.base.adapter_config - assert (adapter_checkpoint_path is None and adapter_config_path is None) or ( - adapter_checkpoint_path is not None and adapter_config_path is not None - ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified." + # LoRA adapter configuration. + lora_config = self.llm_config.base.lora_config self.use_kv_cache = self.llm_config.model.use_kv_cache self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache @@ -69,10 +65,18 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): with open(params_path, "r") as f: params = json.loads(f.read()) - # Get adapter checkpoint and config. + # Get adapter checkpoint. adapter_checkpoint = {} - adapter_config = {} - if adapter_checkpoint_path: + if lora_config: + # Resolve LoRA params from adapter_config JSON if not already set. + if lora_config.adapter_config and lora_config.lora_rank == 0: + with open(lora_config.adapter_config, "r") as f: + cfg = json.load(f) + lora_config.lora_rank = cfg["r"] + lora_config.lora_alpha = cfg["lora_alpha"] + lora_config.target_modules = cfg["target_modules"] + + adapter_checkpoint_path = lora_config.adapter_checkpoint if adapter_checkpoint_path.endswith(".pt"): adapter_checkpoint = torch.load( adapter_checkpoint_path, map_location=device, mmap=True @@ -92,22 +96,6 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): raise ValueError( f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}" ) - - with open(adapter_config_path, "r") as f: - adapter_config_full = json.loads(f.read()) - if ( - "r" not in adapter_config_full - or "lora_alpha" not in adapter_config_full - or "target_modules" not in adapter_config_full - ): - raise ValueError( - "Adapter config must contain r, lora_alpha, and target_modules." - ) - adapter_config = { - "r": adapter_config_full["r"], - "lora_alpha": adapter_config_full["lora_alpha"], - "target_modules": adapter_config_full["target_modules"], - } checkpoint.update(adapter_checkpoint) output_prune_map = None @@ -133,8 +121,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): input_prune_map=input_prune_map, output_prune_map=output_prune_map, enable_dynamic_shape=self.enable_dynamic_shape, + r=lora_config.lora_rank if lora_config else None, + lora_alpha=lora_config.lora_alpha if lora_config else None, + target_modules=lora_config.target_modules if lora_config else None, **params, - **adapter_config, ) if model_args.use_scaled_rope: @@ -356,9 +346,10 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): embedding_bit_width, embedding_group_size = None, None if self.llm_config.base.preq_embedding_quantize: - embedding_bit_width, embedding_group_size = ( - self.llm_config.base.preq_embedding_quantize.split(",") - ) + ( + embedding_bit_width, + embedding_group_size, + ) = self.llm_config.base.preq_embedding_quantize.split(",") from .source_transformation.pre_quantization import ( transform_embedding_for_pre_quantization, ) diff --git a/examples/models/qwen3/config/qwen3_xnnpack.yaml b/examples/models/qwen3/config/qwen3_xnnpack_lora.yaml similarity index 74% rename from examples/models/qwen3/config/qwen3_xnnpack.yaml rename to examples/models/qwen3/config/qwen3_xnnpack_lora.yaml index 1c4801bf5ef..3836b7793fb 100644 --- a/examples/models/qwen3/config/qwen3_xnnpack.yaml +++ b/examples/models/qwen3/config/qwen3_xnnpack_lora.yaml @@ -2,6 +2,9 @@ base: model_class: "qwen3_0_6b" params: "examples/models/qwen3/config/0_6b_config.json" metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' + lora_config: + adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT} + adapter_config: ${oc.env:LORA_ADAPTER_CONFIG} model: use_kv_cache: True diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index a7453fd09c1..db56b686ba5 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -62,6 +62,36 @@ class PreqMode(str, Enum): preq_8da4w_out_8da8w = "8da4w_output_8da8w" +@dataclass +class LoraConfig: + """LoRA adapter configuration. + + Can be created in two ways: + + 1. From an adapter_config JSON file: + LoraConfig( + adapter_checkpoint="/path/to/adapter.safetensors", + adapter_config="/path/to/adapter_config.json", + ) + Note: user is responsible for parsing the config and + ensure it doesn't conflict with any explicit values. + + 2. With explicit values: + LoraConfig( + adapter_checkpoint="/path/to/adapter.safetensors", + lora_rank=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + ) + """ + + adapter_checkpoint: str + adapter_config: Optional[str] = None + lora_rank: int = 0 + lora_alpha: int = 0 + target_modules: List[str] = field(default_factory=list) + + @dataclass class BaseConfig: """ @@ -77,11 +107,7 @@ class BaseConfig: If left empty, the model will either be initialized with random weights if it is a Llama model or the weights will be downloaded from HuggingFace if it is a non-Llama model. - adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if - the model has trained LoRA adapters. Must provide - adapter_config.json. - adapter_config: Path to the adapter_config.json file from torchtune. - Used if the model has trained LoRA adapters. Must provide adapter.pt. + lora_config: LoRA adapter configuration. tokenizer_path: Path to the tokenizer file. metadata: Json string containing metadata information. e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' @@ -98,8 +124,7 @@ class BaseConfig: model_class: ModelType = ModelType.llama3 params: Optional[str] = None checkpoint: Optional[str] = None - adapter_checkpoint: Optional[str] = None - adapter_config: Optional[str] = None + lora_config: Optional[LoraConfig] = None tokenizer_path: Optional[str] = None metadata: Optional[str] = None use_lora: int = 0 @@ -536,10 +561,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 llm_config.base.params = args.params if hasattr(args, "checkpoint"): llm_config.base.checkpoint = args.checkpoint - if hasattr(args, "adapter_checkpoint"): - llm_config.base.adapter_checkpoint = args.adapter_checkpoint - if hasattr(args, "adapter_config"): - llm_config.base.adapter_config = args.adapter_config + if hasattr(args, "adapter_checkpoint") and args.adapter_checkpoint: + if not hasattr(args, "adapter_config") or not args.adapter_config: + raise ValueError("--adapter_checkpoint requires --adapter_config") + llm_config.base.lora_config = LoraConfig( + adapter_checkpoint=args.adapter_checkpoint, + adapter_config=args.adapter_config, + ) if hasattr(args, "tokenizer_path"): llm_config.base.tokenizer_path = args.tokenizer_path if hasattr(args, "metadata"): From 2ed0f5e77e0abec2f53874c041fba2ef4b13f862 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Fri, 13 Feb 2026 12:19:08 -0800 Subject: [PATCH 3/3] Introduce MultimethodLoraConfig Pull Request resolved: https://github.com/pytorch/executorch/pull/17230 MultimethodLoraConfig; collects together method name and lora config. ghstack-source-id: 341143477 @exported-using-ghexport Differential Revision: [D92315627](https://our.internmc.facebook.com/intern/diff/D92315627/) --- extension/llm/export/config/llm_config.py | 36 ++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index db56b686ba5..a65bbe248e2 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -22,7 +22,7 @@ import re from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, List, Optional +from typing import ClassVar, Dict, List, Optional ################################################################################ @@ -287,6 +287,37 @@ class DebugConfig: verbose: bool = False +################################################################################ +############################## MultimethodLoraConfig ########################### +################################################################################ + + +@dataclass +class MultimethodLoraConfig: + """Configuration for exporting multiple methods to a single .pte file. + + Maps method names to optional LoRA configurations. A None value means + the method uses base model weights. + + Attributes: + methods: Dict mapping method names to optional LoRA configs. + Empty dict disables multimethod export. + + Example: + MultimethodLoraConfig(methods={ + "forward": None, # base model + "lora_forward": lora_config, # LoRA variant + }) + """ + + methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict) + + @property + def enabled(self) -> bool: + """Returns True if multimethod_lora export is configured.""" + return len(self.methods) > 0 + + ################################################################################ ############################# QuantizationConfig ############################### ################################################################################ @@ -543,6 +574,9 @@ class LlmConfig: model: ModelConfig = field(default_factory=ModelConfig) export: ExportConfig = field(default_factory=ExportConfig) debug: DebugConfig = field(default_factory=DebugConfig) + multimethod_lora: MultimethodLoraConfig = field( + default_factory=MultimethodLoraConfig + ) quantization: QuantizationConfig = field(default_factory=QuantizationConfig) backend: BackendConfig = field(default_factory=BackendConfig)