While comparing the token-merging path against the paper (Algorithm 1, Globally Informed Token Merging) and the original ToMe, I ran into two things that look off. They're independent, so feel free to split them.
Finding 1: `weighted_patch` energy is computed per-patch (local), not over the global bipartite graph
Algorithm 1 defines each point's energy as its (negative) average cosine similarity to all patch centroids:
```
E(x_i) = -(1/|N(x_i)|) * Σ_j cos(x_i, P̄_j)    # j ranges over every patch centroid
```
That is, it is a global point→centroid graph: patch energy is the mean of its points' energies, and the merge rate is chosen by thresholding that energy.
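To make the intended computation concrete, here is a minimal pure-Python sketch of the energy as written above. It assumes `N(x_i)` is the set of all patch centroids. The function names (`global_patch_energies`, `split_by_threshold`) are hypothetical, not from the repo, and the sketch deliberately does not say which side of the threshold gets the higher merge rate.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def centroid(points):
    """Component-wise mean of a list of points."""
    return [sum(coords) / len(points) for coords in zip(*points)]

def global_patch_energies(patches):
    """E(x_i) = -mean_j cos(x_i, centroid_j) over ALL patches; patch energy = mean over its points."""
    cents = [centroid(p) for p in patches]
    energies = []
    for patch in patches:
        point_e = [-sum(cosine(x, c) for c in cents) / len(cents) for x in patch]
        energies.append(sum(point_e) / len(point_e))
    return energies

def split_by_threshold(energies, threshold):
    """Hypothetical high/low split of patch indices by energy."""
    high = [i for i, e in enumerate(energies) if e >= threshold]
    low = [i for i, e in enumerate(energies) if e < threshold]
    return high, low

patches = [
    [[1.0, 0.0], [0.9, 0.1]],   # patch 0
    [[1.0, 0.1], [0.95, 0.0]],  # patch 1, similar direction to patch 0
    [[0.0, 1.0], [0.1, 0.9]],   # patch 2, points elsewhere
]
energies = global_patch_energies(patches)
print(split_by_threshold(energies, threshold=-0.5))  # -> ([2], [0, 1])
```

Because every point is scored against every centroid, a patch's energy depends on what the other patches look like, which is the "globally informed" property in question.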
In code, `weighted_patch` calls `cal_score(v, v)` (in `point_transformer_v3m2_sonata.py`), with `v` shaped `[B=#patches, H, K=patch_size, D]`:
```python
def cal_score(self, q, k, threshold=0.9):
    q = F.avg_pool1d(q.mean(1).transpose(-1,-2), 32, 32) \
        .transpose(-1,-2).unsqueeze(1).expand(...)          # [B, H, K/32, D]
    q = F.normalize(q, dim=-1); k = F.normalize(k, dim=-1)  # k: [B, H, K, D]
    sim = q @ k.mean(-2, keepdim=True).transpose(-1,-2)     # [B,H,K/32,D] @ [B,H,D,1] = [B,H,K/32,1]
    score = sim.squeeze().mean(-1).mean(-1)                 # [B]
```
`k.mean(-2, keepdim=True)` is each patch's own centroid `[B,H,1,D]`, and the matmul is batched over `[B,H]`, so patch `b`'s tokens are only ever compared against patch `b`'s own centroid. There's no term summing `cos(x_i, P̄_j)` over the other patches, so `score[b]` is intra-patch coherence, not the global energy. The high/low split in `process_weighted_merging` is therefore driven by a purely local signal.
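A toy contrast between the two signals, as a hedged sketch. `local_score` is a simplified stand-in for `cal_score`: it ignores the 32-token pooling, the head dimension, and the exact normalization order, and keeps only the "own centroid" structure. `global_energy` follows the formula quoted from Algorithm 1. Both names are hypothetical. Changing a *different* patch cannot affect the local score, but it does move the global energy.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def centroid(points):
    """Component-wise mean of a list of points."""
    return [sum(coords) / len(points) for coords in zip(*points)]

def local_score(patch):
    """Simplified stand-in for cal_score: each point vs. its OWN patch centroid only."""
    c = centroid(patch)
    return sum(cosine(x, c) for x in patch) / len(patch)

def global_energy(patch, all_patches):
    """Energy per Algorithm 1: each point vs. EVERY patch centroid, negated and averaged."""
    cents = [centroid(p) for p in all_patches]
    per_point = [-sum(cosine(x, c) for c in cents) / len(cents) for x in patch]
    return sum(per_point) / len(per_point)

a = [[1.0, 0.0], [0.8, 0.2]]
b_similar = [[0.9, 0.1]]     # neighbour patch pointing the same way as `a`
b_dissimilar = [[0.0, 1.0]]  # neighbour patch pointing elsewhere

# local_score(a) takes no other patch as input, so it is identical in both scenes.
print(local_score(a))
# The global energy of `a` is higher (less negative) when its neighbour is dissimilar.
print(global_energy(a, [a, b_similar]), global_energy(a, [a, b_dissimilar]))
```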
A global version does exist: `cal_density_score` (`token_merging_algos.py:675`) builds the full point×centroid graph:
```python
sim = q.reshape(-1, C) @ k.mean(1).transpose(-1,-2)  # [B*T, C] @ [C, B] = [B*T, B]
score = sim.max(-1)[0]
```
However, its only reference is commented out (`token_merging_algos.py:832`), so it never runs. `bipartite_soft_matching` and `patch_based_matching` are also batched over the patch dimension, so they don't add any cross-patch term either.
Question: is the local `cal_score` intentional (speed/ablation), or should the global energy be wired into `weighted_patch` to match Algorithm 1? As it stands, the "globally informed" part of the method doesn't seem to be active.
Finding 2: proportional attention (`size.log()`) is added on the wrong axis and cancels out
In `self_attn` (`point_transformer_v3m2_sonata.py`):
```python
if size is not None:
    attn = attn + size.log()
```
with:
```python
attn = (q*scale) @ k.transpose(-2,-1)  # [B, H, T, T] = [B, H, T_query, T_key]
size = merge(ones, mode="sum")         # [B, H, T, 1]
```
`size.log()` is `[B,H,T,1]`, so it broadcasts onto the query axis (the trailing `1` expands over the key axis):
```
attn[b,h,i,j] += log(size[i])  # constant across keys j, varies over query i
```
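Written out for a single query row i, with a_ij denoting the pre-softmax logit and s the merged-token counts in `size` (notation introduced here for illustration), the softmax over keys j gives the following. The query-indexed bias factors out of numerator and denominator, while a key-indexed bias does not:

```math
\operatorname{softmax}_j\!\left(a_{ij} + \log s_i\right)
  = \frac{s_i\,e^{a_{ij}}}{\sum_k s_i\,e^{a_{ik}}}
  = \frac{e^{a_{ij}}}{\sum_k e^{a_{ik}}}
  = \operatorname{softmax}_j\!\left(a_{ij}\right),
\qquad
\operatorname{softmax}_j\!\left(a_{ij} + \log s_j\right)
  = \frac{s_j\,e^{a_{ij}}}{\sum_k s_k\,e^{a_{ik}}}.
```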
Two problems:
- Wrong axis. ToMe's proportional attention adds `log(size)` along the softmax (key) axis, so a key that represents more merged tokens receives proportionally more attention:
  ```python
  # facebookresearch/ToMe
  attn = attn + size.log()[:, None, None, :, 0]  # size [B,N,1] -> [B,1,1,N], on the key axis
  attn = attn.softmax(dim=-1)
  ```
- It's a silent no-op. The softmax runs over the key axis (`dim=-1`), and adding the same per-query constant to every key in a row leaves the softmax unchanged (`softmax(x + c) = softmax(x)`). So `+ log(size[i])` is normalized away entirely, and proportional attention is effectively disabled. It doesn't raise an error because `[B,H,T,1]` broadcasts cleanly against `[B,H,T,T]`, which is why it's easy to miss.
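A small numerical check of both points, as a pure-Python sketch. Rows of `logits` stand in for one head's `[T_query, T_key]` attention logits and `sizes` for the merged-token counts. `query_axis_bias` mirrors the current broadcast, `key_axis_bias` mirrors the ToMe indexing, and `duplicated_keys` is a reference that literally repeats key j `sizes[j]` times before the softmax. All three function names are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row (the key axis)."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def query_axis_bias(logits, sizes):
    """Current form: log(sizes[i]) added to every key logit of query row i."""
    return [softmax([a + math.log(sizes[i]) for a in row]) for i, row in enumerate(logits)]

def key_axis_bias(logits, sizes):
    """ToMe-style form: log(sizes[j]) added to the logit of key j in every row."""
    return [softmax([a + math.log(sizes[j]) for j, a in enumerate(row)]) for row in logits]

def duplicated_keys(logits, sizes):
    """Reference: repeat key j sizes[j] times, softmax, then sum each key's copies."""
    out = []
    for row in logits:
        probs = softmax([a for j, a in enumerate(row) for _ in range(sizes[j])])
        merged, pos = [], 0
        for s in sizes:
            merged.append(sum(probs[pos:pos + s]))
            pos += s
        out.append(merged)
    return out

logits = [[0.5, 1.0, -0.3], [2.0, 0.0, 0.1], [0.0, 0.0, 0.0]]
sizes = [1, 3, 2]  # tokens merged into each of the T=3 positions (self-attention, so T_query == T_key)

print(query_axis_bias(logits, sizes)[2])  # approx [0.3333, 0.3333, 0.3333], same as plain softmax
print(key_axis_bias(logits, sizes)[2])    # approx [0.1667, 0.5, 0.3333], weighted by size
```

The query-axis form reproduces plain softmax up to float rounding, while the key-axis form matches the duplicated-keys reference. That equivalence is the reason the bias belongs on the key axis: a merged key is attended to as if all of its source tokens were still present.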
I might be misreading parts of this, so I'd like to hear the authors' take: are the local `cal_score` (Finding 1) and the current `size.log()` form (Finding 2) intentional, or am I missing something? I'm mainly trying to understand the intended behavior.