diff --git a/backend/app/main.py b/backend/app/main.py index 57c1067..21b5e22 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -42,6 +42,7 @@ upsert_contributor_stat, ) from .models import ( + DeduplicatedScanResponse, Finding, FixRequest, FixResponse, @@ -60,6 +61,7 @@ from .scanners.gitleaks import run_gitleaks from .scanners.osv import run_osv_scanner from .scanners.semgrep import run_semgrep +from .utils.deduplicator import deduplicate from .utils.fs import ensure_dir, safe_rmtree, unzip_to_dir _MAX_UPLOAD_MB_RAW = os.environ.get("MAX_UPLOAD_MB") @@ -312,17 +314,21 @@ async def download_to_path(url: str, dest_path: Path, max_retries: int = 5) -> N ) bytes_received = 0 chunk_size = 1024 * 1024 - + file_limit_exceeded = False with open(dest_path, "wb") as f: async for chunk in r.aiter_bytes(chunk_size=chunk_size): bytes_received += len(chunk) if bytes_received > MAX_UPLOAD_SIZE: - dest_path.unlink(missing_ok=True) - raise HTTPException( - status_code=413, - detail=f"Remote repository exceeds the maximum limit of {MAX_UPLOAD_MB}MB.", - ) + file_limit_exceeded = True + break f.write(chunk) + + if file_limit_exceeded: + dest_path.unlink(missing_ok=True) + raise HTTPException( + status_code=413, + detail=f"Remote repository exceeds the maximum limit of {MAX_UPLOAD_MB}MB.", + ) return if status_code_for_retry in (403, 429): @@ -399,10 +405,26 @@ def update_progress(phase, status): _scan_repo_dir, scan_root, update_progress ) + disable_dedup = os.environ.get("DISABLE_DEDUP", "false").lower() in { + "1", + "true", + "yes", + } + + try: + epsilon = float(os.environ.get("DEDUP_EPSILON", "0.15")) + except ValueError: + epsilon = 0.15 + + if disable_dedup: + dedup_findings = findings + else: + dedup_findings = deduplicate(findings, epsilon=epsilon) + db = await get_db() try: rows = [] - for f in findings: + for f in dedup_findings: engine = (f.metadata or {}).get("engine") scanner = {"osv-scanner": "osv"}.get(engine, engine) rule_id = ( @@ -444,7 +466,8 @@ def 
"""Embedding-based de-duplication of scanner findings."""

import importlib.util
import logging
import math
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Finding is referenced only in annotations, so it is not needed at runtime.
    from app.models import Finding

logger = logging.getLogger(__name__)

# Cosine-distance threshold used when the caller passes an unusable value.
_DEFAULT_EPSILON = 0.15

# Process-wide cache for the embedding model; populated on first use.
_MODEL = None


def get_model():
    """Lazily load and cache the SentenceTransformer model.

    The import is performed inside the function so that importing this
    module does not itself require ``sentence-transformers``.
    """
    global _MODEL
    if _MODEL is None:
        from sentence_transformers import SentenceTransformer

        _MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    return _MODEL


def deduplicate(
    findings: "List[Finding]",
    epsilon: float = _DEFAULT_EPSILON,
    *,
    model=None,
) -> "List[Finding]":
    """Drop findings whose text is a near-duplicate of an earlier finding.

    Each finding is embedded from its ``description`` (falling back to its
    ``title``).  Findings are visited in order; one is kept only if its
    cosine distance to every finding kept so far is greater than
    ``epsilon``.  The first occurrence of each group is the one retained.

    This is best-effort: the original ``findings`` list is returned
    unchanged when it is empty, when ``sentence-transformers`` or ``numpy``
    is unavailable, or when loading/encoding fails for any reason.

    Args:
        findings: Findings to de-duplicate.
        epsilon: Maximum cosine distance at which two findings count as
            duplicates.  NaN, infinite, or negative values fall back to
            the default of 0.15.
        model: Optional encoder exposing
            ``encode(texts, convert_to_numpy=True)``.  When omitted, the
            cached model from ``get_model()`` is used.

    Returns:
        A new list with the retained findings in their original order, or
        the input list itself when de-duplication is skipped.
    """
    if not findings:
        return findings

    if model is None:
        # find_spec itself can raise (e.g. a stubbed module without a spec);
        # treat that the same as "not installed".
        try:
            available = importlib.util.find_spec("sentence_transformers") is not None
        except (ImportError, ValueError):
            available = False
        if not available:
            logger.warning(
                "sentence-transformers is not available. Skipping deduplication."
            )
            return findings

    try:
        import numpy as np
    except ImportError:
        logger.warning("numpy is not available. Skipping deduplication.")
        return findings

    try:
        # An infinite epsilon would collapse everything into one finding and
        # NaN would silently disable de-duplication, so reject both.
        if not math.isfinite(epsilon) or epsilon < 0:
            logger.warning(
                "Invalid dedup epsilon %r. Falling back to %s.",
                epsilon,
                _DEFAULT_EPSILON,
            )
            epsilon = _DEFAULT_EPSILON

        encoder = model if model is not None else get_model()
        texts = [f.description or f.title or "" for f in findings]
        embeddings = np.asarray(encoder.encode(texts, convert_to_numpy=True))

        if embeddings.ndim == 1:
            embeddings = np.expand_dims(embeddings, axis=0)
        if embeddings.shape[0] != len(findings):
            logger.warning(
                "Encoder returned %d embeddings for %d findings. "
                "Skipping deduplication.",
                embeddings.shape[0],
                len(findings),
            )
            return findings

        # Normalize so that a dot product equals cosine similarity.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        unit_vectors = embeddings / norms

        kept: List[int] = []
        for idx, vector in enumerate(unit_vectors):
            if kept:
                # Compare against all retained findings in one vectorized step.
                distances = 1.0 - unit_vectors[kept] @ vector
                if np.any(distances <= epsilon):
                    continue
            kept.append(idx)

        return [findings[idx] for idx in kept]

    except Exception:
        # Best-effort feature: never fail a scan because dedup broke, but
        # keep the traceback so the failure can be diagnosed.
        logger.exception("Error during deduplication. Skipping deduplication.")
        return findings
# Shared scanner output: findings 1 and 2 carry identical descriptions (and
# differ only in line number); finding 3 is unrelated.
findings_input = [
    Finding(
        id="1",
        category="sast",
        severity="HIGH",
        title="SQL Injection",
        description="SQL Injection in auth.py",
        location=Location(path="auth.py", start_line=10),
    ),
    Finding(
        id="2",
        category="sast",
        severity="HIGH",
        title="SQL Injection",
        description="SQL Injection in auth.py",
        location=Location(path="auth.py", start_line=15),
    ),
    Finding(
        id="3",
        category="secret",
        severity="CRITICAL",
        title="Hardcoded Password",
        description="Hardcoded password in config.py",
        location=Location(path="config.py", start_line=5),
    ),
]


@pytest.fixture(autouse=True)
def reset_dedup_cache():
    """Clear the deduplicator's cached model before and after every test."""
    dedup_mod._MODEL = None
    yield
    dedup_mod._MODEL = None


# Case 1: Dedup enabled with duplicate findings
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
@patch("app.utils.deduplicator.get_model")
def test_scan_dedup_enabled(mock_get_model, mock_scan, mock_unzip, monkeypatch):
    """With dedup enabled, the two identical SQL findings collapse into one."""
    import importlib.util

    real_find_spec = importlib.util.find_spec

    def fake_find_spec(name, package=None):
        # Pretend only sentence_transformers is installed; any other lookup
        # made while the test runs keeps its real answer and signature.
        if name == "sentence_transformers":
            return object()
        return real_find_spec(name, package)

    monkeypatch.delenv("DISABLE_DEDUP", raising=False)
    monkeypatch.setenv("DEDUP_EPSILON", "0.15")
    mock_get_model.return_value = MockSentenceTransformer()
    monkeypatch.setattr("importlib.util.find_spec", fake_find_spec)
    mock_scan.return_value = ([], [], [], [], findings_input)

    zip_file = make_dummy_zip()
    res = client.post(
        "/scan",
        files={"project": ("project.zip", zip_file, "application/zip")},
        data={"project_name": "test_project"},
    )
    assert res.status_code == 200
    job_id = res.json()["job_id"]

    from app.main import ACTIVE_SCANS

    assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
    assert ACTIVE_SCANS[job_id]["findings_count"] == 2

    res_findings = client.get(f"/jobs/{job_id}/findings")
    assert res_findings.status_code == 200
    data = res_findings.json()
    assert data["finding_count"] == 2
    assert len(data["findings"]) == 2
    assert {f["rule_id"] for f in data["findings"]} == {
        "SQL Injection",
        "Hardcoded Password",
    }
# Case 2: DISABLE_DEDUP=true
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
@patch("app.utils.deduplicator.get_model")
def test_scan_dedup_disabled(mock_get_model, mock_scan, mock_unzip, monkeypatch):
    """DISABLE_DEDUP=true must bypass deduplication and keep all findings."""
    import importlib.util

    real_find_spec = importlib.util.find_spec

    def fake_find_spec(name, package=None):
        # Report sentence_transformers as installed so that a count of 3 can
        # only come from the DISABLE_DEDUP flag, not from a missing package.
        if name == "sentence_transformers":
            return object()
        return real_find_spec(name, package)

    monkeypatch.setenv("DISABLE_DEDUP", "true")
    monkeypatch.setattr("importlib.util.find_spec", fake_find_spec)
    mock_get_model.return_value = MockSentenceTransformer()
    mock_scan.return_value = ([], [], [], [], findings_input)

    zip_file = make_dummy_zip()
    res = client.post(
        "/scan",
        files={"project": ("project.zip", zip_file, "application/zip")},
        data={"project_name": "test_project"},
    )
    assert res.status_code == 200
    job_id = res.json()["job_id"]

    from app.main import ACTIVE_SCANS

    assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
    assert ACTIVE_SCANS[job_id]["findings_count"] == 3
    # The embedding model must never be requested when dedup is disabled.
    mock_get_model.assert_not_called()

    res_findings = client.get(f"/jobs/{job_id}/findings")
    assert res_findings.status_code == 200
    data = res_findings.json()
    assert data["finding_count"] == 3
    assert len(data["findings"]) == 3


# Case 3: sentence-transformers unavailable
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
def test_scan_dedup_sentence_transformers_unavailable(
    mock_scan, mock_unzip, monkeypatch
):
    """Without sentence-transformers, findings pass through unchanged."""
    import importlib.util

    real_find_spec = importlib.util.find_spec

    def fake_find_spec(name, package=None):
        # Hide only sentence_transformers; hiding every module would break
        # unrelated optional-dependency probes made during the request.
        if name == "sentence_transformers":
            return None
        return real_find_spec(name, package)

    monkeypatch.delenv("DISABLE_DEDUP", raising=False)
    monkeypatch.setattr("importlib.util.find_spec", fake_find_spec)
    mock_scan.return_value = ([], [], [], [], findings_input)

    zip_file = make_dummy_zip()
    res = client.post(
        "/scan",
        files={"project": ("project.zip", zip_file, "application/zip")},
        data={"project_name": "test_project"},
    )
    assert res.status_code == 200
    job_id = res.json()["job_id"]

    from app.main import ACTIVE_SCANS

    assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
    assert ACTIVE_SCANS[job_id]["findings_count"] == 3

    res_findings = client.get(f"/jobs/{job_id}/findings")
    assert res_findings.status_code == 200
    data = res_findings.json()
    assert data["finding_count"] == 3
    assert len(data["findings"]) == 3
# Case 4: Invalid DEDUP_EPSILON value (fallback to 0.15)
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
@patch("app.utils.deduplicator.get_model")
def test_scan_dedup_invalid_epsilon(mock_get_model, mock_scan, mock_unzip, monkeypatch):
    """A non-numeric DEDUP_EPSILON must not break the scan; dedup still runs."""
    import importlib.util

    real_find_spec = importlib.util.find_spec

    def fake_find_spec(name, package=None):
        # Pretend only sentence_transformers is installed; any other lookup
        # made while the test runs keeps its real answer and signature.
        if name == "sentence_transformers":
            return object()
        return real_find_spec(name, package)

    monkeypatch.delenv("DISABLE_DEDUP", raising=False)
    monkeypatch.setenv("DEDUP_EPSILON", "abc")
    mock_get_model.return_value = MockSentenceTransformer()
    monkeypatch.setattr("importlib.util.find_spec", fake_find_spec)
    mock_scan.return_value = ([], [], [], [], findings_input)

    zip_file = make_dummy_zip()
    res = client.post(
        "/scan",
        files={"project": ("project.zip", zip_file, "application/zip")},
        data={"project_name": "test_project"},
    )
    assert res.status_code == 200
    job_id = res.json()["job_id"]

    from app.main import ACTIVE_SCANS

    assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
    assert ACTIVE_SCANS[job_id]["findings_count"] == 2

    res_findings = client.get(f"/jobs/{job_id}/findings")
    assert res_findings.status_code == 200
    data = res_findings.json()
    assert data["finding_count"] == 2
    assert len(data["findings"]) == 2