diff --git a/models/SSITA_adapter.py b/models/SSITA_adapter.py index 5b30dfb..c5bbe89 100644 --- a/models/SSITA_adapter.py +++ b/models/SSITA_adapter.py @@ -70,7 +70,7 @@ def incremental_train(self, data_manager): print("The number of training dataset:", len(self.train_dataset)) self.data_manager = data_manager - self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=n8) + self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8) test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test") self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8) train_dataset_for_protonet = data_manager.get_dataset(np.arange(0, self._total_classes), source="train",