From 172b3f551825e14fd889d3c7b84f9829d5a41329 Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 24 Jul 2025 14:27:55 -0700 Subject: [PATCH 1/2] Add LoRA linear definition Pull Request resolved: https://github.com/pytorch/executorch/pull/11044 ^ Add lora linear definition. Pull out linears from attention, and allow custom linear (eg. lora linear) to be passed in. If none, construct linear (current behaviour). ghstack-source-id: 298411928 @exported-using-ghexport Differential Revision: [D73953776](https://our.internmc.facebook.com/intern/diff/D73953776/) --- examples/models/llama/TARGETS | 1 + examples/models/llama/attention.py | 74 +++++++++++++++++++++++++---- examples/models/llama/lora.py | 48 +++++++++++++++++++ examples/models/llama/model_args.py | 10 ++++ 4 files changed, 125 insertions(+), 8 deletions(-) create mode 100644 examples/models/llama/lora.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 95d57e12f5a..9ea683e4174 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -13,6 +13,7 @@ runtime.python_library( name = "llama_transformer", srcs = [ "llama_transformer.py", + "lora.py", "rope.py", "attention.py", "model_args.py", diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index aa53b330837..6f23456eaaa 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm from executorch.examples.models.llama.rope import Rope @@ -325,7 +326,20 @@ def update( @register_attention("mha") class AttentionMHA(Attention): - def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + ): + """ + Multi-head 
attention layer. + + Args: + args (ModelArgs): Model configuration parameters. + layer_id (int): Layer index. + rope (Rope): Rotary position embedding module. + """ super().__init__() self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads @@ -350,16 +364,60 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps) self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) - self.wq = nn.Linear( - self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + self.wq = ( + LoRALinear( + in_dim=args.dim, + out_dim=args.n_heads * args.head_dim, + rank=args.r, + alpha=args.lora_alpha, + dropout=0.0, + use_bias=args.attention_qkv_bias, + ) + if args.target_modules is not None and "q_proj" in args.target_modules + else nn.Linear( + self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + ) ) - self.wk = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + self.wk = ( + LoRALinear( + in_dim=args.dim, + out_dim=args.n_kv_heads * args.head_dim, + rank=args.r, + alpha=args.lora_alpha, + dropout=0.0, + use_bias=args.attention_qkv_bias, + ) + if args.target_modules is not None and "k_proj" in args.target_modules + else nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) ) - self.wv = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + self.wv = ( + LoRALinear( + in_dim=args.dim, + out_dim=args.n_kv_heads * args.head_dim, + rank=args.r, + alpha=args.lora_alpha, + dropout=0.0, + use_bias=args.attention_qkv_bias, + ) + if args.target_modules is not None and "v_proj" in args.target_modules + else nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + ) + self.wo = ( + LoRALinear( + in_dim=args.n_kv_heads * args.head_dim, + out_dim=args.dim, + rank=args.r, + alpha=args.lora_alpha, + dropout=0.0, + use_bias=args.attention_qkv_bias, + ) + if 
args.target_modules is not None and "output_proj" in args.target_modules + else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) ) - self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.layer_id = layer_id diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py new file mode 100644 index 00000000000..12c1c4e5d68 --- /dev/null +++ b/examples/models/llama/lora.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + + +class LoRALinear(nn.Module): + """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`.""" + + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.rank = rank + self.alpha = alpha + self.use_bias = use_bias + self.dropout = dropout + + linear = nn.Linear(in_dim, out_dim, bias=use_bias) + weight = linear.weight + bias = linear.bias if self.use_bias else None + self.register_parameter("weight", nn.Parameter(weight)) + self.register_parameter( + "bias", nn.Parameter(bias) if bias is not None else None + ) + + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() + self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) + self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.nn.functional.linear(x, self.weight, self.bias) + lora_out = self.lora_a(self.dropout(x)) + lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) + + return out + lora_out diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 5734cd66ef7..18acda9fe93 
100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -55,8 +55,18 @@ class ModelArgs: eos_count: int = 2 quantization_args: Optional[dict] = None + # LoRA for QAT. lora_args: Optional[dict] = None + # LoRA arguments to set up a LoRA inference model. + # These arguments come directly from a torchtune LoRA config. + r: Optional[int] = None # Rank. + lora_alpha: Optional[int] = None # Alpha. + # Eg. q_proj, k_proj, v_proj, output_proj + target_modules: Optional[list] = None + peft_type: Optional[str] = None # PEFT type. + base_model_name_or_path: Optional[str] = None # Base model name or path. + def __post_init__(self): if self.n_kv_heads is None: self.n_kv_heads = self.n_heads From c77a4bdea7ea4ec6547fb5870c823bbf10979c12 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 25 Jul 2025 11:41:23 -0700 Subject: [PATCH 2/2] Export a lora model Pull Request resolved: https://github.com/pytorch/executorch/pull/11045 ^ Program+data combined currently, using the lora linear definition. 
ghstack-source-id: 298641176 Differential Revision: [D75153377](https://our.internmc.facebook.com/intern/diff/D75153377/) --- backends/xnnpack/operators/node_visitor.py | 7 ++++--- examples/models/llama/export_llama_lib.py | 12 ++++++++++++ examples/models/llama/model.py | 22 ++++++++++++++++++++++ examples/models/llama/model_args.py | 2 +- extension/llm/export/config/llm_config.py | 10 +++++++++- 5 files changed, 48 insertions(+), 5 deletions(-) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index b7d16b18bd1..90a9a3063e3 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -622,9 +622,10 @@ def get_serialized_buffer_index( ) external_tag = tensor.meta.get("delegate_constant_tag", None) - logging.info( - f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" - ) + if external_tag is not None: + logging.info( + f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" + ) self._named_data_store.add_named_data( named_key, bytes(array), diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 39f5f2ec0cd..a0cb7dab0ea 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -239,6 +239,18 @@ def build_args_parser() -> argparse.ArgumentParser: help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.", ) + parser.add_argument( + "--adapter_checkpoint", + required=False, + help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json", + ) + + parser.add_argument( + "--adapter_config", + required=False, + help="Path to the adapter_config.json file. 
Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.", + ) + parser.add_argument( "--use_qnn_sha", action="store_true", diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 27d41ac90cd..ac2905ea4c4 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -46,6 +46,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): checkpoint_dir = self.llm_config.base.checkpoint_dir params_path = self.llm_config.base.params + # Adapter checkpoint and config. + adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint + adapter_config_path = self.llm_config.base.adapter_config + assert (adapter_checkpoint_path is None and adapter_config_path is None) or ( + adapter_checkpoint_path is not None and adapter_config_path is not None + ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified." + self.use_kv_cache = self.llm_config.model.use_kv_cache self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache self.generate_full_logits = self.llm_config.debug.generate_full_logits @@ -129,6 +136,20 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): with open(params_path, "r") as f: params = json.loads(f.read()) + # Get adapter checkpoint and config. 
+ adapter_checkpoint = {} + adapter_config = {} + if adapter_checkpoint_path: + adapter_checkpoint = torch.load( + adapter_checkpoint_path, map_location=device, mmap=True + ) + from torchtune.models import convert_weights + + adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint) + with open(adapter_config_path, "r") as f: + adapter_config = json.loads(f.read()) + checkpoint.update(adapter_checkpoint) + output_prune_map = None if self.output_prune_map_path is not None: with open(self.output_prune_map_path, "r") as f: @@ -153,6 +174,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): output_prune_map=output_prune_map, enable_dynamic_shape=self.enable_dynamic_shape, **params, + **adapter_config, ) if model_args.use_scaled_rope: diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 18acda9fe93..1335aaf609e 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -59,7 +59,7 @@ class ModelArgs: lora_args: Optional[dict] = None # LoRA arguments to set up a LoRA inference model. - # These arguments come directly from a torchtune LoRA config. + # These arguments come directly from a torchtune adapter_config.json file. r: Optional[int] = None # Rank. lora_alpha: Optional[int] = None # Alpha. # Eg. q_proj, k_proj, v_proj, output_proj diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 94bbb2d8b2e..3a67bf83dfd 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -73,10 +73,16 @@ class BaseConfig: if it is a Llama model or the weights will be downloaded from HuggingFace if it is a non-Llama model. checkpoint_dir: Path to directory containing sharded checkpoint files. + adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if + the model has trained LoRA adapters. Must provide + adapter_config.json. 
+ adapter_config: Path to the adapter_config.json file from torchtune. + Used if the model has trained LoRA adapters. Must provide adapter.pt. tokenizer_path: Path to the tokenizer file. metadata: Json string containing metadata information. e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' - use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT. + use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled + if set to 0. fairseq2: For legacy internal use cases, this is safe to ignore. preq_mode: Legacy option to specify how prequantized weights are loaded. Going forward, ExecuTorch supports loading weights prequantized through @@ -90,6 +96,8 @@ class BaseConfig: params: Optional[str] = None checkpoint: Optional[str] = None checkpoint_dir: Optional[str] = None + adapter_checkpoint: Optional[str] = None + adapter_config: Optional[str] = None tokenizer_path: Optional[str] = None metadata: Optional[str] = None use_lora: int = 0