Feature/sam2 features #3
Open
folkien wants to merge 22 commits into main from feature/sam2_features
30b6a0b Fixes. (folkien)
a04ac8c Fixes. (folkien)
9aa9499 Fixes. (folkien)
8250987 Fixes. (folkien)
c8ea7fa Fixes. (folkien)
ace2a4c Fixes. (folkien)
1e1b1e5 Fixes. (folkien)
1bc0331 Fixes. (folkien)
c442158 Fixes. (folkien)
331d3f4 Fixes. (folkien)
6d7c691 fixes. (folkien)
6a52f8b Fixes. (folkien)
213cdb3 Fixes. (folkien)
b29eb96 Fixes. (folkien)
e060581 Fixes. (folkien)
71f6493 Fixes. (folkien)
621981c Fixes. (folkien)
ae9d136 Fixes. (folkien)
bc94f6c Fixes. (folkien)
2392bd4 Fixes. (folkien)
ccfb302 Updsate. (folkien)
2b74dab Fixes. (folkien)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -172,3 +172,5 @@ cython_debug/ | |
|
|
||
| # PyPI configuration file | ||
| .pypirc | ||
| zoo/* | ||
| *.pickle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,51 +1,43 @@ | ||
| # Template : How to start and customize? | ||
| # ssya | ||
|
|
||
| - [ ] Create new repository from this template | ||
| - [ ] Inside pyproject.toml rename `package_name` | ||
| - [ ] Rename aisp_template directory to `package_name` | ||
| - [ ] Update `README.md` | ||
| SSYA is a graphical tool for segmenting images and searching for similar regions across image collections, built on Segment Anything v2 (SAM2). | ||
|
|
||
| # Template directory structure | ||
| ## Features | ||
| - Offline indexing and caching of SAM2 embeddings | ||
| - Interactive GUI with progress bars and filtering by similarity threshold | ||
| - Fast search for similar detections | ||
|
|
||
| - package_name/ - Insert package code here | ||
| - tests/ - Insert unit tests here | ||
| - scripts/ - Insert scripts here | ||
| - images/ - If this is CV/AI repository then insert images here | ||
|
|
||
| # Package name | ||
|
|
||
| Write package short description here. | ||
|
|
||
| # Installation : Developer | ||
|
|
||
| Use poetry to install the package in development mode. | ||
| ## Requirements | ||
| - Python 3.11+ | ||
|
|
||
| ## Installation | ||
| ```bash | ||
| git clone {URL} | ||
| uv sync | ||
| uv venv | ||
| git clone <repo-url> | ||
| cd ssya | ||
| pdm install | ||
| ``` | ||
|
|
||
| # Testing | ||
|
|
||
| Run the tests using pytest. | ||
|
|
||
| ## Usage | ||
| ```bash | ||
| uv run pytest | ||
| ssya -i /path/to/dataset | ||
| ``` | ||
| If you omit `-i`, a dialog opens for choosing the dataset folder. | ||
|
|
||
| # Release | ||
| ## Data format | ||
| Place the images and their TXT annotation files in the directory (one line per object: `class xc yc width height`, all values normalized). | ||
|
|
||
| Github workflow is created to automatically release the package to PyPI when a new tag "vX.X.X" (example v1.0.0) is pushed to the main branch. | ||
| ## Tests | ||
| ```bash | ||
| pdm run pytest | ||
| ``` | ||
|
|
||
| ## Releases | ||
| New `vX.X.X` tags pushed to `main` are published to PyPI automatically. | ||
| You can also tag a release manually: | ||
| ```bash | ||
| git tag vX.X.X | ||
| git push --tags | ||
| ``` | ||
|
|
||
| Or manually build and upload the package to PyPI using the following command. | ||
|
|
||
| ``` | ||
| uv build | ||
| pdm build | ||
| pdm publish | ||
| ``` | ||
|
|
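The annotation format described in the README diff above (one line per object: class, box center, width and height, all normalized) can be read roughly as follows. This is a minimal sketch, not code from the PR: `parse_annotation_line` is a hypothetical helper, and the `(x, y, w, h)` pixel layout with a top-left origin is an assumption based on how `mask_and_embed` later adds width and height to the first two box values.

```python
def parse_annotation_line(line: str, img_w: int, img_h: int) -> tuple[int, tuple[int, int, int, int]]:
    """Parse one 'class xc yc width height' line into (class_id, pixel box).

    The pixel box is (x, y, w, h) with a top-left origin. That layout is an
    assumption made for this sketch, not something the PR states.
    """
    cls, xc, yc, w, h = line.split()
    box_w = float(w) * img_w
    box_h = float(h) * img_h
    # Shift from the box center to its top-left corner.
    x = float(xc) * img_w - box_w / 2
    y = float(yc) * img_h - box_h / 2
    return int(cls), (round(x), round(y), round(box_w), round(box_h))


class_id, box = parse_annotation_line("2 0.5 0.5 0.25 0.5", img_w=640, img_h=480)
print(class_id, box)
```

Lines that do not contain exactly five fields raise `ValueError` on unpacking, so a caller may want to skip blank lines first.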
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 3 0.493349 0.459251 0.911729 0.442731 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| 2 0.410937 0.560937 0.268750 0.288542 | ||
| 2 0.721484 0.518750 0.228906 0.179167 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| 6 0.415104 0.306019 0.545833 0.350926 | ||
| 2 0.670052 0.399537 0.165104 0.191667 | ||
| 2 0.926042 0.181019 0.020833 0.036111 | ||
| 2 0.685937 0.126389 0.022917 0.025000 | ||
| 2 0.890365 0.175000 0.021354 0.024074 | ||
| 2 0.865104 0.165278 0.019792 0.019444 | ||
| 14 0.608333 0.418981 0.016667 0.021296 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ dynamic = ["version"] | |
|
|
||
| dependencies = [ | ||
| "dotenv>=0.9.9", | ||
| "sam2>=1.1.0", | ||
| "yaya-tools", | ||
| ] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| import logging | ||
| from pathlib import Path | ||
|
|
||
| import cv2 | ||
| import numpy as np | ||
| from tqdm import tqdm | ||
| from yaya_tools.helpers.dataset import load_directory_images_annotatations | ||
|
|
||
| from ssya.controllers.features_index import FeatureIndex | ||
| from ssya.controllers.sam2_wrapper import Sam2Runner | ||
| from ssya.models.detection import Detection | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class DatasetManager: | ||
| """Loads dataset, detections, builds/loads feature index.""" | ||
|
|
||
| def __init__(self, root: Path): | ||
| self.root = root | ||
| ann_map = load_directory_images_annotatations(str(root)) | ||
| self.images: list[str] = list(ann_map.keys()) | ||
| self.ann_map = ann_map | ||
| self.detections: dict[str, list[Detection]] = {} | ||
| for img_idx, img_path in enumerate(self.images): | ||
| if not ann_map[img_path]: | ||
| self.detections[img_path] = [] | ||
| continue | ||
| with open(root / ann_map[img_path]) as f: | ||
| lines = [l.split() for l in f] | ||
| self.detections[img_path] = [ | ||
| Detection(int(cls), (float(xc), float(yc), float(w), float(h)), img_idx) for cls, xc, yc, w, h in lines | ||
| ] | ||
| logger.info("Dataset: %d images (%d with annotations)", len(self.images), len(self.detections)) | ||
|
|
||
| # Build or load feature index --------------------------------- | ||
| self.index_path = root / "features.pickle" | ||
| if self.index_path.exists(): | ||
| logger.info("Loading cached features …") | ||
| self.fidx = FeatureIndex.load(self.index_path) | ||
| else: | ||
| self.fidx = FeatureIndex() | ||
| self._build_index() | ||
| self.fidx.save(self.index_path) | ||
|
|
||
| # Detections : Update with embeddings from the index | ||
| for img_path, dets in self.detections.items(): | ||
| self.detections[img_path] = self.fidx.get_features(dets) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
|
|
||
| def _build_index(self) -> None: | ||
| sam = Sam2Runner() | ||
| logger.info("Building feature index (SAM2)…") | ||
| for img_idx, img_path in enumerate(tqdm(self.images, desc="Images")): | ||
| img = cv2.imread(str(self.root / img_path)) | ||
| if img is None: | ||
| continue | ||
| for det_idx, det in enumerate(self.detections[img_path]): | ||
| mask, emb = sam.mask_and_embed(img, det.bbox_pixels(img.shape[1], img.shape[0])) | ||
| det.embedding = emb | ||
| self.fidx.add(img_idx, det_idx, emb) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
|
|
||
| # Convenience helpers used by GUI ---------------------------------- | ||
| def image(self, idx: int) -> np.ndarray: | ||
| """Get image at index `idx`.""" | ||
| return cv2.imread(str(self.root / self.images[idx])) | ||
|
|
||
| def image_detections(self, idx: int) -> list[Detection]: | ||
| """Get detections for the image at index `idx`.""" | ||
| return self.detections[self.images[idx]] | ||
|
|
||
| def image_count(self) -> int: | ||
| """Get the number of images in the dataset.""" | ||
| return len(self.images) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import pickle | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
|
|
||
| from ssya.helpers.metrics import cosine_similarity | ||
| from ssya.models.detection import Detection # type: ignore | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class FeatureIndex: | ||
| """Persistent RAM index: list of (image_idx, det_idx, embedding).""" | ||
|
|
||
| def __init__(self, entries: list[dict[str, Any]] | None = None): | ||
| """Initialize with existing entries or empty.""" | ||
| if entries is not None: | ||
| self.entries = entries | ||
| else: | ||
| self.entries: list[dict[str, Any]] = [] | ||
|
|
||
| def add(self, image_idx: int, det_idx: int, emb: np.ndarray): | ||
| self.entries.append({"image_idx": image_idx, "det_idx": det_idx, "emb": emb}) | ||
|
|
||
| def save(self, path: Path): | ||
| with open(path, "wb") as f: | ||
| pickle.dump(self.entries, f, protocol=pickle.HIGHEST_PROTOCOL) | ||
|
|
||
| def get_features(self, detections: list[Detection]) -> list[Detection]: | ||
| """Update detections list with embeddings from the index.""" | ||
| for det in detections: | ||
| if det.embedding is None: | ||
| for e in self.entries: | ||
| if e["image_idx"] == det.image_idx and e["det_idx"] == det.class_id: | ||
|
||
| det.embedding = e["emb"] | ||
| break | ||
|
|
||
| return detections | ||
|
|
||
| @classmethod | ||
| def load(cls, path: Path) -> FeatureIndex: | ||
| with open(path, "rb") as f: | ||
| entries = pickle.load(f) | ||
|
|
||
| return cls(entries) | ||
|
|
||
| def get_similar_images(self, ref_emb: np.ndarray, thresh: float) -> set[int]: | ||
| """Find images with at least one detection above the threshold.""" | ||
| imgs: set[int] = set() | ||
| for e in self.entries: | ||
| if cosine_similarity(ref_emb, e["emb"]) >= thresh: | ||
| imgs.add(e["image_idx"]) | ||
|
|
||
| return imgs | ||
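A note on `get_features` in the file above: it scans every entry for each detection and matches `e["det_idx"]` against `det.class_id`, while `_build_index` stores the detection's position within its image as `det_idx`. If entries are meant to be identified by `(image_idx, det_idx)`, a keyed lookup avoids both the linear scan and the mismatch. A minimal sketch under that assumption; `build_lookup` is a hypothetical helper, not part of the PR:

```python
from typing import Any


def build_lookup(entries: list[dict[str, Any]]) -> dict[tuple[int, int], Any]:
    """Map (image_idx, det_idx) to the stored embedding.

    Assumes each entry has the keys written by FeatureIndex.add:
    "image_idx", "det_idx" and "emb".
    """
    return {(e["image_idx"], e["det_idx"]): e["emb"] for e in entries}


entries = [
    {"image_idx": 0, "det_idx": 0, "emb": [0.1, 0.9]},
    {"image_idx": 0, "det_idx": 1, "emb": [0.7, 0.3]},
    {"image_idx": 2, "det_idx": 0, "emb": [0.5, 0.5]},
]
lookup = build_lookup(entries)
print(lookup.get((0, 1)))  # embedding of the second detection in image 0
print(lookup.get((1, 0)))  # None: image 1 has no indexed detections
```

Using it from `get_features` would require each detection to know its position in the image's list, which the `Detection` objects in this diff do not obviously carry.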
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| import cv2 # type: ignore | ||
| import numpy as np # type: ignore | ||
| import requests | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from sam2.build_sam import build_sam2 | ||
| from sam2.sam2_image_predictor import SAM2ImagePredictor | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def gem_pooling(features: torch.Tensor, mask: torch.Tensor, p: float = 3.0): | ||
| """ | ||
| Masked GeM pooling: features (B, C, H, W), mask (B, 1, H, W), bool/int. | ||
| Returns (B, C). | ||
| """ | ||
| eps = 1e-6 | ||
| masked = features * mask # (B, C, H, W) | ||
| pooled = F.avg_pool2d(masked.clamp(min=eps).pow(p), kernel_size=masked.shape[-2:]) # (B, C, 1, 1) | ||
| pooled = pooled.pow(1.0 / p).squeeze(-1).squeeze(-1) | ||
| # account for the number of active pixels | ||
| denom = mask.flatten(2).sum(-1).clamp(min=1e-6) # (B,1) | ||
| pooled = pooled / denom | ||
| return F.normalize(pooled, dim=-1) | ||
|
|
||
|
|
||
| class Sam2Runner: | ||
| """Light wrapper that exposes mask + embedding for a bbox.""" | ||
|
|
||
| _instance = None # singleton for reuse | ||
|
|
||
| def __new__(cls): | ||
| if cls._instance is None: | ||
| cls._instance = super().__new__(cls) | ||
| cls._instance._init_model() | ||
| return cls._instance | ||
|
|
||
| # ------------------------------------------------------------------ | ||
|
|
||
| def _init_model(self) -> None: | ||
| """Initialize the SAM2 model.""" | ||
| model_path = "zoo/sam2_tiny.pth" | ||
| if not os.path.exists(model_path): | ||
| url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt" | ||
| Path(model_path).parent.mkdir(parents=True, exist_ok=True) | ||
| logger.info("Downloading SAM2 weights …") | ||
| with requests.get(url, stream=True) as r, open(model_path, "wb") as f: | ||
| for chunk in r.iter_content(1 << 14): | ||
| f.write(chunk) | ||
| cfg = "configs/sam2.1/sam2.1_hiera_t.yaml" | ||
| device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES", "") else "cpu" | ||
| model = build_sam2(cfg, model_path).to(device).eval() | ||
| self._predictor = SAM2ImagePredictor(model) | ||
| self.device = device | ||
|
|
||
| # ------------------------------------------------------------------ | ||
|
|
||
| def mask_and_embed(self, img_bgr: np.ndarray, box_px: tuple[int, int, int, int]) -> tuple[np.ndarray, np.ndarray]: | ||
| img_rgb = img_bgr[:, :, ::-1].copy() | ||
| self._predictor.set_image(img_rgb) # SAM2 computes the image embedding here | ||
|
|
||
| # ---------- segmentation ---------- | ||
| masks, _, _ = self._predictor.predict( | ||
| box=np.array([box_px[0], box_px[1], box_px[0] + box_px[2], box_px[1] + box_px[3]]), | ||
| multimask_output=False, | ||
| return_logits=False, | ||
| ) | ||
| mask_hr = masks[0] # (H, W) bool | ||
|
|
||
| # ---------- feature map ---------- | ||
| feat_container = getattr(self._predictor, "_features", None) | ||
| if feat_container is None: | ||
| raise RuntimeError("No _features on the predictor; check the library version") | ||
|
|
||
| # dict: take 'image_embed' | ||
| feat_map = feat_container.get("image_embed", None) if isinstance(feat_container, dict) else feat_container | ||
|
|
||
| if feat_map is None or not torch.is_tensor(feat_map): | ||
| raise RuntimeError("No feature-map tensor found in _features") | ||
|
|
||
| B, C, h, w = feat_map.shape | ||
|
|
||
| mask_lr = cv2.resize(mask_hr.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool) | ||
| mask_t = torch.from_numpy(mask_lr).to(feat_map.device).view(1, 1, h, w) | ||
|
|
||
| emb_t = gem_pooling(feat_map, mask_t, p=3.0) # masked GeM pooling over the feature map | ||
| emb = emb_t.cpu().numpy()[0] # (C,) | ||
|
|
||
| return mask_hr, emb |
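For readers unfamiliar with generalized-mean (GeM) pooling, the idea behind `gem_pooling` in the file above can be illustrated without PyTorch: for each channel, take the p-th root of the mean of the p-th powers over the masked positions, then L2-normalize the resulting vector. This is a pure-Python sketch of that textbook formulation, not a line-by-line port of the PR's function; `masked_gem` and its flat-list inputs are assumptions made for illustration.

```python
import math


def masked_gem(features: list[list[float]], mask: list[bool], p: float = 3.0, eps: float = 1e-6) -> list[float]:
    """Masked GeM pooling over flattened feature maps.

    features: C channels, each a flat list of N values.
    mask:     N booleans; only positions where mask is True are pooled.
    Returns an L2-normalized list of C values.
    """
    count = sum(1 for m in mask if m)
    if count == 0:
        return [0.0] * len(features)
    pooled = []
    for channel in features:
        # Clamp to eps so negative or zero activations do not break the power.
        total = sum(max(v, eps) ** p for v, m in zip(channel, mask) if m)
        pooled.append((total / count) ** (1.0 / p))
    norm = math.sqrt(sum(v * v for v in pooled))
    return [v / norm for v in pooled] if norm > 0 else pooled


emb = masked_gem([[1.0, 1.0, 5.0], [2.0, 2.0, 9.0]], mask=[True, True, False])
print(emb)
```

With `p = 1` this reduces to masked average pooling, and larger `p` pushes the result toward the masked maximum.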
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| import numpy as np | ||
|
|
||
|
|
||
| def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: | ||
| if a is None or b is None: | ||
| return 0.0 | ||
| if np.linalg.norm(a) == 0 or np.linalg.norm(b) == 0: | ||
| return 0.0 | ||
| return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))) |
len(self.detections) returns the total number of images (keys in the dict), not only those with annotations. To accurately log images with annotations, count entries where the detection list is non-empty.
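One way to act on this comment is to count only the dictionary values that are non-empty. A minimal sketch, assuming `detections` maps image paths to lists of detections as in `DatasetManager`; `count_annotated` is a hypothetical helper name:

```python
def count_annotated(detections: dict[str, list]) -> int:
    """Count images whose detection list is non-empty."""
    return sum(1 for dets in detections.values() if dets)


detections = {"a.jpg": ["det"], "b.jpg": [], "c.jpg": ["det", "det"]}
print(len(detections), count_annotated(detections))
```

The log call could then pass `count_annotated(self.detections)` as its second argument instead of `len(self.detections)`.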