-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpreprocess_map.py
More file actions
109 lines (92 loc) · 5.6 KB
/
Copy pathpreprocess_map.py
File metadata and controls
109 lines (92 loc) · 5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from utils.utils_correspondence import resize
def set_seed(seed=42):
    """Seed the random number generators used by this script for reproducibility.

    Seeds Python's built-in ``random`` module, NumPy's global RNG, and PyTorch
    (CPU plus all CUDA devices). It also forces deterministic cuDNN kernels and
    disables the cuDNN autotuner.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels + no autotuning so repeated runs produce identical features.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: setting PYTHONHASHSEED here does not change hash randomization of the
    # already-running interpreter; it only takes effect for child processes started later.
    os.environ['PYTHONHASHSEED'] = str(seed)
def _save_features(obj, path, kind, file_path):
    """Write ``obj`` to ``path`` with ``torch.save``, creating parent directories.

    Saving is best-effort. A failure is reported and processing continues with
    the next image. Only ``Exception`` subclasses are caught, so
    ``KeyboardInterrupt`` / ``SystemExit`` still propagate.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    try:
        torch.save(obj, path)
    except Exception as err:
        print(f"Error saving {kind} features of {file_path}: {err}")


def _extract_sd_features(model, aug, image, num_ensemble):
    """Return the element-wise mean of ``num_ensemble`` SD feature dicts for ``image``.

    The ``'s2'`` entry of each dict returned by ``process_features_and_mask``
    is discarded before averaging.
    """
    # process_features_and_mask is looked up as a module global; the __main__
    # block imports it from model_utils.extractor_sd when --sd is given.
    accumulated = {}
    for _ in range(num_ensemble):
        feats = process_features_and_mask(model, aug, image, mask=False, raw=True)
        del feats['s2']
        for key, value in feats.items():
            accumulated[key] = accumulated.get(key, 0) + value
    for key in accumulated:
        accumulated[key] /= num_ensemble
    return accumulated


def process_and_save_features(file_paths, sd_size, dino_size, layer, facet, model, aug, extractor_vit, num_ensemble, flip=False, do_sd=False, do_dino=False, add_str='', sd_post_fix='_sd', dino_post_fix='_dino'):
    """Compute and save SD and/or DINO features for each image in ``file_paths``.

    Features are written next to a mirrored path in which ``'JPEGImages'`` is
    replaced by ``'features'`` (or ``'features_ensemble{N}'`` when
    ``num_ensemble != 1``). The files are named
    ``<stem><sd_post_fix|dino_post_fix><add_str>[_flip].pt``. Outputs that
    already exist are skipped independently for SD and DINO.

    Args:
        file_paths: Iterable of image paths.
        sd_size: Target size passed to ``resize`` for the SD input.
        dino_size: Target size passed to ``resize`` for the DINO input.
        layer: DINO layer index given to ``extract_descriptors``.
        facet: DINO facet given to ``extract_descriptors``.
        model, aug: SD model objects forwarded to ``process_features_and_mask``.
        extractor_vit: DINO extractor exposing ``preprocess_pil`` and
            ``extract_descriptors``.
        num_ensemble: Number of SD passes to average. Must be >= 1 when
            ``do_sd`` is true.
        flip: Mirror each image left-right before extraction and add
            ``'_flip'`` to the output names.
        do_sd / do_dino: Which feature types to compute.
        add_str: Extra string inserted before the optional ``'_flip'`` suffix.
        sd_post_fix / dino_post_fix: Name tags for SD / DINO output files.

    Raises:
        ValueError: If ``do_sd`` is true and ``num_ensemble < 1``.
    """
    if do_sd and num_ensemble < 1:
        raise ValueError(f"num_ensemble must be >= 1 when computing SD features, got {num_ensemble}")

    subdir_name = 'features' if num_ensemble == 1 else f'features_ensemble{num_ensemble}'
    suffix = f'{add_str}_flip' if flip else f'{add_str}'

    for file_path in tqdm(file_paths, desc="Processing images (Flip: {})".format(flip)):
        output_subdir = os.path.dirname(file_path.replace('JPEGImages', subdir_name))
        os.makedirs(output_subdir, exist_ok=True)
        stem = os.path.splitext(os.path.basename(file_path))[0]
        # Use the configured post-fix for both the existence check and the save.
        output_path = os.path.join(output_subdir, f'{stem}{sd_post_fix}{suffix}.pt')
        output_path_dino = os.path.join(output_subdir, f'{stem}{dino_post_fix}{suffix}.pt')

        # Decide per feature type, so an existing SD file no longer suppresses DINO.
        need_sd = do_sd
        if do_sd and os.path.exists(output_path):
            print(f"File {output_path} already exists. Skipping computing SD features.")
            need_sd = False
        need_dino = do_dino
        if do_dino and os.path.exists(output_path_dino):
            print(f"File {output_path_dino} already exists. Skipping computing DINO features.")
            need_dino = False
        if not (need_sd or need_dino):
            continue  # nothing left to compute: avoid decoding the image at all

        with Image.open(file_path) as raw_img:
            img = raw_img.convert('RGB')
        if flip:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if need_sd:
            img_sd = resize(img, sd_size, resize=True, to_pil=True)
            sd_features = _extract_sd_features(model, aug, img_sd, num_ensemble)
            _save_features(sd_features, output_path, 'SD', file_path)

        if need_dino:
            img_dino = resize(img, dino_size, resize=True, to_pil=True)
            img_batch = extractor_vit.preprocess_pil(img_dino)
            with torch.no_grad():
                # NOTE(review): 60x60 presumably equals 840 (default dino_size) / 14
                # (the value passed to ViTExtractor in __main__); other dino_size
                # values will likely fail this reshape. Confirm before changing sizes.
                dino_desc = extractor_vit.extract_descriptors(img_batch.cuda(), layer, facet).permute(0, 1, 3, 2).reshape(1, -1, 60, 60)
            _save_features(dino_desc, output_path_dino, 'DINO', file_path)
if __name__ == '__main__':
    # Argument parser
    parser = argparse.ArgumentParser(description="Process and save features from images.")
    parser.add_argument('--base_dir', type=str, default='data/SPair-71k/JPEGImages', help='Base directory containing images.')
    parser.add_argument('--dino', action='store_true', help='Whether to compute DINO features.')
    parser.add_argument('--sd', action='store_true', help='Whether to compute SD features.')
    # The flip applied downstream is Image.FLIP_LEFT_RIGHT, i.e. a horizontal mirror.
    parser.add_argument('--do_flip', action='store_true', help='Whether to flip images horizontally (left-right).')
    parser.add_argument('--sd_size', type=int, default=960, help='Image size for SD.')
    parser.add_argument('--dino_size', type=int, default=840, help='Image size for DINOv2.')
    parser.add_argument('--layer', type=int, default=11, help='DINOv2 layer for feature extraction.')
    parser.add_argument('--facet', type=str, default='token', help='Facet for feature extraction.')
    parser.add_argument('--num_ensemble', type=int, default=1, help='Number of ensembles for SD processing.')
    parser.add_argument('--dino_model', type=str, default='dinov2_vitb14', help='DINO model.')
    parser.add_argument('--sd_path_suffix', type=str, default='_sd', help='Str for SD feature paths.')
    parser.add_argument('--dino_path_suffix', type=str, default='_dino', help='Str for DINO feature paths.')
    args = parser.parse_args()
    set_seed()

    # Same (case-sensitive) extension set as before, checked with a single endswith call.
    image_exts = ('.jpg', '.JPEG', '.png')
    all_files = sorted(
        os.path.join(subdir, name)
        for subdir, _, names in os.walk(args.base_dir)
        for name in names
        if name.endswith(image_exts)
    )
    print('Number of images', len(all_files))

    # Load models
    model, aug = None, None
    extractor_vit = None
    if args.sd:
        # Kept at module scope (inside this guard) on purpose: process_and_save_features
        # resolves process_features_and_mask as a module-level global.
        from model_utils.extractor_sd import load_model, process_features_and_mask
        model, aug = load_model(diffusion_ver='v1-5', image_size=args.sd_size, num_timesteps=50, block_indices=[2, 5, 8, 11])
    if args.dino:
        from model_utils.extractor_dino import ViTExtractor
        extractor_vit = ViTExtractor(args.dino_model, 14, device='cuda')

    try:
        process_and_save_features(all_files, args.sd_size, args.dino_size, args.layer, args.facet, model, aug, extractor_vit, args.num_ensemble, flip=args.do_flip, do_dino=args.dino, do_sd=args.sd, add_str='', sd_post_fix=args.sd_path_suffix, dino_post_fix=args.dino_path_suffix)
    except KeyboardInterrupt:
        print("\nProcessing interrupted by user.")
    else:
        # Only claim completion when the run was not interrupted.
        print("Feature processing completed.")