Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
try:
from inference import load_models, predict_stream_a, predict_stream_b
from fusion import process_and_fuse
from species import load_species_model, predict_species

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don’t import species metadata inside history conversion.

Line 304 bypasses the existing PyTorch gating: if species.py failed to import because Torch/TorchVision is unavailable, history serialization can still crash here. Import SPECIES_METADATA in the guarded top-level import path and provide a lightweight fallback in the except branch, then use that module-level value in _row_to_payload().

Proposed direction
-    from species import load_species_model, predict_species
+    from species import SPECIES_METADATA, load_species_model, predict_species
-    from species import SPECIES_METADATA
     metadata = SPECIES_METADATA.get(species_name, {"scientific_name": "Labeo rohita", "habitat": "Freshwater"})

Also define a fallback SPECIES_METADATA in the existing import failure branch so _row_to_payload() remains usable without PyTorch.

Also applies to: 302-305

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/main.py` at line 35, the import of the species module inside the
_row_to_payload() function bypasses the top-level PyTorch gating mechanism,
which means history serialization can crash if PyTorch/TorchVision is
unavailable. Move the species module import and SPECIES_METADATA extraction to
the top-level guarded import block (where the try-except handles PyTorch
availability), define a lightweight fallback SPECIES_METADATA in the except
branch for when PyTorch is unavailable, then replace the dynamic import inside
_row_to_payload() with a reference to the module-level SPECIES_METADATA variable
to ensure it uses the pre-initialized value.


_torch_available = True
except ModuleNotFoundError:
Expand All @@ -53,6 +54,7 @@
MODEL_DIR = Path(os.environ.get("MODEL_DIR", str(_repo_root / "Models")))
STREAM_A_PATH = os.environ.get("STREAM_A_MODEL", str(MODEL_DIR / "freshscan_stream_a_body.pth"))
STREAM_B_PATH = os.environ.get("STREAM_B_MODEL", str(MODEL_DIR / "stream_b_checkpoint.pth"))
SPECIES_MODEL_PATH = os.environ.get("SPECIES_MODEL", str(MODEL_DIR / "species_mobilenetv3.pth"))


# ── Supabase clients ──────────────────────────────────────────────────────────
Expand Down Expand Up @@ -81,18 +83,22 @@ async def lifespan(app: FastAPI):
global _models_loaded
a = Path(STREAM_A_PATH)
b = Path(STREAM_B_PATH)
sp = Path(SPECIES_MODEL_PATH)
if not _torch_available:
print("WARNING: PyTorch not installed. Scan endpoints will return 503.")
elif a.exists() and b.exists():
print(f"Loading models from {MODEL_DIR} ...")
load_models(str(a), str(b))
_models_loaded = True
print("Models loaded successfully.")
else:
print(
f"WARNING: Model files not found at {MODEL_DIR}. "
"Scan endpoints will return 503 until models are present."
)
if a.exists() and b.exists():
print(f"Loading models from {MODEL_DIR} ...")
load_models(str(a), str(b))
_models_loaded = True
print("Models loaded successfully.")
else:
print(
f"WARNING: Model files not found at {MODEL_DIR}. "
"Scan endpoints will return 503 until models are present."
)
# Load species classifier (optional — falls back to default if missing)
load_species_model(str(sp))
yield


Expand Down Expand Up @@ -224,6 +230,7 @@ def _build_scan_payload(
scan_id: str,
display_id: str,
photo_url: Optional[str] = None,
species_info: Optional[dict] = None,
) -> dict:
score = fusion["final_score_percent"]
reg = fusion["regional_breakdown"]
Expand All @@ -246,6 +253,14 @@ def _build_scan_payload(

consume_hours = max(0, int((freshness - 40) * 0.6)) if is_fresh else 0

# Use detected species or fallback to default
if species_info is None:
species_info = {}

species_name = species_info.get("common_name", "Rohu Carp")
scientific_name = species_info.get("scientific_name", "Labeo rohita")
habitat = species_info.get("habitat", "Freshwater")

return {
"scan_id": scan_id,
"scan_display_id": display_id,
Expand All @@ -256,10 +271,10 @@ def _build_scan_payload(
"is_fresh": is_fresh,
"uncertain_flag": fusion["uncertain_prediction_flag"],
"species": {
"common_name": "Rohu Carp",
"scientific_name": "Labeo rohita",
"habitat": "Freshwater",
"tags": ["ROHU CARP", "LABEO ROHITA", "FRESHWATER"],
"common_name": species_name,
"scientific_name": scientific_name,
"habitat": habitat,
"tags": [species_name.upper(), scientific_name.upper(), habitat.upper()],
"weight_estimate_kg": 1.2,
"catch_age_hours": 6,
},
Expand All @@ -284,6 +299,11 @@ def _row_to_payload(row: dict) -> dict:
if not bm:
bm = _build_biomarkers(freshness, freshness, freshness)

# Use species_detected from DB or fallback
species_name = row.get("species_detected") or "Rohu Carp"
from species import SPECIES_METADATA
metadata = SPECIES_METADATA.get(species_name, {"scientific_name": "Labeo rohita", "habitat": "Freshwater"})

return {
"scan_id": row["id"],
"scan_display_id": row.get("scan_display_id") or row["id"][:8].upper(),
Expand All @@ -294,10 +314,10 @@ def _row_to_payload(row: dict) -> dict:
"is_fresh": is_fresh,
"uncertain_flag": False,
"species": {
"common_name": "Rohu Carp",
"scientific_name": "Labeo rohita",
"habitat": "Freshwater",
"tags": ["ROHU CARP", "LABEO ROHITA", "FRESHWATER"],
"common_name": species_name,
"scientific_name": metadata["scientific_name"],
"habitat": metadata["habitat"],
"tags": [species_name.upper(), metadata["scientific_name"].upper(), metadata["habitat"].upper()],
"weight_estimate_kg": 1.2,
"catch_age_hours": 6,
},
Expand Down Expand Up @@ -490,7 +510,10 @@ async def process_scan(
predict_stream_b(img_gill),
temperature=1.5,
)
payload = _build_scan_payload(fusion, scan_id, display_id)

# Classify species from the body image
species_info = predict_species(img_body)
payload = _build_scan_payload(fusion, scan_id, display_id, species_info=species_info)

try:
_db().table("scans").insert(
Expand All @@ -503,7 +526,7 @@ async def process_scan(
"image_type": "full_scan",
"freshness_index": payload["freshness_index"],
"scan_display_id": display_id,
"species_detected": "Rohu Carp",
"species_detected": species_info["common_name"],
"biomarker_json": payload["biomarkers"],
"storage_hours": payload["recommendations"]["consume_within_hours"],
"alert_flags": payload["recommendations"]["alert_flags"],
Expand Down Expand Up @@ -546,7 +569,7 @@ async def scan_auto(
},
}
photo_url = await _upload_image(image_bytes, str(current_user.id), scan_id)
payload = _build_scan_payload(demo_fusion, scan_id, display_id, photo_url)
payload = _build_scan_payload(demo_fusion, scan_id, display_id, photo_url, species_info={"common_name": "Rohu Carp", "scientific_name": "Labeo rohita", "habitat": "Freshwater"})

try:
_db().table("scans").insert(
Expand Down Expand Up @@ -608,7 +631,10 @@ async def scan_auto(

fusion = process_and_fuse(body_logits, eye_logits, gill_logits, temperature=1.5)
photo_url = await _upload_image(image_bytes, str(current_user.id), scan_id)
payload = _build_scan_payload(fusion, scan_id, display_id, photo_url)

# Classify species from the uploaded image
species_info = predict_species(img)
payload = _build_scan_payload(fusion, scan_id, display_id, photo_url, species_info=species_info)

try:
_db().table("scans").insert(
Expand All @@ -620,7 +646,7 @@ async def scan_auto(
"image_type": image_type.value,
"freshness_index": payload["freshness_index"],
"scan_display_id": display_id,
"species_detected": "Rohu Carp",
"species_detected": species_info["common_name"],
"biomarker_json": payload["biomarkers"],
"storage_hours": payload["recommendations"]["consume_within_hours"],
"alert_flags": payload["recommendations"]["alert_flags"],
Expand Down
147 changes: 147 additions & 0 deletions backend/species.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
Fish Species Classification Module

Uses MobileNetV3-Small for lightweight, fast species identification.
Inference in this module runs through PyTorch; the small model is intended to
suit edge deployment (e.g. via ONNX Runtime).
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
Comment on lines +8 to +10

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Locate dependency declarations and show torch/torchvision pins with context.

fd -i '^(requirements.*|pyproject\.toml|setup\.py|poetry\.lock|uv\.lock|Pipfile|Pipfile\.lock|Dockerfile.*)$' \
  --exec sh -c 'echo "### $1"; rg -n -C2 "torch|torchvision" "$1" || true' sh {}

Repository: jpdevhub/FreshScanAi

Length of output: 1064


Fix incompatible torch/torchvision version constraint in requirements.txt.

The Dockerfile correctly pins torch==2.2.2 and torchvision==0.17.2 (a compatible pair), but requirements.txt pins torchvision>=0.27.0, which requires torch>=2.4.0. This mismatch will cause import failures when installing from requirements.txt without the Dockerfile context.

Update requirements.txt to pin torchvision>=0.17.0,<0.18.0 (or use torchvision==0.17.2 to match Dockerfile) to align with torch>=2.2.0.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/species.py` around lines 8-10, the torchvision version constraint
in requirements.txt is incompatible with the torch version specified. Update the
torchvision constraint in requirements.txt from the current incompatible version
specification to match the compatible version pair pinned in the Dockerfile
(torchvision should be pinned to version 0.17.2 or use a range like
>=0.17.0,<0.18.0 to align with torch>=2.2.0). This ensures that both the
Dockerfile and requirements.txt will install compatible versions of torch and
torchvision, preventing import failures when installing from requirements.txt.

from PIL import Image
import numpy as np
from pathlib import Path

# ── Species Labels ──────────────────────────────────────────────────────────
# Common fish species found in Indian/local markets
# One row per supported species: (common name, scientific name, habitat).
# predict_species() indexes SPECIES_LABELS with the classifier's argmax, so the
# row order here is the class-index order.
_SPECIES_TABLE = (
    ("Rohu Carp", "Labeo rohita", "Freshwater"),
    ("Catla Carp", "Catla catla", "Freshwater"),
    ("Mrigal Carp", "Cirrhinus cirrhosus", "Freshwater"),
    ("Pangas", "Pangasius hypophthalmus", "Freshwater"),
    ("Basa", "Pangasius bocourti", "Freshwater"),
    ("Tilapia", "Oreochromis niloticus", "Freshwater"),
    ("Pomfret", "Pampus argenteus", "Marine"),
    ("Kingfish", "Scomberomorus commerson", "Marine"),
    ("Mackerel", "Rastrelliger kanagurta", "Marine"),
    ("Sardine", "Sardinella longiceps", "Marine"),
)

# Class index -> common name.
SPECIES_LABELS = [common for common, _, _ in _SPECIES_TABLE]

# Common name -> display metadata.
SPECIES_METADATA = {
    common: {"scientific_name": scientific, "habitat": habitat}
    for common, scientific, habitat in _SPECIES_TABLE
}

# Number of classes the classifier head is built for.
NUM_SPECIES = len(SPECIES_LABELS)

# ── Device ──────────────────────────────────────────────────────────────────
# Inference device, chosen once at import time: CUDA when torch reports it as
# available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ── Model Architecture ──────────────────────────────────────────────────────
def get_species_model(num_classes: int = NUM_SPECIES) -> nn.Module:
    """Build an untrained MobileNetV3-Small with a species classification head.

    The torchvision backbone is created without pretrained weights and its
    stock classifier is replaced by a small head:
    Linear -> Hardswish -> Dropout(0.2) -> Linear.

    Args:
        num_classes: Width of the final layer; defaults to NUM_SPECIES.

    Returns:
        The assembled model with freshly initialised weights.
    """
    backbone = models.mobilenet_v3_small(weights=None)
    # Reuse the input width of the stock classifier's first layer for the new head.
    feature_dim = backbone.classifier[0].in_features
    head_layers = [
        nn.Linear(feature_dim, 256),
        nn.Hardswish(inplace=True),
        nn.Dropout(p=0.2),
        nn.Linear(256, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone


# ── Preprocessing ───────────────────────────────────────────────────────────
# Preprocessing applied to every image before classification:
# resize to 224x224, convert to a tensor, then normalise each of the three
# channels with the mean/std below (these match the widely used ImageNet values).
species_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# ── Global Model Reference ─────────────────────────────────────────────────
# Classifier instance assigned by load_species_model(); None before that.
_species_model = None
# Set to True by load_species_model() once weights are loaded and the model is
# in eval mode; predict_species() returns its default answer while this is False.
_species_loaded = False


def load_species_model(weights_path: str):
"""
Load pre-trained species classification weights.
Run once on server startup.
"""
global _species_model, _species_loaded

path = Path(weights_path)
if not path.exists():
print(f"WARNING: Species model not found at {path}. Using default species.")
return

_species_model = get_species_model()
checkpoint = torch.load(path, map_location=device, weights_only=True)

if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
_species_model.load_state_dict(checkpoint["model_state_dict"])
else:
_species_model.load_state_dict(checkpoint)
Comment on lines +95 to +100

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Persist and validate the class-index mapping with the checkpoint.

Inference maps top_idx directly into SPECIES_LABELS, but the training snippet saves only a bare state_dict, so no label order is available to validate. If the dataset class order differs from SPECIES_LABELS, every prediction can be assigned the wrong species. Save species_labels or class_to_idx with the checkpoint and reject/remap mismatches during load.

Also applies to: 137-146

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/species.py` around lines 95 - 100, The checkpoint loading code
(around line 95-100 in the conditional branches checking for "model_state_dict")
does not validate that the class-to-index mapping used during training matches
the current SPECIES_LABELS ordering. Modify the checkpoint saving code
(referenced in lines 137-146) to include the species_labels or class_to_idx
mapping alongside the model_state_dict. Then in the checkpoint loading code,
extract and validate this saved mapping against the current SPECIES_LABELS, and
either reject the checkpoint if there is a critical mismatch or apply a
remapping transformation to correct the model output indices before mapping them
to species names during inference.


_species_model.to(device)
_species_model.eval()
_species_loaded = True
Comment on lines +89 to +104

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep species loading optional when the checkpoint is present but invalid.

A bad or incompatible weights file currently raises from torch.load() / load_state_dict() and aborts startup, even though species classification is intended to fall back when unavailable. Reset the globals and return the fallback path on load failure.

Proposed defensive load handling
     path = Path(weights_path)
     if not path.exists():
         print(f"WARNING: Species model not found at {path}. Using default species.")
+        _species_model = None
+        _species_loaded = False
         return
 
-    _species_model = get_species_model()
-    checkpoint = torch.load(path, map_location=device, weights_only=True)
-
-    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
-        _species_model.load_state_dict(checkpoint["model_state_dict"])
-    else:
-        _species_model.load_state_dict(checkpoint)
-
-    _species_model.to(device)
-    _species_model.eval()
-    _species_loaded = True
+    try:
+        _species_model = get_species_model()
+        checkpoint = torch.load(path, map_location=device, weights_only=True)
+
+        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
+            _species_model.load_state_dict(checkpoint["model_state_dict"])
+        else:
+            _species_model.load_state_dict(checkpoint)
+
+        _species_model.to(device)
+        _species_model.eval()
+        _species_loaded = True
+    except Exception as exc:
+        print(f"WARNING: Failed to load species model from {path}: {exc}. Using default species.")
+        _species_model = None
+        _species_loaded = False
+        return
     print(f"Species model loaded from {path}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/species.py` around lines 89 - 104, The torch.load() and
load_state_dict() calls in the species model loading section can raise
exceptions from invalid or incompatible checkpoint files, which will abort the
entire startup process. Wrap the checkpoint loading logic (from torch.load
through the load_state_dict call and model preparation) in a try-except block.
When any exception occurs during loading, catch it, reset the global variables
_species_model and _species_loaded to their initial state, log a warning message
about the load failure, and return early. This ensures species classification
remains optional and the application can continue with the fallback behavior
when weights are invalid or incompatible.

print(f"Species model loaded from {path}")


# ── Inference ───────────────────────────────────────────────────────────────
@torch.no_grad()
def predict_species(image: Image.Image) -> dict:
    """Classify the fish species shown in an image.

    When no species model has been loaded, a fixed default ("Rohu Carp") is
    returned with zero confidence, so callers always receive the same keys.

    Args:
        image: PIL image of the fish. Non-RGB images (grayscale, palette,
            RGBA, ...) are converted to RGB before preprocessing.

    Returns:
        A dict with the keys ``common_name``, ``scientific_name``,
        ``habitat``, ``confidence`` (probability of the top class) and
        ``all_probs`` (label -> probability for each entry of SPECIES_LABELS).
    """
    if not _species_loaded or _species_model is None:
        # Fallback: return default species with low confidence
        return {
            "common_name": "Rohu Carp",
            "scientific_name": "Labeo rohita",
            "habitat": "Freshwater",
            "confidence": 0.0,
            "all_probs": {label: 0.0 for label in SPECIES_LABELS},
        }

    # The normalisation step uses three-channel statistics, so any image that
    # is not already RGB would fail there; convert it up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    batch = species_transform(image).unsqueeze(0).to(device)
    logits = _species_model(batch)
    # squeeze(0) removes only the batch axis, so the result stays a 1-D vector
    # of class probabilities regardless of the number of classes.
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    top_idx = int(np.argmax(probs))
    top_label = SPECIES_LABELS[top_idx]
    metadata = SPECIES_METADATA[top_label]

    return {
        "common_name": top_label,
        "scientific_name": metadata["scientific_name"],
        "habitat": metadata["habitat"],
        "confidence": float(probs[top_idx]),
        "all_probs": {label: float(probs[i]) for i, label in enumerate(SPECIES_LABELS)},
    }
Loading
Loading