-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathImageFilesDataset.py
More file actions
84 lines (69 loc) · 2.78 KB
/
ImageFilesDataset.py
File metadata and controls
84 lines (69 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
'''
This code is adapted from https://github.com/marcojira/fld
'''
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from collections import defaultdict
import random
class ImageFilesDataset(Dataset):
    """
    Creates a torch Dataset from a directory of images.

    For ``conditional=True`` the directory must be structured as
    ``dir/<class>/<img_name>.<extension>``; each image's label is the name of
    its parent folder. For ``conditional=False`` all files matching the
    extension are found recursively and every label is ``0``.

    Args:
        path: Root directory to search (recursively).
        name: Optional dataset name. If it contains "mnist" but not "color"
            (case-insensitive), images are loaded in PIL mode "L"; otherwise
            they are loaded in mode "RGB". ``None`` is treated as "".
        extension: File extension to match, with or without a leading dot.
        transform: Optional callable applied to each loaded PIL image.
        conditional: If True, the label is the image's parent folder name.
        n: Maximum number of images kept per label. In unconditional mode
            every label is ``0``, so this caps the total. ``None`` = no cap.

    Files are discovered lazily, on the first call to ``__len__``,
    ``__getitem__`` or ``get_class``.
    """

    def __init__(
        self, path, name=None, extension="png", transform=None, conditional=False, n=None
    ):
        self.path = path
        self.name = name
        self.extension = extension
        self.conditional = conditional  # If True, the label is the parent folder's name
        self.transform = transform
        self.files = []  # (file_path, label) pairs, filled by load_files()
        self.files_loaded = False  # For lazy loading of files
        self.n = n  # Maximum number of images per label (None = unlimited)

    def load_files(self):
        """Discover matching files and (re)build ``self.files``.

        Paths are sorted, then shuffled with the module-level ``random``
        generator, so that when ``n`` is set the kept images are a random
        sample per label. Sorting first makes that sample reproducible for a
        given ``random`` seed, independent of filesystem listing order.

        Calling this again rebuilds the list rather than appending duplicates.
        """
        class_count = defaultdict(int)
        files = []
        pattern = f"*.{self.extension.lstrip('.')}"
        all_paths = sorted(Path(self.path).rglob(pattern))
        random.shuffle(all_paths)
        for curr_path in all_paths:
            label = curr_path.parent.name if self.conditional else 0
            if self.n is None or class_count[label] < self.n:
                files.append((curr_path, label))
                class_count[label] += 1
        self.files = files
        self.files_loaded = True

    def _ensure_loaded(self):
        """Run ``load_files`` on first access only."""
        if not self.files_loaded:
            self.load_files()

    def _image_mode(self):
        """Return the PIL mode to load images in: "L" for non-color MNIST, else "RGB"."""
        name = (self.name or "").lower()
        if "mnist" in name and "color" not in name:
            return "L"
        return "RGB"

    def __len__(self):
        self._ensure_loaded()
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx``, apply ``transform`` if set, and return ``(img, label)``."""
        self._ensure_loaded()
        img_path, class_id = self.files[idx]
        # The file handle is owned by `src` and released when the block exits.
        # convert() returns a new, fully loaded image, so `img` stays usable
        # after `src` is closed.
        with Image.open(img_path) as src:
            img = src.convert(self._image_mode())
        if self.transform:
            img = self.transform(img)
        return img, class_id

    def get_class(self, idx):
        """Return the label of item ``idx`` without opening the image."""
        self._ensure_loaded()
        return self.files[idx][1]