From dcfea58856b74d22300dd10c2a6172c8e4593ea6 Mon Sep 17 00:00:00 2001
From: Di Xu
Date: Wed, 17 Jun 2026 08:25:57 -0700
Subject: [PATCH] Extract `_lora_call` into a shared free function in `lora.py` (#20306)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/20306

Both `MultimodalTransformer` (transformer.py) and `StaticAttention` (static_attention.py) had identical `_lora_call` methods. Extracted the logic into a module-level `lora_call()` function in `lora.py` and updated both consumers to import and call it directly.

Reviewed By: billmguo

Differential Revision: D108757232

Signed-off-by: Di Xu
---
 examples/models/llama/lora.py             |  9 +++++++++
 examples/models/llama/static_attention.py | 20 ++++++--------------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py
index 1f6cca6403a..fd732bea780 100644
--- a/examples/models/llama/lora.py
+++ b/examples/models/llama/lora.py
@@ -69,3 +69,12 @@ def forward(
         z = self.lora_a(self.dropout(x))
         z = (self.alpha / self.rank) * self.lora_b(z)
         return out + z
+
+
+def lora_call(linear, x_in, lora_blob):
+    if lora_blob is not None:
+        key = getattr(linear, "_lora_key", None)
+        if key is not None and key in lora_blob:
+            a, b = lora_blob[key]
+            return linear(x_in, a, b)
+    return linear(x_in)
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index fddd451e3ac..8e985239651 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -13,7 +13,7 @@
     ForwardOptions,
     register_attention,
 )
-from executorch.examples.models.llama.lora import LoRALinear
+from executorch.examples.models.llama.lora import lora_call, LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -1014,14 +1014,6 @@ def from_attention_mha(
 
         return instance
 
-    def _lora_call(self, linear, x_in, lora_blob):
-        if lora_blob is not None:
-            key = getattr(linear, "_lora_key", None)
-            if key is not None and key in lora_blob:
-                a, b = lora_blob[key]
-                return linear(x_in, a, b)
-        return linear(x_in)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -1044,7 +1036,7 @@ def forward(
         # Default behavior (no blob, or no `_lora_key`) is unchanged.
         _lora_blob = kwargs.get("__lora_io_blob__")
 
-        new_qs = [self._lora_call(wq, x, _lora_blob) for wq in self.wqs]
+        new_qs = [lora_call(wq, x, _lora_blob) for wq in self.wqs]
 
         shared_kv = kwargs.get("shared_kv")
         if shared_kv is not None:
@@ -1054,8 +1046,8 @@ def forward(
             new_ks = []
             new_vs = []
         else:
-            new_ks = [self._lora_call(wk, x, _lora_blob) for wk in self.wks]
-            new_vs = [self._lora_call(wv, x, _lora_blob) for wv in self.wvs]
+            new_ks = [lora_call(wk, x, _lora_blob) for wk in self.wks]
+            new_vs = [lora_call(wv, x, _lora_blob) for wv in self.wvs]
 
         if self.use_conv2d:
 
@@ -1092,7 +1084,7 @@ def from_conv2ds(ts):
 
         if self.use_conv2d:
             y = (
-                self._lora_call(
+                lora_call(
                     self.wo,
                     y.reshape(bsz, -1, 1, self.n_heads * self.head_dim).transpose(1, 3),
                     _lora_blob,
@@ -1101,7 +1093,7 @@ def from_conv2ds(ts):
                 .reshape(bsz, -1, self.dim)
             )
         else:
-            y = self._lora_call(self.wo, y, _lora_blob)
+            y = lora_call(self.wo, y, _lora_blob)
 
         update = {"out_cache_state": out_cache_state}
         if kv_to_share is not None: