diff --git a/.gitignore b/.gitignore index a18c1a6..a160ad2 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +model_exports/* +model_exports2/* +videos/* \ No newline at end of file diff --git a/README.md b/README.md index 0733c71..d452f62 100644 --- a/README.md +++ b/README.md @@ -3,23 +3,30 @@ This repository is inspired by https://github.com/NVIDIA-AI-IOT/nanosam and adap Although the inference speed of the SAM2.1 Hiera backbones is already quite fast on GPUs, it is still difficult to deploy on edge devices. This repository aims to provide a more efficient alternative for SAM2.1 inference, with a focus on backbones that are smaller and faster to deploy. +## Installation -## Dependencies and Prerequirements +### Dependencies and Prerequirements - - Create a new Python 3.10+ environment and clone the repository. + - Create a new Python 3.12+ environment and clone the repository. - Install the dependencies listed below: ``` -pip install matplotlib torchvision tqdm hydra-core pycocotools requests iopath +pip install matplotlib torchvision tqdm hydra-core pycocotools requests iopath opencv-python ``` - Install the repository as editable package `pip install -e .` -## Inference +### Download checkpoints + +You can find and download pretrained nanosam2 checkpoints [here](https://drive.google.com/drive/folders/15wApVHwqJGunjDP_cx5YZDCTEKliOMCQ?usp=sharing). Each backbone was trained for 10 epochs on 14 SA1 datasets, i.e. ~175k images. + +## Inference Demos + +All inference demos are executed on the [Bedroom](https://github.com/facebookresearch/sam2/blob/2b90b9f5ceec907a1c18123530e92e794ad901a4/notebooks/videos/bedroom.mp4) video file, shared in the original [sam2](https://github.com/facebookresearch/sam2) repository. 
### Video -Load all frames of a video at once into Nanosam2 and perform tracking of objects from any frame. To use the script you have to obtain all frames of the video as `.jpg` file. Place all `.jpg` files +Load all frames of a video at once into Nanosam2 and perform tracking of objects from any frame. To run the video demo you have to obtain all frames of the video as `.jpg` file. Place all `.jpg` files in the same directory and pass the directory to `video_frames_demo.py`. Extracting all frames of a video uns FFmpeg: @@ -36,7 +43,7 @@ python demos/video_frames_demo.py --config nanosam2.1_resnet18.yaml --checkpoint ### Camera Live Stream -Stream a video (of a camera or a video file) frame by frame into Nanosam2. Perform tracking of objects from any frame. +Stream a video (from a camera or video file source) frame by frame into Nanosam2. Start object tracking from any frame in the stream. For ResNet18 backend: @@ -122,9 +129,9 @@ python nanosam2/tools/compute_eval_coco_metric.py results/sam2.1_hiera_s_resnet1 ## Results FP32 -You can find pretrained nanosam2 checkpoints [here](https://drive.google.com/drive/folders/15wApVHwqJGunjDP_cx5YZDCTEKliOMCQ?usp=sharing). Each backbone was trained for 10 epochs on 14 SA1 datasets, i.e. ~175k images. + | Backbone | num_epochs | mIoU All | mIoU Small | mIoU Medium | mIoU Large | | -------- | -------- | -------- | -------- | -------- | -------- | | resnet18 | 10 | 0.69 | 0.62 | 0.73 | 0.76 | diff --git a/demos/live_demo.py b/demos/live_demo.py index 767ee3d..84f2ae7 100644 --- a/demos/live_demo.py +++ b/demos/live_demo.py @@ -1,4 +1,10 @@ -# Live Demo +# Nanosam2 Live Demo +# +# Run inferences on a video stram from a camera or a video file. +# +# To run this script, create a new python environment (3.12) install all packages listed in the README.md file +# and add the "nanosam2" directory to your pythonpath (or install the package). +# # Based on "https://github.com/Gy920/segment-anything-2-real-time/blob/main/demo/demo.py". 
@@ -14,10 +20,11 @@ parser.add_argument("--config", type=str, default="sam2_hiera_s", help="The path to a sam2 config.") parser.add_argument("--checkpoint", type=str, default="sam2_checkpoints/sam2.1_hiera_small.pt") parser.add_argument('--video', default=0, help='Path to a video or a camera id, default: 0') +parser.add_argument('--device', default="cpu", help='Device to run the model on, default: cpu, also supports cuda') args = parser.parse_args() # Configure Device. -device = "cuda" +device = args.device if device == "cuda": # use bfloat16 for the entire notebook torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() @@ -32,22 +39,30 @@ frametimes = [] def _compile_model_blocks(model, model_settings:list, compile_backend): + if (all(model_settings) == False): + print("Skipping Model Compilation...") + return model print("Compiling Model...") if model_settings[0]: # image_encoder + print(" - Compiling Image Encoder...") model.image_encoder = torch.compile(model.image_encoder, backend=compile_backend, dynamic=False) if model_settings[1]: # memory_attention + print(" - Compiling Memory Attention...") model.memory_attention = torch.compile(model.memory_attention, backend=compile_backend) if model_settings[2]: # sam_mask_decoder + print(" - Compiling SAM Mask Decoder...") model.sam_mask_decoder = torch.compile(model.sam_mask_decoder, backend=compile_backend) if model_settings[3]: # sam_prompt_encoder + print(" - Compiling SAM Prompt Encoder...") model.sam_prompt_encoder = torch.compile(model.sam_prompt_encoder, backend=compile_backend) if model_settings[4]: # memory_encoder + print(" - Compiling Memory Encoder...") model.memory_encoder = torch.compile(model.memory_encoder, backend=compile_backend) print("Compile finished.") return model # Compile Model if Required. 
-#predictor = _compile_model_blocks(predictor, [True, False, False, False, False], "inductor") +predictor = _compile_model_blocks(predictor, [False, False, False, False, False], "inductor") # Open Video Stream. cap = cv2.VideoCapture(args.video) @@ -67,12 +82,31 @@ def _compile_model_blocks(model, model_settings:list, compile_backend): predictor.load_first_frame(frame) if_init = True - ann_frame_idx = 0 # the frame index we interact with - ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) - # Let's add a positive click at (x, y) = (210, 350) to get started + # --------------------------------------------------------------------------- + # for demo video: https://github.com/facebookresearch/sam2/blob/2b90b9f5ceec907a1c18123530e92e794ad901a4/notebooks/videos/bedroom.mp4 + # --------------------------------------------------------------------------- + + # Add bbox - boy + ann_frame_idx = 0 # frame index to annotate + ann_obj_id = 1 # unique object id to annotate + bbox = np.array([[230, 134], [294, 219]], dtype=np.float32) + _, out_obj_ids, out_mask_logits = predictor.add_new_prompt( + frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox + ) + # Add bbox - girl + ann_frame_idx = 0 # frame index to annotate + ann_obj_id = 2 # unique object id to annotate + bbox = np.array([[353, 11], [451, 122]], dtype=np.float32) + _, out_obj_ids, out_mask_logits = predictor.add_new_prompt( + frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox + ) - ##! 
add points, `1` means positive click and `0` means negative click + # --------------------------------------------------------------------------- + # other bounding box, mask and point examples + # --------------------------------------------------------------------------- + + # Add points, `1` means positive click and `0` means negative click # points = np.array([[660, 267]], dtype=np.float32) # labels = np.array([1], dtype=np.int32) @@ -80,13 +114,7 @@ def _compile_model_blocks(model, model_settings:list, compile_backend): # frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels # ) - ## ! add bbox - bbox = np.array([[600, 214], [765, 286]], dtype=np.float32) - _, out_obj_ids, out_mask_logits = predictor.add_new_prompt( - frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox - ) - - ##! add mask + # Add mask # mask_img_path="../notebooks/masks/aquarium/aquarium_mask.png" # mask = cv2.imread(mask_img_path, cv2.IMREAD_GRAYSCALE) # mask = mask / 255 diff --git a/demos/video_frames_demo.py b/demos/video_frames_demo.py index 6a6322b..89e3688 100644 --- a/demos/video_frames_demo.py +++ b/demos/video_frames_demo.py @@ -3,7 +3,7 @@ # Run inferences on the frames exported form a video. # # To run this script, create a new python environment (3.12) install all packages listed in the README.md file -# and add the "nanosam2" directory to your pythonpath. +# and add the "nanosam2" directory to your pythonpath (or install the package). 
# If you are using bash add the following line to your .bashrc # export PYTHONPATH="/nanosam2" @@ -102,13 +102,20 @@ def show_box(box, ax): predictor.reset_state(inference_state) +# --------------------------------------------------------------------------- +# for demo video: https://github.com/facebookresearch/sam2/blob/2b90b9f5ceec907a1c18123530e92e794ad901a4/notebooks/videos/bedroom.mp4 +# --------------------------------------------------------------------------- + print("\n#1: Set two points and predict a mask...") ann_frame_idx = 0 # the frame index we interact with ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) # Add a positive click at (x, y). # For labels, `1` means positive click and `0` means negative click -points = np.array([[770, 420], [750, 380]], dtype=np.float32) + + +# Add two points on the shirt of the girl. +points = np.array([[388, 139], [414, 165]], dtype=np.float32) labels = np.array([1,1], np.int32) _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( inference_state=inference_state, @@ -155,10 +162,10 @@ def show_box(box, ax): ann_obj_id = 4 # give a unique id to each object we interact with (it can be any integers) # Add a positive click at (x, y). -points = np.array([[560, 350], [770, 420], [750, 380]], dtype=np.float32) +points = np.array([[379, 149], [102, 138], [100, 172]], dtype=np.float32) labels = np.array([1,1,1], np.int32) # Box coordinates. -box = np.array([400, 320, 1100, 650], dtype=np.float32) +box = np.array([370, 111, 427, 185], dtype=np.float32) _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( inference_state=inference_state, frame_idx=ann_frame_idx, diff --git a/nanosam2/datasets/containers.py b/nanosam2/datasets/containers.py new file mode 100644 index 0000000..2829981 --- /dev/null +++ b/nanosam2/datasets/containers.py @@ -0,0 +1,9 @@ +# Containers for different use cases for Nanosam2. 
+# + + +class ModelSource: + def __init__(self, name:str, checkpoint:str, cfg:str): + self.name = name + self.checkpoint = checkpoint + self.cfg = cfg diff --git a/nanosam2/sam2/__init__.py b/nanosam2/sam2/__init__.py index 744c794..0d76214 100644 --- a/nanosam2/sam2/__init__.py +++ b/nanosam2/sam2/__init__.py @@ -1,5 +1,5 @@ from hydra import initialize_config_module from hydra.core.global_hydra import GlobalHydra -if not GlobalHydra().is_initialized(): - initialize_config_module("sam2_configs", version_base="1.2") \ No newline at end of file +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("sam2_configs", version_base="1.2") diff --git a/nanosam2/sam2/build_sam.py b/nanosam2/sam2/build_sam.py index 3579f5e..5e50320 100644 --- a/nanosam2/sam2/build_sam.py +++ b/nanosam2/sam2/build_sam.py @@ -75,13 +75,21 @@ def build_sam2( mode="eval", load_image_encoder=True, hydra_overrides_extra=[], + apply_postprocessing=True, **kwargs, ): + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] # Read config and init model config_name = f'{config_dir}/{config_file}' if config_dir is not None else config_file - - cfg = compose(config_name=config_name) + cfg = compose(config_name=config_name, overrides=hydra_overrides_extra) OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path, load_image_encoder=load_image_encoder) @@ -98,11 +106,18 @@ def build_sam2_video_predictor( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + vos_optimized=False, **kwargs, ): hydra_overrides = [ 
"++model._target_=nanosam2.sam2.sam2_video_predictor.SAM2VideoPredictor", ] + if vos_optimized: + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", + "++model.compile_image_encoder=True", # Let sam2_base handle this + ] + if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ @@ -191,6 +206,13 @@ def _load_checkpoint(model, ckpt_path, load_image_encoder=True): for k in list(sd.keys()): if "image_encoder" in k: del sd[k] + missing_keys, unexpected_keys = model.load_state_dict(sd, strict=load_image_encoder) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() logging.info("Loaded checkpoint sucessfully") diff --git a/nanosam2/sam2/modeling/backbones/image_encoder.py b/nanosam2/sam2/modeling/backbones/image_encoder.py index 005f47b..2508442 100644 --- a/nanosam2/sam2/modeling/backbones/image_encoder.py +++ b/nanosam2/sam2/modeling/backbones/image_encoder.py @@ -22,13 +22,25 @@ def __init__( self.trunk = trunk self.neck = neck self.scalp = scalp + self.feature_maps_callback=None # assert ( # self.trunk.channel_list == self.neck.backbone_channel_list # ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" def forward(self, sample: torch.Tensor): - # Forward through backbone - features, pos = self.neck(self.trunk(sample)) + # Feature map callback. + if self.feature_maps_callback is not None: + self.feature_maps_callback("image-encoder:trunk-input", {"0":sample.cpu()}) + + # Forward through backbone (trunk) + trunk = self.trunk(sample) + + # Feature map callback. 
+ if self.feature_maps_callback is not None: + self.feature_maps_callback("image-encoder:trunk-output", {"0":trunk[0].cpu(), "1":trunk[1].cpu(), "2":trunk[2].cpu(), "3":trunk[3].cpu()}) + + # Forward through backbone (neck) + features, pos = self.neck(trunk) if self.scalp > 0: # Discard the lowest resolution features features, pos = features[: -self.scalp], pos[: -self.scalp] @@ -41,6 +53,9 @@ def forward(self, sample: torch.Tensor): "backbone_fpn": features, } return output + + def set_feature_maps_callback(self, fun): + self.feature_maps_callback = fun class FpnNeck(nn.Module): diff --git a/nanosam2/sam2/modeling/backbones/utils.py b/nanosam2/sam2/modeling/backbones/utils.py index 32d55c7..930b1b7 100644 --- a/nanosam2/sam2/modeling/backbones/utils.py +++ b/nanosam2/sam2/modeling/backbones/utils.py @@ -32,9 +32,7 @@ def window_partition(x, window_size): Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) + windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) return windows, (Hp, Wp) @@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw): Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view( + x = windows.reshape( B, Hp // window_size, Wp // window_size, window_size, window_size, -1 ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :] return x diff --git a/nanosam2/sam2/modeling/position_encoding.py b/nanosam2/sam2/modeling/position_encoding.py index 52ac226..2241d4c 100644 --- a/nanosam2/sam2/modeling/position_encoding.py +++ b/nanosam2/sam2/modeling/position_encoding.py @@ -25,6 +25,11 @@ def __init__( temperature: int = 10000, normalize: bool = True, 
scale: Optional[float] = None, + # Following settings only relevant + # for warmping up cache for compilation + warmup_cache: bool = True, + image_size: int = 1024, + strides: Tuple[int] = (4, 8, 16, 32), ): super().__init__() assert num_pos_feats % 2 == 0, "Expecting even model width" @@ -38,6 +43,12 @@ def __init__( self.scale = scale self.cache = {} + if warmup_cache and torch.cuda.is_available(): + # Warmup cache for cuda, to help with compilation + device = torch.device("cuda") + for stride in strides: + cache_key = (image_size // stride, image_size // stride) + self._pe(1, device, *cache_key) def _encode_xy(self, x, y): # The positions are expected to be normalized @@ -76,19 +87,20 @@ def encode_points(self, x, y, labels): return pos @torch.no_grad() - def forward(self, x: torch.Tensor): - cache_key = (x.shape[-2], x.shape[-1]) + def _pe(self, B, device, *cache_key): + H, W = cache_key if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) + y_embed = ( - torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, H + 1, dtype=torch.float32, device=device) .view(1, -1, 1) - .repeat(x.shape[0], 1, x.shape[-1]) + .repeat(B, 1, W) ) x_embed = ( - torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, W + 1, dtype=torch.float32, device=device) .view(1, 1, -1) - .repeat(x.shape[0], x.shape[-2], 1) + .repeat(B, H, 1) ) if self.normalize: @@ -96,7 +108,7 @@ def forward(self, x: torch.Tensor): y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -111,6 +123,12 @@ def 
forward(self, x: torch.Tensor): self.cache[cache_key] = pos[0] return pos + @torch.no_grad() + def forward(self, x: torch.Tensor): + B = x.shape[0] + cache_key = (x.shape[-2], x.shape[-1]) + return self._pe(B, x.device, *cache_key) + class PositionEmbeddingRandom(nn.Module): """ diff --git a/nanosam2/sam2/modeling/sam/mask_decoder.py b/nanosam2/sam2/modeling/sam/mask_decoder.py index cf04ba7..2c72ac8 100644 --- a/nanosam2/sam2/modeling/sam/mask_decoder.py +++ b/nanosam2/sam2/modeling/sam/mask_decoder.py @@ -10,7 +10,7 @@ from torch import nn from nanosam2.sam2.modeling.sam2_utils import LayerNorm2d, MLP - +from nanosam2.sam2.utils.misc import pad_tensor, copy_to_smaller_tensor class MaskDecoder(nn.Module): def __init__( @@ -18,6 +18,7 @@ def __init__( *, transformer_dim: int, transformer: nn.Module, + feature_maps_callback=None, num_multimask_outputs: int = 3, activation: Type[nn.Module] = nn.GELU, iou_head_depth: int = 3, @@ -48,8 +49,10 @@ def __init__( used to predict mask quality """ super().__init__() + self.fixed_transformer_shapes=False self.transformer_dim = transformer_dim self.transformer = transformer + self.feature_maps_callback=feature_maps_callback self.num_multimask_outputs = num_multimask_outputs @@ -107,6 +110,9 @@ def __init__( self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + def set_feature_maps_callback(self, fun): + self.feature_maps_callback = fun + def forward( self, image_embeddings: torch.Tensor, @@ -164,6 +170,9 @@ def forward( # Prepare output return masks, iou_pred, sam_tokens_out, object_score_logits + + def set_fixed_transformer_shapes(self, fixed:bool): + self.fixed_transformer_shapes = fixed def predict_masks( self, @@ -209,8 +218,27 @@ def predict_masks( pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w = src.shape + if self.fixed_transformer_shapes: + tokens_limit = 16 + objects_limit = 1 + 
objects_incoming = src.shape[0] + tokens_incoming = tokens.shape[1] + src = pad_tensor(src, (objects_limit,src.shape[1],src.shape[2],src.shape[3])) + pos_src = pad_tensor(pos_src, (objects_limit,pos_src.shape[1],pos_src.shape[2],pos_src.shape[3])) + tokens = pad_tensor(tokens, (objects_limit,tokens_limit,256)) + + # Feature map callback. + if self.feature_maps_callback is not None: + self.feature_maps_callback("mask-decoder-transformer:inputs", {"src":src.cpu(), "pos_src":pos_src.cpu(), "tokens":tokens.cpu()}) + # Run the transformer hs, src = self.transformer(src, pos_src, tokens) + + # Restore original shapes. + if self.fixed_transformer_shapes: + src = copy_to_smaller_tensor(src, (objects_incoming, 1024, 256)) + hs = copy_to_smaller_tensor(hs, (objects_incoming, tokens_incoming, 256)) + iou_token_out = hs[:, s, :] mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] diff --git a/nanosam2/sam2/modeling/sam/prompt_encoder.py b/nanosam2/sam2/modeling/sam/prompt_encoder.py index 9073e62..cd36afa 100644 --- a/nanosam2/sam2/modeling/sam/prompt_encoder.py +++ b/nanosam2/sam2/modeling/sam/prompt_encoder.py @@ -92,12 +92,32 @@ def _embed_points( point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - point_embedding[labels == 2] += self.point_embeddings[2].weight - point_embedding[labels == 3] += self.point_embeddings[3].weight + + point_embedding = torch.where( + (labels == -1).unsqueeze(-1), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 0).unsqueeze(-1), + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 1).unsqueeze(-1), + 
point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 2).unsqueeze(-1), + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 3).unsqueeze(-1), + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: diff --git a/nanosam2/sam2/modeling/sam/transformer.py b/nanosam2/sam2/modeling/sam/transformer.py index b68925f..7793998 100644 --- a/nanosam2/sam2/modeling/sam/transformer.py +++ b/nanosam2/sam2/modeling/sam/transformer.py @@ -4,9 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import contextlib import math -import warnings from functools import partial from typing import Tuple, Type @@ -16,29 +14,6 @@ from nanosam2.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from nanosam2.sam2.modeling.sam2_utils import MLP -from nanosam2.sam2.utils.misc import get_sdpa_settings - -warnings.simplefilter(action="ignore", category=FutureWarning) -# Check whether Flash Attention is available (and use it by default) -OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() -# A fallback setting to allow all available kernels if Flash Attention fails -ALLOW_ALL_KERNELS = False - - -def sdp_kernel_context(dropout_p): - """ - Get the context for the attention scaled dot-product kernel. We use Flash Attention - by default, but fall back to all available kernels if Flash Attention fails. 
- """ - if ALLOW_ALL_KERNELS: - return contextlib.nullcontext() - - return torch.backends.cuda.sdp_kernel( - enable_flash=USE_FLASH_ATTN, - # if Flash attention kernel is off, then math kernel needs to be enabled - enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, - enable_mem_efficient=OLD_GPU, - ) class TwoWayTransformer(nn.Module): @@ -265,20 +240,7 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) @@ -305,7 +267,9 @@ def __init__( compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta ) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis + self.freqs_cis = ( + freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis + ) self.rope_k_repeat = rope_k_repeat def forward( @@ -339,20 +303,7 @@ def forward( dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower 
speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/nanosam2/sam2/modeling/sam2_base.py b/nanosam2/sam2/modeling/sam2_base.py index 94b603b..82b7841 100644 --- a/nanosam2/sam2/modeling/sam2_base.py +++ b/nanosam2/sam2/modeling/sam2_base.py @@ -96,6 +96,8 @@ def __init__( compile_image_encoder: bool = False, ): super().__init__() + # Feature Maps callback, default: None + self.feature_maps_callback = None # Part 1: the image backbone self.image_encoder = image_encoder @@ -198,6 +200,13 @@ def __init__( @property def device(self): return next(self.parameters()).device + + def set_feature_maps_callback(self, fun): + """Set the feature maps callback for all blocks supporting feature maps callback""" + self.feature_maps_callback = fun + self.sam_mask_decoder.set_feature_maps_callback(self.feature_maps_callback) + self.image_encoder.set_feature_maps_callback(self.feature_maps_callback) + def forward(self, *args, **kwargs): raise NotImplementedError( @@ -631,7 +640,9 @@ def _prepare_memory_conditioned_features( if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) + obj_pos = torch.tensor(pos_list).to( + device=device, non_blocking=True + ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) diff --git a/nanosam2/sam2/sam2_image_predictor.py b/nanosam2/sam2/sam2_image_predictor.py index 9125873..49a5977 100644 --- a/nanosam2/sam2/sam2_image_predictor.py +++ b/nanosam2/sam2/sam2_image_predictor.py @@ -59,9 +59,11 @@ def __init__( self.mask_threshold = mask_threshold # 
Spatial dim for backbone feature maps - # tanks for the comment here https://github.com/facebookresearch/sam2/issues/138#issuecomment-2269907504 - hires_size = self.model.image_size // 4 - self._bb_feat_sizes = [[hires_size // (2**k)]*2 for k in range(3)] + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": diff --git a/nanosam2/sam2/sam2_video_predictor.py b/nanosam2/sam2/sam2_video_predictor.py index 47b5dc9..c0fd1ea 100644 --- a/nanosam2/sam2/sam2_video_predictor.py +++ b/nanosam2/sam2/sam2_video_predictor.py @@ -8,6 +8,7 @@ from collections import OrderedDict import torch +import torch.nn.functional as F from tqdm import tqdm @@ -26,8 +27,6 @@ def __init__( # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) clear_non_cond_mem_around_input=False, - # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). 
- clear_non_cond_mem_for_multi_obj=False, # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames add_all_frames_to_correct_as_cond=False, @@ -37,7 +36,6 @@ def __init__( self.fill_hole_area = fill_hole_area self.non_overlap_masks = non_overlap_masks self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input - self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond @torch.inference_mode() @@ -47,7 +45,6 @@ def init_state( offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, - disable_prints=False ): """Initialize an inference state.""" compute_device = self.device # device of the model @@ -57,7 +54,6 @@ def init_state( offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, compute_device=compute_device, - disable_prints=disable_prints ) inference_state = {} inference_state["images"] = images @@ -89,11 +85,6 @@ def init_state( inference_state["obj_id_to_idx"] = OrderedDict() inference_state["obj_idx_to_id"] = OrderedDict() inference_state["obj_ids"] = [] - # A storage to hold the model's tracking results and states on each frame - inference_state["output_dict"] = { - "cond_frame_outputs": {}, # dict containing {frame_idx: } - "non_cond_frame_outputs": {}, # dict containing {frame_idx: } - } # Slice (view) of each object tracking results, sharing the same memory with "output_dict" inference_state["output_dict_per_obj"] = {} # A temporary storage to hold new outputs when user interact with a frame @@ -101,13 +92,8 @@ def init_state( inference_state["temp_output_dict_per_obj"] = {} # Frames that already holds consolidated outputs from click or mask inputs # (we directly use their consolidated 
outputs during tracking) - inference_state["consolidated_frame_inds"] = { - "cond_frame_outputs": set(), # set containing frame indices - "non_cond_frame_outputs": set(), # set containing frame indices - } # metadata for each tracking frame (e.g. which direction it's tracked) - inference_state["tracking_has_started"] = False - inference_state["frames_already_tracked"] = {} + inference_state["frames_tracked_per_obj"] = {} # Warm up the visual backbone and cache the image feature on frame 0 self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state @@ -135,9 +121,8 @@ def _obj_id_to_idx(self, inference_state, obj_id): if obj_idx is not None: return obj_idx - # This is a new object id not sent to the server before. We only allow adding - # new objects *before* the tracking starts. - allow_new_object = not inference_state["tracking_has_started"] + # We always allow adding new objects (including after tracking starts). + allow_new_object = True if allow_new_object: # get the next object slot obj_idx = len(inference_state["obj_id_to_idx"]) @@ -155,6 +140,7 @@ def _obj_id_to_idx(self, inference_state, obj_id): "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } + inference_state["frames_tracked_per_obj"][obj_idx] = {} return obj_idx else: raise RuntimeError( @@ -215,15 +201,6 @@ def add_new_points_or_box( "box prompt must be provided before any point prompt " "(please use clear_old_points=True instead)" ) - if inference_state["tracking_has_started"]: - warnings.warn( - "You are adding a box after tracking starts. SAM 2 may not always be " - "able to incorporate a box prompt for *refinement*. 
If you intend to " - "use box prompt as an *initial* input before tracking, please call " - "'reset_state' on the inference state to restart from scratch.", - category=UserWarning, - stacklevel=2, - ) if not isinstance(box, torch.Tensor): box = torch.tensor(box, dtype=torch.float32, device=points.device) box_coords = box.reshape(1, 2, 2) @@ -253,12 +230,13 @@ def add_new_points_or_box( # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. - is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in obj_frames_tracked # whether to track in reverse time order if is_init_cond_frame: reverse = False else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + reverse = obj_frames_tracked[frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or @@ -307,7 +285,6 @@ def add_new_points_or_box( inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -358,12 +335,13 @@ def add_new_mask( # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. 
- is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in obj_frames_tracked # whether to track in reverse time order if is_init_cond_frame: reverse = False else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + reverse = obj_frames_tracked[frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or @@ -395,7 +373,6 @@ def add_new_mask( inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -430,7 +407,6 @@ def _consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond, - run_mem_encoder, consolidate_at_video_res=False, ): """ @@ -447,7 +423,6 @@ def _consolidate_temp_output_across_obj( # Optionally, we allow consolidating the temporary outputs at the original # video resolution (to provide a better editing experience for mask prompts). if consolidate_at_video_res: - assert not run_mem_encoder, "memory encoder cannot run at video resolution" consolidated_H = inference_state["video_height"] consolidated_W = inference_state["video_width"] consolidated_mask_key = "pred_masks_video_res" @@ -460,30 +435,13 @@ def _consolidate_temp_output_across_obj( # constraints to object scores. Its "pred_masks" are prefilled with a large # negative value (NO_OBJ_SCORE) to represent missing objects. 
consolidated_out = { - "maskmem_features": None, - "maskmem_pos_enc": None, consolidated_mask_key: torch.full( size=(batch_size, 1, consolidated_H, consolidated_W), fill_value=NO_OBJ_SCORE, dtype=torch.float32, device=inference_state["storage_device"], ), - "obj_ptr": torch.full( - size=(batch_size, self.hidden_dim), - fill_value=NO_OBJ_SCORE, - dtype=torch.float32, - device=inference_state["device"], - ), - "object_score_logits": torch.full( - size=(batch_size, 1), - # default to 10.0 for object_score_logits, i.e. assuming the object is - # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` - fill_value=10.0, - dtype=torch.float32, - device=inference_state["device"], - ), } - empty_mask_ptr = None for obj_idx in range(batch_size): obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] @@ -500,16 +458,6 @@ def _consolidate_temp_output_across_obj( # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE # placeholder above) and set its object pointer to be a dummy pointer. if out is None: - # Fill in dummy object pointers for those objects without any inputs or - # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, - # i.e. when we need to build the memory for tracking). 
- if run_mem_encoder: - if empty_mask_ptr is None: - empty_mask_ptr = self._get_empty_mask_ptr( - inference_state, frame_idx - ) - # fill object pointer with a dummy pointer (based on an empty mask) - consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr continue # Add the temporary object output mask to consolidated output mask obj_mask = out["pred_masks"] @@ -525,141 +473,74 @@ def _consolidate_temp_output_across_obj( align_corners=False, ) consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask - consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] - consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ - "object_score_logits" - ] - - # Optionally, apply non-overlapping constraints on the consolidated scores - # and rerun the memory encoder - if run_mem_encoder: - device = inference_state["device"] - high_res_masks = torch.nn.functional.interpolate( - consolidated_out["pred_masks"].to(device, non_blocking=True), - size=(self.image_size, self.image_size), - mode="bilinear", - align_corners=False, - ) - if self.non_overlap_masks_for_mem_enc: - high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) - maskmem_features, maskmem_pos_enc = self._run_memory_encoder( - inference_state=inference_state, - frame_idx=frame_idx, - batch_size=batch_size, - high_res_masks=high_res_masks, - object_score_logits=consolidated_out["object_score_logits"], - is_mask_from_pts=True, # these frames are what the user interacted with - ) - consolidated_out["maskmem_features"] = maskmem_features - consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc return consolidated_out - def _get_empty_mask_ptr(self, inference_state, frame_idx): - """Get a dummy object pointer based on an empty mask on the current frame.""" - # A dummy (empty) mask with a single object - batch_size = 1 - mask_inputs = torch.zeros( - (batch_size, 1, self.image_size, self.image_size), - dtype=torch.float32, - device=inference_state["device"], - ) - 
- # Retrieve correct image features - ( - _, - _, - current_vision_feats, - current_vision_pos_embeds, - feat_sizes, - ) = self._get_image_feature(inference_state, frame_idx, batch_size) - - # Feed the empty mask and image feature above to get a dummy object pointer - current_out = self.track_step( - frame_idx=frame_idx, - is_init_cond_frame=True, - current_vision_feats=current_vision_feats, - current_vision_pos_embeds=current_vision_pos_embeds, - feat_sizes=feat_sizes, - point_inputs=None, - mask_inputs=mask_inputs, - output_dict={}, - num_frames=inference_state["num_frames"], - track_in_reverse=False, - run_mem_encoder=False, - prev_sam_mask_logits=None, - ) - return current_out["obj_ptr"] - @torch.inference_mode() def propagate_in_video_preflight(self, inference_state): """Prepare inference_state and consolidate temporary outputs before tracking.""" - # Tracking has started and we don't allow adding new objects until session is reset. - inference_state["tracking_has_started"] = True + # Check and make sure that every object has received input points or masks. batch_size = self._get_obj_num(inference_state) + if batch_size == 0: + raise RuntimeError( + "No input points or masks are provided for any object; please add inputs first." + ) # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and # add them into "output_dict". - temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] - output_dict = inference_state["output_dict"] - # "consolidated_frame_inds" contains indices of those frames where consolidated - # temporary outputs have been added (either in this call or any previous calls - # to `propagate_in_video_preflight`). 
- consolidated_frame_inds = inference_state["consolidated_frame_inds"] - for is_cond in [False, True]: - # Separately consolidate conditioning and non-conditioning temp outputs - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - # Find all the frames that contain temporary outputs for any objects - # (these should be the frames that have just received clicks for mask inputs - # via `add_new_points_or_box` or `add_new_mask`) - temp_frame_inds = set() - for obj_temp_output_dict in temp_output_dict_per_obj.values(): - temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) - consolidated_frame_inds[storage_key].update(temp_frame_inds) - # consolidate the temporary output across all objects on this frame - for frame_idx in temp_frame_inds: - consolidated_out = self._consolidate_temp_output_across_obj( - inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True - ) - # merge them into "output_dict" and also create per-object slices - output_dict[storage_key][frame_idx] = consolidated_out - self._add_output_per_object( - inference_state, frame_idx, consolidated_out, storage_key - ) - clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( - self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = ( + "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" ) - if clear_non_cond_mem: - # clear non-conditioning memory of the surrounding frames - self._clear_non_cond_mem_around_input(inference_state, frame_idx) + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + for 
frame_idx, out in obj_temp_output_dict[storage_key].items(): + # Run memory encoder on the temporary outputs (if the memory feature is missing) + if out["maskmem_features"] is None: + high_res_masks = torch.nn.functional.interpolate( + out["pred_masks"].to(inference_state["device"]), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + high_res_masks=high_res_masks, + object_score_logits=out["object_score_logits"], + # these frames are what the user interacted with + is_mask_from_pts=True, + ) + out["maskmem_features"] = maskmem_features + out["maskmem_pos_enc"] = maskmem_pos_enc + + obj_output_dict[storage_key][frame_idx] = out + if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx + ) - # clear temporary outputs in `temp_output_dict_per_obj` - for obj_temp_output_dict in temp_output_dict_per_obj.values(): + # clear temporary outputs in `temp_output_dict_per_obj` obj_temp_output_dict[storage_key].clear() - # edge case: if an output is added to "cond_frame_outputs", we remove any prior - # output on the same frame in "non_cond_frame_outputs" - for frame_idx in output_dict["cond_frame_outputs"]: - output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for obj_output_dict in inference_state["output_dict_per_obj"].values(): + # check and make sure that every object has received input points or masks + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + if len(obj_output_dict["cond_frame_outputs"]) == 0: + obj_id = self._obj_idx_to_id(inference_state, obj_idx) + raise RuntimeError( + f"No input points or masks are provided for object id {obj_id}; please add inputs first." 
+ ) + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" for frame_idx in obj_output_dict["cond_frame_outputs"]: obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: - assert frame_idx in output_dict["cond_frame_outputs"] - consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) - - # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames - # with either points or mask inputs (which should be true under a correct workflow). - all_consolidated_frame_inds = ( - consolidated_frame_inds["cond_frame_outputs"] - | consolidated_frame_inds["non_cond_frame_outputs"] - ) - input_frames_inds = set() - for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): - input_frames_inds.update(point_inputs_per_frame.keys()) - for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): - input_frames_inds.update(mask_inputs_per_frame.keys()) - assert all_consolidated_frame_inds == input_frames_inds @torch.inference_mode() def propagate_in_video( @@ -668,26 +549,22 @@ def propagate_in_video( start_frame_idx=None, max_frame_num_to_track=None, reverse=False, - disable_prints=False, ): """Propagate the input points across frames to track in the entire video.""" self.propagate_in_video_preflight(inference_state) - output_dict = inference_state["output_dict"] - consolidated_frame_inds = inference_state["consolidated_frame_inds"] obj_ids = inference_state["obj_ids"] num_frames = inference_state["num_frames"] batch_size = self._get_obj_num(inference_state) - if len(output_dict["cond_frame_outputs"]) == 0: - raise RuntimeError("No points are provided; please add points first") - clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( - self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 - ) # set start index, end index, and processing 
order if start_frame_idx is None: # default: start from the earliest frame with input points - start_frame_idx = min(output_dict["cond_frame_outputs"]) + start_frame_idx = min( + t + for obj_output_dict in inference_state["output_dict_per_obj"].values() + for t in obj_output_dict["cond_frame_outputs"] + ) if max_frame_num_to_track is None: # default: track all the frames in the video max_frame_num_to_track = num_frames @@ -703,79 +580,55 @@ def propagate_in_video( ) processing_order = range(start_frame_idx, end_frame_idx + 1) - for frame_idx in tqdm(processing_order, desc="propagate in video", disable=disable_prints): - # We skip those frames already in consolidated outputs (these are frames - # that received input clicks or mask). Note that we cannot directly run - # batched forward on them via `_run_single_frame_inference` because the - # number of clicks on each object might be different. - if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: - storage_key = "cond_frame_outputs" - current_out = output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] - if clear_non_cond_mem: - # clear non-conditioning memory of the surrounding frames - self._clear_non_cond_mem_around_input(inference_state, frame_idx) - elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: - storage_key = "non_cond_frame_outputs" - current_out = output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] - else: - storage_key = "non_cond_frame_outputs" - current_out, pred_masks = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=output_dict, - frame_idx=frame_idx, - batch_size=batch_size, - is_init_cond_frame=False, - point_inputs=None, - mask_inputs=None, - reverse=reverse, - run_mem_encoder=True, - ) - output_dict[storage_key][frame_idx] = current_out - # Create slices of per-object outputs for subsequent interaction with each - # individual object after tracking. 
- self._add_output_per_object( - inference_state, frame_idx, current_out, storage_key - ) - inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + for frame_idx in tqdm(processing_order, desc="propagate in video"): + pred_masks_per_obj = [None] * batch_size + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in obj_output_dict["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = obj_output_dict[storage_key][frame_idx] + device = inference_state["device"] + pred_masks = current_out["pred_masks"].to(device, non_blocking=True) + if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx + ) + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + obj_output_dict[storage_key][frame_idx] = current_out + + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = pred_masks # Resize the output mask to the original video resolution (we directly use # the mask scores on GPU for output to avoid any CPU conversion in between) + if len(pred_masks_per_obj) > 1: + all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) + else: + all_pred_masks = pred_masks_per_obj[0] _, video_res_masks = 
self._get_orig_video_res_output( - inference_state, pred_masks + inference_state, all_pred_masks ) yield frame_idx, obj_ids, video_res_masks - def _add_output_per_object( - self, inference_state, frame_idx, current_out, storage_key - ): - """ - Split a multi-object output into per-object output slices and add them into - `output_dict_per_obj`. The resulting slices share the same tensor storage. - """ - maskmem_features = current_out["maskmem_features"] - assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) - - maskmem_pos_enc = current_out["maskmem_pos_enc"] - assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) - - output_dict_per_obj = inference_state["output_dict_per_obj"] - for obj_idx, obj_output_dict in output_dict_per_obj.items(): - obj_slice = slice(obj_idx, obj_idx + 1) - obj_out = { - "maskmem_features": None, - "maskmem_pos_enc": None, - "pred_masks": current_out["pred_masks"][obj_slice], - "obj_ptr": current_out["obj_ptr"][obj_slice], - "object_score_logits": current_out["object_score_logits"][obj_slice], - } - if maskmem_features is not None: - obj_out["maskmem_features"] = maskmem_features[obj_slice] - if maskmem_pos_enc is not None: - obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] - obj_output_dict[storage_key][frame_idx] = obj_out - @torch.inference_mode() def clear_all_prompts_in_frame( self, inference_state, frame_idx, obj_id, need_output=True @@ -791,41 +644,14 @@ def clear_all_prompts_in_frame( temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) - # Check and see if there are still any inputs left on this frame - batch_size = self._get_obj_num(inference_state) - frame_has_input = False - for obj_idx2 in range(batch_size): - if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: - frame_has_input = True - break - if frame_idx in 
inference_state["mask_inputs_per_obj"][obj_idx2]: - frame_has_input = True - break - - # If this frame has no remaining inputs for any objects, we further clear its - # conditioning frame status - if not frame_has_input: - output_dict = inference_state["output_dict"] - consolidated_frame_inds = inference_state["consolidated_frame_inds"] - consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) - consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) - # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) - out = output_dict["cond_frame_outputs"].pop(frame_idx, None) - if out is not None: - # The frame is not a conditioning frame anymore since it's not receiving inputs, - # so we "downgrade" its output (if exists) to a non-conditioning frame output. - output_dict["non_cond_frame_outputs"][frame_idx] = out - inference_state["frames_already_tracked"].pop(frame_idx, None) - # Similarly, do it for the sliced output on each object. - for obj_idx2 in range(batch_size): - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] - obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) - if obj_out is not None: - obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out - - # If all the conditioning frames have been removed, we also clear the tracking outputs - if len(output_dict["cond_frame_outputs"]) == 0: - self._reset_tracking_results(inference_state) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. 
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None) if not need_output: return @@ -839,7 +665,6 @@ def clear_all_prompts_in_frame( inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -859,6 +684,7 @@ def reset_state(self, inference_state): inference_state["mask_inputs_per_obj"].clear() inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() + inference_state["frames_tracked_per_obj"].clear() def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos.""" @@ -872,12 +698,8 @@ def _reset_tracking_results(self, inference_state): for v in inference_state["temp_output_dict_per_obj"].values(): v["cond_frame_outputs"].clear() v["non_cond_frame_outputs"].clear() - inference_state["output_dict"]["cond_frame_outputs"].clear() - inference_state["output_dict"]["non_cond_frame_outputs"].clear() - inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() - inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() - inference_state["tracking_has_started"] = False - inference_state["frames_already_tracked"].clear() + for v in inference_state["frames_tracked_per_obj"].values(): + v.clear() def _get_image_feature(self, inference_state, frame_idx, batch_size): """Compute the image features on a given frame.""" @@ -1095,8 +917,6 @@ def remove_object(self, inference_state, obj_id, strict=False, need_output=True) inference_state["obj_ids"] = new_obj_ids # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. 
- # (note that "consolidated_frame_inds" doesn't need to be updated in this step as - # it's already handled in Step 0) def _map_keys(container): new_kvs = [] for k in old_obj_inds: @@ -1109,30 +929,9 @@ def _map_keys(container): _map_keys(inference_state["mask_inputs_per_obj"]) _map_keys(inference_state["output_dict_per_obj"]) _map_keys(inference_state["temp_output_dict_per_obj"]) + _map_keys(inference_state["frames_tracked_per_obj"]) - # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. - def _slice_state(output_dict, storage_key): - for frame_idx, out in output_dict[storage_key].items(): - out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] - out["maskmem_pos_enc"] = [ - x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] - ] - # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it - out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) - out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] - out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] - out["object_score_logits"] = out["object_score_logits"][ - remain_old_obj_inds - ] - # also update the per-object slices - self._add_output_per_object( - inference_state, frame_idx, out, storage_key - ) - - _slice_state(inference_state["output_dict"], "cond_frame_outputs") - _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") - - # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which # could show an updated mask for objects previously occluded by the object being removed if need_output: temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] @@ -1145,7 +944,6 @@ def _slice_state(output_dict, storage_key): inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = 
self._get_orig_video_res_output( @@ -1167,9 +965,259 @@ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): r = self.memory_temporal_stride_for_eval frame_idx_begin = frame_idx - r * self.num_maskmem frame_idx_end = frame_idx + r * self.num_maskmem - output_dict = inference_state["output_dict"] - non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] - for t in range(frame_idx_begin, frame_idx_end + 1): - non_cond_frame_outputs.pop(t, None) - for obj_output_dict in inference_state["output_dict_per_obj"].values(): - obj_output_dict["non_cond_frame_outputs"].pop(t, None) + batch_size = self._get_obj_num(inference_state) + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + + +class SAM2VideoPredictorVOS(SAM2VideoPredictor): + """Optimized for the VOS setting""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._compile_all_components() + + def _compile_all_components(self): + print("Compiling all components for VOS setting. First time may be very slow.") + self.memory_encoder.forward = torch.compile( + self.memory_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + self.memory_attention.forward = torch.compile( + self.memory_attention.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, # Num. 
of memories varies + ) + + self.sam_prompt_encoder.forward = torch.compile( + self.sam_prompt_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + self.sam_mask_decoder.forward = torch.compile( + self.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + def forward_image(self, img_batch: torch.Tensor): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the backbone features and pos encoding to enable compilation. + """ + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + # Clone to help torch.compile + for i in range(len(backbone_out["backbone_fpn"])): + backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone() + backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][ + i + ].clone() + return backbone_out + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the outputs of prompt_encoder and mask_decoder to enable compilation. 
+ """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). 
+ sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + # Clone image_pe and the outputs of sam_prompt_encoder + # to enable compilation + sparse_embeddings = sparse_embeddings.clone() + dense_embeddings = dense_embeddings.clone() + image_pe = self.sam_prompt_encoder.get_dense_pe().clone() + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # Clone the output of sam_mask_decoder + # to enable compilation + low_res_multimasks = low_res_multimasks.clone() + ious = ious.clone() + sam_output_tokens = sam_output_tokens.clone() + object_score_logits = object_score_logits.clone() + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, 
best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the memories and their pos enc to enable compilation. + """ + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). 
+ pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + # Clone the feats and pos_enc to enable compilation + maskmem_features = maskmem_out["vision_features"].clone() + maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc diff --git a/nanosam2/sam2/utils/misc.py b/nanosam2/sam2/utils/misc.py index b7c11f0..f61a87f 100644 --- a/nanosam2/sam2/utils/misc.py +++ b/nanosam2/sam2/utils/misc.py @@ -42,6 +42,79 @@ def __repr__(self): def __len__(self): # Return the number of elements in the dictionary return len(self.data) + +def pad_tensor(original_tensor:torch.Tensor, larger_shape:torch.Tensor, pad="front"): + """ + Pads the original tensor into a larger tensor filled with zeros. 
def pad_tensor(original_tensor: torch.Tensor, larger_shape, pad="front"):
    """
    Pad a tensor into a larger tensor filled with zeros.

    Parameters:
        original_tensor (torch.Tensor): The tensor to be padded.
        larger_shape (tuple): The desired shape for the larger tensor. Must have
            the same number of dimensions as ``original_tensor`` and be at least
            as large along every dimension.
        pad (str): ``"front"`` pads zeros in front, i.e. the original data ends
            up at the high-index corner; any other value pads zeros at the back
            (data at the low-index corner).

    Returns:
        torch.Tensor: A tensor of shape ``larger_shape`` containing
        ``original_tensor``, zero-padded elsewhere.
    """
    # Allocate the target on the same device and with the same dtype as the input.
    larger_tensor = torch.zeros(
        larger_shape, dtype=original_tensor.dtype, device=original_tensor.device
    )
    # Build one slice per dimension instead of hard-coding the 2D/3D/4D cases.
    # This also fixes the silent all-zeros result the old `match` produced for
    # any other rank, and drops the leftover debug prints in the "back" branch.
    if pad == "front":
        # Copy into the trailing region: larger_tensor[-s0:, -s1:, ...]
        region = tuple(slice(-s, None) for s in original_tensor.shape)
    else:
        # pad == "back": copy into the leading region: larger_tensor[:s0, :s1, ...]
        region = tuple(slice(None, s) for s in original_tensor.shape)
    larger_tensor[region] = original_tensor
    return larger_tensor
+ """ + + if pad=="front": + match original_tensor.dim(): + case 4: + extracted_tensor = original_tensor[-smaller_shape[0]:, -smaller_shape[1]:, -smaller_shape[2]:, -smaller_shape[3]:] + case 3: + extracted_tensor = original_tensor[-smaller_shape[0]:, -smaller_shape[1]:, -smaller_shape[2]:] + case 2: + extracted_tensor = original_tensor[-smaller_shape[0]:, -smaller_shape[1]:] + case _: + extracted_tensor = original_tensor[-smaller_shape[0]:, -smaller_shape[1]:] + else: + match original_tensor.dim(): + case 4: + extracted_tensor = original_tensor[:smaller_shape[0], :smaller_shape[1], :smaller_shape[2], :smaller_shape[3]] + case 3: + extracted_tensor = original_tensor[:smaller_shape[0], :smaller_shape[1]:, :smaller_shape[2]] + case 2: + extracted_tensor = original_tensor[:smaller_shape[0], :smaller_shape[1]] + case _: + extracted_tensor = original_tensor[:smaller_shape[0], :smaller_shape[1]] + + return extracted_tensor + def get_sdpa_settings(): if torch.cuda.is_available(): diff --git a/nanosam2/tools/benchmark_video_performance.py b/nanosam2/tools/benchmark_video_performance.py index 1dcc1de..2842671 100644 --- a/nanosam2/tools/benchmark_video_performance.py +++ b/nanosam2/tools/benchmark_video_performance.py @@ -13,13 +13,7 @@ import itertools import time from nanosam2.sam2.build_sam import build_sam2_video_predictor - - -class ModelSource: - def __init__(self, name:str, checkpoint:str, cfg:str): - self.name = name - self.checkpoint = checkpoint - self.cfg = cfg +from nanosam2.datasets.containers import ModelSource class BenchmarkVideoPerformance: class BenchmarkIterationMetadata: diff --git a/sam2_configs/nanosam2.1_resnet18.yaml b/sam2_configs/nanosam2.1_resnet18.yaml index 145b938..4a0d31e 100644 --- a/sam2_configs/nanosam2.1_resnet18.yaml +++ b/sam2_configs/nanosam2.1_resnet18.yaml @@ -82,7 +82,7 @@ model: num_layers: 2 num_maskmem: 7 - image_size: 1024 + image_size: 512 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as 
def modify_filename(file_path, str_extension="_modified") -> Path:
    """
    Insert ``str_extension`` between a file name's stem and its extension.

    Parameters:
        file_path (str | Path): The path to modify; strings are converted.
        str_extension (str): Text appended to the stem, before the suffix.

    Returns:
        Path: e.g. ``modify_filename("m.onnx", "-simply")`` -> ``Path("m-simply.onnx")``.
    """
    # Path() is a cheap no-op on an existing Path, so no isinstance branch needed.
    path = Path(file_path)
    return path.with_name(f"{path.stem}{str_extension}{path.suffix}")
+ """ + model = onnx.load(model_path) + + # Get the current output nodes + original_outputs = model.graph.output + + # Dictionary to track unique outputs + unique_outputs = {} + new_outputs = [] + to_remove = [] + + # Iterate through the output nodes + for output in original_outputs: + # Create a key based on the output name and shape + shape = tuple(dim.dim_value for dim in output.type.tensor_type.shape.dim) + key = (output.name, shape) + + # Check if this key already exists + if key not in unique_outputs: + unique_outputs[key] = output + new_outputs.append(output) + else: + print(f"Duplicate output found: {output.name} with shape: {shape}. Removing it.") + to_remove.append(output) + + # Update the model's output nodes + for o in to_remove: + model.graph.output.remove(o) + + if out_file is not None: + onnx.save(model, out_file) + print(f"Modified model saved to {out_file}") + return {"onnx_model":model, "out_file": out_file} + + +def analyze_onnx_model(model:onnx.ModelProto): + """ + Identify the in- and output nodes of a onnx model. + """ + graph = model.graph + + for input in graph.input: + output_name = input.name + output_shape = [dim.dim_value for dim in input.type.tensor_type.shape.dim] + print(f'Input Name: {output_name}, Shape: {output_shape}') + + for output in graph.output: + output_name = output.name + output_shape = [dim.dim_value for dim in output.type.tensor_type.shape.dim] + print(f'Output Name: {output_name}, Shape: {output_shape}') + +def test_onnx_model(model_path:Path, input): + session = ort.InferenceSession(str(model_path)) + input_name = session.get_inputs()[0].name + outputs = session.run(None, {input_name: input}) + + print("ONNX model output shapes:") + for i, output in enumerate(outputs): + print(f"{i} | shape: {output.shape}") + return output + +def determine_torch_output_shapes(y, id=0): + """ + Determine the output shapes of the feature map produced by a torch inference. 
+ """ + if isinstance(y, dict): + for k,v in y.items(): + if isinstance(v, list) or isinstance(v, tuple): + determine_torch_output_shapes(v) + else: + print(f"{id} | {k} shape: {v.shape}") + id += 1 + elif isinstance(y, list) or isinstance(y, tuple): + for k,v in enumerate(y): + determine_torch_output_shapes(v) + else: + print(f"out | - {y.shape}") + +def test_torch_model(torch_model:torch.nn, input, silent=False, use_unpack_operator=False) -> bool: + if use_unpack_operator: y = torch_model(*input) + else: y = torch_model(input) + if not silent: + determine_torch_output_shapes(y) + return True + + +def get_block_and_inputs(predictor:torch.nn, block:str, img_shape:list=[3,512,512]): + input_names = ["input"] + d_axes = None + use_unpack_operator = True + match block: + case "nanosam2": + print("Hint: Full model export not supported.") + torch_model = predictor + torch_input = (torch.randn(1,img_shape[0],img_shape[1],img_shape[2]),) + case "image-encoder": + torch_model = predictor.image_encoder + torch_input = (torch.randn(1,img_shape[0],img_shape[1],img_shape[2]),) + case "image-encoder-trunk": + torch_model = predictor.image_encoder.trunk + torch_input = (torch.randn(1,img_shape[0],img_shape[1],img_shape[2]),) + case "image-encoder-neck": + torch_model = predictor.image_encoder.neck + torch_input = [torch.randn(1,64,128,128), + torch.randn(1,128,64,64), + torch.randn(1,256,32,32), + torch.randn(1,512,16,16)] + use_unpack_operator = False + case "mask-decoder": + torch_model = predictor.sam_mask_decoder + torch_input = (torch.randn(1, 256, 32, 32), + torch.randn(1, 256, 32, 32), + torch.randn(1, 8, 256)) + case "mask-decoder-transformer": + torch_model = predictor.sam_mask_decoder.transformer + #tokens_limit = 16 + #objects_limit = 10 + tokens_limit = 8 + objects_limit = 1 + torch_input = (torch.randn(objects_limit, 256, 32, 32), + torch.randn(objects_limit, 256, 32, 32), + torch.randn(objects_limit, tokens_limit, 256)) + input_names = ["src", "pos_src", "tokens"] 
+ d_axes = { + #"src": {0:"tokens_dyn_input_num_objects"}, + #"pos_src": {0:"tokens_dyn_input_num_objects"}, + 'tokens': {1: "tokens_dyn_input_num_points"}} + case "memory-encoder": + torch_model = predictor.memory_encoder + torch_input = (torch.randn(1, 256, 32, 32), + torch.randn(1, 1, 512, 512), + True) + case "not supported: prompt-encoder-mask-downscaling": + # mask_downscaling is only used in SAM2 when prompting with masks instead of boxes or points. + torch_model = predictor.sam_prompt_encoder.mask_downscaling + case "memory-attention": + print("Hint: Fails due to implementation of memory_attention. Value p is for the number of prompts entered.") + p = 2 + torch_model = predictor.memory_attention + torch_input = ([torch.randn(1024, p, 256)], + torch.randn(7200, p, 64), + [torch.randn(1024, p, 256)], + torch.randn(7200, p, 64), + 32) + case _: + print(f"Unknown model block: {block}") + exit() + return (torch_model, torch_input, use_unpack_operator, input_names, d_axes) + +def export_model_block(m:ModelSource, block:str, out_dir:Path, img_shape:list, use_simplify:bool=False, opset_version:int=13): + """ + Export a building block of nanosam2 to onnx. Not all blocks can be converted. 
+ """ + print(f"Exporting Model: {m.name} block: {block}...") + out_dir.mkdir(parents=True, exist_ok=True) + predictor = build_sam2_video_predictor(m.cfg, m.checkpoint, torch.device("cpu")) + torch_model, torch_input, use_unpack_operator, input_names, d_axes = get_block_and_inputs(predictor=predictor, block=block, img_shape=img_shape) + + print(" - Testing torch model...", end="") + test_torch_model(torch_model, torch_input, silent=True, use_unpack_operator=use_unpack_operator) + print("OK") + + print(" - Exporting to ONNX...", end="") + export_path = out_dir / Path(f"{m.name}-{block}-sa1-v01-op{opset_version}.onnx") + torch.onnx.export(torch_model, torch_input, export_path, + export_params=True, + opset_version=opset_version, + simplify=True, + input_names=input_names, + dynamic_axes=d_axes + ) + print("OK") + print(f" - Model stored in {export_path}") + + if use_simplify: + print(" - simplify model...", end="") + model = onnx.load(export_path) + model, check = onnxsim.simplify(model) + if check: + export_path = modify_filename(export_path, str_extension="-simply") + onnx.save(model, export_path) + print("OK") + print(f" - Simplify model stored in {export_path}") + else: + print("Model simplification failed!") + + print(" - Analyze model...") + analyze_onnx_model(onnx.load(export_path)) + print(" - Model block successfully converted to ONNX.") + # test_onnx_model(export_path, torch_input) + +if __name__ == "__main__": + import argparse + print("---\n" + "Welcome to Nanosam2 ONNX exporter!\n" + "Nanosam2 ONNX exporter is used to export parts of the Nanosam2 model to ONNX files.\n" + "These ONNX files can be executed using onnxruntime or further processed and deployed on the desired hardware.\n" + "\n" + " - Use parameter img_shape to set the input shape. 
Default is [3,512,512].\n" + "---" + ) + models = [ + ModelSource("sam2.1_small", "results/sam2.1_hiera_s/sam2.1_hiera_small.pt", "../sam2_configs/sam2.1_hiera_s.yaml"), + ModelSource("nanosam2-resnet18", "results/sam2.1_hiera_s_resnet18/checkpoint.pth", "../sam2_configs/nanosam2.1_resnet18.yaml"), + ModelSource("nanosam2-mobilenetV3", "results/sam2.1_hiera_s_mobilenetV3_large/checkpoint.pth", "../sam2_configs/nanosam2.1_mobilenet_v3_large.yaml") + ] + + valid_exports = [ + "nanosam2", + "all", + "image-encoder", + "image-encoder-trunk", + "image-encoder-neck", + "mask-decoder", + "mask-decoder-transformer", + "memory-encoder", + "memory-attention" + ] + + valid_encoders = [ + "hiera_small", + "resnet18", + "mobilenetV3_large", + "casvit_s" + ] + + parser = argparse.ArgumentParser("Nanosam2 ONNX exporter") + parser.add_argument("--export", type=str, default="image-encoder", choices=valid_exports, help='Model block to export, use all to export all supported blocks as individuals, use nanosam2 to export the whole model.') + parser.add_argument("--output_path", type=str, default="model_exports2", help="Export directory.") + parser.add_argument("--img_shape", nargs='+', type=int, default=[3,512,512], help="Image shape to use.") + parser.add_argument("--encoder_type", type=str, default="resnet18", choices=valid_encoders) + parser.add_argument("--opset", type=int, default=13) + parser.add_argument("--simplify", action='store_true', help='Simplify model') + + out_dir = Path("model_exports2") + + args = parser.parse_args() + + match args.encoder_type: + case "hiera_small": + model = models[0] + case "resnet18": + model = models[1] + case "mobilenetV3_large": + model = models[2] + case "casvit_s": + model = models[3] + case _: + print("Enocer type not supported.") + exit(1) + + if not Path(model.checkpoint).is_file(): + print(f"Path to model {model.checkpoint} not found.") + exit(1) + + if args.export != "all": + # Export a single block or the whole nanosam2 model. 
+ export_model_block(model, args.export, out_dir, args.img_shape, use_simplify=args.simplify, opset_version=args.opset) + else: + # Export all blocks as individuals. + for b in valid_exports: + if b == "all": continue + export_model_block(model, b, out_dir, args.img_shape, use_simplify=args.simplify, opset_version=args.opset) + print("done.")