From bfb9189023d727e1db599096f072c79e1e60a11c Mon Sep 17 00:00:00 2001 From: Abhin Date: Wed, 22 Apr 2026 19:04:11 +0530 Subject: [PATCH] Optimize temporal dropout using vectorized masked_fill --- tribev2/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tribev2/model.py b/tribev2/model.py index d08d170..93f324e 100644 --- a/tribev2/model.py +++ b/tribev2/model.py @@ -219,9 +219,8 @@ def aggregate_features(self, batch): elif self.config.extractor_aggregation == "sum": out = sum(tensors) if self.config.temporal_dropout > 0 and self.training: - for batch_idx in range(out.shape[0]): - mask = torch.rand(out.shape[1]) < self.config.temporal_dropout - out[batch_idx, mask, :] = torch.zeros_like(out[batch_idx, mask, :]) + mask = torch.rand(out.shape[:2], device=out.device) < self.config.temporal_dropout + out = out.masked_fill_(mask.unsqueeze(-1), 0.0) return out def transformer_forward(self, x, subject_id=None):