Changes from all commits
26 commits
7c48d7f
Move class ModelSource into module misc
Dec 11, 2024
fea6c45
Merge branch 'main' of https://github.com/paspf/nanosam2 into main
Dec 11, 2024
7d8ff28
Add tools to export model parts as onnx.
Dec 13, 2024
29329c8
Add new script for exporting onnx files.
Dec 17, 2024
2a47159
Add support for exporting memory-encoder to onnx.
Dec 17, 2024
2406feb
Improve onnx export.
Dec 18, 2024
a006365
Add onnx export for mask-decoder
Jan 7, 2025
f350cbc
Add support for in-model callbacks.
Jan 9, 2025
c31b839
Refactor ModelSource class.
Jan 10, 2025
c718df3
Update start message.
Jan 10, 2025
83cc831
Update imports.
Jan 10, 2025
78e62e1
Add option to pad mask decoders transformer.
Jan 14, 2025
4f1fa6c
Onnx export can handle dynamic shapes.
Jan 14, 2025
848e669
Only use single dynamic axis for transformer
Jan 15, 2025
29f1fa5
Add functions to pad tensors
Jan 15, 2025
0a1ad2e
Set default objects limit to 1
Jan 15, 2025
c4d0ff8
Make fixed transformer shapes more flexible
Jan 16, 2025
efe9fdb
Reduce default image size to 512px
Jan 16, 2025
73dc704
Add feature maps callback for image encoder neck.
Jan 28, 2025
c94e7b0
Disable fixed transformer shapes by default.
Jan 28, 2025
2894dc4
Add support for exporting image encoders input feature maps.
Jan 28, 2025
c72da8a
Set opset using cli.
Feb 28, 2025
60aa7f2
Fix opset version passing when using cli.
Mar 6, 2025
f39be0b
Add option to enable onnxsim
Dec 18, 2025
00d5dbc
Update to latest SAM2 version
Dec 22, 2025
d967c12
Print current step when compiling model
Jan 11, 2026
3 changes: 3 additions & 0 deletions .gitignore
@@ -165,3 +165,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
model_exports/*
model_exports2/*
videos/*
21 changes: 14 additions & 7 deletions README.md
@@ -3,23 +3,30 @@ This repository is inspired by https://github.com/NVIDIA-AI-IOT/nanosam and adap
Although the inference speed of the SAM2.1 Hiera backbones is already quite fast on GPUs, it is still difficult to deploy on edge devices.
This repository aims to provide a more efficient alternative for SAM2.1 inference, with a focus on backbones that are smaller and faster to deploy.

## Installation

## Dependencies and Prerequisites
### Dependencies and Prerequisites

- Create a new Python 3.10+ environment and clone the repository.
- Create a new Python 3.12+ environment and clone the repository.
- Install the dependencies listed below:

```
pip install matplotlib torchvision tqdm hydra-core pycocotools requests iopath
pip install matplotlib torchvision tqdm hydra-core pycocotools requests iopath opencv-python
```

- Install the repository as editable package `pip install -e .`

## Inference
### Download checkpoints

You can find and download pretrained nanosam2 checkpoints [here](https://drive.google.com/drive/folders/15wApVHwqJGunjDP_cx5YZDCTEKliOMCQ?usp=sharing). Each backbone was trained for 10 epochs on 14 SA1 datasets, i.e. ~175k images.

## Inference Demos

All inference demos are executed on the [Bedroom](https://github.com/facebookresearch/sam2/blob/2b90b9f5ceec907a1c18123530e92e794ad901a4/notebooks/videos/bedroom.mp4) video file, shared in the original [sam2](https://github.com/facebookresearch/sam2) repository.

### Video

Load all frames of a video at once into Nanosam2 and perform tracking of objects from any frame. To use the script you have to obtain all frames of the video as `.jpg` files. Place all `.jpg` files
Load all frames of a video at once into Nanosam2 and perform tracking of objects from any frame. To run the video demo, you have to obtain all frames of the video as `.jpg` files. Place all `.jpg` files
in the same directory and pass the directory to `video_frames_demo.py`.

Extracting all frames of a video using FFmpeg:
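A possible FFmpeg invocation for this step is sketched below; the input file name and output directory are placeholders, and the zero-padded `%05d.jpg` numbering is an assumption chosen to keep the frames in order:

```shell
mkdir -p video_frames
# Extract every frame as a numbered JPEG (00000.jpg, 00001.jpg, ...).
# -q:v 2 keeps JPEG quality high; -start_number 0 starts numbering at zero.
# Guarded so the sketch is a no-op when ffmpeg or the video is absent.
if command -v ffmpeg >/dev/null && [ -f bedroom.mp4 ]; then
  ffmpeg -i bedroom.mp4 -q:v 2 -start_number 0 video_frames/%05d.jpg
fi
```

Pass the resulting directory to `video_frames_demo.py` as described below.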
@@ -36,7 +43,7 @@ python demos/video_frames_demo.py --config nanosam2.1_resnet18.yaml --checkpoint

### Camera Live Stream

Stream a video (of a camera or a video file) frame by frame into Nanosam2. Perform tracking of objects from any frame.
Stream a video (from a camera or video file source) frame by frame into Nanosam2. Start object tracking from any frame in the stream.


For ResNet18 backend:
@@ -122,9 +129,9 @@ python nanosam2/tools/compute_eval_coco_metric.py results/sam2.1_hiera_s_resnet1


## Results FP32
You can find pretrained nanosam2 checkpoints [here](https://drive.google.com/drive/folders/15wApVHwqJGunjDP_cx5YZDCTEKliOMCQ?usp=sharing).

Each backbone was trained for 10 epochs on 14 SA1 datasets, i.e. ~175k images.

| Backbone | num_epochs | mIoU All | mIoU Small | mIoU Medium | mIoU Large |
| -------- | -------- | -------- | -------- | -------- | -------- |
| resnet18 | 10 | 0.69 | 0.62 | 0.73 | 0.76 |
56 changes: 42 additions & 14 deletions demos/live_demo.py
@@ -1,4 +1,10 @@
# Live Demo
# Nanosam2 Live Demo
#
# Run inference on a video stream from a camera or a video file.
#
# To run this script, create a new Python environment (3.12), install all packages listed in the README.md file
# and add the "nanosam2" directory to your pythonpath (or install the package).
#
# Based on "https://github.com/Gy920/segment-anything-2-real-time/blob/main/demo/demo.py".


@@ -14,10 +20,11 @@
parser.add_argument("--config", type=str, default="sam2_hiera_s", help="The path to a sam2 config.")
parser.add_argument("--checkpoint", type=str, default="sam2_checkpoints/sam2.1_hiera_small.pt")
parser.add_argument('--video', default=0, help='Path to a video or a camera id, default: 0')
parser.add_argument('--device', default="cpu", help='Device to run the model on, default: cpu, also supports cuda')
args = parser.parse_args()

# Configure Device.
device = "cuda"
device = args.device
if device == "cuda":
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
@@ -32,22 +39,30 @@
frametimes = []

def _compile_model_blocks(model, model_settings:list, compile_backend):
if not any(model_settings):  # skip only when no block is selected
print("Skipping Model Compilation...")
return model
print("Compiling Model...")
if model_settings[0]: # image_encoder
print(" - Compiling Image Encoder...")
model.image_encoder = torch.compile(model.image_encoder, backend=compile_backend, dynamic=False)
if model_settings[1]: # memory_attention
print(" - Compiling Memory Attention...")
model.memory_attention = torch.compile(model.memory_attention, backend=compile_backend)
if model_settings[2]: # sam_mask_decoder
print(" - Compiling SAM Mask Decoder...")
model.sam_mask_decoder = torch.compile(model.sam_mask_decoder, backend=compile_backend)
if model_settings[3]: # sam_prompt_encoder
print(" - Compiling SAM Prompt Encoder...")
model.sam_prompt_encoder = torch.compile(model.sam_prompt_encoder, backend=compile_backend)
if model_settings[4]: # memory_encoder
print(" - Compiling Memory Encoder...")
model.memory_encoder = torch.compile(model.memory_encoder, backend=compile_backend)
print("Compile finished.")
return model

# Compile Model if Required.
#predictor = _compile_model_blocks(predictor, [True, False, False, False, False], "inductor")
predictor = _compile_model_blocks(predictor, [False, False, False, False, False], "inductor")

# Open Video Stream.
cap = cv2.VideoCapture(args.video)
@@ -67,26 +82,39 @@ def _compile_model_blocks(model, model_settings:list, compile_backend):
predictor.load_first_frame(frame)
if_init = True

ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# Let's add a positive click at (x, y) = (210, 350) to get started
# ---------------------------------------------------------------------------
# for demo video: https://github.com/facebookresearch/sam2/blob/2b90b9f5ceec907a1c18123530e92e794ad901a4/notebooks/videos/bedroom.mp4
# ---------------------------------------------------------------------------

# Add bbox - boy
ann_frame_idx = 0 # frame index to annotate
ann_obj_id = 1 # unique object id to annotate
bbox = np.array([[230, 134], [294, 219]], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
)

# Add bbox - girl
ann_frame_idx = 0 # frame index to annotate
ann_obj_id = 2 # unique object id to annotate
bbox = np.array([[353, 11], [451, 122]], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
)

##! add points, `1` means positive click and `0` means negative click
# ---------------------------------------------------------------------------
# other bounding box, mask and point examples
# ---------------------------------------------------------------------------

# Add points, `1` means positive click and `0` means negative click
# points = np.array([[660, 267]], dtype=np.float32)
# labels = np.array([1], dtype=np.int32)

# _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
# frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
# )

## ! add bbox
bbox = np.array([[600, 214], [765, 286]], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
)

##! add mask
# Add mask
# mask_img_path="../notebooks/masks/aquarium/aquarium_mask.png"
# mask = cv2.imread(mask_img_path, cv2.IMREAD_GRAYSCALE)
# mask = mask / 255
15 changes: 11 additions & 4 deletions demos/video_frames_demo.py
@@ -3,7 +3,7 @@
# Run inference on the frames exported from a video.
#
# To run this script, create a new Python environment (3.12), install all packages listed in the README.md file
# and add the "nanosam2" directory to your pythonpath.
# and add the "nanosam2" directory to your pythonpath (or install the package).
# If you are using bash add the following line to your .bashrc
# export PYTHONPATH="<path-to>/nanosam2"

@@ -102,13 +102,20 @@ def show_box(box, ax):

predictor.reset_state(inference_state)

# ---------------------------------------------------------------------------
# for demo video: https://github.com/facebookresearch/sam2/blob/2b90b9f5ceec907a1c18123530e92e794ad901a4/notebooks/videos/bedroom.mp4
# ---------------------------------------------------------------------------

print("\n#1: Set two points and predict a mask...")
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)

# Add a positive click at (x, y).
# For labels, `1` means positive click and `0` means negative click
points = np.array([[770, 420], [750, 380]], dtype=np.float32)


# Add two points on the shirt of the girl.
points = np.array([[388, 139], [414, 165]], dtype=np.float32)
labels = np.array([1,1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
@@ -155,10 +162,10 @@ def show_box(box, ax):
ann_obj_id = 4 # give a unique id to each object we interact with (it can be any integer)

# Add a positive click at (x, y).
points = np.array([[560, 350], [770, 420], [750, 380]], dtype=np.float32)
points = np.array([[379, 149], [102, 138], [100, 172]], dtype=np.float32)
labels = np.array([1,1,1], np.int32)
# Box coordinates.
box = np.array([400, 320, 1100, 650], dtype=np.float32)
box = np.array([370, 111, 427, 185], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
9 changes: 9 additions & 0 deletions nanosam2/datasets/containers.py
@@ -0,0 +1,9 @@
# Container classes for different Nanosam2 use cases.
#


class ModelSource:
def __init__(self, name:str, checkpoint:str, cfg:str):
self.name = name
self.checkpoint = checkpoint
self.cfg = cfg
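For illustration, a minimal usage sketch of the new `ModelSource` container follows; the class body is mirrored here so the snippet is self-contained, and the checkpoint and config paths are hypothetical:

```python
# Mirrors nanosam2/datasets/containers.py so the sketch runs standalone.
class ModelSource:
    def __init__(self, name: str, checkpoint: str, cfg: str):
        self.name = name          # human-readable backbone name
        self.checkpoint = checkpoint  # path to a .pt checkpoint
        self.cfg = cfg            # hydra config file name

# Hypothetical paths -- substitute your own checkpoint/config locations.
source = ModelSource(
    name="resnet18",
    checkpoint="nanosam2_checkpoints/nanosam2.1_resnet18.pt",
    cfg="nanosam2.1_resnet18.yaml",
)
print(source.name, source.cfg)
```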
4 changes: 2 additions & 2 deletions nanosam2/sam2/__init__.py
@@ -1,5 +1,5 @@
from hydra import initialize_config_module
from hydra.core.global_hydra import GlobalHydra

if not GlobalHydra().is_initialized():
initialize_config_module("sam2_configs", version_base="1.2")
if not GlobalHydra.instance().is_initialized():
initialize_config_module("sam2_configs", version_base="1.2")
26 changes: 24 additions & 2 deletions nanosam2/sam2/build_sam.py
@@ -75,13 +75,21 @@ def build_sam2(
mode="eval",
load_image_encoder=True,
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):

if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
# Read config and init model
config_name = f'{config_dir}/{config_file}' if config_dir is not None else config_file

cfg = compose(config_name=config_name)
cfg = compose(config_name=config_name, overrides=hydra_overrides_extra)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path, load_image_encoder=load_image_encoder)
@@ -98,11 +106,18 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
vos_optimized=False,
**kwargs,
):
hydra_overrides = [
"++model._target_=nanosam2.sam2.sam2_video_predictor.SAM2VideoPredictor",
]
if vos_optimized:
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
"++model.compile_image_encoder=True", # Let sam2_base handle this
]

if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
@@ -191,6 +206,13 @@ def _load_checkpoint(model, ckpt_path, load_image_encoder=True):
for k in list(sd.keys()):
if "image_encoder" in k:
del sd[k]

missing_keys, unexpected_keys = model.load_state_dict(sd, strict=load_image_encoder)

if missing_keys:
logging.error(missing_keys)
raise RuntimeError()
if unexpected_keys:
logging.error(unexpected_keys)
raise RuntimeError()
logging.info("Loaded checkpoint successfully")
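The stricter `_load_checkpoint` above aborts when the checkpoint and model disagree on any state-dict key. Conceptually the check reduces to a set comparison; a self-contained sketch with plain dicts of key names (the key names themselves are invented for illustration):

```python
# Sketch of the key check torch's load_state_dict performs internally:
# it returns (missing_keys, unexpected_keys), which the new code turns
# into hard errors instead of silent warnings.
def check_state_dict_keys(model_keys, checkpoint_keys):
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected

missing, unexpected = check_state_dict_keys(
    model_keys=["image_encoder.conv1.weight", "sam_mask_decoder.head.bias"],
    checkpoint_keys=["sam_mask_decoder.head.bias", "extra.weight"],
)
# A strict loader would now log both lists and raise RuntimeError.
```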
19 changes: 17 additions & 2 deletions nanosam2/sam2/modeling/backbones/image_encoder.py
@@ -22,13 +22,25 @@ def __init__(
self.trunk = trunk
self.neck = neck
self.scalp = scalp
self.feature_maps_callback = None
# assert (
# self.trunk.channel_list == self.neck.backbone_channel_list
# ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

def forward(self, sample: torch.Tensor):
# Forward through backbone
features, pos = self.neck(self.trunk(sample))
# Feature map callback.
if self.feature_maps_callback is not None:
self.feature_maps_callback("image-encoder:trunk-input", {"0":sample.cpu()})

# Forward through backbone (trunk)
trunk = self.trunk(sample)

# Feature map callback.
if self.feature_maps_callback is not None:
self.feature_maps_callback("image-encoder:trunk-output", {"0":trunk[0].cpu(), "1":trunk[1].cpu(), "2":trunk[2].cpu(), "3":trunk[3].cpu()})

# Forward through backbone (neck)
features, pos = self.neck(trunk)
if self.scalp > 0:
# Discard the lowest resolution features
features, pos = features[: -self.scalp], pos[: -self.scalp]
@@ -41,6 +53,9 @@ def forward(self, sample: torch.Tensor):
"backbone_fpn": features,
}
return output

def set_feature_maps_callback(self, fun):
self.feature_maps_callback = fun


class FpnNeck(nn.Module):
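The `set_feature_maps_callback` hook expects a callable taking a stage name and a dict of CPU tensors. A minimal recording callback is sketched below with a stand-in tensor object, since building a full encoder is out of scope here; a real callback might instead dump the arrays to disk for ONNX export debugging:

```python
# Record the shape of every feature map the encoder reports, keyed by stage.
recorded = {}

def record_feature_maps(stage: str, tensors: dict):
    recorded[stage] = {k: getattr(v, "shape", None) for k, v in tensors.items()}

# With a built model you would attach it like:
#   predictor.image_encoder.set_feature_maps_callback(record_feature_maps)
# Simulated call with a stand-in object instead of a torch.Tensor:
class FakeTensor:
    shape = (1, 3, 512, 512)  # matches the new 512px default input size

record_feature_maps("image-encoder:trunk-input", {"0": FakeTensor()})
```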
10 changes: 4 additions & 6 deletions nanosam2/sam2/modeling/backbones/utils.py
@@ -32,9 +32,7 @@ def window_partition(x, window_size):
Hp, Wp = H + pad_h, W + pad_w

x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
return windows, (Hp, Wp)


@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(
x = windows.reshape(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)

if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
x = x[:, :H, :W, :]
return x


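The swap from `.contiguous().view(...)` to `.reshape(...)` above is behavior-preserving (`reshape` copies only when the input isn't contiguous) and tends to trace more cleanly for ONNX export. The permute-plus-reshape round trip can be sketched in NumPy, covering the no-padding case with made-up shapes:

```python
import numpy as np

# NumPy re-implementation of the window partition/unpartition round trip.
def window_partition_np(x, ws):
    B, H, W, C = x.shape  # assumes H and W are multiples of ws (no padding)
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

def window_unpartition_np(windows, ws, H, W):
    B = windows.shape[0] // (H * W // ws // ws)
    x = windows.reshape(B, H // ws, W // ws, ws, ws, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

x = np.arange(2 * 8 * 8 * 3, dtype=np.float32).reshape(2, 8, 8, 3)
windows = window_partition_np(x, 4)            # 2 images * 4 windows each
restored = window_unpartition_np(windows, 4, 8, 8)
```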