From 9ce101115f4b656dd6ff9987b01adf130f1489a0 Mon Sep 17 00:00:00 2001
From: Nekodificador
Date: Tue, 26 May 2026 11:54:30 +0200
Subject: [PATCH] fix(stg): skip self-attention by module name instead of call index

STG's PatchAttention monkey-patches the global optimized_attention
function and counts calls to match a target index. Any extension that
changes the number of attention calls per transformer block offsets the
index, causing STG to skip the wrong layer or crash with shape
mismatches.

Affected by this fragility:
- NAG (Normalized Attention Guidance) adds an extra attention call per
  cross-attn for the negative context.
- Memory-efficient cross-attention chunking splits a single call into
  2-3 calls.
- Prompt routing / attention masks wrapping attention calls.

Replace PatchAttention with PatchSelfAttn, which selects self-attention
modules by name (attn1, audio_attn1) and replaces their forward with a
stub applying V (and gated attention if present) - the semantic
equivalent of the previous skip. Respects run_vx / run_ax flags from
MultimodalGuider for the audio-video model.

This is semantically what STG was always trying to do (skip
self-attention, not "the N-th attention call"), and it is robust to any
patch that changes the call structure within the block.

Patch authored by Kijai while debugging NAG compatibility with
LTXVAddGuide.
Co-Authored-By: Kijai --- stg.py | 100 ++++++++++++++++++++++++++++++++------------------------- 1 file changed, 56 insertions(+), 44 deletions(-) diff --git a/stg.py b/stg.py index 9e54821..3b124a8 100644 --- a/stg.py +++ b/stg.py @@ -3,9 +3,8 @@ import math import os from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List -import comfy.ldm.modules.attention import comfy.samplers import torch from comfy.model_patcher import ModelPatcher @@ -119,55 +118,48 @@ class STGFlag: skip_layers: List[int] = None -# context manager that replaces the attention function in a transformer block -class PatchAttention(contextlib.AbstractContextManager): - def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None): - self.current_idx = -1 - - if isinstance(attn_idx, int): - self.attn_idx = [attn_idx] - elif attn_idx is None: - self.attn_idx = [0] - else: - self.attn_idx = list(attn_idx) +# context manager that replaces specific self-attention modules' forward with a "skip" stub. 
+class PatchSelfAttn(contextlib.AbstractContextManager): + def __init__(self, attn_modules): + self.attn_modules = list(attn_modules) + self._originals = [] def __enter__(self): - self.original_attention = comfy.ldm.modules.attention.optimized_attention - self.original_attention_masked = ( - comfy.ldm.modules.attention.optimized_attention_masked - ) - - comfy.ldm.modules.attention.optimized_attention = self.stg_attention - comfy.ldm.modules.attention.optimized_attention_masked = ( - self.stg_attention_masked - ) + for attn in self.attn_modules: + self._originals.append((attn, attn.forward)) + attn.forward = self._make_stub(attn) + return self def __exit__(self, exc_type, exc_value, traceback): - comfy.ldm.modules.attention.optimized_attention = self.original_attention - comfy.ldm.modules.attention.optimized_attention_masked = ( - self.original_attention_masked - ) + for attn, orig in self._originals: + attn.forward = orig + self._originals.clear() - self.original_attention = None - self.original_attention_masked = None + @staticmethod + def _make_stub(attn): + def stub(x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): + ctx = x if context is None else context + out = attn.to_v(ctx) + if getattr(attn, "to_gate_logits", None) is not None: + gate_logits = attn.to_gate_logits(x) + b, t, _ = out.shape + out = out.view(b, t, attn.heads, attn.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) + out = out * gates.unsqueeze(-1) + out = out.view(b, t, attn.heads * attn.dim_head) + return attn.to_out(out) + return stub - def stg_attention(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v - else: - return self.original_attention(q, k, v, heads, *args, **kwargs) - def stg_attention_masked(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v - else: - return self.original_attention_masked(q, k, v, heads, *args, **kwargs) +class 
STGBlockWrapper: + """Wraps transformer blocks to skip self-attention layers for STG. + Selects which self-attentions to skip by module name (attn1 / audio_attn1) + rather than by counting optimized_attention call indices, so it isn't + perturbed by changes in how many internal attention calls a layer makes. + """ -class STGBlockWrapper: - """Wraps transformer blocks to be able to skip attention layers.""" + SELF_ATTN_NAMES = ("attn1", "audio_attn1") def __init__(self, block, stg_flag: STGFlag, idx: int): self.flag = stg_flag @@ -177,14 +169,34 @@ def __init__(self, block, stg_flag: STGFlag, idx: int): def __call__(self, args, extra_args): context_manager = contextlib.nullcontext() - stg_indexes = args["transformer_options"].get("stg_indexes", [0]) if self.flag.do_skip and self.idx in self.flag.skip_layers: - context_manager = PatchAttention(stg_indexes) + attns = self._select_self_attns(args.get("transformer_options", {})) + if attns: + context_manager = PatchSelfAttn(attns) with context_manager: hidden_state = extra_args["original_block"](args) return hidden_state + def _select_self_attns(self, transformer_options): + has_modality_flags = ( + "run_vx" in transformer_options or "run_ax" in transformer_options + ) + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + attns = [] + for name in self.SELF_ATTN_NAMES: + if not hasattr(self.block, name): + continue + if has_modality_flags: + if name == "attn1" and not run_vx: + continue + if name == "audio_attn1" and not run_ax: + continue + attns.append(getattr(self.block, name)) + return attns + class STGGuider(comfy.samplers.CFGGuider): def __init__(