Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions backend/app/attack_paths/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

import uuid
from typing import List
from pydantic import BaseModel

from ..models import Finding
from .models import NormalizedFinding, AttackPath, AttackStep
from .graph_builder import build_graph, extract_paths
from .scorer import calculate_risk
from ..utils.fs import ensure_dir
from ..db import get_db
import json

async def generate_attack_paths(job_id: str) -> List[AttackPath]:
"""Generate attack paths for a given scan job.

Steps:
1. Load raw findings from the database.
2. Normalize them to a common schema.
3. Build a directed graph linking related findings.
4. Extract all possible paths.
5. Score each path.
"""
# --- 1. Load findings ---
db = await get_db()
try:
cur = await db.execute(
"""
SELECT id, rule_id, severity, category, file_path, line_number, message, metadata
FROM findings
WHERE job_id = ?
""",
(job_id,)
)
rows = await cur.fetchall()
finally:
await db.close()

raw_findings: List[Finding] = []
for row in rows:
fid, rule_id, severity, category, file_path, line_number, message, metadata_json = row
metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else {}
location = None
if file_path:
from ..models import Location
location = Location(path=file_path, start_line=line_number)
finding = Finding(
id=fid,
category=category,
severity=severity,
title=rule_id or "",
description=message or "",
location=location,
metadata=metadata,
)
raw_findings.append(finding)

# --- 2. Normalize ---
normalized: List[NormalizedFinding] = []
for f in raw_findings:
norm = NormalizedFinding(
id=f.id,
category=f.category.lower(),
severity=f.severity,
title=f.title,
description=f.description,
metadata=f.metadata,
)
normalized.append(norm)

# --- 3. Build graph ---
graph = build_graph(normalized)
# --- 4. Extract paths ---
paths = extract_paths(graph)

# --- 5. Score paths ---
scored_paths: List[AttackPath] = []
for p in paths:
risk = calculate_risk(p)
scored = AttackPath(id=str(uuid.uuid4()), steps=p.steps, risk_score=risk)
scored_paths.append(scored)

return scored_paths
65 changes: 65 additions & 0 deletions backend/app/attack_paths/graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

import networkx as nx
from typing import List, Dict
from .models import NormalizedFinding, AttackStep, AttackPath
import uuid

# Deterministic correlation rules mapping categories to next step labels
_CORRELATION_MAP: Dict[str, str] = {
"secret": "Cloud Access",
"dependency": "Remote Code Execution",
"privilege_escalation": "Data Exposure",
}

def build_graph(findings: List[NormalizedFinding]) -> nx.DiGraph:
"""Build a directed graph linking findings according to correlation rules.

Each finding becomes a node. For a finding whose ``category`` matches a key in
``_CORRELATION_MAP`` an edge is added to an abstract intermediate node that
represents the correlated step.
"""
graph = nx.DiGraph()

# Add finding nodes
for f in findings:
node_id = f.id
label = f.title if f.title else f.category
graph.add_node(node_id, step=AttackStep(label=label, finding_id=f.id))

# Add correlation edges using abstract intermediate nodes
for f in findings:
next_label = _CORRELATION_MAP.get(f.category)
if not next_label:
continue
# Create a unique intermediate node for this correlation type if not exists
inter_id = f"{f.category}_intermediate"
if not graph.has_node(inter_id):
graph.add_node(inter_id, step=AttackStep(label=next_label))
graph.add_edge(f.id, inter_id)

return graph

def extract_paths(graph: nx.DiGraph) -> List[AttackPath]:
"""Extract all linear paths from source finding nodes to leaf nodes.

The function walks each source node (nodes without incoming edges) to every
reachable leaf (nodes without outgoing edges) and builds an ``AttackPath``
consisting of the ordered ``AttackStep`` objects.
"""
paths: List[AttackPath] = []
sources = [n for n in graph.nodes if graph.in_degree(n) == 0]
leaves = [n for n in graph.nodes if graph.out_degree(n) == 0]

for src in sources:
for leaf in leaves:
if src == leaf:
continue
try:
for node_path in nx.all_simple_paths(graph, source=src, target=leaf):
steps = [graph.nodes[n]["step"] for n in node_path]
path_id = str(uuid.uuid4())
paths.append(AttackPath(id=path_id, steps=steps, risk_score=0.0))
except nx.NetworkXNoPath:
continue
return paths
45 changes: 45 additions & 0 deletions backend/app/attack_paths/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import List, Dict, Any
from pydantic import BaseModel, Field

class NormalizedFinding(BaseModel):
"""A normalized representation of a finding from any scanner.

Attributes
----------
id: str
Unique identifier of the finding.
category: str
Normalized category (e.g., "secret", "dependency", "sast").
severity: str
Original severity string.
title: str
Short title or rule identifier.
description: str
Detailed description.
metadata: Dict[str, Any]
Raw metadata from the original finding.
"""

id: str
category: str
severity: str
title: str
description: str = ""
metadata: Dict[str, Any] = Field(default_factory=dict)

class AttackStep(BaseModel):
"""A single step in an attack path.

label: str – human readable label for the step (e.g., "AWS Secret").
finding_id: str | None – optional reference to the underlying finding.
"""
label: str
finding_id: str | None = None

class AttackPath(BaseModel):
"""A complete attack path consisting of ordered steps and a risk score."""
id: str
steps: List[AttackStep]
risk_score: float
65 changes: 65 additions & 0 deletions backend/app/attack_paths/scorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import Dict

from .models import AttackPath, AttackStep

# Severity to numeric score (same as earlier prioritization)
_SEVERITY_SCORE: Dict[str, int] = {
"CRITICAL": 100,
"HIGH": 80,
"MEDIUM": 50,
"LOW": 20,
"INFO": 5,
}

# Category weight – higher weight for more exploitable categories
_CATEGORY_WEIGHT: Dict[str, int] = {
"secret": 35,
"dependency": 25,
"privilege_escalation": 30,
"sast": 20,
}

def _step_score(step: AttackStep) -> int:
"""Calculate a base score for a single step.

If the step is linked to a finding (has ``finding_id``) we look at its
``category`` and ``severity`` via the underlying ``AttackStep`` label – the
label is typically the finding title, but we also store the original
``category`` in the step's metadata when available. For intermediate nodes
created by the correlation engine we fall back to the category weight only.
"""
# For intermediate nodes the label comes from ``_CORRELATION_MAP`` – we can
# infer a pseudo‑category based on the label.
label = step.label.lower()
# Attempt to map label back to a known category; this is heuristic but works
# for the deterministic rules used.
if "secret" in label:
category = "secret"
elif "dependency" in label or "cve" in label:
category = "dependency"
elif "privilege" in label:
category = "privilege_escalation"
else:
category = "sast"

cat_weight = _CATEGORY_WEIGHT.get(category, 10)
# No severity for intermediate nodes – use a default medium value.
sev_score = 50 if step.finding_id is None else _SEVERITY_SCORE.get(step.label.upper(), 30)
return cat_weight + sev_score

def calculate_risk(path: AttackPath) -> float:
"""Calculate a risk score for an attack path.

The risk is a weighted sum of step scores, adjusted by chain length. The
final value is capped to the 0‑100 range.
"""
if not path.steps:
return 0.0
base = sum(_step_score(step) for step in path.steps)
length_factor = len(path.steps) * 5 # each step adds up to 5 points
raw_score = base + length_factor
# Normalise to 0‑100 – the maximum plausible raw_score is roughly 250.
normalized = min(100.0, (raw_score / 250.0) * 100.0)
return round(normalized, 2)
17 changes: 17 additions & 0 deletions backend/app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ async def init_db():
last_updated TEXT DEFAULT (datetime('now'))
)
""")
await db.execute("""
CREATE TABLE IF NOT EXISTS root_cause_groups (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT,
created_at TEXT DEFAULT (datetime('now'))
)
""")
await db.execute("""
CREATE TABLE IF NOT EXISTS root_cause_group_finding (
group_id TEXT NOT NULL,
finding_id TEXT NOT NULL,
FOREIGN KEY(group_id) REFERENCES root_cause_groups(id),
FOREIGN KEY(finding_id) REFERENCES findings(id),
PRIMARY KEY(group_id, finding_id)
)
""")
await db.execute("""
CREATE TABLE IF NOT EXISTS dependency_links (
id TEXT PRIMARY KEY,
Expand Down
34 changes: 34 additions & 0 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@
logger = logging.getLogger(__name__)
app = FastAPI(title="PatchPilot API", version="0.1.0")

from app.ml.root_cause.api import router as root_cause_router
app.include_router(root_cause_router)

ALLOWED_ORIGINS = [
"http://localhost:5173",
"http://127.0.0.1:5173",
Expand Down Expand Up @@ -799,6 +802,37 @@ async def get_verify(job_id: str):
status_code=404, detail=f"No verify outcome recorded yet for job '{job_id}'"
)

# ==== Attack Path Correlation Endpoint ====
from .attack_paths.engine import generate_attack_paths

@app.get("/attack-paths/{job_id}")
async def get_attack_paths(job_id: str):
"""Return attack path analysis for a scan job.

Generates attack paths from stored findings, scores them, and returns the
highest‑risk path along with all paths.
"""
paths = await generate_attack_paths(job_id)
if not paths:
raise HTTPException(status_code=404, detail="No attack paths found")
# Sort by risk_score descending
paths.sort(key=lambda p: p.risk_score, reverse=True)
top = paths[0]
return {
"attack_path_id": top.id,
"risk_score": top.risk_score,
"steps": [step.label for step in top.steps],
"all_paths": [
{
"id": p.id,
"risk_score": p.risk_score,
"steps": [step.label for step in p.steps],
}
for p in paths
],
}


return dict(zip(columns, row))


Expand Down
3 changes: 3 additions & 0 deletions backend/app/ml/root_cause/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .engine import analyze_root_cause
from .clustering import cluster_findings
from .models import RootCauseFinding, RootCauseGroup, RootCauseResponse
19 changes: 19 additions & 0 deletions backend/app/ml/root_cause/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fastapi import APIRouter, HTTPException
from ..models import Finding
from .engine import analyze_root_cause

router = APIRouter()

@router.get("/jobs/{job_id}/root-cause-groups", response_model=dict)
async def get_root_cause_groups(job_id: str):
"""Return root cause grouping for a given job.

The response matches the structure of ``RootCauseResponse`` defined in
``backend/app/ml/root_cause/models.py`` but is returned as a plain dict for
simplicity in FastAPI serialization.
"""
try:
result = await analyze_root_cause(job_id)
return result
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
Loading
Loading