Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vggt/models/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def __init__(
self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

self.use_reentrant = False # hardcoded to False
self.use_checkpointing_inference = False # set True to checkpoint during eval (saves VRAM)

def __build_patch_embed__(
self,
Expand Down Expand Up @@ -612,7 +613,7 @@ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):

# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
if self.training:
if self.training or self.use_checkpointing_inference:
tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
else:
tokens = self.frame_blocks[frame_idx](tokens, pos=pos, viz_attention=False)
Expand All @@ -635,7 +636,7 @@ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None, co

# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
if self.training:
if self.training or (self.use_checkpointing_inference and not (viz_attention or compute_similarity)):
tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
else:
if viz_attention or compute_similarity:
Expand Down
9 changes: 8 additions & 1 deletion vggt/models/vggt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(self, img_size=518, patch_size=14, embed_dim=1024,
self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") if enable_depth else None
self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) if enable_track else None

self.sequential_heads = False # when True, run heads one at a time to reduce peak memory

def forward(self, images: torch.Tensor, query_points: torch.Tensor = None, compute_similarity=False):
"""
Forward pass of the VGGT model.
Expand Down Expand Up @@ -67,7 +69,12 @@ def forward(self, images: torch.Tensor, query_points: torch.Tensor = None, compu
pose_enc_list = self.camera_head(aggregated_tokens_list)
predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
predictions["pose_enc_list"] = pose_enc_list


if self.sequential_heads and self.depth_head is not None:
# Free camera head intermediates before running depth head
del pose_enc_list
torch.cuda.empty_cache()

if self.depth_head is not None:
depth, depth_conf = self.depth_head(
aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
Expand Down