-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_transforms.py
More file actions
29 lines (27 loc) · 949 Bytes
/
test_transforms.py
File metadata and controls
29 lines (27 loc) · 949 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import csv
import os

# Set before torch (or any project module that may touch CUDA) is imported, so
# the CUDA caching allocator is guaranteed to see it when it initialises.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import hydra
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported: this script
# only writes image files and never opens a window.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torchvision
import wandb
from tqdm import tqdm

# NOTE(review): star import — unclear from this file which of its names are used.
from models.multimodalAttention import *
@hydra.main(config_path="configs", config_name="train")
def main(cfg):
    """Save an image grid of each training batch to inspect the data transforms.

    Builds the datamodule from ``cfg.datamodule`` with
    ``hydra.utils.instantiate`` and takes its training dataloader. For each
    batch, it tiles the images with ``torchvision.utils.make_grid`` and writes
    the result to ``train_batch_<i>.png`` in the current working directory.

    NOTE(review): depending on the Hydra version/settings, the working
    directory may be the per-run output directory rather than the launch
    directory. Confirm where the PNGs are expected to land.

    Args:
        cfg: Hydra config composed from ``configs/train``. Must contain a
            ``datamodule`` node. May optionally contain ``max_batches``:
            - an integer saves only that many batches;
            - absent or null saves every batch in the loader.

    Raises:
        KeyError: if a batch is a dict without an ``'image'`` entry.
    """
    from itertools import islice

    datamodule = hydra.utils.instantiate(cfg.datamodule)
    # NOTE(review): if this is a Lightning-style DataModule, train_dataloader()
    # may require datamodule.setup() first. Confirm against the datamodule class.
    train_loader = datamodule.train_dataloader()
    print("Train dataloader created with batch size:", datamodule.batch_size)

    # Check membership with `in` rather than attribute access. On Hydra's
    # struct-mode config, accessing a missing key raises an error.
    max_batches = cfg.max_batches if "max_batches" in cfg else None

    # islice(..., None) yields every batch. A limit stops without fetching
    # an extra batch.
    for i, batch in enumerate(islice(train_loader, max_batches)):
        if isinstance(batch, dict):
            if "image" not in batch:
                raise KeyError(
                    f"batch dict has no 'image' key; available keys: {list(batch)}"
                )
            images = batch["image"]
        else:
            # Sequence batches are assumed to be (images, ...). TODO confirm.
            images = batch[0]

        # NOTE(review): if the transforms normalise with mean/std, pixel values
        # fall outside [0, 1] and imshow will clip them.
        # make_grid(..., normalize=True) would rescale them for display.
        grid = torchvision.utils.make_grid(images).detach().cpu()

        # Create one explicit figure per batch and close it by handle after
        # saving. plt.show() is omitted: under the Agg backend selected at
        # import time it cannot display anything and only warns.
        fig, ax = plt.subplots()
        ax.imshow(grid.permute(1, 2, 0).numpy())
        ax.axis("off")
        fig.savefig(f"train_batch_{i}.png", bbox_inches="tight")
        plt.close(fig)


if __name__ == "__main__":
    # The @hydra.main wrapper parses the command line and supplies `cfg`.
    main()