diff --git a/stg.py b/stg.py index 9e54821..fbcd8f7 100644 --- a/stg.py +++ b/stg.py @@ -123,6 +123,7 @@ class STGFlag: class PatchAttention(contextlib.AbstractContextManager): def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None): self.current_idx = -1 + self._guide_offset = 0 if isinstance(attn_idx, int): self.attn_idx = [attn_idx] @@ -151,19 +152,42 @@ def __exit__(self, exc_type, exc_value, traceback): self.original_attention = None self.original_attention_masked = None - def stg_attention(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v + def _stg_call(self, original, q, k, v, heads, args, kwargs): + # comfy's guide-mask self-attention (_attention_with_guide_mask in + # comfy/ldm/lightricks/model.py) splits one self-attention into several + # optimized_attention calls over contiguous *query slices*, each against + # the full key/value. Those sub-calls are the only ones that pass + # low_precision_attention=False, which lets us recognise them: a plain + # "return v" would be the wrong length (full sequence vs. the query + # slice) and would also miscount the STG attention index (one logical + # self-attention would consume several indices, shifting audio_attn_idx). + # We collapse the split into a single logical attention and, when + # skipping, return the matching slice of v. + guide_split = kwargs.get("low_precision_attention") is False and q.shape[1] < v.shape[1] + continuation = guide_split and self._guide_offset > 0 + + if not continuation: + self.current_idx += 1 + skip = self.current_idx in self.attn_idx + + if not guide_split: + return v if skip else original(q, k, v, heads, *args, **kwargs) + + off = self._guide_offset + q_len = q.shape[1] + if skip: + out = v[:, off:off + q_len] else: - return self.original_attention(q, k, v, heads, *args, **kwargs) + out = original(q, k, v, heads, *args, **kwargs) + off += q_len + self._guide_offset = 0 if off >= v.shape[1] else off + return out + + def stg_attention(self, q, k, v, heads, *args, **kwargs): + return self._stg_call(self.original_attention, q, k, v, heads, args, kwargs) def stg_attention_masked(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v - else: - return self.original_attention_masked(q, k, v, heads, *args, **kwargs) + return self._stg_call(self.original_attention_masked, q, k, v, heads, args, kwargs) class STGBlockWrapper: