Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Lint

on:
  push:
    branches: [main]
  pull_request:

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/setup-uv@v5
      - run: uv python install 3.10
      - run: uv pip install ruff
      - run: uv run ruff check .
      - run: uv run ruff format --check .
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.2
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
7 changes: 4 additions & 3 deletions benchmark_capacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
Uses exponential probing (1, 2, 4, 8, ...) to quickly find the max power-of-2
number of images that fits in GPU memory. No slow binary search.
"""
import torch
import argparse
import gc
import time
import json
import argparse
import time
from datetime import datetime

import torch


def clear_gpu():
gc.collect()
Expand Down
27 changes: 13 additions & 14 deletions demo_gradio.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import gc
import glob
import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
# import spaces # only for web demo

from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge
from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor
from datetime import datetime

import trimesh
import cv2
import gradio as gr
import matplotlib
import numpy as np
import torch
import trimesh
from scipy.spatial.transform import Rotation

from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor

# import spaces # only for web demo
from pi3.utils.geometry import depth_edge

"""
Gradio utils
Expand Down
8 changes: 5 additions & 3 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
import argparse

import torch

from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor, write_ply
from pi3.utils.geometry import depth_edge
from pi3.models.pi3 import Pi3

if __name__ == '__main__':
# --- Argument Parsing ---
Expand All @@ -25,7 +27,7 @@
print(f'Sampling interval: {args.interval}')

# 1. Prepare model
print(f"Loading model...")
print("Loading model...")
device = torch.device(args.device)
if args.ckpt is not None:
model = Pi3().to(device).eval()
Expand Down
10 changes: 6 additions & 4 deletions example_mm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch
import argparse
import numpy as np
import os

import numpy as np
import torch

from pi3.models.pi3x import Pi3X
from pi3.utils.basic import load_multimodal_data, write_ply
from pi3.utils.geometry import depth_edge, recover_intrinsic_from_rays_d
from pi3.models.pi3x import Pi3X

if __name__ == '__main__':
# --- Argument Parsing ---
Expand Down Expand Up @@ -60,7 +62,7 @@
print("No multimodal conditions found. Disable multimodal branch to reduce memory usage.")

# 2. Prepare model
print(f"Loading model...")
print("Loading model...")
if args.ckpt is not None:
model = Pi3X(use_multimodal=use_multimodal).eval()
if args.ckpt.endswith('.safetensors'):
Expand Down
9 changes: 5 additions & 4 deletions example_vo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
import argparse
import numpy as np
import os
from pi3.utils.basic import load_multimodal_data, write_ply

import torch

from pi3.models.pi3x import Pi3X
from pi3.pipe.pi3x_vo import Pi3XVO
from pi3.utils.basic import load_multimodal_data, write_ply

if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Run inference with the Pi3 model.")
Expand All @@ -27,7 +28,7 @@
print(f'Sampling interval: {args.interval}')

# 1. Prepare model
print(f"Loading model...")
print("Loading model...")
device = torch.device(args.device)
if args.ckpt is not None:
model = Pi3X().to(device).eval()
Expand Down
1 change: 0 additions & 1 deletion pi3/models/dinov2/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch.nn as nn
import torch.nn.functional as F


_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


Expand Down
4 changes: 2 additions & 2 deletions pi3/models/dinov2/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .attention import MemEffAttention
from .block import NestedTensorBlock
from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
5 changes: 1 addition & 4 deletions pi3/models/dinov2/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@

import logging
import os
import warnings

from torch import Tensor
from torch import nn

from torch import Tensor, nn

logger = logging.getLogger("dinov2")

Expand Down
8 changes: 3 additions & 5 deletions pi3/models/dinov2/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,23 @@

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch import nn, Tensor
from torch import Tensor, nn

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
if XFORMERS_ENABLED:
from xformers.ops import fmha, scaled_index_add, index_select_cat
from xformers.ops import fmha, index_select_cat, scaled_index_add

XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (Block)")
Expand Down
3 changes: 1 addition & 2 deletions pi3/models/dinov2/layers/layer_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from typing import Union

import torch
from torch import Tensor
from torch import nn
from torch import Tensor, nn


class LayerScale(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion pi3/models/dinov2/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn
from torch import Tensor


def make_2tuple(x):
Expand Down
3 changes: 1 addition & 2 deletions pi3/models/dinov2/layers/swiglu_ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLUFFN(nn.Module):
Expand Down
1 change: 0 additions & 1 deletion pi3/models/dinov2/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from . import vision_transformer as vits


logger = logging.getLogger("dinov2")


Expand Down
11 changes: 5 additions & 6 deletions pi3/models/dinov2/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
from functools import partial
from typing import Callable, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn.init import trunc_normal_
from torch.utils.checkpoint import checkpoint

from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from ...layers.attention import FlashAttention

from ..layers import MemEffAttention, Mlp, PatchEmbed, SwiGLUFFNFused
from ..layers import NestedTensorBlock as Block

# logger = logging.getLogger("dinov2")

Expand Down
2 changes: 1 addition & 1 deletion pi3/models/dinov2/utils/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional

Expand Down
5 changes: 2 additions & 3 deletions pi3/models/dinov2/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math
import logging
import math
import os

from omegaconf import OmegaConf

import dinov2.distributed as distributed
from dinov2.configs import dinov2_default_config
from dinov2.logging import setup_logging
from dinov2.utils import utils
from dinov2.configs import dinov2_default_config


logger = logging.getLogger("dinov2")

Expand Down
1 change: 0 additions & 1 deletion pi3/models/dinov2/utils/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
import torch


TypeSpec = Union[str, np.dtype, torch.dtype]


Expand Down
3 changes: 1 addition & 2 deletions pi3/models/dinov2/utils/param_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from collections import defaultdict
import logging

from collections import defaultdict

logger = logging.getLogger("dinov2")

Expand Down
2 changes: 0 additions & 2 deletions pi3/models/dinov2/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging
import os
import random
import subprocess
Expand All @@ -13,7 +12,6 @@
import torch
from torch import nn


# logger = logging.getLogger("dinov2")


Expand Down
14 changes: 6 additions & 8 deletions pi3/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch

from torch.nn.functional import scaled_dot_product_attention
from torch import Tensor, nn
from torch.nn.attention import SDPBackend
from torch.nn.functional import scaled_dot_product_attention

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
if XFORMERS_ENABLED:
from xformers.ops import memory_efficient_attention, unbind
from xformers.ops import memory_efficient_attention, unbind # noqa: F401

XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (Attention)")
Expand Down Expand Up @@ -370,7 +366,9 @@ def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
return score


from .prope import _prepare_apply_fns, _prepare_apply_fns_query
from .prope import _prepare_apply_fns


class PRopeFlashAttention(AttentionRope):
def forward(self, x: Tensor, extrinsics, H, W, patch_h, patch_w, K=None, attn_mask=None) -> Tensor:
B, N, C = x.shape
Expand Down
Loading