diff --git a/.gitignore b/.gitignore
index e43b0f9..777695f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,5 @@
+outputs
+cache
.DS_Store
+*__pycache*
+.idea
diff --git a/README.md b/README.md
index 7edb9a4..5f14711 100644
--- a/README.md
+++ b/README.md
@@ -1,35 +1,8 @@
## GET3D: A Generative Model of High Quality 3D Textured Shapes Learned from Images (NeurIPS 2022)
Official PyTorch implementation
-
+This is a GET3D implementation with multi-node training support.
-**GET3D: A Generative Model of High Quality 3D Textured Shapes Learned from Images**
-[Jun Gao](http://www.cs.toronto.edu/~jungao/)
-, [Tianchang Shen](http://www.cs.toronto.edu/~shenti11/)
-, [Zian Wang](http://www.cs.toronto.edu/~zianwang/),
-[Wenzheng Chen](http://www.cs.toronto.edu/~wenzheng/), [Kangxue Yin](https://kangxue.org/)
-, [Daiqing Li](https://scholar.google.ca/citations?user=8q2ISMIAAAAJ&hl=en),
-[Or Litany](https://orlitany.github.io/), [Zan Gojcic](https://zgojcic.github.io/),
-[Sanja Fidler](https://www.cs.toronto.edu/~fidler/)
-**[Paper](https://nv-tlabs.github.io/GET3D/assets/paper.pdf)
-, [Project Page](https://nv-tlabs.github.io/GET3D/)**
-
-Abstract: *As several industries are moving towards modeling massive 3D virtual worlds,
-the need for content creation tools that can scale in terms of the quantity, quality, and
-diversity of 3D content is becoming evident. In our work, we aim to train performant 3D
-generative models that synthesize textured meshes which can be directly consumed by 3D
-rendering engines, thus immediately usable in downstream applications. Prior works on 3D
-generative modeling either lack geometric details, are limited in the mesh topology they
-can produce, typically do not support textures, or utilize neural renderers in the
-synthesis process, which makes their use in common 3D software non-trivial. In this work,
-we introduce GET3D, a Generative model that directly generates Explicit Textured 3D meshes
-with complex topology, rich geometric details, and high fidelity textures. We bridge
-recent success in the differentiable surface modeling, differentiable rendering as well as
-2D Generative Adversarial Networks to train our model from 2D image collections. GET3D is
-able to generate high-quality 3D textured meshes, ranging from cars, chairs, animals,
-motorbikes and human characters to buildings, achieving significant improvements over
-previous methods.*
-
-
+
For business inquiries, please visit our website and submit the
form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
@@ -177,4 +150,4 @@ and Daiqing Li and Or Litany and Zan Gojcic and Sanja Fidler},
booktitle={Advances In Neural Information Processing Systems},
year={2022}
}
-```
\ No newline at end of file
+```
diff --git a/evaluation_scripts/sample_surface.py b/evaluation_scripts/sample_surface.py
index c2e34f3..744e239 100644
--- a/evaluation_scripts/sample_surface.py
+++ b/evaluation_scripts/sample_surface.py
@@ -64,7 +64,7 @@
[0, options.focal_length_y, options.principal_point_y],
[0, 0, 1]
])
-glctx = dr.RasterizeGLContext()
+glctx = dr.RasterizeCudaContext()
def CalcLinearZ(depth):
diff --git a/metrics/metric_utils.py b/metrics/metric_utils.py
index 06806f1..9ce4ab2 100644
--- a/metrics/metric_utils.py
+++ b/metrics/metric_utils.py
@@ -46,15 +46,16 @@ def save_image_grid(img, fname, drange, grid_size):
# ----------------------------------------------------------------------------
class MetricOptions:
- def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True, batch_size=None):
assert 0 <= rank < num_gpus
self.G = G
self.G_kwargs = dnnlib.EasyDict(G_kwargs)
self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
self.num_gpus = num_gpus
self.rank = rank
- self.device = device if device is not None else torch.device('cuda', rank)
+ self.device = device if device is not None else torch.device('cuda')
self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.batch_size = batch_size
self.cache = cache
@@ -227,7 +228,9 @@ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1)
# ----------------------------------------------------------------------------
-def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=None, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+ if not batch_size:
+ batch_size = opts.batch_size
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
if data_loader_kwargs is None:
data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
diff --git a/train_3d.py b/train_3d.py
index f2b87f8..8757441 100644
--- a/train_3d.py
+++ b/train_3d.py
@@ -21,23 +21,32 @@
# ----------------------------------------------------------------------------
-def subprocess_fn(rank, c, temp_dir):
+def subprocess_fn(local_rank, c, temp_dir):
dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
+ rank = 0
# Init torch.distributed.
if c.num_gpus > 1:
+        rank = int(os.environ['RANK'])  # index of the current machine (node)
+        gpus = torch.cuda.device_count()  # number of GPUs on each machine
+        rank = rank * gpus + local_rank
+        hosts = int(os.environ['WORLD_SIZE'])  # number of machines (nodes)
+ c.num_gpus = hosts * gpus
+
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
if os.name == 'nt':
init_method = 'file:///' + init_file.replace('\\', '/')
torch.distributed.init_process_group(
                backend='gloo', init_method=init_method, rank=rank, world_size=c.num_gpus)
else:
- init_method = f'file://{init_file}'
+ ip = os.environ.get('MASTER_ADDR', 'localhost')
+ port = os.environ['MASTER_PORT']
+ init_method = f'tcp://{ip}:{port}'
torch.distributed.init_process_group(
backend='nccl', init_method=init_method, rank=rank, world_size=c.num_gpus)
-
+ torch.cuda.set_device(local_rank)
# Init torch_utils.
- sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
+ sync_device = torch.device('cuda') if c.num_gpus > 1 else None
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
if rank != 0:
custom_ops.verbosity = 'none'
@@ -100,9 +109,10 @@ def launch_training(c, desc, outdir, dry_run):
torch.multiprocessing.set_start_method('spawn', force=True)
with tempfile.TemporaryDirectory() as temp_dir:
if c.num_gpus == 1:
- subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
+ subprocess_fn(local_rank=0, c=c, temp_dir=temp_dir)
else:
- torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)
+ ngpus = torch.cuda.device_count()
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=ngpus)
# ----------------------------------------------------------------------------
@@ -289,7 +299,8 @@ def main(**kwargs):
c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap
c.random_seed = c.training_set_kwargs.random_seed = opts.seed
c.data_loader_kwargs.num_workers = opts.workers
- c.network_snapshot_ticks = 200
+ if opts.gpus <= 8:
+ c.network_snapshot_ticks = 200
# Sanity checks.
if c.batch_size % c.num_gpus != 0:
raise click.ClickException('--batch must be a multiple of --gpus')
diff --git a/training/dataset.py b/training/dataset.py
index cfef640..3b2a1d8 100644
--- a/training/dataset.py
+++ b/training/dataset.py
@@ -289,6 +289,7 @@ def __getitem__(self, idx):
or self.data_camera_mode == 'shapenet_motorbike' or self.data_camera_mode == 'ts_house' or self.data_camera_mode == 'ts_animal' \
:
ori_img = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
+ assert ori_img is not None, f'{fname}. No such img.'
img = ori_img[:, :, :3][..., ::-1]
mask = ori_img[:, :, 3:4]
condinfo = np.zeros(2)
diff --git a/training/inference_3d.py b/training/inference_3d.py
index befba57..7a70a15 100644
--- a/training/inference_3d.py
+++ b/training/inference_3d.py
@@ -51,7 +51,7 @@ def inference(
bias_act._init()
filtered_lrelu._init()
- device = torch.device('cuda', rank)
+ device = torch.device('cuda')
np.random.seed(random_seed * num_gpus + rank)
torch.manual_seed(random_seed * num_gpus + rank)
torch.backends.cudnn.enabled = True
diff --git a/training/networks_get3d.py b/training/networks_get3d.py
index 67175ce..ba42ffd 100644
--- a/training/networks_get3d.py
+++ b/training/networks_get3d.py
@@ -382,7 +382,7 @@ def extract_3d_shape(
all_gb_pose = []
all_uv_mask = []
if self.dmtet_geometry.renderer.ctx is None:
- self.dmtet_geometry.renderer.ctx = dr.RasterizeGLContext(device=self.device)
+ self.dmtet_geometry.renderer.ctx = dr.RasterizeCudaContext(device=self.device)
for v, f in zip(mesh_v, mesh_f):
uvs, mesh_tex_idx, gb_pos, mask = xatlas_uvmap(
self.dmtet_geometry.renderer.ctx, v, f, resolution=texture_resolution)
diff --git a/training/training_loop_3d.py b/training/training_loop_3d.py
index 343416c..88e3953 100644
--- a/training/training_loop_3d.py
+++ b/training/training_loop_3d.py
@@ -116,7 +116,7 @@ def training_loop(
if num_gpus > 1:
torch.distributed.barrier()
start_time = time.time()
- device = torch.device('cuda', rank)
+ device = torch.device('cuda')
np.random.seed(random_seed * num_gpus + rank)
torch.manual_seed(random_seed * num_gpus + rank)
torch.backends.cudnn.enabled = True
@@ -380,7 +380,7 @@ def training_loop(
snapshot_data = dict(
G=G, D=D, G_ema=G_ema)
for key, value in snapshot_data.items():
- if isinstance(value, torch.nn.Module) and not isinstance(value, dr.ops.RasterizeGLContext):
+ if isinstance(value, torch.nn.Module) and not isinstance(value, dr.ops.RasterizeCudaContext):
if num_gpus > 1:
misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema|ctx)')
for param in misc.params_and_buffers(value):
@@ -406,7 +406,9 @@ def training_loop(
with torch.no_grad():
result_dict = metric_main.calc_metric(
metric=metric, G=snapshot_data['G_ema'],
- dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
+ dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
+ batch_size=batch_size,
+ rank=rank, device='cuda')
if rank == 0:
metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
stats_metrics.update(result_dict.results)
diff --git a/uni_rep/render/neural_render.py b/uni_rep/render/neural_render.py
index 5fc9440..864cdc9 100644
--- a/uni_rep/render/neural_render.py
+++ b/uni_rep/render/neural_render.py
@@ -56,7 +56,7 @@ def render_mesh(
):
assert not hierarchical_mask
if self.ctx is None:
- self.ctx = dr.RasterizeGLContext(device=self.device)
+ self.ctx = dr.RasterizeCudaContext(device=self.device)
mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates