From aac7fd68e63a853bc31fb5e6de6bf084fec215b4 Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 21 May 2025 10:53:51 -0700 Subject: [PATCH] Export a lora model ^ Program+data combined currently, using the lora linear definition. Differential Revision: [D75153377](https://our.internmc.facebook.com/intern/diff/D75153377/) [ghstack-poisoned] --- backends/xnnpack/operators/node_visitor.py | 7 +++--- examples/models/llama/export_llama_lib.py | 29 ++++++++++++++++++++++ examples/models/llama/model.py | 21 +++++++++++++++- 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 8470184d808..b13fd094d31 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -595,9 +595,10 @@ def get_serialized_buffer_index( ) external_tag = tensor.meta.get("delegate_constant_tag", None) - logging.info( - f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" - ) + if external_tag is not None: + logging.info( + f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" + ) self._named_data_store.add_named_data( named_key, bytes(array), diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 3a3102886f8..dcd74a45124 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -235,6 +235,18 @@ def build_args_parser() -> argparse.ArgumentParser: help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.", ) + parser.add_argument( + "--adapter_checkpoint", + required=False, + help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json", + ) + + parser.add_argument( + "--adapter_config", + required=False, + help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.", + ) + parser.add_argument( "--use_qnn_sha", action="store_true", @@ -631,6 +643,17 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None ) params_path = canonical_path(args.params) if args.params else None + + assert (args.adapter_checkpoint is None and args.adapter_config is None) or ( + args.adapter_checkpoint is not None and args.adapter_config is not None + ), "Must provide both adapter_checkpoint and adapter_config, or neither" + adapter_checkpoint_path = ( + canonical_path(args.adapter_checkpoint) if args.adapter_checkpoint else None + ) + adapter_config_path = ( + canonical_path(args.adapter_config) if args.adapter_config else None + ) + output_dir_path = canonical_path(args.output_dir, dir=True) weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA @@ -642,6 +665,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: checkpoint=checkpoint_path, checkpoint_dir=checkpoint_dir, params_path=params_path, + adapter_checkpoint=adapter_checkpoint_path, + adapter_config=adapter_config_path, use_kv_cache=args.use_kv_cache, use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, generate_full_logits=args.generate_full_logits, @@ -1141,6 +1166,8 @@ def _load_llama_model( checkpoint: Optional[str] = None, checkpoint_dir: Optional[str] = None, params_path: Optional[str] = None, + adapter_checkpoint: Optional[str] = None, + adapter_config: Optional[str] = None, use_kv_cache: bool = False, use_sdpa_with_kv_cache: bool = False, generate_full_logits: bool = False, @@ -1188,6 +1215,8 @@ def _load_llama_model( checkpoint=checkpoint, checkpoint_dir=checkpoint_dir, params=params_path, + adapter_checkpoint=adapter_checkpoint, + adapter_config=adapter_config, use_kv_cache=use_kv_cache, use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, generate_full_logits=generate_full_logits, diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index d6400c29db8..9bbb5be3363 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -8,7 +8,6 @@ import json import os -from typing import Dict, Tuple import torch from executorch.examples.models.checkpoint import ( @@ -47,6 +46,10 @@ def __init__(self, **kwargs): # Params file. params_path = kwargs.get("params", None) + # Adapter + adapter_checkpoint = kwargs.get("adapter_checkpoint", None) + adapter_config = kwargs.get("adapter_config", None) + self.use_kv_cache = kwargs.get("use_kv_cache", False) self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False) self.generate_full_logits = kwargs.get("generate_full_logits", False) @@ -130,6 +133,21 @@ def __init__(self, **kwargs): with open(params_path, "r") as f: params = json.loads(f.read()) + # Get adapter checkpoint and config. + adapter_checkpoint = {} + adapter_config = {} + adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None) + if adapter_checkpoint_path: + adapter_checkpoint = torch.load( + adapter_checkpoint_path, map_location=device, mmap=True + ) + from torchtune.models import convert_weights + adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint) + adapter_config = kwargs.get("adapter_config", None) + with open(adapter_config, "r") as f: + adapter_config = json.loads(f.read()) + checkpoint.update(adapter_checkpoint) + output_prune_map = None if self.output_prune_map_path is not None: with open(self.output_prune_map_path, "r") as f: @@ -154,6 +172,7 @@ def __init__(self, **kwargs): output_prune_map=output_prune_map, enable_dynamic_shape=self.enable_dynamic_shape, **params, + **adapter_config, ) if model_args.use_scaled_rope: