Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
# Keyframe selection / filtering / loop-closure options.
# NOTE: argparse only applies `type=` to command-line strings and string
# defaults, so numeric defaults are written as floats to keep the parsed value
# a float whether or not the user overrides it.
parser.add_argument(
    "--min_disparity",
    type=float,
    default=50.0,
    help="Minimum disparity to generate a new keyframe",
)
parser.add_argument(
    "--conf_threshold",
    type=float,
    default=25.0,
    help="Initial percentage of low-confidence points to filter out",
)
parser.add_argument(
    "--lc_thres",
    type=float,
    default=0.95,
    help="Threshold for image retrieval. Range: [0, 1.0]. Higher = more loop closures",
)

# Memory-saving options. --low_vram is a convenience switch: main() uses it to
# cap the submap size from the detected VRAM and to turn on the two flags below.
parser.add_argument(
    "--low_vram",
    action="store_true",
    help=(
        "Enable low-VRAM mode: caps submap size based on detected VRAM "
        "(tiers at <=8GB, <=12GB, <=16GB) and enables --checkpoint_inference "
        "and --sequential_heads"
    ),
)
# `%%` is required here: argparse runs help strings through %-formatting.
parser.add_argument(
    "--checkpoint_inference",
    action="store_true",
    help="Enable gradient checkpointing during inference to reduce VRAM (slower but uses ~40%% less memory)",
)
parser.add_argument(
    "--sequential_heads",
    action="store_true",
    help="Run prediction heads sequentially to reduce peak memory",
)


def main():
Expand All @@ -39,6 +42,20 @@ def main():
"""
args = parser.parse_args()

# Auto-configure for low-VRAM GPUs
if args.low_vram:
vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
print(f"Low-VRAM mode enabled. Detected {vram_gb:.1f} GB VRAM.")
if vram_gb <= 8:
args.submap_size = min(args.submap_size, 4)
elif vram_gb <= 12:
args.submap_size = min(args.submap_size, 6)
elif vram_gb <= 16:
args.submap_size = min(args.submap_size, 10)
args.checkpoint_inference = True
args.sequential_heads = True
print(f" submap_size={args.submap_size}, checkpointing=True, sequential_heads=True")

use_optical_flow_downsample = True
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Expand Down Expand Up @@ -75,6 +92,17 @@ def main():

model.eval()
model = model.to(torch.bfloat16) # use half precision

# Enable gradient checkpointing during inference to reduce activation memory
if args.checkpoint_inference:
model.aggregator.use_checkpointing_inference = True
print("Gradient checkpointing enabled for inference (saves ~40% VRAM)")

# Enable sequential head execution to reduce peak memory
if args.sequential_heads:
model.sequential_heads = True
print("Sequential head execution enabled (reduces peak memory)")

model = model.to(device)

# Use the provided image folder path
Expand Down
5 changes: 5 additions & 0 deletions vggt_slam/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,11 @@ def sample_pixel_coordinates(self, H, W, n):

def run_predictions(self, image_names, model, max_loops, clip_model, clip_preprocess):
device = "cuda" if torch.cuda.is_available() else "cpu"

# Clear GPU cache before inference to maximize available memory
if device == "cuda":
torch.cuda.empty_cache()

t1 = time.time()
with self.vggt_timer:
images = load_and_preprocess_images(image_names).to(device)
Expand Down