diff --git a/tribev2/model.py b/tribev2/model.py index d08d170..93f324e 100644 --- a/tribev2/model.py +++ b/tribev2/model.py @@ -219,9 +219,8 @@ def aggregate_features(self, batch): elif self.config.extractor_aggregation == "sum": out = sum(tensors) if self.config.temporal_dropout > 0 and self.training: - for batch_idx in range(out.shape[0]): - mask = torch.rand(out.shape[1]) < self.config.temporal_dropout - out[batch_idx, mask, :] = torch.zeros_like(out[batch_idx, mask, :]) + mask = torch.rand(out.shape[:2], device=out.device) < self.config.temporal_dropout + out = out.masked_fill_(mask.unsqueeze(-1), 0.0) return out def transformer_forward(self, x, subject_id=None):