diff --git a/libreyolo/models/rfdetr/loss.py b/libreyolo/models/rfdetr/loss.py index 2d41b65e..8c4a70b2 100644 --- a/libreyolo/models/rfdetr/loss.py +++ b/libreyolo/models/rfdetr/loss.py @@ -394,10 +394,14 @@ def loss_masks(self, outputs, targets, indices, num_boxes): spatial_features = outputs["pred_masks"]["spatial_features"] query_features = outputs["pred_masks"]["query_features"] bias = outputs["pred_masks"]["bias"] - # If there are no matches, return an empty tensor like the Tensor branch does. if idx[0].numel() == 0: - device = spatial_features.device - src_masks = torch.tensor([], device=device) + # Return zero losses that ARE connected to the mask head tensors so + # every mask-head parameter still receives a (zero) gradient. + # torch.tensor([]) has no grad_fn and silently drops those params from + # the backward graph, which violates DDP static_graph=True and causes + # a "finished reduction" crash whenever a rank sees an all-unlabeled batch. + zero = spatial_features.sum() * 0.0 + query_features.sum() * 0.0 + bias.sum() * 0.0 + return {"loss_mask_ce": zero, "loss_mask_dice": zero} else: batched_selected_masks = [] per_batch_counts = idx[0].unique(return_counts=True)[1] diff --git a/tests/unit/test_rfdetr_seg_ddp_static_graph.py b/tests/unit/test_rfdetr_seg_ddp_static_graph.py index 9d937900..1dc488ab 100644 --- a/tests/unit/test_rfdetr_seg_ddp_static_graph.py +++ b/tests/unit/test_rfdetr_seg_ddp_static_graph.py @@ -14,9 +14,11 @@ from __future__ import annotations import pytest +import torch pytestmark = pytest.mark.unit +rfdetr_loss = pytest.importorskip("libreyolo.models.rfdetr.loss") rfdetr_trainer = pytest.importorskip("libreyolo.models.rfdetr.trainer") @@ -44,3 +46,49 @@ def test_det_trainer_ddp_uses_find_unused_not_static_graph(): kwargs = trainer._ddp_kwargs() assert kwargs["find_unused_parameters"] is True assert kwargs["static_graph"] is False + + +def test_mask_loss_no_match_zero_stays_connected_to_mask_head_tensors(): + criterion = rfdetr_loss.SetCriterion( + num_classes=1, + matcher=None, + weight_dict={}, + focal_alpha=0.25, + losses=["masks"], + ) + + spatial_features = torch.randn(2, 4, 8, 8, requires_grad=True) + query_features = torch.randn(2, 5, 4, requires_grad=True) + bias = torch.randn(1, requires_grad=True) + outputs = { + "pred_masks": { + "spatial_features": spatial_features, + "query_features": query_features, + "bias": bias, + } + } + targets = [ + { + "labels": torch.zeros(0, dtype=torch.long), + "boxes": torch.zeros(0, 4), + "masks": torch.zeros(0, 8, 8, dtype=torch.bool), + } + for _ in range(2) + ] + indices = [ + (torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)) + for _ in targets + ] + + losses = criterion.loss_masks(outputs, targets, indices, num_boxes=1.0) + loss = losses["loss_mask_ce"] + losses["loss_mask_dice"] + loss.backward() + + assert loss.ndim == 0 + assert loss.item() == 0.0 + assert spatial_features.grad is not None + assert query_features.grad is not None + assert bias.grad is not None + assert spatial_features.grad.abs().sum().item() == 0.0 + assert query_features.grad.abs().sum().item() == 0.0 + assert bias.grad.abs().sum().item() == 0.0