-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
103 lines (82 loc) · 2.7 KB
/
evaluate.py
File metadata and controls
103 lines (82 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from functools import partial
import numpy as np
import random
import argparse
import torch
import torch.nn.functional as F
import ignite
from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import (
TensorboardLogger, OutputHandler
)
import logging
import workflow
from workflow import json
from workflow.functional import starcompose
from workflow.torch import set_seeds
from workflow.ignite import worker_init, evaluator
from workflow.ignite.handlers.learning_rate import (
LearningRateScheduler, warmup, cyclical
)
from workflow.ignite.handlers import (
EpochLogger,
MetricsLogger,
ProgressBar
)
from datastream import Datastream
from vae import datastream, architecture, metrics
# Raise the threshold of the 'ignite' logger so only WARNING and above
# are emitted from it.
logging.getLogger('ignite').setLevel(logging.WARNING)
# NOTE(review): presumably silences TensorFlow's native (C++) log output that
# can appear when TensorBoard tooling is loaded — confirm this is still needed
# and that it is set early enough to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def evaluate(config):
    """Evaluate the checkpointed model on every evaluation datastream.

    Builds ``architecture.Model`` on the configured device, restores its
    state from ``model/checkpoints`` and then runs one evaluator engine per
    entry returned by ``datastream.evaluate_datastreams()``. A
    ``TensorboardLogger`` writing to ``tb`` is handed to every evaluator and
    is closed once all runs have finished (or one of them has failed).

    Args:
        config: Mapping that is read for the keys ``use_cuda``,
            ``eval_batch_size`` and ``n_workers``.
    """
    device = torch.device('cuda' if config['use_cuda'] else 'cpu')
    model = architecture.Model().to(device)
    train_state = dict(model=model)

    print('Loading model checkpoint')
    workflow.ignite.handlers.ModelCheckpoint.load(
        train_state, 'model/checkpoints', device
    )

    # NOTE(review): the decorator presumably switches the model to eval mode
    # and disables gradients — confirm in workflow.ignite.decorators.
    @workflow.ignite.decorators.evaluate(model)
    def evaluate_batch(engine, examples):
        """Run the model on one batch and return its predictions and loss."""
        predictions = model.predicted(
            tuple(example.image for example in examples)
        )
        loss = predictions.loss(
            tuple(example.class_name for example in examples)
        )
        return dict(
            examples=examples,
            predictions=predictions.cpu().detach(),
            loss=loss,
        )

    # The loop variable is intentionally not named ``datastream`` so it does
    # not shadow the imported ``datastream`` module used on the same line.
    evaluate_data_loaders = {
        f'evaluate_{name}': stream.data_loader(
            batch_size=config['eval_batch_size'],
            num_workers=config['n_workers'],
            collate_fn=tuple,
        )
        for name, stream in datastream.evaluate_datastreams().items()
    }

    tensorboard_logger = TensorboardLogger(log_dir='tb')
    try:
        for description, data_loader in evaluate_data_loaders.items():
            engine = evaluator(
                evaluate_batch,
                description,
                metrics.evaluate_metrics(),
                tensorboard_logger,
            )
            engine.run(data=data_loader)
    finally:
        # Always release the underlying writer, even when a run raises, so
        # the logger is not left open.
        tensorboard_logger.close()
if __name__ == '__main__':
    # Command-line options controlling the evaluation data loaders.
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_batch_size', default=128, type=int)
    parser.add_argument('--n_workers', type=int, default=2)

    # Inside IPython the name __IPYTHON__ exists; unknown arguments are then
    # tolerated. Otherwise the lookup raises NameError and parsing is strict.
    try:
        __IPYTHON__
    except NameError:
        args = parser.parse_args()
    else:
        args = parser.parse_known_args()[0]

    # Extend the parsed options with the fixed / environment-derived settings.
    cuda_available = torch.cuda.is_available()
    run_id = os.getenv('RUN_ID')
    config = vars(args)
    config['seed'] = 1
    config['use_cuda'] = cuda_available
    config['run_id'] = run_id

    # Persist the effective configuration before running the evaluation.
    json.write(config, 'config.json')

    evaluate(config)