diff --git a/stg.py b/stg.py index 9e54821..3b124a8 100644 --- a/stg.py +++ b/stg.py @@ -3,9 +3,8 @@ import math import os from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List -import comfy.ldm.modules.attention import comfy.samplers import torch from comfy.model_patcher import ModelPatcher @@ -119,55 +118,48 @@ class STGFlag: skip_layers: List[int] = None -# context manager that replaces the attention function in a transformer block -class PatchAttention(contextlib.AbstractContextManager): - def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None): - self.current_idx = -1 - - if isinstance(attn_idx, int): - self.attn_idx = [attn_idx] - elif attn_idx is None: - self.attn_idx = [0] - else: - self.attn_idx = list(attn_idx) +# context manager that replaces specific self-attention modules' forward with a "skip" stub. +class PatchSelfAttn(contextlib.AbstractContextManager): + def __init__(self, attn_modules): + self.attn_modules = list(attn_modules) + self._originals = [] def __enter__(self): - self.original_attention = comfy.ldm.modules.attention.optimized_attention - self.original_attention_masked = ( - comfy.ldm.modules.attention.optimized_attention_masked - ) - - comfy.ldm.modules.attention.optimized_attention = self.stg_attention - comfy.ldm.modules.attention.optimized_attention_masked = ( - self.stg_attention_masked - ) + for attn in self.attn_modules: + self._originals.append((attn, attn.forward)) + attn.forward = self._make_stub(attn) + return self def __exit__(self, exc_type, exc_value, traceback): - comfy.ldm.modules.attention.optimized_attention = self.original_attention - comfy.ldm.modules.attention.optimized_attention_masked = ( - self.original_attention_masked - ) + for attn, orig in self._originals: + attn.forward = orig + self._originals.clear() - self.original_attention = None - self.original_attention_masked = None + @staticmethod + def _make_stub(attn): + def stub(x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): + ctx = x if context is None else context + out = attn.to_v(ctx) + if getattr(attn, "to_gate_logits", None) is not None: + gate_logits = attn.to_gate_logits(x) + b, t, _ = out.shape + out = out.view(b, t, attn.heads, attn.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) + out = out * gates.unsqueeze(-1) + out = out.view(b, t, attn.heads * attn.dim_head) + return attn.to_out(out) + return stub - def stg_attention(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v - else: - return self.original_attention(q, k, v, heads, *args, **kwargs) - def stg_attention_masked(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v - else: - return self.original_attention_masked(q, k, v, heads, *args, **kwargs) +class STGBlockWrapper: + """Wraps transformer blocks to skip self-attention layers for STG. + Selects which self-attentions to skip by module name (attn1 / audio_attn1) + rather than by counting optimized_attention call indices, so it isn't + perturbed by changes in how many internal attention calls a layer makes. + """ -class STGBlockWrapper: - """Wraps transformer blocks to be able to skip attention layers.""" + SELF_ATTN_NAMES = ("attn1", "audio_attn1") def __init__(self, block, stg_flag: STGFlag, idx: int): self.flag = stg_flag @@ -177,14 +169,34 @@ def __init__(self, block, stg_flag: STGFlag, idx: int): def __call__(self, args, extra_args): context_manager = contextlib.nullcontext() - stg_indexes = args["transformer_options"].get("stg_indexes", [0]) if self.flag.do_skip and self.idx in self.flag.skip_layers: - context_manager = PatchAttention(stg_indexes) + attns = self._select_self_attns(args.get("transformer_options", {})) + if attns: + context_manager = PatchSelfAttn(attns) with context_manager: hidden_state = extra_args["original_block"](args) return hidden_state + def _select_self_attns(self, transformer_options): + has_modality_flags = ( + "run_vx" in transformer_options or "run_ax" in transformer_options + ) + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + attns = [] + for name in self.SELF_ATTN_NAMES: + if not hasattr(self.block, name): + continue + if has_modality_flags: + if name == "attn1" and not run_vx: + continue + if name == "audio_attn1" and not run_ax: + continue + attns.append(getattr(self.block, name)) + return attns + class STGGuider(comfy.samplers.CFGGuider): def __init__(