diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e660f74 --- /dev/null +++ b/.gitignore @@ -0,0 +1,141 @@ +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + diff --git a/0_prepare.sh b/0_prepare.sh new file mode 100644 index 0000000..56bfa6c --- /dev/null +++ b/0_prepare.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# Load tensorflow models (we need DeepLabV3+ for coarse people segmentation) +git clone https://github.com/tensorflow/models.git + +# Create a docker image with all dependencies +docker build -t backmatting -f dockerfile . + +# Download pretrained models +wget https://gist.githubusercontent.com/andreyryabtsev/458f7450c630952d1e75e195f94845a0/raw/0b4336ac2a2140ac2313f9966316467e8cd3002a/download.sh +chmod +x download.sh +./download.sh diff --git a/dockerfile b/dockerfile new file mode 100644 index 0000000..6884395 --- /dev/null +++ b/dockerfile @@ -0,0 +1,24 @@ +# Base the image on Ubuntu 18.04 with CUDA 10.1 and cuDNN 7 +FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 +ENV TZ=Europe/Berlin +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone +# Install python 3.6 +RUN apt-get update && apt-get install -y \ + python3.6-dev\ + python3-pip \ + python3-tk \ + git libgtk2.0-dev +# Install OpenCV requirements +RUN apt-get update && apt-get install -y \ + libopencv-dev \ + python-opencv +# Install required python libraries +ADD requirements.txt . 
+RUN pip3 install --upgrade pip +RUN pip3 install --upgrade setuptools + +RUN pip3 install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html +RUN pip3 install -r requirements.txt +# Configure environment variables +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64 +ENV CUDA_HOME=/usr/local/cuda diff --git a/functions.py b/functions.py index 3c3fbcf..17ef1ec 100644 --- a/functions.py +++ b/functions.py @@ -5,106 +5,106 @@ def composite4(fg, bg, a): - fg = np.array(fg, np.float32) - alpha= np.expand_dims(a / 255,axis=2) - im = alpha * fg + (1 - alpha) * bg - im = im.astype(np.uint8) - return im + fg = np.array(fg, np.float32) + alpha = np.expand_dims(a / 255, axis=2) + im = alpha * fg + (1 - alpha) * bg + im = im.astype(np.uint8) + return im -def compose_image_withshift(alpha_pred,fg_pred,bg,seg): - image_sh=torch.zeros(fg_pred.shape).cuda() +def compose_image_withshift(alpha_pred, fg_pred, bg, seg): + image_sh = torch.zeros(fg_pred.shape).cuda() - for t in range(0,fg_pred.shape[0]): - al_tmp=to_image(seg[t,...]).squeeze(2) - where = np.array(np.where((al_tmp>0.1).astype(np.float32))) + for t in range(0, fg_pred.shape[0]): + al_tmp = to_image(seg[t, ...]).squeeze(2) + where = np.array(np.where((al_tmp > 0.1).astype(np.float32))) x1, y1 = np.amin(where, axis=1) x2, y2 = np.amax(where, axis=1) - #select shift - n=np.random.randint(-(y1-10),al_tmp.shape[1]-y2-10) - #n positive indicates shift to right - alpha_pred_sh=torch.cat((alpha_pred[t,:,:,-n:],alpha_pred[t,:,:,:-n]),dim=2) - fg_pred_sh=torch.cat((fg_pred[t,:,:,-n:],fg_pred[t,:,:,:-n]),dim=2) + # select shift + n = np.random.randint(-(y1 - 10), al_tmp.shape[1] - y2 - 10) + # n positive indicates shift to right + alpha_pred_sh = torch.cat((alpha_pred[t, :, :, -n:], alpha_pred[t, :, :, :-n]), dim=2) + fg_pred_sh = torch.cat((fg_pred[t, :, :, -n:], fg_pred[t, :, :, :-n]), dim=2) - alpha_pred_sh=(alpha_pred_sh+1)/2 + alpha_pred_sh = 
(alpha_pred_sh + 1) / 2 - image_sh[t,...]=fg_pred_sh*alpha_pred_sh + (1-alpha_pred_sh)*bg[t,...] + image_sh[t, ...] = fg_pred_sh * alpha_pred_sh + (1 - alpha_pred_sh) * bg[t, ...] return torch.autograd.Variable(image_sh.cuda()) -def get_bbox(mask,R,C): + +def get_bbox(mask, R, C): where = np.array(np.where(mask)) x1, y1 = np.amin(where, axis=1) x2, y2 = np.amax(where, axis=1) - bbox_init=[x1,y1,np.maximum(x2-x1,y2-y1),np.maximum(x2-x1,y2-y1)] - + bbox_init = [x1, y1, np.maximum(x2 - x1, y2 - y1), np.maximum(x2 - x1, y2 - y1)] - bbox=create_bbox(bbox_init,(R,C)) + bbox = create_bbox(bbox_init, (R, C)) return bbox -def crop_images(crop_list,reso,bbox): - for i in range(0,len(crop_list)): - img=crop_list[i] - if img.ndim>=3: - img_crop=img[bbox[0]:bbox[0]+bbox[2],bbox[1]:bbox[1]+bbox[3],...]; img_crop=cv2.resize(img_crop,reso) +def crop_images(crop_list, reso, bbox): + for i in range(len(crop_list)): + img = crop_list[i] + if img.ndim >= 3: + img_crop = img[bbox[0]:bbox[0] + bbox[2], bbox[1]:bbox[1] + bbox[3], ...] 
+ img_crop = cv2.resize(img_crop, reso) else: - img_crop=img[bbox[0]:bbox[0]+bbox[2],bbox[1]:bbox[1]+bbox[3]]; img_crop=cv2.resize(img_crop,reso) - crop_list[i]=img_crop + img_crop = img[bbox[0]:bbox[0] + bbox[2], bbox[1]:bbox[1] + bbox[3]] + img_crop = cv2.resize(img_crop, reso) + crop_list[i] = img_crop return crop_list -def create_bbox(bbox_init,sh): - w=np.maximum(bbox_init[2],bbox_init[3]) +def create_bbox(bbox_init, sh): + w = np.maximum(bbox_init[2], bbox_init[3]) - x1=bbox_init[0]-0.1*w - y1=bbox_init[1]-0.1*w + x1 = bbox_init[0] - 0.1 * w + y1 = bbox_init[1] - 0.1 * w - x2=bbox_init[0]+1.1*w - y2=bbox_init[1]+1.1*w + x2 = bbox_init[0] + 1.1 * w + y2 = bbox_init[1] + 1.1 * w - if x1<0: x1=0 - if y1<0: y1=0 - if x2>=sh[0]: x2=sh[0]-1 - if y2>=sh[1]: y2=sh[1]-1 + if x1 < 0: x1 = 0 + if y1 < 0: y1 = 0 + if x2 >= sh[0]: x2 = sh[0] - 1 + if y2 >= sh[1]: y2 = sh[1] - 1 - bbox=np.around([x1,y1,x2-x1,y2-y1]).astype('int') + bbox = np.around([x1, y1, x2 - x1, y2 - y1]).astype('int') return bbox -def uncrop(alpha,bbox,R=720,C=1280): - - alpha=cv2.resize(alpha,(bbox[3],bbox[2])) +def uncrop(alpha, bbox, R=720, C=1280): + alpha = cv2.resize(alpha, (bbox[3], bbox[2])) - if alpha.ndim==2: - alpha_uncrop=np.zeros((R,C)) - alpha_uncrop[bbox[0]:bbox[0]+bbox[2],bbox[1]:bbox[1]+bbox[3]]=alpha + if alpha.ndim == 2: + alpha_uncrop = np.zeros((R, C)) + alpha_uncrop[bbox[0]:bbox[0] + bbox[2], bbox[1]:bbox[1] + bbox[3]] = alpha else: - alpha_uncrop=np.zeros((R,C,3)) - alpha_uncrop[bbox[0]:bbox[0]+bbox[2],bbox[1]:bbox[1]+bbox[3],:]=alpha - + alpha_uncrop = np.zeros((R, C, 3)) + alpha_uncrop[bbox[0]:bbox[0] + bbox[2], bbox[1]:bbox[1] + bbox[3], :] = alpha return alpha_uncrop.astype(np.uint8) def to_image(rec0): - rec0=((rec0.data).cpu()).numpy() - rec0=(rec0+1)/2 - rec0=rec0.transpose((1,2,0)) - rec0[rec0>1]=1 - rec0[rec0<0]=0 + rec0 = ((rec0.data).cpu()).numpy() + rec0 = (rec0 + 1) / 2 + rec0 = rec0.transpose((1, 2, 0)) + rec0[rec0 > 1] = 1 + rec0[rec0 < 0] = 0 return rec0 -def 
write_tb_log(image,tag,log_writer,i): + +def write_tb_log(image, tag, log_writer, i): # image1 - output_to_show = image.cpu().data[0:4,...] - output_to_show = (output_to_show + 1)/2.0 - grid = torchvision.utils.make_grid(output_to_show,nrow=4) + output_to_show = image.cpu().data[0:4, ...] + output_to_show = (output_to_show + 1) / 2.0 + grid = torchvision.utils.make_grid(output_to_show, nrow=4) log_writer.add_image(tag, grid, i + 1) - diff --git a/requirements.txt b/requirements.txt index ae1764b..2e58970 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +tensorflow-gpu==1.14.0 numpy==1.17.0 opencv-python==3.4.5.20 pandas diff --git a/test_background-matting_image.py b/test_background-matting_image.py index a8d785b..182b3b5 100644 --- a/test_background-matting_image.py +++ b/test_background-matting_image.py @@ -1,198 +1,208 @@ from __future__ import print_function +import argparse +import glob +import os -import os, glob, time, argparse, pdb, cv2 -#import matplotlib.pyplot as plt -import numpy as np -from skimage.measure import label - - -import torch +import torch.backends.cudnn as cudnn import torch.nn as nn +import tqdm +from skimage.measure import label from torch.autograd import Variable -import torch.backends.cudnn as cudnn from functions import * from networks import ResnetConditionHR torch.set_num_threads(1) -#os.environ["CUDA_VISIBLE_DEVICES"]="4" -print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"]) - -"""Parses arguments.""" parser = argparse.ArgumentParser(description='Background Matting.') -parser.add_argument('-m', '--trained_model', type=str, default='real-fixed-cam',choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'],help='Trained background matting model') -parser.add_argument('-o', '--output_dir', type=str, required=True,help='Directory to save the output results. (required)') -parser.add_argument('-i', '--input_dir', type=str, required=True,help='Directory to load input images. 
(required)') -parser.add_argument('-tb', '--target_back', type=str,help='Directory to load the target background.') -parser.add_argument('-b', '--back', type=str,default=None,help='Captured background image. (only use for inference on videos with fixed camera') - - -args=parser.parse_args() - -#input model -model_main_dir='Models/' + args.trained_model + '/'; -#input data path -data_path=args.input_dir +parser.add_argument('-m', '--trained_model', type=str, default='real-fixed-cam', + choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'], + help='Trained background matting model') +parser.add_argument('-o', '--output_dir', type=str, required=True, + help='Directory to save the output results. (required)') +parser.add_argument('-i', '--input_dir', type=str, required=True, help='Directory to load input images. (required)') +parser.add_argument('-tb', '--target_back', type=str, + help='Target background to put foreground on. Either path to an image or an directory.') +parser.add_argument('-b', '--back', type=str, default=None, + help='Captured background image for fixed-cam mode. 
In case of hand-held mode, leave empty') + +args = parser.parse_args() +# input data path +data_path = args.input_dir if os.path.isdir(args.target_back): - args.video=True - print('Using video mode') + args.video = True + print('Using video mode') else: - args.video=False - print('Using image mode') - #target background path - back_img10=cv2.imread(args.target_back); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB); - #Green-screen background - back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155; - - - -#initialize network -fo=glob.glob(model_main_dir + 'netG_epoch_*') -model_name1=fo[0] -netM=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=7,n_blocks2=3) -netM=nn.DataParallel(netM) -netM.load_state_dict(torch.load(model_name1)) -netM.cuda(); netM.eval() -cudnn.benchmark=True -reso=(512,512) #input reoslution to the network - -#load captured background for video mode, fixed camera + args.video = False + print('Using image mode') + # target background path + target_img = cv2.imread(args.target_back) + target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB) + # Green-screen background + target_green_img = np.zeros(target_img.shape, dtype=np.uint8) + target_green_img[..., 0] = 0 + target_green_img[..., 1] = 255 + target_green_img[..., 2] = 0 + +# initialize network +model_main_dir = 'Models/' + args.trained_model + '/' +model_filepath = glob.glob(model_main_dir + 'netG_epoch_*')[0] +print("Loading model", model_filepath) +net = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3) +net = nn.DataParallel(net) +net.load_state_dict(torch.load(model_filepath)) +net.cuda() +net.eval() +cudnn.benchmark = True +reso = (512, 512) # input resolution to the network + +# load captured background for video mode, fixed camera if args.back is not None: - bg_im0=cv2.imread(args.back); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB); - + back_img = cv2.imread(args.back) + back_img = 
cv2.cvtColor(back_img, cv2.COLOR_BGR2RGB) -#Create a list of test images +# Create a list of test images test_imgs = [f for f in os.listdir(data_path) if - os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')] + os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')] test_imgs.sort() -#output directory -result_path=args.output_dir +# output directory +result_path = args.output_dir if not os.path.exists(result_path): - os.makedirs(result_path) - -for i in range(0,len(test_imgs)): - filename = test_imgs[i] - #original image - bgr_img = cv2.imread(os.path.join(data_path, filename)); bgr_img=cv2.cvtColor(bgr_img,cv2.COLOR_BGR2RGB); - - if args.back is None: - #captured background image - bg_im0=cv2.imread(os.path.join(data_path, filename.replace('_img','_back'))); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB); - - #segmentation mask - rcnn = cv2.imread(os.path.join(data_path, filename.replace('_img','_masksDL')),0); - - if args.video: #if video mode, load target background frames - #target background path - back_img10=cv2.imread(os.path.join(args.target_back,filename.replace('_img.png','.png'))); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB); - #Green-screen background - back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155; - - #create multiple frames with adjoining frames - gap=20 - multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4)) - idx=[i-2*gap,i-gap,i+gap,i+2*gap] - for t in range(0,4): - if idx[t]<0: - idx[t]=len(test_imgs)+idx[t] - elif idx[t]>=len(test_imgs): - idx[t]=idx[t]-len(test_imgs) - - file_tmp=test_imgs[idx[t]] - bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp)); - multi_fr_w[...,t]=cv2.cvtColor(bgr_img_mul,cv2.COLOR_BGR2GRAY); - - else: - ## create the multi-frame - multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4)) - multi_fr_w[...,0] = cv2.cvtColor(bgr_img,cv2.COLOR_BGR2GRAY); - multi_fr_w[...,1] = multi_fr_w[...,0] - 
multi_fr_w[...,2] = multi_fr_w[...,0] - multi_fr_w[...,3] = multi_fr_w[...,0] - - - #crop tightly - bgr_img0=bgr_img; - bbox=get_bbox(rcnn,R=bgr_img0.shape[0],C=bgr_img0.shape[1]) - - crop_list=[bgr_img,bg_im0,rcnn,back_img10,back_img20,multi_fr_w] - crop_list=crop_images(crop_list,reso,bbox) - bgr_img=crop_list[0]; bg_im=crop_list[1]; rcnn=crop_list[2]; back_img1=crop_list[3]; back_img2=crop_list[4]; multi_fr=crop_list[5] - - #process segmentation mask - kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) - kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) - rcnn=rcnn.astype(np.float32)/255; rcnn[rcnn>0.2]=1; - K=25 - - zero_id=np.nonzero(np.sum(rcnn,axis=1)==0) - del_id=zero_id[0][zero_id[0]>250] - if len(del_id)>0: - del_id=[del_id[0]-2,del_id[0]-1,*del_id] - rcnn=np.delete(rcnn,del_id,0) - rcnn = cv2.copyMakeBorder( rcnn, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE) - - - rcnn = cv2.erode(rcnn, kernel_er, iterations=10) - rcnn = cv2.dilate(rcnn, kernel_dil, iterations=5) - rcnn=cv2.GaussianBlur(rcnn.astype(np.float32),(31,31),0) - rcnn=(255*rcnn).astype(np.uint8) - rcnn=np.delete(rcnn, range(reso[0],reso[0]+K), 0) - - - #convert to torch - img=torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0); img=2*img.float().div(255)-1 - bg=torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0); bg=2*bg.float().div(255)-1 - rcnn_al=torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0); rcnn_al=2*rcnn_al.float().div(255)-1 - multi_fr=torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0); multi_fr=2*multi_fr.float().div(255)-1 - - - with torch.no_grad(): - img,bg,rcnn_al, multi_fr =Variable(img.cuda()), Variable(bg.cuda()), Variable(rcnn_al.cuda()), Variable(multi_fr.cuda()) - input_im=torch.cat([img,bg,rcnn_al,multi_fr],dim=1) - - alpha_pred,fg_pred_tmp=netM(img,bg,rcnn_al,multi_fr) - - al_mask=(alpha_pred>0.95).type(torch.cuda.FloatTensor) - - # for regions with alpha>0.95, simply use the image as fg - fg_pred=img*al_mask + 
fg_pred_tmp*(1-al_mask) - - alpha_out=to_image(alpha_pred[0,...]); - - #refine alpha with connected component - labels=label((alpha_out>0.05).astype(int)) - try: - assert( labels.max() != 0 ) - except: - continue - largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 - alpha_out=alpha_out*largestCC - - alpha_out=(255*alpha_out[...,0]).astype(np.uint8) - - fg_out=to_image(fg_pred[0,...]); fg_out=fg_out*np.expand_dims((alpha_out.astype(float)/255>0.01).astype(float),axis=2); fg_out=(255*fg_out).astype(np.uint8) - - #Uncrop - R0=bgr_img0.shape[0];C0=bgr_img0.shape[1] - alpha_out0=uncrop(alpha_out,bbox,R0,C0) - fg_out0=uncrop(fg_out,bbox,R0,C0) - - #compose - back_img10=cv2.resize(back_img10,(C0,R0)); back_img20=cv2.resize(back_img20,(C0,R0)) - comp_im_tr1=composite4(fg_out0,back_img10,alpha_out0) - comp_im_tr2=composite4(fg_out0,back_img20,alpha_out0) - - cv2.imwrite(result_path+'/'+filename.replace('_img','_out'), alpha_out0) - cv2.imwrite(result_path+'/'+filename.replace('_img','_fg'), cv2.cvtColor(fg_out0,cv2.COLOR_BGR2RGB)) - cv2.imwrite(result_path+'/'+filename.replace('_img','_compose'), cv2.cvtColor(comp_im_tr1,cv2.COLOR_BGR2RGB)) - cv2.imwrite(result_path+'/'+filename.replace('_img','_matte').format(i), cv2.cvtColor(comp_im_tr2,cv2.COLOR_BGR2RGB)) - - - print('Done: ' + str(i+1) + '/' + str(len(test_imgs))) - + os.makedirs(result_path) + +for i in tqdm.trange(len(test_imgs)): + filename = test_imgs[i] + + # original image + bgr_img = cv2.imread(os.path.join(data_path, filename)) + bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + output_height = bgr_img.shape[0] + output_width = bgr_img.shape[1] + + if args.back is None: + # captured background image + back_img = cv2.imread(os.path.join(data_path, filename.replace('_img', '_back'))) + back_img = cv2.cvtColor(back_img, cv2.COLOR_BGR2RGB) + + # segmentation mask + seg_mask = cv2.imread(os.path.join(data_path, filename.replace('_img', '_masksDL')), 0) + + if args.video: # if video mode, load target 
background frames + # target background path + target_img = cv2.imread(os.path.join(args.target_back, filename.replace('_img.png', '.png'))) + target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB) + # Green-screen background + target_green_img = np.zeros(target_img.shape) + target_green_img[..., 0] = 0 + target_green_img[..., 1] = 255 + target_green_img[..., 2] = 0 + + # create multiple frames with adjoining frames + gap = 20 + multi_fr_w = np.zeros((output_height, output_width, 4)) + idx = [i - 2 * gap, i - gap, i + gap, i + 2 * gap] + for t in range(0, 4): + if idx[t] < 0: + idx[t] = len(test_imgs) + idx[t] + elif idx[t] >= len(test_imgs): + idx[t] = idx[t] - len(test_imgs) + + file_tmp = test_imgs[idx[t]] + bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp)) + multi_fr_w[..., t] = cv2.cvtColor(bgr_img_mul, cv2.COLOR_BGR2GRAY) + else: + ## create the multi-frame + multi_fr_w = np.zeros((output_height, output_width, 4)) + multi_fr_w[..., 0] = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY) + multi_fr_w[..., 1] = multi_fr_w[..., 0] + multi_fr_w[..., 2] = multi_fr_w[..., 0] + multi_fr_w[..., 3] = multi_fr_w[..., 0] + + # Crop all images by the bbox of the rough segmentation mask + bbox = get_bbox(seg_mask, R=output_height, C=output_width) + + crop_list = [bgr_img, back_img, seg_mask, multi_fr_w] + crop_list = crop_images(crop_list, reso, bbox) + bgr_img = crop_list[0] + bg_im = crop_list[1] + seg_mask = crop_list[2] + multi_fr = crop_list[3] + + # Preprocess the rough segmentation mask + kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + seg_mask = seg_mask.astype(np.float32) / 255 + seg_mask[seg_mask > 0.2] = 1 + K = 25 + + zero_id = np.nonzero(np.sum(seg_mask, axis=1) == 0) + del_id = zero_id[0][zero_id[0] > 250] + if len(del_id) > 0: + del_id = [del_id[0] - 2, del_id[0] - 1, *del_id] + seg_mask = np.delete(seg_mask, del_id, 0) + seg_mask = cv2.copyMakeBorder(seg_mask, 0, K + 
len(del_id), 0, 0, cv2.BORDER_REPLICATE) + + seg_mask = cv2.erode(seg_mask, kernel_er, iterations=10) + seg_mask = cv2.dilate(seg_mask, kernel_dil, iterations=5) + seg_mask = cv2.GaussianBlur(seg_mask.astype(np.float32), (31, 31), 0) + seg_mask = (255 * seg_mask).astype(np.uint8) + seg_mask = np.delete(seg_mask, range(reso[0], reso[0] + K), 0) + + # Convert images to torch and normalize to range [-1, 1] + img = torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0) + img = 2 * img.float().div(255) - 1 + bg = torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0) + bg = 2 * bg.float().div(255) - 1 + rcnn_al = torch.from_numpy(seg_mask).unsqueeze(0).unsqueeze(0) + rcnn_al = 2 * rcnn_al.float().div(255) - 1 + multi_fr = torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0) + multi_fr = 2 * multi_fr.float().div(255) - 1 + + with torch.no_grad(): + img, bg, rcnn_al, multi_fr = Variable(img.cuda()), Variable(bg.cuda()), Variable(rcnn_al.cuda()), Variable( + multi_fr.cuda()) + input_im = torch.cat([img, bg, rcnn_al, multi_fr], dim=1) + + alpha_pred, fg_pred_tmp = net(img, bg, rcnn_al, multi_fr) + + # for regions with alpha>0.95, simply use the image as fg + al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor) + fg_pred = img * al_mask + fg_pred_tmp * (1 - al_mask) + + alpha_out = to_image(alpha_pred[0, ...]) + + # Filter alpha image by largest connected area + labels = label((alpha_out > 0.05).astype(int)) + try: + assert (labels.max() != 0) + except: + continue + largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 + alpha_out = alpha_out * largestCC + alpha_out = (255 * alpha_out[..., 0]).astype(np.uint8) + + fg_out = to_image(fg_pred[0, ...]) + fg_out = fg_out * np.expand_dims((alpha_out.astype(float) / 255 > 0.01).astype(float), axis=2) + fg_out = (255 * fg_out).astype(np.uint8) + + # Uncrop + alpha_out = uncrop(alpha_out, bbox, output_height, output_width) + fg_out = uncrop(fg_out, bbox, output_height, output_width) + + # Resize 
target backgrounds to foreground size + target_img = cv2.resize(target_img, (output_width, output_height)) + target_green_img = cv2.resize(target_green_img, (output_width, output_height)) + # Compose cutout foreground on Background + compose_target_img = composite4(fg_out, target_img, alpha_out) + compose_target_green_img = composite4(fg_out, target_green_img, alpha_out) + + # Write images to file + cv2.imwrite(result_path + '/' + filename.replace('_img', '_out'), alpha_out) + cv2.imwrite(result_path + '/' + filename.replace('_img', '_fg'), cv2.cvtColor(fg_out, cv2.COLOR_BGR2RGB)) + cv2.imwrite(result_path + '/' + filename.replace('_img', '_compose'), cv2.cvtColor(compose_target_img, cv2.COLOR_BGR2RGB)) + cv2.imwrite(result_path + '/' + filename.replace('_img', '_matte').format(i), + cv2.cvtColor(compose_target_green_img, cv2.COLOR_BGR2RGB)) diff --git a/test_background-matting_video.py b/test_background-matting_video.py new file mode 100644 index 0000000..84e4f19 --- /dev/null +++ b/test_background-matting_video.py @@ -0,0 +1,220 @@ +from __future__ import print_function + +import argparse +import glob +import os +import sys + +import torch.backends.cudnn as cudnn +import torch.nn as nn +import tqdm +from skimage.measure import label +from torch.autograd import Variable + +from functions import * +from networks import ResnetConditionHR + +torch.set_num_threads(1) + +parser = argparse.ArgumentParser(description='Background Matting.') +parser.add_argument('-m', '--trained_model', type=str, default='real-fixed-cam', + choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'], + help='Trained background matting model') +parser.add_argument('-o', '--output_dir', type=str, required=True, + help='Directory to save the output results. (required)') +parser.add_argument('-i', '--input_dir', type=str, required=True, help='Directory to load input images. (required)') +parser.add_argument('-tb', '--target_back', type=str, + help='Target background to put foreground on. 
Either path to an image or an directory.') +parser.add_argument('-b', '--back', type=str, default=None, + help='Captured background image for fixed-cam mode. In case of hand-held mode, leave empty') + +args = parser.parse_args() +# input data path +data_path = args.input_dir + +if os.path.isdir(args.target_back): + args.video = True + print('Using video mode') +else: + args.video = False + print('Using image mode') + # target background path + target_img = cv2.imread(args.target_back) + target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB) + # Green-screen background + target_green_img = np.zeros(target_img.shape, dtype=np.uint8) + target_green_img[..., 0] = 0 + target_green_img[..., 1] = 255 + target_green_img[..., 2] = 0 + +# initialize network +model_main_dir = 'Models/' + args.trained_model + '/' +model_filepath = glob.glob(model_main_dir + 'netG_epoch_*')[0] +print("Loading model", model_filepath) +net = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3) +net = nn.DataParallel(net) +net.load_state_dict(torch.load(model_filepath)) +net.cuda() +net.eval() +cudnn.benchmark = True +reso = (512, 512) # input resolution to the network + +# load captured background for video mode, fixed camera +if args.back is not None: + back_img = cv2.imread(args.back) + back_img = cv2.cvtColor(back_img, cv2.COLOR_BGR2RGB) + +# output directory +result_path = args.output_dir + +if not os.path.exists(result_path): + os.makedirs(result_path) + +video = cv2.VideoCapture(data_path) +num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) +fps = int(video.get(cv2.CAP_PROP_FPS)) +width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) +height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + +masks_video = cv2.VideoCapture(data_path.replace("_raw.mp4", "_masksDL.avi")) + +if not "raw" in data_path: + print("Wrong input file specified!") + sys.exit(0) + +# Create a video writer for cutout object placed on target +output_path = data_path.replace("raw", "out") 
+video_writer = cv2.VideoWriter(output_path, + cv2.VideoWriter_fourcc(*'mp4v'), + fps, + (width, height)) +print("Writing video to", output_path) + +# Create video writer for matte. Use lossless png compression +output_path = data_path.replace("raw", "masks").replace("mp4", "avi") +masks_video_writer = cv2.VideoWriter(output_path, + cv2.VideoWriter_fourcc(*'png '), + fps, + (width, height)) +for i in tqdm.trange(num_frames): + + # original image + ret, bgr_img = video.read() + bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + output_height = bgr_img.shape[0] + output_width = bgr_img.shape[1] + + # segmentation mask + ret, seg_mask = masks_video.read() + seg_mask = seg_mask[:, :, 0] + + ## create the multi-frame + multi_fr_w = np.zeros((output_height, output_width, 4)) + multi_fr_w[..., 0] = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY) + multi_fr_w[..., 1] = multi_fr_w[..., 0] + multi_fr_w[..., 2] = multi_fr_w[..., 0] + multi_fr_w[..., 3] = multi_fr_w[..., 0] + + # Crop all images by the bbox of the rough segmentation mask + try: + bbox = get_bbox(seg_mask, R=output_height, C=output_width) + except: # Catch error occurring if no object is currently visible + masks_video_writer.write(np.zeros((output_height, output_width, 3), dtype=np.uint8)) + target_green_img = cv2.resize(target_green_img, (output_width, output_height)) + video_writer.write(cv2.cvtColor(target_green_img, cv2.COLOR_BGR2RGB)) + print("Skipping frame", i) + continue + + try: + crop_list = [bgr_img, back_img, seg_mask, multi_fr_w] + crop_list = crop_images(crop_list, reso, bbox) + bgr_img = crop_list[0] + bg_im = crop_list[1] + seg_mask = crop_list[2] + multi_fr = crop_list[3] + except: # Catch error occurring if no (big enough) object is currently visible + masks_video_writer.write(np.zeros((output_height, output_width, 3), dtype=np.uint8)) + target_green_img = cv2.resize(target_green_img, (output_width, output_height)) + video_writer.write(cv2.cvtColor(target_green_img, cv2.COLOR_BGR2RGB)) + 
print("Error cropping", i) + continue + + # Preprocess the rough segmentation mask + kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + seg_mask = seg_mask.astype(np.float32) / 255 + seg_mask[seg_mask > 0.2] = 1 + K = 25 + + zero_id = np.nonzero(np.sum(seg_mask, axis=1) == 0) + del_id = zero_id[0][zero_id[0] > 250] + if len(del_id) > 0: + del_id = [del_id[0] - 2, del_id[0] - 1, *del_id] + seg_mask = np.delete(seg_mask, del_id, 0) + seg_mask = cv2.copyMakeBorder(seg_mask, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE) + + seg_mask = cv2.erode(seg_mask, kernel_er, iterations=10) + seg_mask = cv2.dilate(seg_mask, kernel_dil, iterations=5) + seg_mask = cv2.GaussianBlur(seg_mask.astype(np.float32), (31, 31), 0) + seg_mask = (255 * seg_mask).astype(np.uint8) + seg_mask = np.delete(seg_mask, range(reso[0], reso[0] + K), 0) + + # Convert images to torch and normalize to range [-1, 1] + img = torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0) + img = 2 * img.float().div(255) - 1 + bg = torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0) + bg = 2 * bg.float().div(255) - 1 + rcnn_al = torch.from_numpy(seg_mask).unsqueeze(0).unsqueeze(0) + rcnn_al = 2 * rcnn_al.float().div(255) - 1 + multi_fr = torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0) + multi_fr = 2 * multi_fr.float().div(255) - 1 + + with torch.no_grad(): + img, bg, rcnn_al, multi_fr = Variable(img.cuda()), Variable(bg.cuda()), Variable(rcnn_al.cuda()), Variable( + multi_fr.cuda()) + input_im = torch.cat([img, bg, rcnn_al, multi_fr], dim=1) + + alpha_pred, fg_pred_tmp = net(img, bg, rcnn_al, multi_fr) + + # for regions with alpha>0.95, simply use the image as fg + al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor) + fg_pred = img * al_mask + fg_pred_tmp * (1 - al_mask) + + alpha_out = to_image(alpha_pred[0, ...]) + + # Filter alpha image by largest connected area + labels = label((alpha_out > 
0.05).astype(int)) + try: + assert (labels.max() != 0) + except: + masks_video_writer.write(np.zeros((output_height, output_width, 3), dtype=np.uint8)) + target_green_img = cv2.resize(target_green_img, (output_width, output_height)) + video_writer.write(cv2.cvtColor(target_green_img, cv2.COLOR_BGR2RGB)) + print("Skipping frame", i) + continue + largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 + alpha_out = alpha_out * largestCC + alpha_out = (255 * alpha_out[..., 0]).astype(np.uint8) + + fg_out = to_image(fg_pred[0, ...]) + fg_out = fg_out * np.expand_dims((alpha_out.astype(float) / 255 > 0.01).astype(float), axis=2) + fg_out = (255 * fg_out).astype(np.uint8) + + # Uncrop + alpha_out = uncrop(alpha_out, bbox, output_height, output_width) + fg_out = uncrop(fg_out, bbox, output_height, output_width) + + # Resize target backgrounds to foreground size + target_img = cv2.resize(target_img, (output_width, output_height)) + target_green_img = cv2.resize(target_green_img, (output_width, output_height)) + # Compose cutout foreground on Background + compose_target_img = composite4(fg_out, target_img, alpha_out) + compose_target_green_img = composite4(fg_out, target_green_img, alpha_out) + + # Write results to videos + masks_video_writer.write(np.repeat(np.expand_dims(alpha_out, -1), 3, axis=2)) + video_writer.write(cv2.cvtColor(compose_target_green_img, cv2.COLOR_BGR2RGB)) + +video_writer.release() +video.release() +print("Finished") diff --git a/test_segmentation_deeplab.py b/test_segmentation_deeplab.py index 3d9ca37..48c0981 100644 --- a/test_segmentation_deeplab.py +++ b/test_segmentation_deeplab.py @@ -1,179 +1,231 @@ import os -from io import BytesIO import tarfile -import tempfile -from six.moves import urllib +from pathlib import Path +import argparse +import cv2 +import glob import numpy as np -from PIL import Image -import cv2, pdb, glob, argparse - import tensorflow as tf - +import tqdm +from PIL import Image +from six.moves import urllib class 
class DeepLabModel(object):
    """Load a frozen DeepLab graph from a tarball and run inference with it."""

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """Create the TensorFlow graph/session from a model tarball.

        Args:
            tarball_path: Path to a tar archive containing a member whose
                file name includes ``FROZEN_GRAPH_NAME``.

        Raises:
            RuntimeError: If no such member is found in the archive.
        """
        self.graph = tf.Graph()
        graph_def = None

        # The context manager closes the archive even if parsing the graph fails.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                # Only regular files can be extracted; skip directories/links
                # that merely happen to contain the graph name.
                if not tar_info.isfile():
                    continue
                if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                    file_handle = tar_file.extractfile(tar_info)
                    # tf.compat.v1 resolves on both TF 1.13+ and TF 2.x,
                    # whereas tf.GraphDef / tf.Session only exist in TF 1.x.
                    graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())
                    break

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.compat.v1.Session(graph=self.graph)

    def run(self, image):
        """Run inference on a single image.

        Args:
            image: A PIL.Image object, raw input image.

        Returns:
            resized_image: RGB image resized so that its longer side equals
                ``INPUT_SIZE``.
            seg_map: Segmentation map of ``resized_image``.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        # Clamp to 1 so extreme aspect ratios never yield a zero-sized image.
        target_size = (max(1, int(resize_ratio * width)),
                       max(1, int(resize_ratio * height)))
        # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """Create the label colormap used in the PASCAL VOC segmentation benchmark.

    Returns:
        An integer array of shape (256, 3); row ``i`` is the color of label ``i``.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    # Each pass consumes three bits of the label id (one per channel) and
    # places them at the current bit position, most significant bit first.
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """Map a 2-D array of label ids to their PASCAL colors.

    Args:
        label: A 2D array with integer type, storing the segmentation label.

    Returns:
        An integer array of shape ``label.shape + (3,)`` holding the color of
        each label according to the PASCAL color map.

    Raises:
        ValueError: If label is not of rank 2 or its value is larger than
            the color map maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


## setup ####################

LABEL_NAMES = np.asarray([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'
])

# Index of the class that is written out as the foreground mask.
PERSON_LABEL = int(np.flatnonzero(LABEL_NAMES == 'person')[0])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

MODEL_NAME = 'xception_coco_voctrainval'  # @param ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']

_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {
    'mobilenetv2_coco_voctrainaug':
        'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
    'mobilenetv2_coco_voctrainval':
        'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
    'xception_coco_voctrainaug':
        'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval':
        'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}
_TARBALL_NAME = _MODEL_URLS[MODEL_NAME]


def mask_output_path(input_path, marker, new_suffix=None):
    """Derive the mask file path that belongs to an input file.

    Only the trailing ``marker`` of the file *name* is replaced by
    ``masksDL``; directory components are never touched, so a folder called
    e.g. ``imgs`` or ``raw_data`` does not get rewritten.

    Args:
        input_path: Path of the input image/video.
        marker: Trailing part of the file stem to replace (``'img'``/``'raw'``).
        new_suffix: Optional new extension including the dot (e.g. ``'.avi'``);
            the original extension is kept when omitted.

    Returns:
        The output path as a string.
    """
    path = Path(input_path)
    stem = path.stem
    if marker and stem.endswith(marker):
        stem = stem[:len(stem) - len(marker)] + 'masksDL'
    else:
        # Never overwrite the input if the name does not follow the convention.
        stem = stem + '_masksDL'
    suffix = path.suffix if new_suffix is None else new_suffix
    return str(path.with_name(stem + suffix))


def download_model(model_dir='deeplab_model', model_name=MODEL_NAME):
    """Make sure the model tarball exists locally and return its path.

    The archive is downloaded to a temporary ``.part`` file and renamed only
    after completion, so an interrupted download is never mistaken for a
    complete tarball on the next run.

    Args:
        model_dir: Directory in which the tarball is stored.
        model_name: Key into ``_MODEL_URLS``.

    Returns:
        Path of the local tarball.
    """
    os.makedirs(model_dir, exist_ok=True)

    tarball_name = _MODEL_URLS[model_name]
    download_path = os.path.join(model_dir, tarball_name)
    if not os.path.exists(download_path):
        print('downloading model to %s, this might take a while...' % download_path)
        partial_path = download_path + '.part'
        try:
            urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + tarball_name,
                                       partial_path)
            os.replace(partial_path, download_path)
        finally:
            # Only still present when the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print('download completed! loading DeepLab model...')
    return download_path


def person_mask(model, image):
    """Segment ``image`` and return the person mask at the original size.

    Args:
        model: Object with a ``run(image)`` method returning
            ``(resized_image, seg_map)``, e.g. a ``DeepLabModel``.
        image: A PIL.Image object.

    Returns:
        A 2-D ``uint8`` array of the size of ``image`` that is 255 where the
        ``person`` class was predicted and 0 elsewhere.
    """
    _, seg = model.run(image)
    # Class ids are categorical: nearest-neighbour keeps them intact, whereas
    # interpolating between two other ids could produce a spurious person id.
    seg = cv2.resize(seg.astype(np.uint8), image.size,
                     interpolation=cv2.INTER_NEAREST)
    return (255 * (seg == PERSON_LABEL)).astype(np.uint8)


def process_images(model, image_dir):
    """Write a ``*_masksDL.png`` person mask next to every ``*_img.png``.

    Args:
        model: Segmentation model, see ``person_mask``.
        image_dir: Directory that is searched for ``*_img.png`` files.
    """
    image_paths = sorted(glob.glob(os.path.join(image_dir, '*_img.png')))
    print("Found {} images in {}".format(len(image_paths), image_dir))

    for image_path in tqdm.tqdm(image_paths):
        image = Image.open(image_path)
        name = mask_output_path(image_path, 'img')
        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite(name, person_mask(model, image)):
            print("Could not write mask {}!".format(name))

    print('\nDone: ' + image_dir)


def process_videos(model, video_dir):
    """Write a ``*_masksDL.avi`` person-mask video for every ``*_raw.mp4``.

    Args:
        model: Segmentation model, see ``person_mask``.
        video_dir: Directory that is searched for ``*_raw.mp4`` files.
    """
    video_paths = sorted(glob.glob(os.path.join(video_dir, '*_raw.mp4')))
    print("Found {} videos in {}".format(len(video_paths), video_dir))

    for video_path in tqdm.tqdm(video_paths):
        print("Processing video", video_path)
        video = cv2.VideoCapture(video_path)
        video_writer = None
        try:
            if not video.isOpened():
                print("Could not open video {}!".format(video_path))
                continue

            num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            # Keep the fractional part (e.g. 29.97); truncating it makes the
            # mask video drift against its source.
            fps = video.get(cv2.CAP_PROP_FPS)
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

            output_path = mask_output_path(video_path, 'raw', '.avi')
            video_writer = cv2.VideoWriter(output_path,
                                           cv2.VideoWriter_fourcc(*'png '),
                                           fps,
                                           (width, height))

            # The reported frame count is only an estimate for some containers,
            # so read until the decoder signals the end of the stream.
            i_frame = 0
            with tqdm.tqdm(total=num_frames if num_frames > 0 else None) as progress:
                while True:
                    ret, frame = video.read()
                    if not ret:
                        break

                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    mask = person_mask(model, Image.fromarray(frame))

                    # Make 3 channel image
                    video_writer.write(np.repeat(np.expand_dims(mask, -1), 3, axis=2))
                    i_frame += 1
                    progress.update(1)

            if i_frame < num_frames:
                print("Could not read video frame {}!".format(i_frame))
        finally:
            video.release()
            if video_writer is not None:
                video_writer.release()


def main():
    """Parse the command line and create person masks for images and videos."""
    parser = argparse.ArgumentParser(description='Deeplab Segmentation')
    parser.add_argument('-i', '--image_dir', default=None, type=str,
                        help='Directory to search for images (*_img.png)')
    parser.add_argument('-v', '--video_dir', default=None, type=str,
                        help='Directory to search for videos (*_raw.mp4)')
    args = parser.parse_args()

    model = DeepLabModel(download_model())
    print('model loaded successfully!')

    #######################################################################################
    # Images
    if args.image_dir:
        process_images(model, args.image_dir)
    else:
        print("No image dir specified!")

    #######################################################################################
    # Videos
    if args.video_dir:
        process_videos(model, args.video_dir)
    else:
        print("No video dir specified!")

    print("Finished processing")


if __name__ == '__main__':
    main()