Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions examples/models/llama/feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,16 @@ def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
else nn.Linear(dim, hidden_dim, bias=False)
)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
def forward(self, x, lora_blob=None):
    """Compute ``w2(silu(w1(x)) * w3(x))`` with optional per-call LoRA tensors.

    CoreML LoRA-as-IO Path-2: when ``lora_blob`` is supplied, any projection
    that carries a ``_lora_key`` attribute whose value is present in
    ``lora_blob`` is invoked with the ``(a, b)`` pair stored under that key
    as two extra positional arguments. With ``lora_blob=None`` every
    projection is called with its input only, so default behavior is
    unchanged.
    """

    def _project(layer, inp):
        # Only look up the tag when a blob was actually passed in.
        tag = None if lora_blob is None else getattr(layer, "_lora_key", None)
        if tag is None or tag not in lora_blob:
            return layer(inp)
        lora_a, lora_b = lora_blob[tag]
        return layer(inp, lora_a, lora_b)

    # Evaluation order matches the single-expression form: w1, then w3, then w2.
    gate = F.silu(_project(self.w1, x))
    up = _project(self.w3, x)
    return _project(self.w2, gate * up)
8 changes: 7 additions & 1 deletion examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,13 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
else:
out = h + ffn_out
else:
ffn_out = self.feed_forward(self.ffn_norm(h))
if isinstance(self.feed_forward, LoRAFeedForward):
ffn_out = self.feed_forward(
self.ffn_norm(h),
lora_blob=attn_options.get("__lora_io_blob__"),
)
else:
ffn_out = self.feed_forward(self.ffn_norm(h))
if hasattr(self, "post_ffn_norm"):
ffn_out = self.post_ffn_norm(ffn_out)
if self.use_residual_gate:
Expand Down
24 changes: 19 additions & 5 deletions examples/models/llama/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


Expand Down Expand Up @@ -49,9 +52,20 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
# Optional forward-arg LoRA tensors (CoreML LoRA-as-IO Path 2). When
# both are provided, they override the stored lora_a/lora_b for this
# call. Default behavior (None, None) is unchanged.
lora_a: Optional[torch.Tensor] = None,
lora_b: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = self.linear(x)
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)

return out + lora_out
if lora_a is not None and lora_b is not None:
z = F.linear(self.dropout(x), lora_a)
z = (self.alpha / self.rank) * F.linear(z, lora_b)
else:
z = self.lora_a(self.dropout(x))
z = (self.alpha / self.rank) * self.lora_b(z)
return out + z
28 changes: 22 additions & 6 deletions examples/models/llama/static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,14 @@ def from_attention_mha(

return instance

def _lora_call(self, linear, x_in, lora_blob):
    """Call ``linear`` on ``x_in``, adding LoRA tensors from ``lora_blob`` if tagged.

    When ``lora_blob`` is not None and ``linear`` has a ``_lora_key``
    attribute whose value is a key of ``lora_blob``, the ``(a, b)`` pair
    stored under that key is passed as two extra positional arguments.
    In every other case ``linear`` is called with ``x_in`` alone.
    """
    # Skip the attribute lookup entirely when no blob was provided.
    tag = getattr(linear, "_lora_key", None) if lora_blob is not None else None
    if tag is None or tag not in lora_blob:
        return linear(x_in)
    lora_a, lora_b = lora_blob[tag]
    return linear(x_in, lora_a, lora_b)

def forward(
self,
x: torch.Tensor,
Expand All @@ -1030,7 +1038,13 @@ def forward(
if self.use_conv2d:
x = x.reshape(bsz, -1, 1, dim).transpose(1, 3)

new_qs = [wq(x) for wq in self.wqs]
# CoreML LoRA-as-IO Path-2: when an upstream wrapper has stashed
# a per-key LoRA blob in attn_options, route per-projection slices
# to LoRALinear instances that have been tagged with `_lora_key`.
# Default behavior (no blob, or no `_lora_key`) is unchanged.
_lora_blob = kwargs.get("__lora_io_blob__")

new_qs = [self._lora_call(wq, x, _lora_blob) for wq in self.wqs]

shared_kv = kwargs.get("shared_kv")
if shared_kv is not None:
Expand All @@ -1040,8 +1054,8 @@ def forward(
new_ks = []
new_vs = []
else:
new_ks = [wk(x) for wk in self.wks]
new_vs = [wv(x) for wv in self.wvs]
new_ks = [self._lora_call(wk, x, _lora_blob) for wk in self.wks]
new_vs = [self._lora_call(wv, x, _lora_blob) for wv in self.wvs]

if self.use_conv2d:

Expand Down Expand Up @@ -1078,14 +1092,16 @@ def from_conv2ds(ts):

if self.use_conv2d:
y = (
self.wo(
y.reshape(bsz, -1, 1, self.n_heads * self.head_dim).transpose(1, 3)
self._lora_call(
self.wo,
y.reshape(bsz, -1, 1, self.n_heads * self.head_dim).transpose(1, 3),
_lora_blob,
)
.transpose(1, 3)
.reshape(bsz, -1, self.dim)
)
else:
y = self.wo(y)
y = self._lora_call(self.wo, y, _lora_blob)

update = {"out_cache_state": out_cache_state}
if kv_to_share is not None:
Expand Down
Loading