-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
144 lines (113 loc) · 4.31 KB
/
data.py
File metadata and controls
144 lines (113 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# -*- coding:utf-8 -*-
import os
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import config as cfg
import glob
from PIL import Image
# import gdal
from osgeo import gdal
from torchvision import transforms
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
def gdal_read(input_file):
    """Read a raster with GDAL and return its pixels as a float32 (bands, rows, cols) array.

    Args:
        input_file: Path of a raster GDAL can open (e.g. a GeoTIFF).

    Returns:
        numpy.ndarray of dtype float32. A single-band raster (which GDAL's
        ``ReadAsArray`` returns as 2-D) is given a leading band axis of length 1,
        so callers always receive a 3-D, channel-first array.

    Raises:
        OSError: if GDAL cannot open the file or cannot read its pixel data.
    """
    dataset = gdal.Open(input_file)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open reports failure by returning None.
        raise OSError("GDAL could not open raster: %r" % (input_file,))
    raw = dataset.ReadAsArray()
    if raw is None:
        raise OSError("GDAL could not read pixel data from %r" % (input_file,))
    # float32 so the array can be handed to torch without another cast.
    pixels = raw.astype(np.float32)
    if pixels.ndim == 2:  # single band -> (1, rows, cols)
        pixels = pixels[np.newaxis, :, :]
    return pixels
def transform(LR, HR, FR, args):
    """Randomly crop and geometrically augment an aligned (LR, HR, FR) triple.

    All arrays are channel-first ``(c, h, w)``. A random ``args.img_size``
    square window is cut from LR. The matching window, ``args.scale`` times
    larger, is cut from HR and FR. The same random sequence of flips,
    rotation and transpose (each applied with probability 0.5) is then
    applied to all three arrays so they stay aligned.

    Args:
        LR: low-resolution array of shape (c, h, w).
        HR: array expected to be ``args.scale`` times LR's spatial size
            (assumption, not checked here).
        FR: array cropped and augmented exactly like HR.
        args: object providing integer ``img_size`` and ``scale``.

    Returns:
        Tuple ``(LR, HR, FR)`` of fresh C-contiguous copies; the copy removes
        the negative strides created by the flips.

    Raises:
        ValueError: if LR is not 3-D, or (from ``random.randint``) if its
            spatial size is smaller than ``args.img_size``.
    """
    _, height, width = LR.shape
    size = args.img_size
    scale = args.scale

    # Top-left corner of the crop in LR coordinates; HR/FR use it times scale.
    top = random.randint(0, height - size)
    left = random.randint(0, width - size)
    patches = [
        LR[:, top:top + size, left:left + size],
        HR[:, top * scale:(top + size) * scale, left * scale:(left + size) * scale],
        FR[:, top * scale:(top + size) * scale, left * scale:(left + size) * scale],
    ]

    # Each op is drawn independently, in this fixed order, so a seeded RNG
    # always yields the same augmentation.
    augmentations = (
        lambda a: a[:, ::-1, :],                   # flip rows
        lambda a: a[:, :, ::-1],                   # flip columns
        lambda a: a.swapaxes(-2, -1)[:, :, ::-1],  # 90-degree rotation
        lambda a: a[:, ::-1, :],                   # flip rows again
        lambda a: a.swapaxes(-2, -1),              # transpose spatial axes
    )
    for op in augmentations:
        if random.random() < 0.5:
            patches = [op(a) for a in patches]

    lr_out, hr_out, fr_out = (a.copy() for a in patches)
    return lr_out, hr_out, fr_out
class PsDataset(data.Dataset):
    """Dataset of aligned LR / HR / FR raster triples.

    Files are read from ``<apath>/LR``, ``<apath>/HR`` and ``<apath>/FR``.
    The i-th sample pairs the i-th file of each directory after sorting and
    a seeded shuffle. Matching therefore relies on all three directories
    holding the same number of ``*.tif`` files in the same sorted order;
    file names are not compared.

    ``args`` must provide ``isAug``, ``scale``, ``seed`` and ``data_range``,
    plus whatever the module-level ``transform`` reads when augmentation is on.
    """

    def __init__(self, args, apath, isAug=True, isUnlabel=False):
        # NOTE(review): the ``isAug`` parameter is ignored; augmentation is
        # controlled by ``args.isAug``. The parameter is kept so existing
        # call sites stay valid.
        self.isAug = args.isAug
        self.isUnlabel = isUnlabel  # stored only; not otherwise used in this class
        self.scale = args.scale
        self.args = args
        print(self.isAug, self.isUnlabel, self.scale)

        self.dirIn = os.path.join(apath, 'LR')
        self.dirTar = os.path.join(apath, 'HR')
        self.dirFR = os.path.join(apath, 'FR')
        # Same pattern for all three directories so their listings stay comparable.
        self.LRList = sorted(glob.glob(os.path.join(self.dirIn, '*.tif')))
        self.HRList = sorted(glob.glob(os.path.join(self.dirTar, '*.tif')))
        self.FRList = sorted(glob.glob(os.path.join(self.dirFR, '*.tif')))

        counts = (len(self.LRList), len(self.HRList), len(self.FRList))
        if len(set(counts)) != 1:
            # Unequal lengths would get different shuffle permutations and
            # silently mispair samples (or raise IndexError in __getitem__).
            raise ValueError(
                "LR/HR/FR file counts differ (%d/%d/%d) under %r; samples are "
                "paired by position, so the counts must match." % (counts + (apath,)))

        # One shared permutation keeps triples aligned. random.shuffle's
        # permutation depends only on the list length and the RNG state, so this
        # gives the same order, and leaves the global RNG in the same state, as
        # reseeding before each of three equal-length shuffles.
        random.seed(args.seed)
        order = list(range(counts[0]))
        random.shuffle(order)
        self.LRList = [self.LRList[i] for i in order]
        self.HRList = [self.HRList[i] for i in order]
        self.FRList = [self.FRList[i] for i in order]

        self.len = len(self.LRList)
        self.transform = transform

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(lr, hr, fr)`` float tensors.

        Each raster is read with ``gdal_read``, passed through
        ``self.transform`` when ``self.isAug`` is truthy, and divided by
        ``args.data_range``.
        """
        lr = gdal_read(self.LRList[idx])
        hr = gdal_read(self.HRList[idx])
        fr = gdal_read(self.FRList[idx])
        if self.isAug:
            lr, hr, fr = self.transform(lr, hr, fr, self.args)
        data_range = self.args.data_range
        return tuple(torch.Tensor(arr).float() / data_range for arr in (lr, hr, fr))

    def __len__(self):
        """Number of aligned triples found at construction time."""
        return self.len
def Get_DataSet(dataset, length):
    """Split ``dataset`` into contiguous ``Subset``s at fractional cut points.

    Let ``N = len(dataset)``. The entries of ``length`` are cumulative cut
    points expressed as fractions of ``N``:

    * ``[f]`` returns only the first ``int(f * N)`` items.
    * ``k >= 2`` entries return ``k`` subsets separated at
      ``int(length[i] * N)`` for ``i < k - 1``. The last subset runs to the
      end, so ``length[-1]`` is not used. For example ``[0.8, 0.9, 1.0]``
      gives an 80% / 10% / 10% split.

    The entries are cut points, not per-split sizes: ``[0.8, 0.1, 0.1]``
    would give an empty middle split and a last split overlapping the first.

    Args:
        dataset: any object supporting ``len()`` and indexing.
        length: non-empty sequence of fractions.

    Returns:
        A single ``torch.utils.data.Subset`` when ``len(length) == 1``,
        otherwise a tuple of ``len(length)`` subsets, in order.

    Raises:
        ValueError: if ``length`` is empty (previously this returned ``None``).
    """
    n_splits = len(length)
    if n_splits == 0:
        raise ValueError("length must contain at least one fraction")
    n_items = len(dataset)
    indices = list(range(n_items))

    if n_splits == 1:
        return data.Subset(dataset, indices[:int(length[0] * n_items)])

    cuts = [int(length[i] * n_items) for i in range(n_splits - 1)]
    bounds = [0] + cuts + [None]  # None: the final split runs to the end
    return tuple(
        data.Subset(dataset, indices[start:stop])
        for start, stop in zip(bounds, bounds[1:])
    )