Skip to content

Fixed MSLoss GPU utilization: completely rewritten for CUDA#33

Open
Fleyderer wants to merge 1 commit into
jixunbo:masterfrom
Fleyderer:master
Open

Fixed MSLoss GPU utilization: completely rewritten for CUDA#33
Fleyderer wants to merge 1 commit into
jixunbo:masterfrom
Fleyderer:master

Conversation

@Fleyderer

@Fleyderer Fleyderer commented Oct 2, 2025

Copy link
Copy Markdown

When I tried to train the LMBN model with the existing code, I ran into very low GPU utilization: I could fill 95% of GPU memory, but GPU utilization stayed at 7-10%, rising to 50% for only about half a second at a time (a typical symptom of a CPU-GPU bottleneck).

After profiling I had this picture of trace:
image

The first row of small magenta bars is CUDA usage. You can see that while computing the loss forward and backward passes, we spend about 2.5 seconds on the CPU. After some investigation I found that the problem lies in the for-loops and the Python built-in min/max functions, which perform an enormous number of CPU operations.

After fixing this code, I got 70-99% GPU utilization and a 10x training speed-up (in my case, 10 days -> 1 day).

Additionally, I will provide some code for input-output validation:

import torch
from torch import nn
import torch.nn.functional as F

class MultiSimilarityLoss(nn.Module):
    """Multi-similarity loss evaluated one anchor row at a time.

    Features are L2-normalized, so the pairwise dot products are cosine
    similarities.  For each row, same-label entries are positives and
    different-label entries are negatives; both sets are mined with
    ``margin`` and then weighted with a log-sum-exp term around ``thresh``.
    Rows that end up without a mined positive or a mined negative contribute
    nothing to the sum, but the sum is still divided by the full batch size.
    """

    def __init__(self, margin=0.1):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = 0.5      # similarity offset inside both exponentials
        self.margin = margin   # slack used when mining pairs

        self.scale_pos = 2.0   # multiplier applied to positive pairs
        self.scale_neg = 40.0  # multiplier applied to negative pairs

    def forward(self, feats, labels):
        """Return the scalar loss for one batch.

        Args:
            feats: 2-D tensor with one embedding per row.
            labels: tensor of class ids whose first dimension matches
                ``feats``.

        Returns:
            A 0-dim tensor.  When no row yields a usable positive/negative
            combination, a fresh zero tensor with ``requires_grad=True`` on
            the device of ``feats`` is returned.

        Raises:
            AssertionError: if ``feats`` and ``labels`` disagree on the
                batch dimension.
        """
        assert feats.size(0) == labels.size(0), \
            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
        batch_size = feats.size(0)
        feats = nn.functional.normalize(feats, p=2, dim=1)

        # Shape: batch_size x batch_size
        sim_mat = torch.matmul(feats, torch.t(feats))

        epsilon = 1e-5
        losses = []

        # same_label[i][j] is True when samples i and j share a label.
        same_label = labels.expand(batch_size, batch_size).eq(
            labels.expand(batch_size, batch_size).t())
        for i in range(batch_size):
            pos_pair_ = sim_mat[i][same_label[i]]
            # Drop similarities that are (numerically) 1, i.e. the anchor itself.
            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
            neg_pair_ = sim_mat[i][~same_label[i]]

            # A row with no positive or no negative candidates cannot be
            # mined; skip it instead of reducing over an empty tensor
            # (which raises).
            if pos_pair_.numel() == 0 or neg_pair_.numel() == 0:
                continue

            # Tensor reductions instead of the Python builtins, which iterate
            # over the tensor element by element.
            neg_pair = neg_pair_[neg_pair_ + self.margin > pos_pair_.min()]
            pos_pair = pos_pair_[pos_pair_ - self.margin < neg_pair_.max()]

            if len(neg_pair) < 1 or len(pos_pair) < 1:
                continue

            # weighting step
            pos_loss = 1.0 / self.scale_pos * torch.log(
                1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
            neg_loss = 1.0 / self.scale_neg * torch.log(
                1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
            losses.append(pos_loss + neg_loss)

        if not losses:
            return torch.zeros([], requires_grad=True, device=feats.device)

        return sum(losses) / batch_size

class MultiSimilarityLossNew(nn.Module):
    """Vectorized multi-similarity loss (no per-row Python loop).

    Features are L2-normalized, so the pairwise dot products are cosine
    similarities.  For each row, same-label entries (excluding the diagonal
    and anything not below ``1 - eps``) are positives and different-label
    entries are negatives.  Both sets are mined with ``margin`` and weighted
    with a log-sum-exp term around ``thresh``.  Rows without a mined positive
    or a mined negative contribute zero; the sum is divided by the batch size.
    """

    def __init__(self, margin: float = 0.1, thresh: float = 0.5,
                 scale_pos: float = 2.0, scale_neg: float = 40.0, eps: float = 1e-5):
        super(MultiSimilarityLossNew, self).__init__()
        self.margin = margin        # slack used when mining pairs
        self.thresh = thresh        # similarity offset inside both exponentials
        self.scale_pos = scale_pos  # multiplier applied to positive pairs
        self.scale_neg = scale_neg  # multiplier applied to negative pairs
        self.eps = eps              # positives must be below 1 - eps

    def forward(self, feats: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for one batch.

        Args:
            feats: (B, D) tensor of embeddings.
            labels: tensor holding B class ids; it is flattened, so both
                (B,) and (B, 1) are accepted.

        Returns:
            A 0-dim tensor on the device of ``feats``.  When no row is
            usable, a fresh zero tensor with ``requires_grad=True`` is
            returned.

        Raises:
            ValueError: if the number of labels differs from B.
        """
        device = feats.device
        B = feats.size(0)

        # Without this check a label tensor of the wrong length can silently
        # broadcast against the (B, B) masks below.
        labels = labels.reshape(-1)
        if labels.size(0) != B:
            raise ValueError(
                f"feats.size(0): {B} is not equal to number of labels: {labels.size(0)}")

        if B == 0:
            return torch.zeros([], device=device, requires_grad=True)

        # normalize features
        feats = F.normalize(feats, p=2, dim=1)

        # similarity matrix (B, B)
        sim_mat = torch.matmul(feats, feats.t())
        # Mining only compares values, so it does not need the autograd graph.
        sim_ref = sim_mat.detach()

        # boolean masks
        same_label = labels.unsqueeze(1).eq(labels.unsqueeze(0))   # includes diagonal
        eye = torch.eye(B, dtype=torch.bool, device=device)
        mask_neg_all = ~same_label                                 # all negatives
        # positives: same label, not the anchor itself, and sim < 1 - eps
        pos_mask_init = same_label & ~eye & (sim_ref < (1.0 - self.eps))

        # Per-row hardest candidates.  Rows without positives get +inf and
        # rows without negatives get -inf, which makes the comparisons below
        # reject every pair in such rows.
        min_pos_b = sim_ref.masked_fill(~pos_mask_init, float('inf')).min(
            dim=1, keepdim=True).values                            # (B, 1)
        max_neg_b = sim_ref.masked_fill(~mask_neg_all, float('-inf')).max(
            dim=1, keepdim=True).values                            # (B, 1)

        # neg_pair = neg_pair_[neg_pair_ + margin > min(pos_pair_)]
        # pos_pair = pos_pair_[pos_pair_ - margin < max(neg_pair_)]
        neg_keep_mask = mask_neg_all & ((sim_ref + self.margin) > min_pos_b)
        pos_keep_mask = pos_mask_init & ((sim_ref - self.margin) < max_neg_b)

        # Valid rows have at least one kept positive AND one kept negative.
        valid_rows = pos_keep_mask.any(dim=1) & neg_keep_mask.any(dim=1)

        # NOTE: turning this tensor into a Python bool reads its value on the
        # host.  It is kept so the "nothing to learn from" case still returns
        # a fresh zero leaf tensor.
        if not bool(valid_rows.any()):
            return torch.zeros([], device=device, requires_grad=True)

        # Mask BEFORE exponentiating: discarded entries become exp(-inf) = 0.
        # Multiplying exp(...) by a 0/1 mask afterwards would turn an
        # overflowed discarded entry (e.g. the diagonal, where sim = 1) into
        # inf * 0 = NaN.
        neg_inf = float('-inf')
        pos_arg = (-self.scale_pos * (sim_mat - self.thresh)).masked_fill(~pos_keep_mask, neg_inf)
        neg_arg = (self.scale_neg * (sim_mat - self.thresh)).masked_fill(~neg_keep_mask, neg_inf)

        pos_sum = torch.exp(pos_arg).sum(dim=1)  # (B,)
        neg_sum = torch.exp(neg_arg).sum(dim=1)  # (B,)

        # per-row losses
        pos_loss_row = (1.0 / self.scale_pos) * torch.log1p(pos_sum)
        neg_loss_row = (1.0 / self.scale_neg) * torch.log1p(neg_sum)
        row_loss = pos_loss_row + neg_loss_row   # (B,)

        # Invalid rows contribute exactly zero.
        row_loss = torch.where(valid_rows, row_loss, torch.zeros_like(row_loss))

        # Final loss: sum(valid_row_losses) / batch_size
        total_loss = row_loss.sum() / float(B)

        return total_loss

And then:

# Evaluate both implementations on the same random batch and print each
# result so the two values can be compared by eye.
ms_loss_old = MultiSimilarityLoss()
ms_loss_new = MultiSimilarityLossNew()

# 6 embeddings of width 128, two samples for each of three classes.
feats = torch.randn(6, 128)
labels = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.long)

for criterion in (ms_loss_old, ms_loss_new):
    loss = criterion(feats, labels)
    print("Loss:", loss.item())

I've found that most of the time these values are identical; occasionally they differ around the 7th or 8th decimal place (on the order of 1e-7 to 1e-8), which is an acceptable tolerance for GPU optimizations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant