diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index 772dfa41..ed9b268d 100644 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -314,7 +314,11 @@ def validation_loop( loss_dict = {} time_logs = {} count = 0 - denom = len(dataloader) if full else self.short_validation_length + denom = ( + len(dataloader) + if full + else min(self.short_validation_length, len(dataloader)) + ) with torch.autocast( self.device.type, enabled=self.enable_amp, dtype=self.amp_type ):