From 287e1ddd452fc941a5f324739141f9a06f667975 Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 21 May 2025 10:53:47 -0700 Subject: [PATCH] Add LoRA linear definition ^ Add lora linear definition. Pull out linears from attention, and allow custom linear (eg. lora linear) to be passed in. If none, construct linear (current behaviour). Differential Revision: [D73953776](https://our.internmc.facebook.com/intern/diff/D73953776/) [ghstack-poisoned] --- examples/models/llama/TARGETS | 1 + examples/models/llama/attention.py | 54 ++++++++++++--- examples/models/llama/llama_transformer.py | 79 +++++++++++++++++++++- examples/models/llama/lora.py | 48 +++++++++++++ examples/models/llama/model_args.py | 10 +++ 5 files changed, 182 insertions(+), 10 deletions(-) create mode 100644 examples/models/llama/lora.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index f2aa396f7a1..45b7be10b7e 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -13,6 +13,7 @@ runtime.python_library( name = "llama_transformer", srcs = [ "llama_transformer.py", + "lora.py", "rope.py", "attention.py", "model_args.py", diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 63d783c3332..4a71d679a53 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -324,7 +324,28 @@ def update( @register_attention("mha") class AttentionMHA(Attention): - def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + wq: Optional[nn.Module] = None, + wk: Optional[nn.Module] = None, + wv: Optional[nn.Module] = None, + wo: Optional[nn.Module] = None, + ): + """ + Multi-head attention layer. + + Args: + args (ModelArgs): Model configuration parameters. + layer_id (int): Layer index. + rope (Rope): Rotary position embedding module. + wq (Optional[nn.Module]): Query projection module. If None, use regular nn.Linear. 
+ wk (Optional[nn.Module]): Key projection module. If None, use regular nn.Linear. + wv (Optional[nn.Module]): Value projection module. If None, use regular nn.Linear. + wo (Optional[nn.Module]): Output projection module. If None, use regular nn.Linear. + """ super().__init__() self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads @@ -349,19 +370,34 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps) self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) - self.wq = nn.Linear( - self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + self.wq = ( + wq + if wq is not None + else nn.Linear( + self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + ) + ) + self.wk = ( + wk + if wk is not None + else nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) ) - self.wk = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + self.wv = ( + wv + if wv is not None + else nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) ) - self.wv = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + self.wo = ( + wo + if wo is not None + else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) ) - self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.layer_id = layer_id - self.rope = rope causal_mask = torch.tril( diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 1fdcdcd91fc..f991116619a 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -18,6 +18,7 @@ ForwardOptions, ) +from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm from executorch.examples.models.llama.rope 
import Rope @@ -254,7 +255,83 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: layers = torch.nn.ModuleList() cls = ATTENTION_REGISTRY[model_args.attention_type] for layer_id in range(model_args.n_layers): - attention = cls(model_args, layer_id, rope) + wq = ( + LoRALinear( + in_dim=model_args.dim, + out_dim=model_args.n_heads * model_args.head_dim, + rank=model_args.r, + alpha=model_args.lora_alpha, + dropout=0.0, + use_bias=model_args.attention_qkv_bias, + ) + if model_args.target_modules is not None + and "q_proj" in model_args.target_modules + else ( + torch.nn.Linear( + model_args.dim, + model_args.n_heads * model_args.head_dim, + bias=model_args.attention_qkv_bias, + ) + ) + ) + + wk = ( + LoRALinear( + in_dim=model_args.dim, + out_dim=model_args.n_kv_heads * model_args.head_dim, + rank=model_args.r, + alpha=model_args.lora_alpha, + dropout=0.0, + use_bias=model_args.attention_qkv_bias, + ) + if model_args.target_modules is not None + and "k_proj" in model_args.target_modules + else ( + torch.nn.Linear( + model_args.dim, + model_args.n_kv_heads * model_args.head_dim, + bias=model_args.attention_qkv_bias, + ) + ) + ) + wv = ( + LoRALinear( + in_dim=model_args.dim, + out_dim=model_args.n_kv_heads * model_args.head_dim, + rank=model_args.r, + alpha=model_args.lora_alpha, + dropout=0.0, + use_bias=model_args.attention_qkv_bias, + ) + if model_args.target_modules is not None + and "v_proj" in model_args.target_modules + else ( + torch.nn.Linear( + model_args.dim, + model_args.n_kv_heads * model_args.head_dim, + bias=model_args.attention_qkv_bias, + ) + ) + ) + + wo = ( + LoRALinear( + in_dim=model_args.n_heads * model_args.head_dim, # wo consumes the concatenated query heads, not the KV heads + out_dim=model_args.dim, + rank=model_args.r, + alpha=model_args.lora_alpha, + dropout=0.0, + use_bias=False, # matches the bias-free nn.Linear fallback below + ) + if model_args.target_modules is not None + and "output_proj" in model_args.target_modules + else ( + torch.nn.Linear( + model_args.n_heads * model_args.head_dim, 
model_args.dim, bias=False + ) + ) + ) + attention = cls(model_args, layer_id, rope, wq, wk, wv, wo) transformer_block = TransformerBlock(model_args, attention) layers.append(transformer_block) diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py new file mode 100644 index 00000000000..12c1c4e5d68 --- /dev/null +++ b/examples/models/llama/lora.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + + +class LoRALinear(nn.Module): + """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models` (https://arxiv.org/abs/2106.09685): a base linear projection plus a low-rank update scaled by alpha / rank.""" + + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.rank = rank + self.alpha = alpha + self.use_bias = use_bias + self.dropout = dropout # float placeholder; rebound to a module below + + linear = nn.Linear(in_dim, out_dim, bias=use_bias) # built only to obtain initialized weight/bias tensors + weight = linear.weight + bias = linear.bias if self.use_bias else None + self.register_parameter("weight", nn.Parameter(weight)) + self.register_parameter( + "bias", nn.Parameter(bias) if bias is not None else None + ) + + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() # applied to the LoRA branch input only + self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) + self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.nn.functional.linear(x, self.weight, self.bias) + lora_out = self.lora_a(self.dropout(x)) + lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) + + return out + lora_out diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 94dbb5a0651..b131d992138 100644 --- a/examples/models/llama/model_args.py +++ 
b/examples/models/llama/model_args.py @@ -55,8 +55,18 @@ class ModelArgs: eos_count: int = 2 quantization_args: Optional[dict] = None + # LoRA for QAT. lora_args: Optional[dict] = None + # LoRA arguments to set up a LoRA inference model. + # These arguments come directly from a torchtune LoRA config. + r: Optional[int] = None # Rank of the low-rank LoRA matrices (passed to LoRALinear as `rank`). + lora_alpha: Optional[int] = None # LoRA scaling numerator; the LoRA branch is scaled by lora_alpha / r. + # Attention projections to build as LoRALinear, e.g. q_proj, k_proj, v_proj, output_proj. + target_modules: Optional[list] = None + peft_type: Optional[str] = None # PEFT type. + base_model_name_or_path: Optional[str] = None # Base model name or path. + def __post_init__(self): if self.n_kv_heads is None: self.n_kv_heads = self.n_heads