Token-merging path diverges from the paper (Algorithm 1) and from ToMe: local energy + no-op proportional attention #3

@YangDongHeon

While comparing the token-merging path against the paper (Algorithm 1, Globally Informed Token Merging) and the original ToMe, I ran into two things that look off. They're independent, so feel free to split them.


Finding 1 — weighted_patch energy is computed per-patch (local), not over the global bipartite graph

Algorithm 1 defines each point's energy as its (negative) average cosine similarity to all patch centroids:

E(x_i) = -(1/|N(x_i)|) * Σ_j cos(x_i, P̄_j)     # j over every patch centroid

i.e. a global point→centroid graph, then patch energy = mean of its points' energies, then the merge rate is chosen by thresholding that energy.
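
For reference, here is a minimal sketch of the energy as described above, not the paper's or the repo's implementation. It assumes N(x_i) is the set of all patch centroids and that patches are stacked as [P, K, D]. The name global_patch_energy and the mean-based threshold are hypothetical; the actual thresholding rule is not specified here.

```python
import torch
import torch.nn.functional as F

def global_patch_energy(x: torch.Tensor) -> torch.Tensor:
    """x: [P, K, D] = P patches of K points each. Returns patch energy [P]."""
    P, K, D = x.shape
    centroids = F.normalize(x.mean(dim=1), dim=-1)     # [P, D], one centroid per patch
    points = F.normalize(x.reshape(P * K, D), dim=-1)  # [P*K, D]
    cos = points @ centroids.T                         # [P*K, P], every point vs every centroid
    point_energy = -cos.mean(dim=-1)                   # [P*K], negative average cosine similarity
    return point_energy.reshape(P, K).mean(dim=-1)     # [P], mean over each patch's points

torch.manual_seed(0)
x = torch.randn(4, 64, 16)
energy = global_patch_energy(x)
high = energy > energy.mean()  # hypothetical threshold, for illustration only
```

The key point is the [P*K, P] similarity matrix: each point sees every centroid, including those of other patches.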

In code, weighted_patch calls cal_score(v, v) (in point_transformer_v3m2_sonata.py), with v shaped [B=#patches, H, K=patch_size, D]:

def cal_score(self, q, k, threshold=0.9):
    q = F.avg_pool1d(q.mean(1).transpose(-1,-2), 32, 32) \
          .transpose(-1,-2).unsqueeze(1).expand(...)        # [B, H, K/32, D]
    q = F.normalize(q, dim=-1); k = F.normalize(k, dim=-1)  # k: [B, H, K, D]
    sim   = q @ k.mean(-2, keepdim=True).transpose(-1,-2)   # [B,H,K/32,D] @ [B,H,D,1] = [B,H,K/32,1]
    score = sim.squeeze().mean(-1).mean(-1)                 # [B]

k.mean(-2, keepdim=True) is each patch's own centroid [B,H,1,D], and the matmul is batched over [B,H], so patch b's tokens are only ever compared against patch b's own centroid. There's no term summing cos(x_i, P̄_j) over the other patches, so score[b] is intra-patch coherence, not the global energy. The high/low split in process_weighted_merging is therefore driven by a purely local signal.
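
One way to check the locality claim is to perturb every patch except one and see whether that patch's score moves. The sketch below uses a hypothetical stand-in, local_score, that follows the quoted lines; the elided expand(...) is replaced by matmul broadcasting over H, which is an assumption about its intent rather than the repo's code.

```python
import torch
import torch.nn.functional as F

def local_score(v: torch.Tensor, pool: int = 32) -> torch.Tensor:
    """Stand-in for the quoted cal_score(v, v). v: [B, H, K, D] -> score [B]."""
    q = F.avg_pool1d(v.mean(1).transpose(-1, -2), pool, pool).transpose(-1, -2)  # [B, K/pool, D]
    q = F.normalize(q, dim=-1).unsqueeze(1)   # [B, 1, K/pool, D], broadcasts over H
    k = F.normalize(v, dim=-1)                # [B, H, K, D]
    own_centroid = k.mean(-2, keepdim=True)   # [B, H, 1, D], patch b's own centroid only
    sim = q @ own_centroid.transpose(-1, -2)  # [B, H, K/pool, 1], batched over [B, H]
    return sim.squeeze(-1).mean(-1).mean(-1)  # [B]

torch.manual_seed(0)
v = torch.randn(3, 2, 64, 8)
before = local_score(v)

v_perturbed = v.clone()
v_perturbed[1:] = torch.randn(2, 2, 64, 8)    # replace every patch except patch 0
after = local_score(v_perturbed)

patch0_unchanged = torch.allclose(before[0], after[0])
others_changed = not torch.allclose(before[1:], after[1:])
```

If the score were a global energy, changing the other patches' centroids would change patch 0's score as well; here it stays the same.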

The global version exists — cal_density_score (token_merging_algos.py:675) does build the full point×centroid graph:

sim   = q.reshape(-1, C) @ k.mean(1).transpose(-1,-2)   # [B*T, C] @ [C, B] = [B*T, B]
score = sim.max(-1)[0]

but its only reference is commented out (token_merging_algos.py:832), so it never runs. bipartite_soft_matching / patch_based_matching are batched over the patch dimension too, so they don't add any cross-patch term either.

Question: is the local cal_score intentional (speed/ablation), or should the global energy be wired into weighted_patch to match Algorithm 1? As-is, the "globally informed" part of the method doesn't seem active.


Finding 2 — proportional attention (size.log()) is added on the wrong axis and cancels out

In self_attn (point_transformer_v3m2_sonata.py):

if size is not None:
    attn = attn + size.log()

with:

  • attn = (q*scale) @ k.transpose(-2,-1), shape [B, H, T, T] = [B, H, T_query, T_key]
  • size = merge(ones, mode="sum"), shape [B, H, T, 1]

size.log() is [B,H,T,1], so it broadcasts onto the query axis (the trailing 1 expands over the key axis):

attn[b,h,i,j] += log(size[i])      # constant across keys j, varies over query i
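
The broadcast direction can be confirmed on a tiny example. This is only an illustration with made-up sizes (1, 2, 4) and zero logits, not values from the repo:

```python
import torch

B, H, T = 1, 1, 3
attn = torch.zeros(B, H, T, T)                            # [B, H, T_query, T_key]
size = torch.tensor([1.0, 2.0, 4.0]).reshape(B, H, T, 1)  # [B, H, T, 1]

biased = attn + size.log()  # trailing 1 expands over the key axis
rows = biased[0, 0]         # row i = query i, column j = key j

# Every entry in row i equals log(size[i]); nothing varies along the key axis.
row_is_constant = bool((rows == rows[:, :1]).all())
```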

Two problems:

  1. Wrong axis. ToMe's proportional attention boosts the key that represents more merged tokens, added along the softmax axis:
    # facebookresearch/ToMe
    attn = attn + size.log()[:, None, None, :, 0]   # size [B,N,1] -> [B,1,1,N], on the key axis
    attn = attn.softmax(dim=-1)
  2. It's a silent no-op. softmax is over the key axis (dim=-1), and adding a per-query constant across all keys is invariant under softmax (softmax(x + c) = softmax(x)). So + log(size[i]) is normalized away entirely — proportional attention is effectively disabled. It doesn't raise an error because [B,H,T,1] broadcasts cleanly against [B,H,T,T], which is why it's easy to miss.
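
Both points can be checked numerically. The sketch below compares the current query-axis form against a key-axis form obtained by transposing size.log() to [B, H, 1, T]. That transpose is just one possible way to place the bias on the key axis under the [B, H, T, 1] layout described above; it is an assumption, not the authors' fix. Shapes and size values are made up.

```python
import torch

torch.manual_seed(0)
B, H, T = 2, 4, 5
attn = torch.randn(B, H, T, T)  # [B, H, T_query, T_key]
# Deterministic, non-uniform token sizes 1..T, shaped [B, H, T, 1].
size = torch.arange(1, T + 1).float().reshape(1, 1, T, 1).expand(B, H, T, 1)

plain = attn.softmax(dim=-1)
query_bias = (attn + size.log()).softmax(dim=-1)                  # current form
key_bias = (attn + size.log().transpose(-1, -2)).softmax(dim=-1)  # bias is [B, H, 1, T]

query_bias_is_noop = torch.allclose(plain, query_bias, atol=1e-6)
key_bias_has_effect = not torch.allclose(plain, key_bias, atol=1e-6)
```

With the query-axis form the attention weights match the unbiased softmax up to floating-point error, while the key-axis form shifts weight toward keys with larger size.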

I might be misreading parts of this, so I'd just like to hear the authors' take — are the local cal_score (Finding 1) and the current size.log() form (Finding 2) intentional, or am I missing something? Mainly trying to understand the intended behavior.
