2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -172,3 +172,5 @@ cython_debug/

# PyPI configuration file
.pypirc
zoo/*
*.pickle
9 changes: 7 additions & 2 deletions .vscode/launch.json
@@ -5,12 +5,17 @@
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"name": "SSYA debug",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"program": "${workspaceFolder}/ssya/main.py",
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"args": [
"-i",
"${workspaceFolder}/images",
],
"justMyCode": false,
},
],
60 changes: 26 additions & 34 deletions README.md
@@ -1,51 +1,43 @@
# Template : How to start and customize?
# ssya

- [ ] Create new repository from this template
- [ ] Inside pyproject.toml rename `package_name`
- [ ] Rename aisp_template directory to `package_name`
- [ ] Update `README.md`
SSYA is a graphical tool for segmenting and searching for similar regions in image collections, built on Segment Anything v2 (SAM2).

# Template directory structure
## Features
- Offline indexing and caching of SAM2 embeddings
- Interactive GUI with progress bars and filtering by similarity threshold
- Fast search for similar detections

- package_name/ - Insert package code here
- tests/ - Insert unit tests here
- scripts/ - Insert scripts here
- images/ - If this is CV/AI repository then insert images here

# Package name

Write package short description here.

# Installation : Developer

Use poetry to install the package in development mode.
## Requirements
- Python 3.11+

## Installation
```bash
git clone {URL}
uv sync
uv venv
git clone <repo-url>
cd ssya
pdm install
```

# Testing

Run the tests using pytest.

## Usage
```bash
uv run pytest
ssya -i /path/to/dataset
```
If you omit `-i`, a dialog opens so you can choose the dataset folder.
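
The fallback described above can be sketched as follows. This is a minimal illustration, not the actual `ssya` entry point (which this diff does not show): `resolve_dataset_dir`, its `choose_dir` parameter, and the `--input` long option are all hypothetical.

```python
import argparse
from pathlib import Path
from typing import Callable, Optional, Sequence


def _ask_with_dialog() -> Optional[str]:
    """Open a native folder picker; an empty result means the user cancelled."""
    # Imported lazily so runs that pass -i never need a display.
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    root.withdraw()  # hide the empty main window
    try:
        return filedialog.askdirectory(title="Select dataset folder")
    finally:
        root.destroy()


def resolve_dataset_dir(
    argv: Sequence[str],
    choose_dir: Callable[[], Optional[str]] = _ask_with_dialog,
) -> Path:
    """Return the dataset directory from -i, or ask the user if -i is missing."""
    parser = argparse.ArgumentParser(prog="ssya")
    parser.add_argument("-i", "--input", type=Path, default=None, help="dataset directory")
    args = parser.parse_args(argv)
    if args.input is not None:
        return args.input
    chosen = choose_dir()
    if not chosen:
        raise SystemExit("No dataset folder selected")
    return Path(chosen)
```

Passing the chooser as a parameter keeps the command-line logic testable without opening a window.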

# Release
## Data format
Place the images and their TXT annotation files in one directory (one line per object: `class xc yc width height`, values normalized).
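
As an illustration of this annotation format, here is a minimal sketch that parses one line and converts it to a pixel box. It assumes the values are normalized to the image width and height (the usual YOLO convention, which the README does not spell out); `parse_annotation_line` and `to_pixel_xyxy` are hypothetical helpers, not part of ssya.

```python
def parse_annotation_line(line: str) -> tuple[int, float, float, float, float]:
    """Parse 'class xc yc width height' into typed values."""
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"expected 5 fields, got {len(parts)}: {line!r}")
    cls, xc, yc, w, h = parts
    return int(cls), float(xc), float(yc), float(w), float(h)


def to_pixel_xyxy(
    xc: float, yc: float, w: float, h: float, img_w: int, img_h: int
) -> tuple[int, int, int, int]:
    """Convert a normalized center box to pixel corners (x1, y1, x2, y2)."""
    x1 = (xc - w / 2) * img_w
    y1 = (yc - h / 2) * img_h
    x2 = (xc + w / 2) * img_w
    y2 = (yc + h / 2) * img_h
    return round(x1), round(y1), round(x2), round(y2)
```

Rejecting lines that do not have exactly five fields also catches blank lines, which would otherwise fail later during tuple unpacking.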

Github workflow is created to automatically release the package to PyPI when a new tag "vX.X.X" (example v1.0.0) is pushed to the main branch.
## Tests
```bash
pdm run pytest
```

## Releases
New `vX.X.X` tags pushed to `main` are published to PyPI automatically.
You can also do it manually:
```bash
git tag vX.X.X
git push --tags
```

Or manually build and upload the package to PyPI using the following command.

```
uv build
pdm build
pdm publish
```

1 change: 1 addition & 0 deletions images/414b2c60d4399fd180320d42e7c35b2514b2e0fb.txt
@@ -0,0 +1 @@
3 0.493349 0.459251 0.911729 0.442731
2 changes: 2 additions & 0 deletions images/7e954dc0a66a0659163da55410280574388f137d.txt
@@ -0,0 +1,2 @@
2 0.410937 0.560937 0.268750 0.288542
2 0.721484 0.518750 0.228906 0.179167
7 changes: 7 additions & 0 deletions images/8f4de79e2f402d169ca9d902d8bbd45be34a6361.txt
@@ -0,0 +1,7 @@
6 0.415104 0.306019 0.545833 0.350926
2 0.670052 0.399537 0.165104 0.191667
2 0.926042 0.181019 0.020833 0.036111
2 0.685937 0.126389 0.022917 0.025000
2 0.890365 0.175000 0.021354 0.024074
2 0.865104 0.165278 0.019792 0.019444
14 0.608333 0.418981 0.016667 0.021296
Binary file added images/image.png
Empty file added images/image.txt
Empty file.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -8,6 +8,7 @@ dynamic = ["version"]

dependencies = [
"dotenv>=0.9.9",
"sam2>=1.1.0",
"yaya-tools",
]

77 changes: 77 additions & 0 deletions ssya/controllers/dataset_manager.py
@@ -0,0 +1,77 @@
import logging
from pathlib import Path

import cv2
import numpy as np
from tqdm import tqdm
from yaya_tools.helpers.dataset import load_directory_images_annotatations

from ssya.controllers.features_index import FeatureIndex
from ssya.controllers.sam2_wrapper import Sam2Runner
from ssya.models.detection import Detection

logger = logging.getLogger(__name__)


class DatasetManager:
    """Loads dataset, detections, builds/loads feature index."""

    def __init__(self, root: Path):
        self.root = root
        ann_map = load_directory_images_annotatations(str(root))
        self.images: list[str] = list(ann_map.keys())
        self.ann_map = ann_map
        self.detections: dict[str, list[Detection]] = {}
        for img_idx, img_path in enumerate(self.images):
            if not ann_map[img_path]:
                self.detections[img_path] = []
                continue
            with open(root / ann_map[img_path]) as f:
                lines = [l.split() for l in f]
                self.detections[img_path] = [
                    Detection(int(cls), (float(xc), float(yc), float(w), float(h)), img_idx) for cls, xc, yc, w, h in lines
                ]
        logger.info("Dataset: %d images (%d with annotations)", len(self.images), len(self.detections))

Copilot AI Jul 12, 2025

len(self.detections) returns the total number of images (keys in the dict), not only those with annotations. To accurately log images with annotations, count entries where the detection list is non-empty.

Suggested change
logger.info("Dataset: %d images (%d with annotations)", len(self.images), len(self.detections))
logger.info("Dataset: %d images (%d with annotations)", len(self.images), sum(1 for dets in self.detections.values() if dets))


        # Build or load feature index ---------------------------------
        self.index_path = root / "features.pickle"
        if self.index_path.exists():
            logger.info("Loading cached features …")
            self.fidx = FeatureIndex.load(self.index_path)
        else:
            self.fidx = FeatureIndex()
            self._build_index()
            self.fidx.save(self.index_path)

        # Detections : Update with embeddings from the index
        for img_path, dets in self.detections.items():
            self.detections[img_path] = self.fidx.get_features(dets)

    # ------------------------------------------------------------------

    def _build_index(self) -> None:
        sam = Sam2Runner()
        logger.info("Building feature index (SAM2)…")
        for img_idx, img_path in enumerate(tqdm(self.images, desc="Images")):
            img = cv2.imread(str(self.root / img_path))
            if img is None:
                continue
            for det_idx, det in enumerate(self.detections[img_path]):
                mask, emb = sam.mask_and_embed(img, det.bbox_pixels(img.shape[1], img.shape[0]))
                det.embedding = emb
                self.fidx.add(img_idx, det_idx, emb)

    # ------------------------------------------------------------------

    # Convenience helpers used by GUI ----------------------------------
    def image(self, idx: int) -> np.ndarray:
        """Get image at index `idx`."""
        return cv2.imread(str(self.root / self.images[idx]))

    def image_detections(self, idx: int) -> list[Detection]:
        """Get detections for the image at index `idx`."""
        return self.detections[self.images[idx]]

    def image_count(self) -> int:
        """Get the number of images in the dataset."""
        return len(self.images)
58 changes: 58 additions & 0 deletions ssya/controllers/features_index.py
@@ -0,0 +1,58 @@
from __future__ import annotations

import logging
import pickle
from pathlib import Path
from typing import Any

import numpy as np

from ssya.helpers.metrics import cosine_similarity
from ssya.models.detection import Detection # type: ignore

logger = logging.getLogger(__name__)


class FeatureIndex:
    """Persistent RAM index: list of (image_idx, det_idx, embedding)."""

    def __init__(self, entries: list[dict[str, Any]] | None = None):
        """Initialize with existing entries or empty."""
        if entries is not None:
            self.entries = entries
        else:
            self.entries: list[dict[str, Any]] = []

    def add(self, image_idx: int, det_idx: int, emb: np.ndarray):
        self.entries.append({"image_idx": image_idx, "det_idx": det_idx, "emb": emb})

    def save(self, path: Path):
        with open(path, "wb") as f:
            pickle.dump(self.entries, f, protocol=pickle.HIGHEST_PROTOCOL)

    def get_features(self, detections: list[Detection]) -> list[Detection]:
        """Update detections list with embeddings from the index."""
        for det in detections:
            if det.embedding is None:
                for e in self.entries:
                    if e["image_idx"] == det.image_idx and e["det_idx"] == det.class_id:

Copilot AI Jul 12, 2025

This condition compares the stored det_idx against det.class_id, but det_idx represents the detection's position, not its class. Consider comparing e['det_idx'] to the detection's index or storing class_id instead to ensure correct embedding lookup.
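
One way to address this, sketched under assumptions: key the index by `(image_idx, det_idx)` and attach embeddings by each detection's position within its image. `Det` is a simplified stand-in for `ssya.models.detection.Detection` (whose definition is not shown in this diff), and `attach_embeddings` is a hypothetical helper. The sketch also assumes that every detection in the list belongs to the same image and appears in the order used when the index was built.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Det:
    """Simplified stand-in for Detection (assumed fields only)."""

    class_id: int
    image_idx: int
    embedding: Optional[Any] = None


def attach_embeddings(entries: list[dict[str, Any]], detections: list[Det]) -> list[Det]:
    """Fill missing embeddings using (image_idx, position in image) as the key."""
    # Build the lookup once instead of scanning all entries per detection.
    by_key = {(e["image_idx"], e["det_idx"]): e["emb"] for e in entries}
    for det_idx, det in enumerate(detections):
        if det.embedding is None:
            det.embedding = by_key.get((det.image_idx, det_idx))
    return detections
```

With this keying, two detections of the same class in one image receive their own embeddings instead of both matching on `class_id`.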

                        det.embedding = e["emb"]
                        break

        return detections

    @classmethod
    def load(cls, path: Path) -> FeatureIndex:
        with open(path, "rb") as f:
            entries = pickle.load(f)

        return cls(entries)

    def get_similar_images(self, ref_emb: np.ndarray, thresh: float) -> set[int]:
        """Find images with at least one detection above the threshold."""
        imgs: set[int] = set()
        for e in self.entries:
            if cosine_similarity(ref_emb, e["emb"]) >= thresh:
                imgs.add(e["image_idx"])

        return imgs
95 changes: 95 additions & 0 deletions ssya/controllers/sam2_wrapper.py
@@ -0,0 +1,95 @@
from __future__ import annotations

import logging
import os
from pathlib import Path

import cv2 # type: ignore
import numpy as np # type: ignore
import requests
import torch
import torch.nn.functional as F
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

logger = logging.getLogger(__name__)


def gem_pooling(features: torch.Tensor, mask: torch.Tensor, p: float = 3.0):
    """
    Masked GeM pooling: features (B, C, H, W), mask (B, 1, H, W), bool/int.
    Returns (B, C).
    """
    eps = 1e-6
    masked = features * mask  # (B, C, H, W)
    pooled = F.avg_pool2d(masked.clamp(min=eps).pow(p), kernel_size=masked.shape[-2:])  # (B, C, 1, 1)
    pooled = pooled.pow(1.0 / p).squeeze(-1).squeeze(-1)
    # account for the number of active pixels
    denom = mask.flatten(2).sum(-1).clamp(min=1e-6)  # (B, 1)
    pooled = pooled / denom
    return F.normalize(pooled, dim=-1)


class Sam2Runner:
    """Light wrapper that exposes mask + embedding for a bbox."""

    _instance = None  # singleton for reuse

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_model()
        return cls._instance

    # ------------------------------------------------------------------

    def _init_model(self) -> None:
        """Initialize the SAM2 model."""
        model_path = "zoo/sam2_tiny.pth"
        if not os.path.exists(model_path):
            url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt"
            Path(model_path).parent.mkdir(parents=True, exist_ok=True)
            logger.info("Downloading SAM2 weights …")
            with requests.get(url, stream=True) as r, open(model_path, "wb") as f:
                for chunk in r.iter_content(1 << 14):
                    f.write(chunk)
        cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES", "") else "cpu"
        model = build_sam2(cfg, model_path).to(device).eval()
        self._predictor = SAM2ImagePredictor(model)
        self.device = device

    # ------------------------------------------------------------------

    def mask_and_embed(self, img_bgr: np.ndarray, box_px: tuple[int, int, int, int]) -> tuple[np.ndarray, np.ndarray]:
        img_rgb = img_bgr[:, :, ::-1].copy()
        self._predictor.set_image(img_rgb)  # SAM2 computes the image embedding here

        # ---------- segmentation ----------
        masks, _, _ = self._predictor.predict(
            box=np.array([box_px[0], box_px[1], box_px[0] + box_px[2], box_px[1] + box_px[3]]),
            multimask_output=False,
            return_logits=False,
        )
        mask_hr = masks[0]  # (H, W) bool

        # ---------- feature map ----------
        feat_container = getattr(self._predictor, "_features", None)
        if feat_container is None:
            raise RuntimeError("No _features on the predictor; check the library version")

        # if it is a dict, take 'image_embed'
        feat_map = feat_container.get("image_embed", None) if isinstance(feat_container, dict) else feat_container

        if feat_map is None or not torch.is_tensor(feat_map):
            raise RuntimeError("Could not find a feature-map tensor in _features")

        B, C, h, w = feat_map.shape

        mask_lr = cv2.resize(mask_hr.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
        mask_t = torch.from_numpy(mask_lr).to(feat_map.device).view(1, 1, h, w)

        emb_t = gem_pooling(feat_map, mask_t, p=3.0)  # (1, C)
        emb = emb_t.cpu().numpy()[0]  # (C,)

        return mask_hr, emb
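
For reference, mask-restricted GeM pooling takes the generalized mean of each channel over the masked pixels only, then L2-normalizes the result. The NumPy sketch below illustrates that idea for a single image. It is not a drop-in replacement for `gem_pooling` above, and `masked_gem` is a hypothetical name.

```python
import numpy as np


def masked_gem(features: np.ndarray, mask: np.ndarray, p: float = 3.0, eps: float = 1e-6) -> np.ndarray:
    """features: (C, H, W) float; mask: (H, W) bool. Returns an L2-normalized (C,) vector."""
    selected = features[:, mask]  # (C, N): only the pixels inside the mask
    if selected.shape[1] == 0:
        # Empty mask: nothing to pool, return a zero vector.
        return np.zeros(features.shape[0], dtype=features.dtype)
    powered = np.clip(selected, eps, None) ** p  # clamp so the p-th power stays positive
    pooled = powered.mean(axis=1) ** (1.0 / p)   # generalized mean per channel
    norm = np.linalg.norm(pooled)
    return pooled / norm if norm > 0 else pooled
```

Selecting the masked pixels before averaging makes the mean independent of the feature-map size, so no separate division by the pixel count is needed.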
9 changes: 9 additions & 0 deletions ssya/helpers/metrics.py
@@ -0,0 +1,9 @@
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    if a is None or b is None:
        return 0.0
    if np.linalg.norm(a) == 0 or np.linalg.norm(b) == 0:
        return 0.0
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
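
`FeatureIndex.get_similar_images` calls this function once per stored entry. If the index grows large, the same comparison can be done in a single matrix operation. A minimal sketch, assuming all embeddings share one dimensionality; `similar_image_ids` is a hypothetical helper, not part of this PR.

```python
import numpy as np


def similar_image_ids(ref_emb, embs, image_idx, thresh: float) -> set[int]:
    """Return image indices whose embedding has cosine similarity >= thresh to ref_emb."""
    embs = np.asarray(embs, dtype=np.float64)    # (N, C)
    ref = np.asarray(ref_emb, dtype=np.float64)  # (C,)
    if embs.size == 0:
        return set()
    norms = np.linalg.norm(embs, axis=1) * np.linalg.norm(ref)  # (N,)
    sims = np.zeros(len(embs))
    # Zero-norm vectors keep similarity 0.0, matching cosine_similarity above.
    np.divide(embs @ ref, norms, out=sims, where=norms > 0)
    return {int(i) for i in np.asarray(image_idx)[sims >= thresh]}
```

Here `embs` would be the stacked `emb` values of the index entries and `image_idx` the matching `image_idx` values, in the same order.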