diff --git a/dataloader_pth.py b/dataloader_pth.py index 008ddd6..cf44267 100644 --- a/dataloader_pth.py +++ b/dataloader_pth.py @@ -96,10 +96,10 @@ def __init__(self, for aug in augmentation_list: self.augmentation_list.append(self.augmentations[aug]) trainform, testform = self.transform() + self.target_transform = self.to_binary self.build_train_dataset(trainform) self.build_val_dataset(trainform) self.build_test_dataset(testform) - self.target_transform = self.to_binary def list_dataset_variants(self): print(self.list_dataset_variant) @@ -196,7 +196,7 @@ class T50(Dataset): def __init__(self, img_dir, label_file, transform=None, target_transform=None): label_data = json.load(open(label_file, "rb")) self.label_data = label_data["annotations"] - self.frames = self.label_data.keys() + self.frames = list(self.label_data.keys()) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform @@ -229,7 +229,7 @@ def get_binary_labels(self, labels): return (triplet_label, tool_label, verb_label, target_label, phase_label) def __getitem__(self, index): - labels = self.label_data["annotations"][self.frames[index]] + labels = self.label_data[self.frames[index]] basename = "{}.png".format(str(self.frames[index]).zfill(6)) img_path = os.path.join(self.img_dir, basename) image = Image.open(img_path)