-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
75 lines (57 loc) · 2.34 KB
/
train.py
File metadata and controls
75 lines (57 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
from torchvision import datasets, transforms
from options.train_options import TrainOptions
from models.networks import CustomNet
def load_MNIST_train_data(args, kwargs=None):
    """Build a shuffled DataLoader over the MNIST training split.

    Args:
        args: Options object. This function reads ``args.no_flip``,
            ``args.data_root`` and ``args.batch_size``.
        kwargs: Optional dict of extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader`` (the caller in this file passes
            ``num_workers`` / ``pin_memory``). ``None`` means no extras.

    Returns:
        A ``torch.utils.data.DataLoader`` over the MNIST training set with
        ``shuffle=True`` and ``batch_size=args.batch_size``.
    """
    # A None sentinel avoids sharing one mutable default dict across calls.
    loader_kwargs = {} if kwargs is None else kwargs

    # Tensor conversion followed by single-channel normalization
    # (0.1307 / 0.3081 are presumably the MNIST mean/std -- confirm).
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    if not args.no_flip:
        # The random flip is applied first, before tensor conversion.
        transform_list.insert(0, transforms.RandomHorizontalFlip())

    mnist_root = os.path.join(args.data_root, 'mnist')
    # exist_ok=True removes the check-then-create race of
    # os.path.exists() followed by os.makedirs().
    os.makedirs(mnist_root, exist_ok=True)

    mnist = datasets.MNIST(
        mnist_root,
        train=True,
        download=True,
        transform=transforms.Compose(transform_list),
    )
    return torch.utils.data.DataLoader(
        mnist,
        batch_size=args.batch_size,
        shuffle=True,
        **loader_kwargs
    )
def train(args, model, dataloader, optimizer, loss_fn, device):
    """Run the epoch/batch training loop with periodic logging and saving.

    Args:
        args: Options object. Reads ``continue_train``, ``which_epoch``,
            ``num_epochs``, ``print_freq``, ``save_freq``,
            ``checkpoints_dir`` and ``name``.
        model: Callable network that also exposes
            ``load_network(dir, name, epoch)`` and
            ``save_network(dir, name, epoch)``.
        dataloader: Iterable yielding ``(data, target)`` pairs whose items
            support ``.to(device)``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        loss_fn: Callable ``loss_fn(output, target)`` returning a value with
            ``backward()`` and ``item()``.
        device: Device passed to ``.to()`` for every batch.

    Epochs run from ``args.which_epoch + 1`` through ``args.num_epochs``
    inclusive. The logged loss is that of the last batch of the epoch.
    """
    # Load model if resuming training
    if args.continue_train and args.which_epoch > 0:
        model.load_network(args.checkpoints_dir, args.name, args.which_epoch)

    # Cycle epochs and iters
    for epoch in range(args.which_epoch + 1, args.num_epochs + 1):
        # Stays None when the dataloader yields no batches this epoch, so
        # the log line below cannot hit an unbound/stale ``loss``.
        loss = None
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

        # Log info periodically (skipped if no batch was processed)
        if loss is not None and epoch % args.print_freq == 0:
            print("Train Epoch: {}\tLoss: {:.6f}".format(epoch, loss.item()))

        # Save model checkpoint
        if epoch % args.save_freq == 0:
            model.save_network(args.checkpoints_dir, args.name, epoch)
# Parse the training options.
args = TrainOptions().parse()

# Pick the compute device: CUDA only when a non-negative GPU id was requested
# and CUDA is actually available; otherwise fall back to the CPU.
if args.gpu_id >= 0 and torch.cuda.is_available():
    torch.cuda.set_device(args.gpu_id)
    device = torch.device("cuda")
    kwargs = dict(num_workers=1, pin_memory=True)
else:
    device = torch.device("cpu")
    kwargs = dict(num_workers=args.n_threads)

# Build the MNIST training loader with the device-specific loader options.
train_loader = load_MNIST_train_data(args, kwargs)

# Instantiate the network, move it to the chosen device, and switch it to
# training mode.
model = CustomNet(args)
model = model.to(device)
model.train()

# Optimizer (SGD with momentum) and loss function (negative log-likelihood).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
)
loss_fn = torch.nn.functional.nll_loss

# Run training, then report completion.
train(args, model, train_loader, optimizer, loss_fn, device)
print("Training Complete")