diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..fe58f73 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,17 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - run: uv python install 3.10 + - run: uv pip install ruff + - run: uv run ruff check . + - run: uv run ruff format --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..700cbc4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/benchmark_capacity.py b/benchmark_capacity.py index dbd76b3..9b14699 100644 --- a/benchmark_capacity.py +++ b/benchmark_capacity.py @@ -3,13 +3,14 @@ Uses exponential probing (1, 2, 4, 8, ...) to quickly find the max power-of-2 number of images that fits in GPU memory. No slow binary search. """ -import torch +import argparse import gc -import time import json -import argparse +import time from datetime import datetime +import torch + def clear_gpu(): gc.collect() diff --git a/demo_gradio.py b/demo_gradio.py index 71dfd87..53ab679 100644 --- a/demo_gradio.py +++ b/demo_gradio.py @@ -1,24 +1,23 @@ +import gc +import glob import os -import cv2 -import torch -import numpy as np -import gradio as gr -import sys import shutil -from datetime import datetime -import glob -import gc import time -# import spaces # only for web demo - -from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge -from pi3.models.pi3 import Pi3 -from pi3.utils.basic import load_images_as_tensor +from datetime import datetime -import trimesh +import cv2 +import gradio as gr import matplotlib +import numpy as np +import torch +import trimesh from scipy.spatial.transform import Rotation +from pi3.models.pi3 import Pi3 +from pi3.utils.basic import load_images_as_tensor + +# import spaces # only for web demo +from pi3.utils.geometry import depth_edge """ Gradio utils diff --git a/example.py b/example.py index 513f709..237d085 100644 --- a/example.py +++ b/example.py @@ -1,8 +1,10 @@ -import torch import argparse + +import torch + +from pi3.models.pi3 import Pi3 from pi3.utils.basic import load_images_as_tensor, write_ply from pi3.utils.geometry import depth_edge -from pi3.models.pi3 import Pi3 if __name__ == '__main__': # --- Argument Parsing --- @@ -25,7 +27,7 @@ print(f'Sampling interval: {args.interval}') # 1. Prepare model - print(f"Loading model...") + print("Loading model...") device = torch.device(args.device) if args.ckpt is not None: model = Pi3().to(device).eval() diff --git a/example_mm.py b/example_mm.py index 71492ae..5a39ea0 100644 --- a/example_mm.py +++ b/example_mm.py @@ -1,10 +1,12 @@ -import torch import argparse -import numpy as np import os + +import numpy as np +import torch + +from pi3.models.pi3x import Pi3X from pi3.utils.basic import load_multimodal_data, write_ply from pi3.utils.geometry import depth_edge, recover_intrinsic_from_rays_d -from pi3.models.pi3x import Pi3X if __name__ == '__main__': # --- Argument Parsing --- @@ -60,7 +62,7 @@ print("No multimodal conditions found. Disable multimodal branch to reduce memory usage.") # 2. Prepare model - print(f"Loading model...") + print("Loading model...") if args.ckpt is not None: model = Pi3X(use_multimodal=use_multimodal).eval() if args.ckpt.endswith('.safetensors'): diff --git a/example_vo.py b/example_vo.py index 2e9f9c2..632789b 100644 --- a/example_vo.py +++ b/example_vo.py @@ -1,10 +1,11 @@ -import torch import argparse -import numpy as np import os -from pi3.utils.basic import load_multimodal_data, write_ply + +import torch + from pi3.models.pi3x import Pi3X from pi3.pipe.pi3x_vo import Pi3XVO +from pi3.utils.basic import load_multimodal_data, write_ply if __name__ == '__main__': parser = argparse.ArgumentParser(description="Run inference with the Pi3 model.") @@ -27,7 +28,7 @@ print(f'Sampling interval: {args.interval}') # 1. Prepare model - print(f"Loading model...") + print("Loading model...") device = torch.device(args.device) if args.ckpt is not None: model = Pi3X().to(device).eval() diff --git a/pi3/models/dinov2/hub/utils.py b/pi3/models/dinov2/hub/utils.py index 9c66414..a7482f9 100644 --- a/pi3/models/dinov2/hub/utils.py +++ b/pi3/models/dinov2/hub/utils.py @@ -10,7 +10,6 @@ import torch.nn as nn import torch.nn.functional as F - _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" diff --git a/pi3/models/dinov2/layers/__init__.py b/pi3/models/dinov2/layers/__init__.py index 05a0b61..26f3f38 100644 --- a/pi3/models/dinov2/layers/__init__.py +++ b/pi3/models/dinov2/layers/__init__.py @@ -3,9 +3,9 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. +from .attention import MemEffAttention +from .block import NestedTensorBlock from .dino_head import DINOHead from .mlp import Mlp from .patch_embed import PatchEmbed from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused -from .block import NestedTensorBlock -from .attention import MemEffAttention diff --git a/pi3/models/dinov2/layers/attention.py b/pi3/models/dinov2/layers/attention.py index 3fed573..6d60673 100644 --- a/pi3/models/dinov2/layers/attention.py +++ b/pi3/models/dinov2/layers/attention.py @@ -9,11 +9,8 @@ import logging import os -import warnings - -from torch import Tensor -from torch import nn +from torch import Tensor, nn logger = logging.getLogger("dinov2") diff --git a/pi3/models/dinov2/layers/block.py b/pi3/models/dinov2/layers/block.py index fd5b8a7..8b852d8 100644 --- a/pi3/models/dinov2/layers/block.py +++ b/pi3/models/dinov2/layers/block.py @@ -9,25 +9,23 @@ import logging import os -from typing import Callable, List, Any, Tuple, Dict -import warnings +from typing import Any, Callable, Dict, List, Tuple import torch -from torch import nn, Tensor +from torch import Tensor, nn from .attention import Attention, MemEffAttention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp - logger = logging.getLogger("dinov2") XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None try: if XFORMERS_ENABLED: - from xformers.ops import fmha, scaled_index_add, index_select_cat + from xformers.ops import fmha, index_select_cat, scaled_index_add XFORMERS_AVAILABLE = True # warnings.warn("xFormers is available (Block)") diff --git a/pi3/models/dinov2/layers/layer_scale.py b/pi3/models/dinov2/layers/layer_scale.py index 51df0d7..405dd84 100644 --- a/pi3/models/dinov2/layers/layer_scale.py +++ b/pi3/models/dinov2/layers/layer_scale.py @@ -8,8 +8,7 @@ from typing import Union import torch -from torch import Tensor -from torch import nn +from torch import Tensor, nn class LayerScale(nn.Module): diff --git a/pi3/models/dinov2/layers/patch_embed.py b/pi3/models/dinov2/layers/patch_embed.py index 8b7c080..d170d10 100644 --- a/pi3/models/dinov2/layers/patch_embed.py +++ b/pi3/models/dinov2/layers/patch_embed.py @@ -9,8 +9,8 @@ from typing import Callable, Optional, Tuple, Union -from torch import Tensor import torch.nn as nn +from torch import Tensor def make_2tuple(x): diff --git a/pi3/models/dinov2/layers/swiglu_ffn.py b/pi3/models/dinov2/layers/swiglu_ffn.py index 5ce2115..22b8326 100644 --- a/pi3/models/dinov2/layers/swiglu_ffn.py +++ b/pi3/models/dinov2/layers/swiglu_ffn.py @@ -5,10 +5,9 @@ import os from typing import Callable, Optional -import warnings -from torch import Tensor, nn import torch.nn.functional as F +from torch import Tensor, nn class SwiGLUFFN(nn.Module): diff --git a/pi3/models/dinov2/models/__init__.py b/pi3/models/dinov2/models/__init__.py index 3fdff20..50f937a 100644 --- a/pi3/models/dinov2/models/__init__.py +++ b/pi3/models/dinov2/models/__init__.py @@ -7,7 +7,6 @@ from . import vision_transformer as vits - logger = logging.getLogger("dinov2") diff --git a/pi3/models/dinov2/models/vision_transformer.py b/pi3/models/dinov2/models/vision_transformer.py index 73f15cf..598c54c 100644 --- a/pi3/models/dinov2/models/vision_transformer.py +++ b/pi3/models/dinov2/models/vision_transformer.py @@ -7,19 +7,18 @@ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py -from functools import partial import math -import logging -from typing import Sequence, Tuple, Union, Callable +from functools import partial +from typing import Callable, Sequence, Tuple, Union import torch import torch.nn as nn -from torch.utils.checkpoint import checkpoint from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint -from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block from ...layers.attention import FlashAttention - +from ..layers import MemEffAttention, Mlp, PatchEmbed, SwiGLUFFNFused +from ..layers import NestedTensorBlock as Block # logger = logging.getLogger("dinov2") diff --git a/pi3/models/dinov2/utils/cluster.py b/pi3/models/dinov2/utils/cluster.py index 3df87dc..36f38ba 100644 --- a/pi3/models/dinov2/utils/cluster.py +++ b/pi3/models/dinov2/utils/cluster.py @@ -3,8 +3,8 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -from enum import Enum import os +from enum import Enum from pathlib import Path from typing import Any, Dict, Optional diff --git a/pi3/models/dinov2/utils/config.py b/pi3/models/dinov2/utils/config.py index c9de578..951e1a4 100644 --- a/pi3/models/dinov2/utils/config.py +++ b/pi3/models/dinov2/utils/config.py @@ -3,17 +3,16 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -import math import logging +import math import os from omegaconf import OmegaConf import dinov2.distributed as distributed +from dinov2.configs import dinov2_default_config from dinov2.logging import setup_logging from dinov2.utils import utils -from dinov2.configs import dinov2_default_config - logger = logging.getLogger("dinov2") diff --git a/pi3/models/dinov2/utils/dtype.py b/pi3/models/dinov2/utils/dtype.py index 80f4cd7..7ebdedf 100644 --- a/pi3/models/dinov2/utils/dtype.py +++ b/pi3/models/dinov2/utils/dtype.py @@ -9,7 +9,6 @@ import numpy as np import torch - TypeSpec = Union[str, np.dtype, torch.dtype] diff --git a/pi3/models/dinov2/utils/param_groups.py b/pi3/models/dinov2/utils/param_groups.py index 9a5d2ff..996775c 100644 --- a/pi3/models/dinov2/utils/param_groups.py +++ b/pi3/models/dinov2/utils/param_groups.py @@ -3,9 +3,8 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -from collections import defaultdict import logging - +from collections import defaultdict logger = logging.getLogger("dinov2") diff --git a/pi3/models/dinov2/utils/utils.py b/pi3/models/dinov2/utils/utils.py index e8842e4..e32352c 100644 --- a/pi3/models/dinov2/utils/utils.py +++ b/pi3/models/dinov2/utils/utils.py @@ -3,7 +3,6 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -import logging import os import random import subprocess @@ -13,7 +12,6 @@ import torch from torch import nn - # logger = logging.getLogger("dinov2") diff --git a/pi3/models/layers/attention.py b/pi3/models/layers/attention.py index b3407eb..7ce5367 100644 --- a/pi3/models/layers/attention.py +++ b/pi3/models/layers/attention.py @@ -7,21 +7,17 @@ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py -import logging import os -import warnings -from torch import Tensor -from torch import nn import torch - -from torch.nn.functional import scaled_dot_product_attention +from torch import Tensor, nn from torch.nn.attention import SDPBackend +from torch.nn.functional import scaled_dot_product_attention XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None try: if XFORMERS_ENABLED: - from xformers.ops import memory_efficient_attention, unbind + from xformers.ops import memory_efficient_attention, unbind # noqa: F401 XFORMERS_AVAILABLE = True # warnings.warn("xFormers is available (Attention)") @@ -370,7 +366,9 @@ def get_attn_score(blk_class, x, frame_num, token_length, xpos=None): return score -from .prope import _prepare_apply_fns, _prepare_apply_fns_query +from .prope import _prepare_apply_fns + + class PRopeFlashAttention(AttentionRope): def forward(self, x: Tensor, extrinsics, H, W, patch_h, patch_w, K=None, attn_mask=None) -> Tensor: B, N, C = x.shape diff --git a/pi3/models/layers/block.py b/pi3/models/layers/block.py index d8fe50f..6d075cd 100644 --- a/pi3/models/layers/block.py +++ b/pi3/models/layers/block.py @@ -7,24 +7,21 @@ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py -import logging import os -from typing import Callable, List, Any, Tuple, Dict -import warnings +from typing import Any, Callable, Dict, List, Tuple import torch -from torch import nn, Tensor +from torch import Tensor, nn -from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope from ..dinov2.layers.drop_path import DropPath from ..dinov2.layers.layer_scale import LayerScale from ..dinov2.layers.mlp import Mlp - +from .attention import Attention, CrossAttentionRope, MemEffAttention XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None try: if XFORMERS_ENABLED: - from xformers.ops import fmha, scaled_index_add, index_select_cat + from xformers.ops import fmha, index_select_cat, scaled_index_add XFORMERS_AVAILABLE = True # warnings.warn("xFormers is available (Block)") @@ -407,8 +404,10 @@ def ffn_residual_func(x: Tensor) -> Tensor: -from .attention import PRopeFlashAttention from ...utils.geometry import se3_inverse +from .attention import PRopeFlashAttention + + class PoseInjectBlock(nn.Module): def __init__( self, diff --git a/pi3/models/layers/camera_head.py b/pi3/models/layers/camera_head.py index 7d844f7..43d1b6a 100644 --- a/pi3/models/layers/camera_head.py +++ b/pi3/models/layers/camera_head.py @@ -1,8 +1,10 @@ +from copy import deepcopy + import torch import torch.nn as nn -from copy import deepcopy import torch.nn.functional as F + # code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172' class ResConvBlock(nn.Module): """ diff --git a/pi3/models/layers/conv_head.py b/pi3/models/layers/conv_head.py index f6b63a5..39e59ca 100644 --- a/pi3/models/layers/conv_head.py +++ b/pi3/models/layers/conv_head.py @@ -2,9 +2,10 @@ Conv head is from MoGe (https://github.com/microsoft/moge) """ +from typing import * + import torch import torch.nn as nn -from typing import * import torch.nn.functional as F diff --git a/pi3/models/layers/pos_embed.py b/pi3/models/layers/pos_embed.py index e27ea0f..b0167ca 100644 --- a/pi3/models/layers/pos_embed.py +++ b/pi3/models/layers/pos_embed.py @@ -9,9 +9,9 @@ import numpy as np - import torch + # -------------------------------------------------------- # 2D sine-cosine position embedding # References: @@ -166,7 +166,7 @@ def __init__(self): self.cache_positions = {} def __call__(self, b, h, w, device): - if not (h,w) in self.cache_positions: + if (h,w) not in self.cache_positions: x = torch.arange(w, device=device) y = torch.arange(h, device=device) self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) diff --git a/pi3/models/layers/prope.py b/pi3/models/layers/prope.py index 34263a6..cb9ad0d 100644 --- a/pi3/models/layers/prope.py +++ b/pi3/models/layers/prope.py @@ -52,7 +52,7 @@ # o_src = attn_src._apply_to_o(o_src) from functools import partial -from typing import Callable, Optional, Tuple, List +from typing import Callable, List, Optional, Tuple import torch import torch.nn.functional as F diff --git a/pi3/models/layers/transformer_head.py b/pi3/models/layers/transformer_head.py index 813aef3..35d1517 100644 --- a/pi3/models/layers/transformer_head.py +++ b/pi3/models/layers/transformer_head.py @@ -1,11 +1,14 @@ -from .attention import FlashAttentionRope, FlashCrossAttentionRope -from .block import BlockRope, CrossOnlyBlockRope -from ..dinov2.layers import Mlp -import torch.nn as nn from functools import partial -from torch.utils.checkpoint import checkpoint + +import torch.nn as nn import torch.nn.functional as F - +from torch.utils.checkpoint import checkpoint + +from ..dinov2.layers import Mlp +from .attention import FlashAttentionRope, FlashCrossAttentionRope +from .block import BlockRope, CrossOnlyBlockRope + + class TransformerDecoder(nn.Module): def __init__( self, diff --git a/pi3/models/pi3.py b/pi3/models/pi3.py index 917c6cc..cf666bd 100644 --- a/pi3/models/pi3.py +++ b/pi3/models/pi3.py @@ -1,17 +1,19 @@ +from copy import deepcopy +from functools import partial + import torch import torch.nn as nn -from functools import partial -from copy import deepcopy +from huggingface_hub import PyTorchModelHubMixin -from .dinov2.layers import Mlp from ..utils.geometry import homogenize_points -from .layers.pos_embed import RoPE2D, PositionGetter -from .layers.block import BlockRope +from .dinov2.hub.backbones import dinov2_vitl14_reg +from .dinov2.layers import Mlp from .layers.attention import FlashAttentionRope -from .layers.transformer_head import TransformerDecoder, LinearPts3d +from .layers.block import BlockRope from .layers.camera_head import CameraHead -from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg -from huggingface_hub import PyTorchModelHubMixin +from .layers.pos_embed import PositionGetter, RoPE2D +from .layers.transformer_head import LinearPts3d, TransformerDecoder + class Pi3(nn.Module, PyTorchModelHubMixin): def __init__( diff --git a/pi3/models/pi3x.py b/pi3/models/pi3x.py index 2d20ca7..225fe6a 100644 --- a/pi3/models/pi3x.py +++ b/pi3/models/pi3x.py @@ -1,19 +1,20 @@ +from copy import deepcopy +from functools import partial + import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial -from copy import deepcopy from huggingface_hub import PyTorchModelHubMixin -from .layers.conv_head import ConvHead -from .layers.camera_head import CameraHead +from ..utils.geometry import get_pixel, homogenize_points, se3_inverse +from .dinov2.hub.backbones import dinov2_vitl14_reg from .dinov2.layers import Mlp, PatchEmbed from .layers.attention import FlashAttentionRope from .layers.block import BlockRope, PoseInjectBlock -from .layers.pos_embed import RoPE2D, PositionGetter -from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg -from ..utils.geometry import se3_inverse, get_pixel, homogenize_points -from .layers.transformer_head import TransformerDecoder, ContextOnlyTransformerDecoder +from .layers.camera_head import CameraHead +from .layers.conv_head import ConvHead +from .layers.pos_embed import PositionGetter, RoPE2D +from .layers.transformer_head import ContextOnlyTransformerDecoder, TransformerDecoder class Pi3X(nn.Module, PyTorchModelHubMixin): diff --git a/pi3/pipe/pi3x_vo.py b/pi3/pipe/pi3x_vo.py index 983fe00..56adedf 100644 --- a/pi3/pipe/pi3x_vo.py +++ b/pi3/pipe/pi3x_vo.py @@ -1,6 +1,6 @@ -from ..utils.geometry import homogenize_points, depth_edge import torch -import torch.nn.functional as F + +from ..utils.geometry import depth_edge class Pi3XVO: diff --git a/pi3/utils/basic.py b/pi3/utils/basic.py index a11546f..9146f00 100644 --- a/pi3/utils/basic.py +++ b/pi3/utils/basic.py @@ -1,12 +1,14 @@ +import math import os import os.path as osp -import math + import cv2 -from PIL import Image +import numpy as np import torch -from torchvision import transforms +from PIL import Image from plyfile import PlyData, PlyElement -import numpy as np +from torchvision import transforms + def load_images_as_tensor(path="data/truck", interval=1, PIXEL_LIMIT=255000, verbose=True): """ diff --git a/pi3/utils/debug.py b/pi3/utils/debug.py index f3da8f3..ead4c63 100644 --- a/pi3/utils/debug.py +++ b/pi3/utils/debug.py @@ -1,8 +1,10 @@ -import os import json -import debugpy -import socket +import os import random +import socket + +import debugpy + def update_vscode_launch_file(host: str, port: int): """Update the .vscode/launch.json file with the new host and port.""" diff --git a/pi3/utils/geometry.py b/pi3/utils/geometry.py index 481d47c..bc05d25 100644 --- a/pi3/utils/geometry.py +++ b/pi3/utils/geometry.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F + def se3_inverse(T): """ Computes the inverse of a batch of SE(3) matrices. diff --git a/pyproject.toml b/pyproject.toml index ebd834c..096c46a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,3 +20,14 @@ packages = ["pi3"] # Include package data [tool.setuptools.package-data] "pi3" = ["**/*"] + +[tool.ruff] +line-length = 120 +extend-exclude = ["*.ipynb"] + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405", "F821"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"]