From 618532f41c73f37ad24cce21816f3f8f4323aa6f Mon Sep 17 00:00:00 2001
From: "Jean J. de Jong"
Date: Tue, 2 Jun 2026 09:45:47 +0200
Subject: [PATCH] Fix STG crash/index miscount with guide-mask self-attention

When cond-image/keyframe guides with strength != 1.0 are combined with STG
perturbation, the perturbed denoise step crashed with:

  RuntimeError: The expanded size of the tensor (N) must match the existing size (M)
  ... in _attention_with_guide_mask

Root cause: comfy core (CORE-166, "Reduce LTX2.3 peak VRAM when guide_mask is
in use") splits one video self-attention into up to three optimized_attention
calls over sliced queries against the full key/value. STG's PatchAttention
skipped a layer with `return v` (the full sequence) and counted each sub-call
as a separate attention index. So the returned value was the wrong length
(crash on the query-slice assignment), and the extra sub-calls shifted
audio_attn_idx in calc_stg_indexes.

Fix: detect a guide-split sub-call via the low_precision_attention=False kwarg
(the only signal core's guide path passes; this avoids a false positive on the
v2a cross-attention, which also has q_len < v_len), collapse the split into a
single logical STG index, and return the matching v[:, off:off+q_len] slice
when skipping. No core changes.

Not AV-specific: any video-only workflow combining cond_images
(strength != 1.0) with the STG/multimodal guider triggers it.

Co-Authored-By: Claude Opus 4.8 --- stg.py | 44 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/stg.py b/stg.py index 9e54821..fbcd8f7 100644 --- a/stg.py +++ b/stg.py @@ -123,6 +123,7 @@ class STGFlag: class PatchAttention(contextlib.AbstractContextManager): def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None): self.current_idx = -1 + self._guide_offset = 0 if isinstance(attn_idx, int): self.attn_idx = [attn_idx] @@ -151,19 +152,42 @@ def __exit__(self, exc_type, exc_value, traceback): self.original_attention = None self.original_attention_masked = None - def stg_attention(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v + def _stg_call(self, original, q, k, v, heads, args, kwargs): + # comfy's guide-mask self-attention (_attention_with_guide_mask in + # comfy/ldm/lightricks/model.py) splits one self-attention into several + # optimized_attention calls over contiguous *query slices*, each against + # the full key/value. Those sub-calls are the only ones that pass + # low_precision_attention=False, which lets us recognise them: a plain + # "return v" would be the wrong length (full sequence vs. the query + # slice) and would also miscount the STG attention index (one logical + # self-attention would consume several indices, shifting audio_attn_idx). + # We collapse the split into a single logical attention and, when + # skipping, return the matching slice of v. 
+ guide_split = kwargs.get("low_precision_attention") is False and q.shape[1] < v.shape[1] + continuation = guide_split and self._guide_offset > 0 + + if not continuation: + self.current_idx += 1 + skip = self.current_idx in self.attn_idx + + if not guide_split: + return v if skip else original(q, k, v, heads, *args, **kwargs) + + off = self._guide_offset + q_len = q.shape[1] + if skip: + out = v[:, off:off + q_len] else: - return self.original_attention(q, k, v, heads, *args, **kwargs) + out = original(q, k, v, heads, *args, **kwargs) + off += q_len + self._guide_offset = 0 if off >= v.shape[1] else off + return out + + def stg_attention(self, q, k, v, heads, *args, **kwargs): + return self._stg_call(self.original_attention, q, k, v, heads, args, kwargs) def stg_attention_masked(self, q, k, v, heads, *args, **kwargs): - self.current_idx += 1 - if self.current_idx in self.attn_idx: - return v - else: - return self.original_attention_masked(q, k, v, heads, *args, **kwargs) + return self._stg_call(self.original_attention_masked, q, k, v, heads, args, kwargs) class STGBlockWrapper: