Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import os, argparse, time, datetime, stat, shutil
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from util.RGB_Depth_Fusion_dataset import RGB_Depth_Fusion_dataset
from util.augmentation import RandomFlip, RandomCrop, RandomCropOut, RandomBrightness, RandomNoise
from util.util import compute_results
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from model import AMFNet

#############################################################################################
parser = argparse.ArgumentParser(description='Train with pytorch')
#############################################################################################
parser.add_argument('--model_name', '-m', type=str, default='AMFNet')
parser.add_argument('--batch_size', '-b', type=int, default=2)
parser.add_argument('--lr_start', '-ls', type=float, default=0.01)
parser.add_argument('--gpu', '-g', type=int, default=0)
#############################################################################################
parser.add_argument('--lr_decay', '-ld', type=float, default=0.95)
parser.add_argument('--epoch_max', '-em', type=int, default=200) # please stop training mannully
parser.add_argument('--epoch_from', '-ef', type=int, default=0)
parser.add_argument('--num_workers', '-j', type=int, default=8)
parser.add_argument('--n_class', '-nc', type=int, default=3)
parser.add_argument('--data_dir', '-dr', type=str, default='./dataset/')
args = parser.parse_args()
#############################################################################################

augmentation_methods = [
RandomFlip(prob=0.5),
RandomCrop(crop_rate=0.1, prob=1.0),
]

def train(epo, model, train_loader, optimizer):
model.train()
for it, (images, labels, names) in enumerate(train_loader):
images = Variable(images).cuda(args.gpu)
labels = Variable(labels).cuda(args.gpu)
start_t = time.time() # time.time() returns the current time
optimizer.zero_grad()
logits = model(images)
loss = F.cross_entropy(logits, labels) # Note that the cross_entropy function has already include the softmax function
loss.backward()
optimizer.step()
lr_this_epo=0
for param_group in optimizer.param_groups:
lr_this_epo = param_group['lr']
print('Train: %s, epo %s/%s, iter %s/%s, lr %.8f, %.2f img/sec, loss %.4f, time %s' \
% (args.model_name, epo, args.epoch_max, it+1, len(train_loader), lr_this_epo, len(names)/(time.time()-start_t), float(loss),
datetime.datetime.now().replace(microsecond=0)-start_datetime))
if accIter['train'] % 1 == 0:
writer.add_scalar('Train/loss', loss, accIter['train'])
view_figure = True # note that I have not colorized the GT and predictions here
if accIter['train'] % 50 == 0:
if view_figure:
input_rgb_images = vutils.make_grid(images[:,:3], nrow=8, padding=10) # can only display 3-channel images, so images[:,:3]
writer.add_image('Train/input_rgb_images', input_rgb_images, accIter['train'])

scale = max(1, 255//args.n_class) # label (0,1,2..) is invisable, multiply a constant for visualization
groundtruth_tensor = labels.unsqueeze(1) * scale # mini_batch*288*512 -> mini_batch*1*288*512
groundtruth_tensor = torch.cat((groundtruth_tensor, groundtruth_tensor, groundtruth_tensor), 1) # change to 3-channel for visualization
groudtruth_images = vutils.make_grid(groundtruth_tensor, nrow=8, padding=10)
writer.add_image('Train/groudtruth_images', groudtruth_images, accIter['train'])

depth_tensor = images[:,3:4] # mini_batch*288*512 -> mini_batch*1*288*512
depth_tensor = torch.cat((depth_tensor, depth_tensor, depth_tensor), 1) # change to 3-channel for visualization
depth_images = vutils.make_grid(depth_tensor, nrow=8, padding=10)
writer.add_image('Train/depth_images', depth_images, accIter['train'])

mask_tensor = images[:,4:] # mini_batch*288*512 -> mini_batch*1*288*512
mask_tensor = torch.cat((mask_tensor, mask_tensor, mask_tensor), 1) # change to 3-channel for visualization
mask_images = vutils.make_grid(mask_tensor, nrow=8, padding=10)
writer.add_image('Train/mask_images', mask_images, accIter['train'])

predicted_tensor = logits.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*288*512 -> mini_batch*288*512 -> mini_batch*1*288*512
predicted_tensor = torch.cat((predicted_tensor, predicted_tensor, predicted_tensor),1) # change to 3-channel for visualization, mini_batch*1*288*512
predicted_images = vutils.make_grid(predicted_tensor, nrow=8, padding=10)
writer.add_image('Train/predicted_images', predicted_images, accIter['train'])
accIter['train'] = accIter['train'] + 1

def validation(epo, model, val_loader):
model.eval()
with torch.no_grad():
for it, (images, labels, names) in enumerate(val_loader):
images = Variable(images).cuda(args.gpu)
labels = Variable(labels).cuda(args.gpu)
start_t = time.time() # time.time() returns the current time
logits = model(images)
loss = F.cross_entropy(logits, labels) # Note that the cross_entropy function has already include the softmax function
print('Val: %s, epo %s/%s, iter %s/%s, %.2f img/sec, loss %.4f, time %s' \
% (args.model_name, epo, args.epoch_max, it + 1, len(val_loader), len(names)/(time.time()-start_t), float(loss),
datetime.datetime.now().replace(microsecond=0)-start_datetime))
if accIter['val'] % 1 == 0:
writer.add_scalar('Validation/loss', loss, accIter['val'])
view_figure = False # note that I have not colorized the GT and predictions here
if accIter['val'] % 100 == 0:
if view_figure:
input_rgb_images = vutils.make_grid(images[:, :3], nrow=8, padding=10) # can only display 3-channel images, so images[:,:3]
writer.add_image('Validation/input_rgb_images', input_rgb_images, accIter['val'])
scale = max(1, 255 // args.n_class) # label (0,1,2..) is invisable, multiply a constant for visualization
groundtruth_tensor = labels.unsqueeze(1) * scale # mini_batch*288*512 -> mini_batch*1*288*512
groundtruth_tensor = torch.cat((groundtruth_tensor, groundtruth_tensor, groundtruth_tensor), 1) # change to 3-channel for visualization
groudtruth_images = vutils.make_grid(groundtruth_tensor, nrow=8, padding=10)
writer.add_image('Validation/groudtruth_images', groudtruth_images, accIter['val'])
predicted_tensor = logits.argmax(1).unsqueeze(1)*scale # mini_batch*args.n_class*288*512 -> mini_batch*288*512 -> mini_batch*1*288*512
predicted_tensor = torch.cat((predicted_tensor, predicted_tensor, predicted_tensor), 1) # change to 3-channel for visualization, mini_batch*1*288*512
predicted_images = vutils.make_grid(predicted_tensor, nrow=8, padding=10)
writer.add_image('Validation/predicted_images', predicted_images, accIter['val'])
accIter['val'] += 1

def testing(epo, model, test_loader):
model.eval()
conf_total = np.zeros((args.n_class, args.n_class))
label_list = ["unlabeled", "road", "negative"]
testing_results_file = os.path.join(weight_dir, 'testing_results_file.txt')
with torch.no_grad():
for it, (images, labels, names) in enumerate(test_loader):
images = Variable(images).cuda(args.gpu)
#images = images[:,:3]
labels = Variable(labels).cuda(args.gpu)
#images = torch.cat((images,labels),dim = 1)
logits = model(images)
label = labels.cpu().numpy().squeeze().flatten()
prediction = logits.argmax(1).cpu().numpy().squeeze().flatten() # prediction and label are both 1-d array, size: minibatch*640*480
conf = confusion_matrix(y_true=label, y_pred=prediction, labels=[0,1,2]) # conf is args.n_class*args.n_class matrix, vertical axis: groundtruth, horizontal axis: prediction
conf_total += conf
print('Test: %s, epo %s/%s, iter %s/%s, time %s' % (args.model_name, epo, args.epoch_max, it+1, len(test_loader),
datetime.datetime.now().replace(microsecond=0)-start_datetime))
precision, recall, IoU, F1 = compute_results(conf_total)
writer.add_scalar('Test/average_recall', recall.mean(), epo)
writer.add_scalar('Test/average_IoU', IoU.mean(), epo)
writer.add_scalar('Test/average_precision',precision.mean(), epo)
writer.add_scalar('Test/average_F1', F1.mean(), epo)
for i in range(len(precision)):
writer.add_scalar("Test(class)/precision_class_%s" % label_list[i], precision[i], epo)
writer.add_scalar("Test(class)/recall_class_%s"% label_list[i], recall[i],epo)
writer.add_scalar('Test(class)/Iou_%s'% label_list[i], IoU[i], epo)
writer.add_scalar('Test(class)/F1_%s'% label_list[i], F1[i], epo)
if epo==0:
with open(testing_results_file, 'w') as f:
f.write("# %s, initial lr: %s, batch size: %s, date: %s \n" %(args.model_name, args.lr_start, args.batch_size, datetime.date.today()))
f.write("# epoch: unlabeled, road, negative, average(nan_to_num). (Pre %, Acc %, IoU %, F1 %)\n")
with open(testing_results_file, 'a') as f:
f.write(str(epo)+': ')
for i in range(len(precision)):
f.write('%0.4f, %0.4f, %0.4f, %0.4f ' % (100*precision[i], 100*recall[i], 100*IoU[i], 100*F1[i]))
f.write('%0.4f, %0.4f, %0.4f, %0.4f \n' % (100*np.mean(np.nan_to_num(precision)), 100*np.mean(np.nan_to_num(recall)), 100*np.mean(np.nan_to_num(IoU)), 100*np.mean(np.nan_to_num(F1)) ))
print('saving testing results.')
with open(testing_results_file, "r") as file:
writer.add_text('testing_results', file.read().replace('\n', ' \n'), epo)


if __name__ == '__main__':

torch.cuda.set_device(args.gpu)
print("\nthe pytorch version:", torch.__version__)
print("the gpu count:", torch.cuda.device_count())
print("the current used gpu:", torch.cuda.current_device(), '\n')

model = eval(args.model_name)(n_class=args.n_class)

if args.gpu >= 0: model.cuda(args.gpu)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_start, momentum=0.9, weight_decay=0.0005)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay, last_epoch=-1)

# preparing folders
if os.path.exists("./runs"):
shutil.rmtree("./runs")
weight_dir = os.path.join("./runs", args.model_name)
os.makedirs(weight_dir)
os.chmod(weight_dir, stat.S_IRWXO) # allow the folder created by docker read, written, and execuated by local machine

writer = SummaryWriter("./runs/tensorboard_log")
os.chmod("./runs/tensorboard_log", stat.S_IRWXO) # allow the folder created by docker read, written, and execuated by local machine
os.chmod("./runs", stat.S_IRWXO)

print('training %s on GPU #%d with pytorch' % (args.model_name, args.gpu))
print('from epoch %d / %s' % (args.epoch_from, args.epoch_max))
print('weight will be saved in: %s' % weight_dir)

train_dataset = RGB_Depth_Fusion_dataset(data_dir=args.data_dir, split='train', transform=augmentation_methods,threshold=0)
val_dataset = RGB_Depth_Fusion_dataset(data_dir=args.data_dir, split='val',threshold=0)
test_dataset = RGB_Depth_Fusion_dataset(data_dir=args.data_dir, split='test',threshold=0)

train_loader = DataLoader(
dataset = train_dataset,
batch_size = args.batch_size,
shuffle = False,
num_workers = args.num_workers,
pin_memory = True,
drop_last = False
)
val_loader = DataLoader(
dataset = val_dataset,
batch_size = args.batch_size,
shuffle = False,
num_workers = args.num_workers,
pin_memory = True,
drop_last = False
)
test_loader = DataLoader(
dataset = test_dataset,
batch_size = args.batch_size,
shuffle = False,
num_workers = args.num_workers,
pin_memory = True,
drop_last = False
)
start_datetime = datetime.datetime.now().replace(microsecond=0)
accIter = {'train': 0, 'val': 0}
for epo in range(args.epoch_from, args.epoch_max):
print('\ntrain %s, epo #%s begin...' % (args.model_name, epo))
#scheduler.step() # if using pytorch 0.4.1, please put this statement here
train(epo, model, train_loader, optimizer)
validation(epo, model, val_loader)

checkpoint_model_file = os.path.join(weight_dir, str(epo) + '.pth')
print('saving check point %s: ' % checkpoint_model_file)
torch.save(model.state_dict(), checkpoint_model_file)

#if epo>15:
testing(epo, model, test_loader) # testing is just for your reference, you can comment this line during training

scheduler.step() # if using pytorch 1.1 or above, please put this statement here