-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathtrain.py
More file actions
108 lines (103 loc) · 4.42 KB
/
train.py
File metadata and controls
108 lines (103 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import time
import copy
import torch
from torchnet import meter
from torch.autograd import Variable
from utils import plot_training
# Data categories (phases). train_model iterates over this list in order for
# every epoch and uses the entries as keys of its per-phase cost, accuracy and
# confusion-meter dicts, and to index `dataloaders` / `dataset_sizes`.
data_cat = ['train', 'valid']
def train_model(model, criterion, optimizer, dataloaders, scheduler,
                dataset_sizes, num_epochs):
    '''
    Trains `model` for `num_epochs` epochs and returns it loaded with the
    weights of the epoch that reached the highest validation accuracy.

    Every epoch runs one pass per entry of `data_cat` ('train', then
    'valid'). Loss, accuracy and a 2-class confusion meter are printed for
    each pass, and the per-epoch losses/accuracies are handed to
    `plot_training` once training is finished.

    Args:
        model: network to optimise. Batches are moved to the device of its
            first parameter, so the loop is not tied to CUDA.
        criterion: callable invoked as `criterion(outputs, labels, phase)`.
            Assumed to return a single-element tensor (required by
            `backward()` and `.item()`).
        optimizer: optimizer stepped after every training batch.
        dataloaders: mapping with 'train' and 'valid' entries; each batch is
            read as `data['images'][0]` and `data['label']`.
        scheduler: stepped once per epoch with the validation loss.
        dataset_sizes: mapping phase -> divisor used to average the summed
            loss and the number of correct predictions.
        num_epochs: number of epochs to run.

    Returns:
        `model`, after `load_state_dict` with the best validation weights.
    '''
    since = time.time()
    # Follow the model's device instead of hard-coding .cuda().
    device = next(model.parameters()).device
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    costs = {x: [] for x in data_cat}  # for storing costs per epoch
    accs = {x: [] for x in data_cat}  # for storing accuracies per epoch
    print('Train batches:', len(dataloaders['train']))
    print('Valid batches:', len(dataloaders['valid']), '\n')
    for epoch in range(num_epochs):
        confusion_matrix = {x: meter.ConfusionMeter(2, normalized=True)
                            for x in data_cat}
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in data_cat:
            is_train = phase == 'train'
            model.train(is_train)
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')
                inputs = data['images'][0].to(device)
                labels = data['label'].float().to(device)
                # Only record the autograd graph when it will be used;
                # validation batches would otherwise hold on to it for
                # nothing.
                with torch.set_grad_enabled(is_train):
                    # One prediction per batch: the mean of all outputs.
                    outputs = torch.mean(model(inputs))
                    loss = criterion(outputs, labels, phase)
                # backward + optimize only if in training phase
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                # statistics, accumulated as plain Python numbers so the
                # epoch averages below use true division and nothing keeps
                # a reference to device tensors
                running_loss += loss.item()
                preds = (outputs.detach() > 0.5).float()
                running_corrects += torch.sum(preds == labels).item()
                # torch.mean yields a 0-dim tensor; the meter indexes
                # shape[0], so hand it 1-D tensors.
                confusion_matrix[phase].add(preds.view(-1), labels.view(-1))
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            costs[phase].append(epoch_loss)
            accs[phase].append(epoch_acc)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print('Confusion Meter:\n', confusion_matrix[phase].value())
            # deep copy the model
            if phase == 'valid':
                scheduler.step(epoch_loss)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
        time_elapsed = time.time() - since
        print('Time elapsed: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # '{:.4f}' (precision), matching the other metric prints; the previous
    # '{:4f}' only set a minimum width.
    print('Best valid Acc: {:.4f}'.format(best_acc))
    plot_training(costs, accs)
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
def get_metrics(model, criterion, dataloaders, dataset_sizes, phase='valid'):
    '''
    Loops over phase (train or valid) set to determine acc, loss and
    confusion meter of the model, and prints them.

    The model is switched to evaluation mode and gradients are disabled
    while the metrics are computed; the model's previous train/eval mode is
    restored before returning.

    Args:
        model: network to evaluate. Batches are moved to the device of its
            first parameter.
        criterion: callable invoked as `criterion(outputs, labels, phase)`.
            Assumed to return a single-element tensor.
        dataloaders: mapping phase -> iterable of batches; each batch is
            read as `data['images'][0]` and `data['label']`.
        dataset_sizes: mapping phase -> divisor used to average the summed
            loss and the number of correct predictions.
        phase: key selecting the loader and divisor (default 'valid').

    Returns:
        None. The metrics are only printed.
    '''
    confusion_matrix = meter.ConfusionMeter(2, normalized=True)
    running_loss = 0.0
    running_corrects = 0
    device = next(model.parameters()).device
    # Evaluate in eval mode, but hand the model back in the mode it
    # arrived in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')
                labels = data['label'].float().to(device)
                inputs = data['images'][0].to(device)
                # forward: one prediction per batch, the mean of all outputs
                outputs = torch.mean(model(inputs))
                loss = criterion(outputs, labels, phase)
                # statistics
                # One loss value per batch, exactly like the single
                # prediction counted for accuracy below. Weighting it by
                # inputs.size(0) while dividing by dataset_sizes[phase]
                # did not give an average and disagreed with the loss
                # train_model reports.
                running_loss += loss.item()
                preds = (outputs > 0.5).float()
                running_corrects += torch.sum(preds == labels).item()
                # torch.mean yields a 0-dim tensor; the meter indexes
                # shape[0], so hand it 1-D tensors.
                confusion_matrix.add(preds.view(-1), labels.view(-1))
    finally:
        model.train(was_training)
    loss = running_loss / dataset_sizes[phase]
    acc = running_corrects / dataset_sizes[phase]
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, loss, acc))
    print('Confusion Meter:\n', confusion_matrix.value())