Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
upsert_contributor_stat,
)
from .models import (
DeduplicatedScanResponse,
Finding,
FixRequest,
FixResponse,
Expand All @@ -60,6 +61,7 @@
from .scanners.gitleaks import run_gitleaks
from .scanners.osv import run_osv_scanner
from .scanners.semgrep import run_semgrep
from .utils.deduplicator import deduplicate
from .utils.fs import ensure_dir, safe_rmtree, unzip_to_dir

_MAX_UPLOAD_MB_RAW = os.environ.get("MAX_UPLOAD_MB")
Expand Down Expand Up @@ -312,17 +314,21 @@ async def download_to_path(url: str, dest_path: Path, max_retries: int = 5) -> N
)
bytes_received = 0
chunk_size = 1024 * 1024

file_limit_exceeded = False
with open(dest_path, "wb") as f:
async for chunk in r.aiter_bytes(chunk_size=chunk_size):
bytes_received += len(chunk)
if bytes_received > MAX_UPLOAD_SIZE:
dest_path.unlink(missing_ok=True)
raise HTTPException(
status_code=413,
detail=f"Remote repository exceeds the maximum limit of {MAX_UPLOAD_MB}MB.",
)
file_limit_exceeded = True
break
f.write(chunk)

if file_limit_exceeded:
dest_path.unlink(missing_ok=True)
raise HTTPException(
status_code=413,
detail=f"Remote repository exceeds the maximum limit of {MAX_UPLOAD_MB}MB.",
)
return

if status_code_for_retry in (403, 429):
Expand Down Expand Up @@ -399,10 +405,26 @@ def update_progress(phase, status):
_scan_repo_dir, scan_root, update_progress
)

disable_dedup = os.environ.get("DISABLE_DEDUP", "false").lower() in {
"1",
"true",
"yes",
}

try:
epsilon = float(os.environ.get("DEDUP_EPSILON", "0.15"))
except ValueError:
epsilon = 0.15

if disable_dedup:
dedup_findings = findings
else:
dedup_findings = deduplicate(findings, epsilon=epsilon)

db = await get_db()
try:
rows = []
for f in findings:
for f in dedup_findings:
engine = (f.metadata or {}).get("engine")
scanner = {"osv-scanner": "osv"}.get(engine, engine)
rule_id = (
Expand Down Expand Up @@ -444,7 +466,8 @@ def update_progress(phase, status):

if job_id in ACTIVE_SCANS:
ACTIVE_SCANS[job_id]["status"] = "completed"
ACTIVE_SCANS[job_id]["findings_count"] = len(findings)
ACTIVE_SCANS[job_id]["findings_count"] = len(dedup_findings)
ACTIVE_SCANS[job_id]["raw_finding_count"] = len(findings)
except Exception:
logger.exception("Failed scan task for %s", job_id)
if job_id in ACTIVE_SCANS:
Expand Down
7 changes: 7 additions & 0 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,10 @@ class OrgJobStatusResponse(BaseModel):
org_job_id: str
status: str
repos: List[RepoStatus]


class DeduplicatedScanResponse(BaseModel):
job_id: str
raw_finding_count: int
finding_count: int
findings: List[Finding]
73 changes: 73 additions & 0 deletions backend/app/utils/deduplicator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
from typing import List

from app.models import Finding

logger = logging.getLogger(__name__)

_MODEL = None


def get_model():
"""Lazily load and cache the SentenceTransformer model."""
global _MODEL
if _MODEL is None:
from sentence_transformers import SentenceTransformer

_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
return _MODEL


def deduplicate(findings: List[Finding], epsilon: float = 0.15) -> List[Finding]:
"""
Deduplicates finding descriptions/messages using SentenceTransformer embeddings.
Returns the original findings list if sentence-transformers is unavailable or if loading/encoding fails.
"""
if not findings:
return findings

# Check for sentence_transformers availability
import importlib.util

if importlib.util.find_spec("sentence_transformers") is None:
logger.warning(
"sentence-transformers is not available. Skipping deduplication."
)
return findings

try:
import numpy as np
except ImportError:
logger.warning("numpy is not available. Skipping deduplication.")
return findings

try:
model = get_model()
texts = [f.description if f.description else f.title for f in findings]
embeddings = model.encode(texts, convert_to_numpy=True)

if len(embeddings.shape) == 1:
embeddings = np.expand_dims(embeddings, axis=0)

# Normalize embeddings to compute cosine similarity using dot product
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1.0
normalized_embeddings = embeddings / norms

keep = []
for i in range(len(findings)):
is_dup = False
for j in keep:
sim = np.dot(normalized_embeddings[i], normalized_embeddings[j])
dist = 1.0 - sim
if dist <= epsilon:
is_dup = True
break
if not is_dup:
keep.append(i)

return [findings[idx] for idx in keep]

except Exception as e:
logger.error(f"Error during deduplication: {e}. Skipping deduplication.")
return findings
193 changes: 193 additions & 0 deletions backend/tests/test_scan_dedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import io
import zipfile
from unittest.mock import patch

import numpy as np
import pytest
from fastapi.testclient import TestClient

import app.utils.deduplicator as dedup_mod
from app.main import app as fastapi_app
from app.models import Finding, Location

client = TestClient(fastapi_app)


def make_dummy_zip():
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
zip_file.writestr("dummy.py", "print('hello')")
zip_buffer.seek(0)
return zip_buffer


class MockSentenceTransformer:
def __init__(self, *args, **kwargs):
"""Initialize mock."""
pass

def encode(self, texts, **kwargs):
embs = []
for text in texts:
if "SQL Injection" in text:
embs.append([1.0, 0.0])
else:
embs.append([0.0, 1.0])
return np.array(embs)


findings_input = [
Finding(
id="1",
category="sast",
severity="HIGH",
title="SQL Injection",
description="SQL Injection in auth.py",
location=Location(path="auth.py", start_line=10),
),
Finding(
id="2",
category="sast",
severity="HIGH",
title="SQL Injection",
description="SQL Injection in auth.py",
location=Location(path="auth.py", start_line=15),
),
Finding(
id="3",
category="secret",
severity="CRITICAL",
title="Hardcoded Password",
description="Hardcoded password in config.py",
location=Location(path="config.py", start_line=5),
),
]


@pytest.fixture(autouse=True)
def reset_dedup_cache():
dedup_mod._MODEL = None
yield
dedup_mod._MODEL = None


# Case 1: Dedup enabled with duplicate findings
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
@patch("app.utils.deduplicator.get_model")
def test_scan_dedup_enabled(mock_get_model, mock_scan, mock_unzip, monkeypatch):
monkeypatch.delenv("DISABLE_DEDUP", raising=False)
monkeypatch.setenv("DEDUP_EPSILON", "0.15")
mock_get_model.return_value = MockSentenceTransformer()
monkeypatch.setattr("importlib.util.find_spec", lambda name: object())
mock_scan.return_value = ([], [], [], [], findings_input)

zip_file = make_dummy_zip()
res = client.post(
"/scan",
files={"project": ("project.zip", zip_file, "application/zip")},
data={"project_name": "test_project"},
)
assert res.status_code == 200
job_id = res.json()["job_id"]

from app.main import ACTIVE_SCANS
assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
assert ACTIVE_SCANS[job_id]["findings_count"] == 2

res_findings = client.get(f"/jobs/{job_id}/findings")
assert res_findings.status_code == 200
data = res_findings.json()
assert data["finding_count"] == 2
assert len(data["findings"]) == 2
assert {f["rule_id"] for f in data["findings"]} == {"SQL Injection", "Hardcoded Password"}


# Case 2: DISABLE_DEDUP=true
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
@patch("app.utils.deduplicator.get_model")
def test_scan_dedup_disabled(mock_get_model, mock_scan, mock_unzip, monkeypatch):
monkeypatch.setenv("DISABLE_DEDUP", "true")
mock_get_model.return_value = MockSentenceTransformer()
mock_scan.return_value = ([], [], [], [], findings_input)

zip_file = make_dummy_zip()
res = client.post(
"/scan",
files={"project": ("project.zip", zip_file, "application/zip")},
data={"project_name": "test_project"},
)
assert res.status_code == 200
job_id = res.json()["job_id"]

from app.main import ACTIVE_SCANS
assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
assert ACTIVE_SCANS[job_id]["findings_count"] == 3

res_findings = client.get(f"/jobs/{job_id}/findings")
assert res_findings.status_code == 200
data = res_findings.json()
assert data["finding_count"] == 3
assert len(data["findings"]) == 3


# Case 3: sentence-transformers unavailable
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
def test_scan_dedup_sentence_transformers_unavailable(
mock_scan, mock_unzip, monkeypatch
):
monkeypatch.delenv("DISABLE_DEDUP", raising=False)
monkeypatch.setattr("importlib.util.find_spec", lambda name: None)
mock_scan.return_value = ([], [], [], [], findings_input)

zip_file = make_dummy_zip()
res = client.post(
"/scan",
files={"project": ("project.zip", zip_file, "application/zip")},
data={"project_name": "test_project"},
)
assert res.status_code == 200
job_id = res.json()["job_id"]

from app.main import ACTIVE_SCANS
assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
assert ACTIVE_SCANS[job_id]["findings_count"] == 3

res_findings = client.get(f"/jobs/{job_id}/findings")
assert res_findings.status_code == 200
data = res_findings.json()
assert data["finding_count"] == 3
assert len(data["findings"]) == 3


# Case 4: Invalid DEDUP_EPSILON value (fallback to 0.15)
@patch("app.main.unzip_to_dir")
@patch("app.main._scan_repo_dir")
@patch("app.utils.deduplicator.get_model")
def test_scan_dedup_invalid_epsilon(mock_get_model, mock_scan, mock_unzip, monkeypatch):
monkeypatch.delenv("DISABLE_DEDUP", raising=False)
monkeypatch.setenv("DEDUP_EPSILON", "abc")
mock_get_model.return_value = MockSentenceTransformer()
monkeypatch.setattr("importlib.util.find_spec", lambda name: object())
mock_scan.return_value = ([], [], [], [], findings_input)

zip_file = make_dummy_zip()
res = client.post(
"/scan",
files={"project": ("project.zip", zip_file, "application/zip")},
data={"project_name": "test_project"},
)
assert res.status_code == 200
job_id = res.json()["job_id"]

from app.main import ACTIVE_SCANS
assert ACTIVE_SCANS[job_id]["raw_finding_count"] == 3
assert ACTIVE_SCANS[job_id]["findings_count"] == 2

res_findings = client.get(f"/jobs/{job_id}/findings")
assert res_findings.status_code == 200
data = res_findings.json()
assert data["finding_count"] == 2
assert len(data["findings"]) == 2
Loading