From 83ac07c34866e596b5186e808671df3e7bbfaf5e Mon Sep 17 00:00:00 2001 From: hello-args Date: Wed, 17 Jun 2026 09:58:55 +0530 Subject: [PATCH 1/2] feat(scoring): Phase 3c attack graph polish (fixes, counterfactuals, UI) Add FixKind registry resolution, optional counterfactual remediation and path compression, dashboard layer filters with policy/inferred edge styling, and doctor --suggest-fixes for template remediations from scan reports. --- CHANGELOG.md | 2 + src/mcts/cli/doctor.py | 40 +++++++ src/mcts/cli/main.py | 39 +++++- src/mcts/core/scanner.py | 4 +- src/mcts/report/assets/dashboard.js | 81 +++++++++++-- src/mcts/report/assets/styles.css | 47 ++++++++ src/mcts/report/templates/dashboard.html | 3 +- src/mcts/scoring/attack_graph.py | 14 ++- src/mcts/scoring/attack_graph_builder.py | 29 ++++- src/mcts/scoring/attack_graph_models.py | 2 + src/mcts/scoring/fixes/registry.yaml | 63 ++++++++++ src/mcts/scoring/graph_compress.py | 54 +++++++++ src/mcts/scoring/graph_counterfactual.py | 41 +++++++ src/mcts/scoring/graph_explain.py | 27 +++-- src/mcts/scoring/graph_fixes.py | 43 +++++++ src/mcts/scoring/graph_suggest.py | 34 ++++++ src/mcts/scoring/graph_ui.py | 34 +++++- tests/scoring/test_graph_phase_3c.py | 144 +++++++++++++++++++++++ 18 files changed, 674 insertions(+), 27 deletions(-) create mode 100644 src/mcts/scoring/graph_compress.py create mode 100644 src/mcts/scoring/graph_counterfactual.py create mode 100644 src/mcts/scoring/graph_fixes.py create mode 100644 src/mcts/scoring/graph_suggest.py create mode 100644 tests/scoring/test_graph_phase_3c.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fe42a4c..8fc5168 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Attack graph Phase 3c polish** — FixKind registry expansion, template `recommended_fixes` on paths, optional counterfactual remediation (`--attack-graph-counterfactuals`), UI path compression (`--attack-graph-compress-ui`), dashboard layer filter + policy/inferred edge styling, `mcts doctor --suggest-fixes --report` + - **Attack graph v3 rollout (Phase 3a/3b)** — default `attack_graph_version=3`; YAML template matcher replaces `AttackChainAnalyzer`; 12 chain templates including `SSRF_RESOURCE`, `ENV_SAMPLING`, `GIT_UNSCOPED`, `PROMPT_BYPASS`, `ELICIT_PHISH`, `TOCTOU_READ`, `READ_EXEC`, `CRED_THEFT`; capability overlap fallbacks; dashboard v3 paths + SARIF `mcts/attackPathExplanation`; R-23–R-25 regression fixtures + `tests/scoring/test_phase_3b_templates.py` - **Fact provenance metrics** — `fact_coverage()` reports `native_pct` / `silver_pct`; dashboard exposes `fact_provenance`; CI gates via `check_ttu_baseline.py` + corpus `--check-only` - **Scoring corpus** — `single_tool_overlap` fixture under enforce; Spearman calibration validates without mutating fixtures in CI diff --git a/src/mcts/cli/doctor.py b/src/mcts/cli/doctor.py index eef2cd9..cb12177 100644 --- a/src/mcts/cli/doctor.py +++ b/src/mcts/cli/doctor.py @@ -36,6 +36,8 @@ def run_doctor( deep: bool = False, json_output: bool = False, output: Path | None = None, + suggest_fixes: bool = False, + report: Path | None = None, ) -> int: """Run read-only preflight checks. Returns exit code (0 ok, 1 failures, 2 user error).""" root = path.expanduser().resolve() @@ -112,12 +114,44 @@ def run_doctor( if deep: warnings += _check_optional_toolchain(checks) + fix_suggestions: list[dict] = [] + if suggest_fixes: + if report is None: + checks.append( + ( + "warn", + "Suggest fixes", + "pass --report scan.json to list attack-graph remediations", + ) + ) + warnings += 1 + elif not report.exists(): + checks.append(("fail", "Suggest fixes", f"report not found: {report}")) + failures += 1 + else: + from mcts.scoring.graph_suggest import suggest_fixes_from_report + + fix_suggestions = suggest_fixes_from_report(report) + if fix_suggestions: + checks.append( + ( + "pass", + "Attack graph fixes", + f"{len(fix_suggestions)} template(s) with remediations", + ) + ) + else: + checks.append(("warn", "Attack graph fixes", "no matched templates in report")) + warnings += 1 + payload = { "path": str(root), "checks": [{"status": s, "label": label, "detail": d} for s, label, d in checks], "failures": failures, "warnings": warnings, } + if fix_suggestions: + payload["attack_graph_fix_suggestions"] = fix_suggestions if json_output or output is not None: import json @@ -136,6 +170,12 @@ def run_doctor( for status, label, detail in checks: icon = {"pass": "[green]✓[/green]", "warn": "[yellow]⚠[/yellow]", "fail": "[red]✗[/red]"}[status] console.print(f"{icon} {escape(label)}: {escape(detail)}") + if fix_suggestions: + console.print("\n[bold]Attack graph suggested fixes[/bold]") + for row in fix_suggestions: + console.print(f" [cyan]{escape(row['template_id'])}[/cyan] — {escape(row['title'])}") + for fix in row.get("recommended_fixes") or []: + console.print(f" • {escape(str(fix.get('description') or fix.get('kind', '')))}") if root.is_dir(): hints = format_discovery_hints(root) if hints: diff --git a/src/mcts/cli/main.py b/src/mcts/cli/main.py index 7384f11..440a907 100644 --- a/src/mcts/cli/main.py +++ b/src/mcts/cli/main.py @@ -807,6 +807,20 @@ def scan( help="Disable chain multiplier (chain_factor=1.0); under v2/both the analyzer still runs", ), ] = False, + attack_graph_counterfactuals: Annotated[ + bool, + typer.Option( + "--attack-graph-counterfactuals", + help="Attach counterfactual remediation to attack graph template findings", + ), + ] = False, + attack_graph_compress_ui: Annotated[ + bool, + typer.Option( + "--attack-graph-compress-ui", + help="Compress matched attack paths in report export for dashboard readability", + ), + ] = False, min_security_score: Annotated[ int | None, typer.Option( @@ -1077,6 +1091,8 @@ def scan( surface_scoped_analyzers=surface_scoped, scoring_mode=scoring.lower(), enable_attack_chains=not no_attack_chains, + attack_graph_enable_counterfactuals=attack_graph_counterfactuals, + attack_graph_compress_for_ui=attack_graph_compress_ui, min_security_score=min_security_score, max_absolute_risk=max_absolute_risk, max_risk_level=max_risk_level.lower() if max_risk_level else None, @@ -2089,11 +2105,32 @@ def doctor( bool, typer.Option("--json", help="Emit machine-readable JSON"), ] = False, + suggest_fixes: Annotated[ + bool, + typer.Option( + "--suggest-fixes", + help="List attack-graph template remediations from a prior scan report", + ), + ] = False, + report: Annotated[ + Path | None, + typer.Option( + "--report", + help="Scan JSON report for --suggest-fixes (e.g. mcts_analysis/scan-report.json)", + ), + ] = None, ) -> None: """Preflight checks before your first scan (no live probes).""" from mcts.cli.doctor import run_doctor - code = run_doctor(path, deep=deep, json_output=json_output, output=output) + code = run_doctor( + path, + deep=deep, + json_output=json_output, + output=output, + suggest_fixes=suggest_fixes, + report=report, + ) if code: raise typer.Exit(code=code) diff --git a/src/mcts/core/scanner.py b/src/mcts/core/scanner.py index 35f1ea2..8c07910 100644 --- a/src/mcts/core/scanner.py +++ b/src/mcts/core/scanner.py @@ -257,7 +257,9 @@ def analyze_server(self, server_info: MCPServerInfo) -> ScanReport: attack_graph_model = GraphBuilder(config=self.config).build(server_info, findings) chain_findings = attack_graph_model.to_findings() findings.extend(chain_findings) - raw_graph = attack_graph_model.to_report_dict() + raw_graph = attack_graph_model.to_report_dict( + compress_for_ui=self.config.attack_graph_compress_for_ui, + ) proven_legacy = { chain.legacy_finding_id for chain in attack_graph_model.matched_chains diff --git a/src/mcts/report/assets/dashboard.js b/src/mcts/report/assets/dashboard.js index 4b1e461..e02a4aa 100644 --- a/src/mcts/report/assets/dashboard.js +++ b/src/mcts/report/assets/dashboard.js @@ -2068,13 +2068,26 @@ if (!svg) return; const graph = DATA.attack_graph || {}; const nodes = graph.nodes || []; - const edges = graph.edges || []; + const allEdges = graph.edges || []; + const activeLayer = window.__attackGraphLayer || "all"; + const edges = + activeLayer === "all" + ? allEdges + : allEdges.filter((e) => (e.layer || "dataflow") === activeLayer); + renderAttackGraphLayers(graph.layers_present || []); if (!nodes.length) { svg.innerHTML = 'No attack chain data'; renderAttackPaths(graph); return; } + const visibleNodeIds = new Set(); + edges.forEach((e) => { + visibleNodeIds.add(e.from || e.from_node); + visibleNodeIds.add(e.to || e.to_node); + }); + const visibleNodes = nodes.filter((n) => visibleNodeIds.has(n.id)); + const width = svg.clientWidth || 800; const height = 400; const cx = width / 2; @@ -2082,8 +2095,8 @@ const radius = Math.min(width, height) * 0.32; const positions = {}; - nodes.forEach((n, i) => { - const angle = (i / nodes.length) * Math.PI * 2 - Math.PI / 2; + visibleNodes.forEach((n, i) => { + const angle = (i / Math.max(visibleNodes.length, 1)) * Math.PI * 2 - Math.PI / 2; positions[n.id] = { x: cx + radius * Math.cos(angle), y: cy + radius * Math.sin(angle), @@ -2098,15 +2111,17 @@ const from = positions[fromId]; const to = positions[toId]; if (!from || !to) return; - markup += ``; + const edgeClass = e.edge_class ? ` graph-edge ${e.edge_class}` : " graph-edge"; + markup += ``; }); - nodes.forEach((n) => { + visibleNodes.forEach((n) => { const p = positions[n.id]; if (!p) return; const label = (n.label || n.id || "").length > 14 ? (n.label || n.id || "").slice(0, 12) + "…" : (n.label || n.id || ""); + const trust = n.trust ? ` data-trust="${escapeHtml(n.trust)}"` : ""; markup += ` - + ${escapeHtml(label)} `; @@ -2117,6 +2132,33 @@ renderAttackPaths(graph); } + function renderAttackGraphLayers(layers) { + const toolbar = document.getElementById("attack-graph-layers"); + if (!toolbar) return; + const unique = [...new Set(layers.filter(Boolean))]; + if (!unique.length) { + toolbar.hidden = true; + toolbar.innerHTML = ""; + return; + } + toolbar.hidden = false; + const active = window.__attackGraphLayer || "all"; + const buttons = [ + ``, + ...unique.map( + (layer) => + ``, + ), + ]; + toolbar.innerHTML = buttons.join(""); + toolbar.querySelectorAll(".graph-layer-btn").forEach((btn) => { + btn.addEventListener("click", () => { + window.__attackGraphLayer = btn.dataset.layer || "all"; + renderAttackGraph(); + }); + }); + } + function renderAttackPaths(graph) { const panel = document.getElementById("attack-paths-panel"); const list = document.getElementById("attack-paths-list"); @@ -2128,7 +2170,12 @@ return; } panel.hidden = false; - list.innerHTML = paths + const compression = graph.compression_stats; + const compressionNote = + compression && compression.dropped > 0 + ? `

Showing ${compression.compressed_count} of ${compression.original_count} matched paths (UI compression).

` + : ""; + list.innerHTML = compressionNote + paths .map((path, idx) => { const template = path.template_id ? `${escapeHtml(path.template_id)}` : `Path ${idx + 1}`; const meta = [ @@ -2144,7 +2191,25 @@ return `
  • ${stepIdx + 1}. ${escapeHtml(msg)}
  • `; }) .join(""); - return `
    ${template}${meta ? ` (${escapeHtml(meta)})` : ""}
    ${steps ? `
      ${steps}
    ` : ""}
    `; + const fixes = (path.recommended_fixes || []) + .map((fix) => { + const label = fix.description || fix.kind || ""; + return `
  • ${escapeHtml(label)}
  • `; + }) + .join(""); + const fixesBlock = fixes + ? `
    Suggested fixes
      ${fixes}
    ` + : ""; + const counterfactual = path.counterfactual_remediation; + const cfActions = counterfactual && counterfactual.actions + ? counterfactual.actions + .map((action) => `
  • ${escapeHtml(action.action || action)}
  • `) + .join("") + : ""; + const cfBlock = cfActions + ? `
    Counterfactual
      ${cfActions}
    ` + : ""; + return `
    ${template}${meta ? ` (${escapeHtml(meta)})` : ""}
    ${steps ? `
      ${steps}
    ` : ""}${fixesBlock}${cfBlock}
    `; }) .join(""); } diff --git a/src/mcts/report/assets/styles.css b/src/mcts/report/assets/styles.css index 4df0310..2b34841 100644 --- a/src/mcts/report/assets/styles.css +++ b/src/mcts/report/assets/styles.css @@ -2909,6 +2909,53 @@ body.modal-open { marker-end: url(#arrowhead); } +.graph-edge.policy { + stroke: rgba(148, 163, 184, 0.75); + stroke-dasharray: 5 4; +} + +.graph-edge.inferred { + stroke: rgba(148, 163, 184, 0.45); +} + +.graph-edge.runtime { + stroke: rgba(34, 197, 94, 0.7); +} + +.graph-layer-toolbar { + display: flex; + flex-wrap: wrap; + gap: 8px; + margin-bottom: 12px; +} + +.graph-layer-btn { + font-size: 12px; + padding: 4px 10px; + border-radius: 999px; + border: 1px solid var(--border); + background: transparent; + color: var(--muted); + cursor: pointer; +} + +.graph-layer-btn.active { + border-color: var(--accent); + color: var(--text); + background: rgba(59, 130, 246, 0.12); +} + +.attack-path-fixes { + margin-top: 8px; + font-size: 12px; + color: var(--muted); +} + +.attack-path-fixes ul { + margin: 4px 0 0; + padding-left: 18px; +} + /* Analyzers */ .analyzer-grid { display: grid; diff --git a/src/mcts/report/templates/dashboard.html b/src/mcts/report/templates/dashboard.html index 91ba324..c452347 100644 --- a/src/mcts/report/templates/dashboard.html +++ b/src/mcts/report/templates/dashboard.html @@ -417,7 +417,8 @@

    Attack Paths

    How tools could be chained together for a multi-step attack (read → exfiltrate, etc.).

    View related issues → -

    Each arrow shows a possible step between tools.

    +

    Each arrow shows a possible step between tools. Policy edges are dashed; inferred edges are muted.

    +
    diff --git a/src/mcts/scoring/attack_graph.py b/src/mcts/scoring/attack_graph.py index 1342400..fa05982 100644 --- a/src/mcts/scoring/attack_graph.py +++ b/src/mcts/scoring/attack_graph.py @@ -197,12 +197,17 @@ def filter_layers(self, layers: Iterable[GraphLayer | str]) -> AttackGraph: filtered.total_risk_score = self.total_risk_score return filtered - def to_report_dict(self) -> dict[str, Any]: + def to_report_dict(self, *, compress_for_ui: bool = False) -> dict[str, Any]: node_layers = {n.layer.value for n in self._nodes.values()} edge_layers = {e.layer.value for e in self._edges.values()} layers_present = sorted(node_layers | edge_layers) paths = [self._path_to_dict(chain) for chain in self.matched_chains] - return { + compression_stats: dict[str, Any] | None = None + if compress_for_ui: + from mcts.scoring.graph_compress import compress_paths + + paths, compression_stats = compress_paths(paths) + payload = { "version": 3, "nodes": [n.model_dump(mode="json") for n in self._nodes.values()], "edges": [e.model_dump(mode="json") for e in self._edges.values()], @@ -211,6 +216,9 @@ def to_report_dict(self) -> dict[str, Any]: "total_risk_score": round(self.total_risk_score, 2), "layers_present": layers_present, } + if compression_stats: + payload["compression_stats"] = compression_stats + return payload def _path_to_dict(self, chain: MatchedChain) -> dict[str, Any]: path = chain.path @@ -229,6 +237,8 @@ def _path_to_dict(self, chain: MatchedChain) -> dict[str, Any]: "finding_ids": [ chain.legacy_finding_id or f"chain-{chain.template_id.lower().replace('_', '-')}" ], + "recommended_fixes": chain.recommended_fixes, + "counterfactual_remediation": chain.counterfactual_remediation, } def to_findings(self) -> list[Finding]: diff --git a/src/mcts/scoring/attack_graph_builder.py b/src/mcts/scoring/attack_graph_builder.py index bf8d183..4f2255a 100644 --- a/src/mcts/scoring/attack_graph_builder.py +++ b/src/mcts/scoring/attack_graph_builder.py @@ -2,11 +2,13 @@ from __future__ import annotations +from typing import Any + from mcts.core.config import ScanConfig from mcts.mcp.models import MCPServerInfo from mcts.reporting.models import Finding from mcts.scoring.attack_graph import AttackGraph -from mcts.scoring.attack_graph_models import EdgeKind, canonical_node_id +from mcts.scoring.attack_graph_models import EdgeKind, MatchedChain, canonical_node_id from mcts.scoring.attack_graph_policy import apply_policy_edges, seed_server_surfaces from mcts.scoring.attack_graph_producers import export_all_edges from mcts.scoring.graph_matcher import match_all_templates @@ -46,10 +48,35 @@ def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph: templates, top_per_template=3, ) + matched = self._attach_graph_polish(matched, counterfactuals=self.config.attack_graph_enable_counterfactuals) graph.matched_chains = matched graph.total_risk_score = sum(chain.chain_risk_score for chain in matched) return graph + def _attach_graph_polish( + self, + chains: list[MatchedChain], + *, + counterfactuals: bool, + ) -> list[MatchedChain]: + from mcts.scoring.graph_counterfactual import counterfactual_for_chain + from mcts.scoring.graph_fixes import describe_fixes + from mcts.scoring.graph_templates import load_chain_templates + + templates = {template.id: template for template in load_chain_templates()} + enriched: list[MatchedChain] = [] + for chain in chains: + template = templates.get(chain.template_id) + fixes = describe_fixes(list(template.recommended_fixes)) if template else [] + update: dict[str, Any] = {"recommended_fixes": fixes} + if counterfactuals: + update["counterfactual_remediation"] = counterfactual_for_chain( + chain.template_id, + chain.path.tool_names_on_path(), + ) + enriched.append(chain.model_copy(update=update)) + return enriched + def _add_corroborated_invokes_edges(self, graph: AttackGraph, server: MCPServerInfo) -> None: """Optional overlap chains — disabled by default (spec forever default False).""" tools = server.tools diff --git a/src/mcts/scoring/attack_graph_models.py b/src/mcts/scoring/attack_graph_models.py index 27895bf..e163235 100644 --- a/src/mcts/scoring/attack_graph_models.py +++ b/src/mcts/scoring/attack_graph_models.py @@ -236,3 +236,5 @@ class MatchedChain(BaseModel): trust_boundary_crossings: int = 0 explanation: list[ExplanationStep] = Field(default_factory=list) legacy_finding_id: str | None = None + counterfactual_remediation: dict[str, Any] | None = None + recommended_fixes: list[dict[str, Any]] = Field(default_factory=list) diff --git a/src/mcts/scoring/fixes/registry.yaml b/src/mcts/scoring/fixes/registry.yaml index d5b96a1..c8a965e 100644 --- a/src/mcts/scoring/fixes/registry.yaml +++ b/src/mcts/scoring/fixes/registry.yaml @@ -20,3 +20,66 @@ remove_env_tool: mutates: - remove_node: pattern: tool:get-env + +bind_localhost_only: + description: Bind HTTP/SSE listeners to 127.0.0.1 only + +disable_network_egress: + description: Remove or block outbound network tools and httpx/fetch sinks + +add_url_policy: + description: Enforce URL allowlist and block RFC1918/metadata hosts + +disable_prompt_egress: + description: Prevent prompt handlers from performing network egress + +disable_elicitation: + description: Disable client URL elicitation tools and CAP-05 surfaces + +require_user_confirmation: + description: Require explicit user confirmation before elicitation or sampling + +disable_sampling: + description: Disable client sampling capability exposure + +gate_sampling_behind_auth: + description: Require authenticated transport before sampling is enabled + +scope_git_tools: + description: Require repository allowlist on all git read tools + +add_repo_allowlist: + description: Add mandatory --repository / repo scope CLI flag + +atomic_open: + description: Open files with O_NOFOLLOW and atomic read-after-validate + +disable_symlink_follow: + description: Use lstat per directory entry; never follow symlinks on read + +isolate_read_and_exec_tools: + description: Split read and exec tools across trust boundaries or separate servers + +validate_inputs_before_exec: + description: Validate and sandbox shell/exec inputs before execution + +remove_egress_tool: + description: Remove webhook/fetch/egress tools from production surface + +scope_read_tools: + description: Restrict read tools to scoped roots and path allowlists + +block_credential_tools: + description: Remove get-env and credential-access tools from production + +use_resource_link_not_inline_blob: + description: Return resource links instead of inline session resource blobs + +cap_session_resources: + description: Cap session resource count and TTL; deny gzip-to-resource staging + +cap_memory_writes: + description: Limit memory graph write tools and observation size + +trust_model_documentation: + description: Document memory trust model and cross-session read risks diff --git a/src/mcts/scoring/graph_compress.py b/src/mcts/scoring/graph_compress.py new file mode 100644 index 0000000..7b68523 --- /dev/null +++ b/src/mcts/scoring/graph_compress.py @@ -0,0 +1,54 @@ +"""Path compression for attack graph dashboard export (Phase 3c).""" + +from __future__ import annotations + +from typing import Any + + +def _path_key(path: dict[str, Any]) -> tuple[str, tuple[str, ...]]: + template_id = str(path.get("template_id") or "") + tools = tuple(sorted(path.get("tools_on_path") or path.get("nodes") or [])) + return template_id, tools + + +def compress_paths( + paths: list[dict[str, Any]], + *, + max_total: int = 12, + max_per_template: int = 2, +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + """Dedupe by template + tool set; keep highest chain_risk_score paths.""" + if not paths or len(paths) <= max_total: + return paths, {"original_count": len(paths), "compressed_count": len(paths), "dropped": 0} + + by_template: dict[str, list[dict[str, Any]]] = {} + for path in paths: + template_id = str(path.get("template_id") or "unknown") + by_template.setdefault(template_id, []).append(path) + + ranked: list[dict[str, Any]] = [] + seen_keys: set[tuple[str, tuple[str, ...]]] = set() + for _template_id, group in sorted(by_template.items()): + group_sorted = sorted( + group, + key=lambda row: float(row.get("chain_risk_score") or 0), + reverse=True, + ) + kept = 0 + for path in group_sorted: + key = _path_key(path) + if key in seen_keys: + continue + seen_keys.add(key) + ranked.append(path) + kept += 1 + if kept >= max_per_template: + break + + ranked.sort(key=lambda row: float(row.get("chain_risk_score") or 0), reverse=True) + compressed = ranked[:max_total] + return compressed, { + "original_count": len(paths), + "compressed_count": len(compressed), + "dropped": max(0, len(paths) - len(compressed)), + } diff --git a/src/mcts/scoring/graph_counterfactual.py b/src/mcts/scoring/graph_counterfactual.py new file mode 100644 index 0000000..883c2a2 --- /dev/null +++ b/src/mcts/scoring/graph_counterfactual.py @@ -0,0 +1,41 @@ +"""Counterfactual remediation for matched attack graph chains (Phase 3c).""" + +from __future__ import annotations + +from typing import Any + +from mcts.scoring.graph_fixes import describe_fixes +from mcts.scoring.graph_templates import load_chain_templates + + +def counterfactual_for_chain( + template_id: str, + tools_on_path: list[str], +) -> dict[str, Any]: + """Build counterfactual payload aligned with trust-layer evidence shape.""" + templates = {template.id: template for template in load_chain_templates()} + template = templates.get(template_id) + fix_kinds = list(template.recommended_fixes) if template else [] + fixes = describe_fixes(fix_kinds) + + triggered: list[str] = [] + actions: list[dict[str, str]] = [] + tool_label = ", ".join(tools_on_path) if tools_on_path else "matched path" + for fix in fixes: + kind = str(fix.get("kind", "")) + description = str(fix.get("description") or kind.replace("_", " ")) + triggered.append(f"{template_id}: {description} ({tool_label})") + actions.append( + { + "action": description, + "removes": kind or template_id, + } + ) + + return { + "triggered_by": triggered, + "removing_any_one_eliminates_finding": len(fixes) > 1, + "actions": actions, + "recommended_fixes": fixes, + "template_id": template_id, + } diff --git a/src/mcts/scoring/graph_explain.py b/src/mcts/scoring/graph_explain.py index 870011f..feb5d03 100644 --- a/src/mcts/scoring/graph_explain.py +++ b/src/mcts/scoring/graph_explain.py @@ -86,6 +86,21 @@ def matched_chain_to_finding(template_id: str, chains: list[MatchedChain]) -> Fi "explanation": [step.model_dump(mode="json") for step in chain.explanation], } ) + evidence_kwargs: dict[str, Any] = { + "template_id": template_id, + "path_proven": top.path.hop_count >= 2, + "chain_confidence": round(top.path_confidence, 3), + "path_reachability": round(top.path_reachability, 3), + "chain_risk_score": round(top.chain_risk_score, 3), + "trust_boundary_crossings": top.trust_boundary_crossings, + "exploit_cost": template.exploit_cost, + "paths": paths_payload, + "finding_class": template.finding_class, + } + if top.counterfactual_remediation: + evidence_kwargs["counterfactual_remediation"] = top.counterfactual_remediation + if top.recommended_fixes: + evidence_kwargs["recommended_fixes"] = top.recommended_fixes builder = ( FindingBuilder( finding_id=finding_id, @@ -97,17 +112,7 @@ def matched_chain_to_finding(template_id: str, chains: list[MatchedChain]) -> Fi ) .confidence(top.path_confidence) .technique("MCTS-T-attack-graph") - .evidence( - template_id=template_id, - path_proven=top.path.hop_count >= 2, - chain_confidence=round(top.path_confidence, 3), - path_reachability=round(top.path_reachability, 3), - chain_risk_score=round(top.chain_risk_score, 3), - trust_boundary_crossings=top.trust_boundary_crossings, - exploit_cost=template.exploit_cost, - paths=paths_payload, - finding_class=template.finding_class, - ) + .evidence(**evidence_kwargs) .fact( rule_id=template_id, match=template.title, diff --git a/src/mcts/scoring/graph_fixes.py b/src/mcts/scoring/graph_fixes.py new file mode 100644 index 0000000..db53524 --- /dev/null +++ b/src/mcts/scoring/graph_fixes.py @@ -0,0 +1,43 @@ +"""FixKind registry for attack graph template remediations (Phase 3c).""" + +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import Any + +import yaml + +FIXES_REGISTRY_PATH = Path(__file__).resolve().parent / "fixes" / "registry.yaml" + + +@lru_cache(maxsize=1) +def load_fixes_registry() -> dict[str, dict[str, Any]]: + if not FIXES_REGISTRY_PATH.exists(): + return {} + raw = yaml.safe_load(FIXES_REGISTRY_PATH.read_text(encoding="utf-8")) or {} + return {str(key): value for key, value in raw.items() if isinstance(value, dict)} + + +def resolve_fix(kind: str) -> dict[str, Any] | None: + entry = load_fixes_registry().get(kind) + if not entry: + return None + return {"kind": kind, **entry} + + +def describe_fixes(fix_kinds: list[str]) -> list[dict[str, Any]]: + """Map template recommended_fixes keys to registry descriptions.""" + described: list[dict[str, Any]] = [] + for kind in fix_kinds: + entry = resolve_fix(kind) + if entry: + described.append(entry) + else: + described.append( + { + "kind": kind, + "description": kind.replace("_", " "), + } + ) + return described diff --git a/src/mcts/scoring/graph_suggest.py b/src/mcts/scoring/graph_suggest.py new file mode 100644 index 0000000..2b6acb1 --- /dev/null +++ b/src/mcts/scoring/graph_suggest.py @@ -0,0 +1,34 @@ +"""Suggest attack-graph remediations from a scan report (Phase 3c).""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from mcts.scoring.graph_fixes import describe_fixes +from mcts.scoring.graph_templates import load_chain_templates + + +def suggest_fixes_from_report(report_path: Path) -> list[dict[str, Any]]: + payload = json.loads(report_path.read_text(encoding="utf-8")) + attack_graph = payload.get("attack_graph") or {} + templates = {template.id: template for template in load_chain_templates()} + suggestions: list[dict[str, Any]] = [] + seen: set[str] = set() + + for template_id in attack_graph.get("templates_matched") or []: + if template_id in seen: + continue + seen.add(template_id) + template = templates.get(str(template_id)) + if not template: + continue + suggestions.append( + { + "template_id": template_id, + "title": template.title, + "recommended_fixes": describe_fixes(list(template.recommended_fixes)), + } + ) + return suggestions diff --git a/src/mcts/scoring/graph_ui.py b/src/mcts/scoring/graph_ui.py index 5e7c43f..b2f45b9 100644 --- a/src/mcts/scoring/graph_ui.py +++ b/src/mcts/scoring/graph_ui.py @@ -22,6 +22,9 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]: "label": node.get("label") or _short_label(node.get("id", "")), "type": node.get("kind", "tool"), "kind": node.get("kind"), + "layer": node.get("layer"), + "trust": node.get("trust"), + "sensitivity": node.get("sensitivity"), } ) @@ -31,13 +34,18 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]: continue src = edge.get("from_node") or edge.get("from") dst = edge.get("to_node") or edge.get("to") + policy = bool(edge.get("policy", False)) + evidence_strength = edge.get("evidence_strength") or "static" edges.append( { "from": src, "to": dst, "label": edge.get("label") or edge.get("kind", ""), "kind": edge.get("kind"), - "policy": edge.get("policy", False), + "layer": edge.get("layer"), + "policy": policy, + "evidence_strength": evidence_strength, + "edge_class": "policy" if policy else _edge_class(evidence_strength), } ) @@ -56,9 +64,16 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]: "chain_risk_score": path.get("chain_risk_score"), "explanation": path.get("explanation") or [], "finding_ids": path.get("finding_ids") or [], + "recommended_fixes": path.get("recommended_fixes") or [], + "counterfactual_remediation": path.get("counterfactual_remediation"), } ) + layers_present = graph.get("layers_present") or sorted( + {layer for node in nodes if (layer := node.get("layer"))} + | {layer for edge in edges if (layer := edge.get("layer"))} + ) + return { "version": version, "nodes": nodes, @@ -66,7 +81,8 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]: "paths": paths, "templates_matched": graph.get("templates_matched") or [], "total_risk_score": graph.get("total_risk_score"), - "layers_present": graph.get("layers_present") or [], + "layers_present": layers_present, + "compression_stats": graph.get("compression_stats"), } @@ -94,9 +110,23 @@ def format_attack_path_explanation(finding: Any) -> str: lines.append(f"{idx}. {step.get('message', step)}") else: lines.append(f"{idx}. {step}") + counterfactual = evidence.get("counterfactual_remediation") + if isinstance(counterfactual, dict) and counterfactual.get("actions"): + lines.append("Counterfactual fixes:") + for action in counterfactual.get("actions") or []: + if isinstance(action, dict): + lines.append(f"- {action.get('action', action)}") return "\n".join(lines) if lines else (finding.description or "") +def _edge_class(evidence_strength: str) -> str: + if evidence_strength == "runtime": + return "runtime" + if evidence_strength == "heuristic": + return "inferred" + return "proven" + + def _short_label(node_id: str) -> str: if ":" in node_id: return node_id.split(":", 1)[1] diff --git a/tests/scoring/test_graph_phase_3c.py b/tests/scoring/test_graph_phase_3c.py new file mode 100644 index 0000000..38b885e --- /dev/null +++ b/tests/scoring/test_graph_phase_3c.py @@ -0,0 +1,144 @@ +"""Phase 3c graph polish: fixes registry, counterfactuals, compression, UI.""" + +from __future__ import annotations + +from mcts.core.config import ScanConfig +from mcts.core.scanner import Scanner +from mcts.scoring.attack_graph import AttackGraph +from mcts.scoring.attack_graph_builder import GraphBuilder +from mcts.scoring.attack_graph_models import EdgeKind, GraphEdge, GraphLayer, NodeKind +from mcts.scoring.graph_compress import compress_paths +from mcts.scoring.graph_counterfactual import counterfactual_for_chain +from mcts.scoring.graph_fixes import describe_fixes, load_fixes_registry +from mcts.scoring.graph_ui import normalize_attack_graph_for_ui + + +def test_fixes_registry_loads() -> None: + registry = load_fixes_registry() + assert "add_http_auth" in registry + assert "remove_env_tool" in registry + + +def test_describe_fixes_uses_registry() -> None: + fixes = describe_fixes(["add_http_auth", "unknown_fix_kind"]) + assert fixes[0]["description"] + assert fixes[1]["kind"] == "unknown_fix_kind" + + +def test_counterfactual_for_chain() -> None: + payload = counterfactual_for_chain("HTTP_TAKEOVER", ["get-env"]) + assert payload["template_id"] == "HTTP_TAKEOVER" + assert payload["actions"] + assert payload["recommended_fixes"] + + +def test_compress_paths_dedupes() -> None: + paths = [ + {"template_id": "A", "tools_on_path": ["t1"], "chain_risk_score": 1}, + {"template_id": "A", "tools_on_path": ["t1"], "chain_risk_score": 2}, + {"template_id": "B", "tools_on_path": ["t2"], "chain_risk_score": 5}, + ] + compressed, stats = compress_paths(paths, max_total=2, max_per_template=1) + assert len(compressed) == 2 + assert stats["dropped"] == 1 + + +def test_graph_builder_attaches_recommended_fixes() -> None: + graph = AttackGraph() + graph.seed_sources_and_sinks() + graph.add_node(NodeKind.TOOL, "trigger-url-elicitation", label="trigger-url-elicitation") + graph.add_node(NodeKind.CAPABILITY, "elicitation", label="elicitation") + graph.merge_edge( + GraphEdge( + id="edge-elicit", + kind=EdgeKind.TRIGGERS, + from_node="tool:trigger-url-elicitation", + to_node="capability:elicitation", + layer=GraphLayer.TRUST_BOUNDARY, + confidence=0.75, + reachability=0.8, + ) + ) + from mcts.mcp.models import MCPServerInfo, MCPTool + + server = MCPServerInfo( + name="demo", + tools=[MCPTool(name="trigger-url-elicitation", description="elicitation tool")], + ) + builder = GraphBuilder(config=ScanConfig(target=".", attack_graph_enable_counterfactuals=True)) + built = builder.build(server, []) + elicit = next((c for c in built.matched_chains if c.template_id == "ELICIT_PHISH"), None) + assert elicit is not None + assert elicit.recommended_fixes + assert elicit.counterfactual_remediation + + +def test_to_report_dict_compresses_when_requested() -> None: + graph = AttackGraph() + graph.matched_chains = [] + report = graph.to_report_dict(compress_for_ui=False) + assert "compression_stats" not in report + + +def test_normalize_ui_includes_layers_and_edge_class() -> None: + raw = { + "version": 3, + "nodes": [ + { + "id": "tool:fetch", + "kind": "tool", + "label": "fetch", + "layer": "mcp_surface", + "trust": "semi_trusted", + "sensitivity": "medium", + } + ], + "edges": [ + { + "from_node": "tool:fetch", + "to_node": "sink:external_network", + "kind": "EGRESS", + "layer": "dataflow", + "policy": False, + "evidence_strength": "static", + } + ], + "paths": [], + "layers_present": ["dataflow", "mcp_surface"], + } + ui = normalize_attack_graph_for_ui(raw) + assert ui["nodes"][0]["trust"] == "semi_trusted" + assert ui["edges"][0]["edge_class"] == "proven" + assert "dataflow" in ui["layers_present"] + + +def test_suggest_fixes_from_report(tmp_path) -> None: + from mcts.scoring.graph_suggest import suggest_fixes_from_report + + report = { + "attack_graph": { + "templates_matched": ["ELICIT_PHISH", "ELICIT_PHISH"], + } + } + path = tmp_path / "scan.json" + path.write_text(__import__("json").dumps(report), encoding="utf-8") + rows = suggest_fixes_from_report(path) + assert len(rows) == 1 + assert rows[0]["template_id"] == "ELICIT_PHISH" + + +def test_scanner_counterfactual_flag() -> None: + target = "tests/fixtures/monorepo-mini/src/everything/tools/trigger-url-elicitation.ts" + report = Scanner( + ScanConfig( + target=target, + surface_depth="full", + attack_graph_enable_counterfactuals=True, + ) + ).run() + chain = next( + f + for f in report.findings + if f.analyzer == "attack_graph" and f.evidence.get("template_id") == "ELICIT_PHISH" + ) + assert chain.evidence.get("counterfactual_remediation") From a9ed11860007dd6d286c387893acc0a53a52224c Mon Sep 17 00:00:00 2001 From: hello-args Date: Wed, 17 Jun 2026 23:02:32 +0530 Subject: [PATCH 2/2] feat(scoring): finish Phase 3c graph polish (mutates, inventory, defaults) Apply FixKind registry mutates at runtime with template re-match simulation, attach inventory-layer cross-server edges when fleet inventory is present, and enable counterfactuals plus UI path compression by default with CLI opt-out. --- CHANGELOG.md | 2 +- src/mcts/cli/main.py | 12 +- src/mcts/core/config.py | 4 +- src/mcts/core/scanner.py | 6 +- src/mcts/scoring/attack_graph_builder.py | 22 ++- src/mcts/scoring/graph_counterfactual.py | 22 ++- src/mcts/scoring/graph_inventory.py | 148 ++++++++++++++++++++ src/mcts/scoring/graph_mutate.py | 164 +++++++++++++++++++++++ tests/scoring/test_graph_phase_3c.py | 105 ++++++++++++++- 9 files changed, 467 insertions(+), 18 deletions(-) create mode 100644 src/mcts/scoring/graph_inventory.py create mode 100644 src/mcts/scoring/graph_mutate.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fc5168..f4ef67b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- **Attack graph Phase 3c polish** — FixKind registry expansion, template `recommended_fixes` on paths, optional counterfactual remediation (`--attack-graph-counterfactuals`), UI path compression (`--attack-graph-compress-ui`), dashboard layer filter + policy/inferred edge styling, `mcts doctor --suggest-fixes --report` +- **Attack graph Phase 3c polish** — FixKind registry with runtime `graph_mutate` engine (apply registry `mutates` and simulate template elimination), inventory multi-server graph layer (`graph_inventory`), counterfactual fix simulation on paths, default-on counterfactuals and UI compression (`--no-attack-graph-counterfactuals`, `--no-attack-graph-compress-ui`), dashboard layer filter + policy/inferred edge styling, `mcts doctor --suggest-fixes --report` - **Attack graph v3 rollout (Phase 3a/3b)** — default `attack_graph_version=3`; YAML template matcher replaces `AttackChainAnalyzer`; 12 chain templates including `SSRF_RESOURCE`, `ENV_SAMPLING`, `GIT_UNSCOPED`, `PROMPT_BYPASS`, `ELICIT_PHISH`, `TOCTOU_READ`, `READ_EXEC`, `CRED_THEFT`; capability overlap fallbacks; dashboard v3 paths + SARIF `mcts/attackPathExplanation`; R-23–R-25 regression fixtures + `tests/scoring/test_phase_3b_templates.py` - **Fact provenance metrics** — `fact_coverage()` reports `native_pct` / `silver_pct`; dashboard exposes `fact_provenance`; CI gates via `check_ttu_baseline.py` + corpus `--check-only` diff --git a/src/mcts/cli/main.py b/src/mcts/cli/main.py index 440a907..fba762f 100644 --- a/src/mcts/cli/main.py +++ b/src/mcts/cli/main.py @@ -810,17 +810,17 @@ def scan( attack_graph_counterfactuals: Annotated[ bool, typer.Option( - "--attack-graph-counterfactuals", - help="Attach counterfactual remediation to attack graph template findings", + "--attack-graph-counterfactuals/--no-attack-graph-counterfactuals", + help="Attach counterfactual remediation to attack graph template findings (default on)", ), - ] = False, + ] = True, attack_graph_compress_ui: Annotated[ bool, typer.Option( - "--attack-graph-compress-ui", - help="Compress matched attack paths in report export for dashboard readability", + "--attack-graph-compress-ui/--no-attack-graph-compress-ui", + help="Compress matched attack paths in report export for dashboard readability (default on)", ), - ] = False, + ] = True, min_security_score: Annotated[ int | None, typer.Option( diff --git a/src/mcts/core/config.py b/src/mcts/core/config.py index e2dc133..5ad2799 100644 --- a/src/mcts/core/config.py +++ b/src/mcts/core/config.py @@ -158,8 +158,8 @@ class ScanConfig(BaseModel): attack_graph_min_confidence: float = Field(default=0.0, ge=0.0, le=1.0) attack_graph_confidence_mode: str = "geometric_mean" attack_graph_include_overlap_chains: bool = False - attack_graph_enable_counterfactuals: bool = False - attack_graph_compress_for_ui: bool = False + attack_graph_enable_counterfactuals: bool = True + attack_graph_compress_for_ui: bool = True @classmethod def _validate_min_evidence_strength(cls, value: str | None) -> str | None: diff --git a/src/mcts/core/scanner.py b/src/mcts/core/scanner.py index 8c07910..8defd76 100644 --- a/src/mcts/core/scanner.py +++ b/src/mcts/core/scanner.py @@ -254,7 +254,11 @@ def analyze_server(self, server_info: MCPServerInfo) -> ScanReport: from mcts.scoring.attack_graph_builder import GraphBuilder from mcts.scoring.capability_overlap import emit_capability_overlap_findings - attack_graph_model = GraphBuilder(config=self.config).build(server_info, findings) + attack_graph_model = GraphBuilder(config=self.config).build( + server_info, + findings, + inventory=self.inventory, + ) chain_findings = attack_graph_model.to_findings() findings.extend(chain_findings) raw_graph = attack_graph_model.to_report_dict( diff --git a/src/mcts/scoring/attack_graph_builder.py b/src/mcts/scoring/attack_graph_builder.py index 4f2255a..ba3d09a 100644 --- a/src/mcts/scoring/attack_graph_builder.py +++ b/src/mcts/scoring/attack_graph_builder.py @@ -5,6 +5,7 @@ from typing import Any from mcts.core.config import ScanConfig +from mcts.inventory.models import InventoryEntry from mcts.mcp.models import MCPServerInfo from mcts.reporting.models import Finding from mcts.scoring.attack_graph import AttackGraph @@ -21,7 +22,13 @@ class GraphBuilder: def __init__(self, config: ScanConfig | None = None) -> None: self.config = config or ScanConfig(target=".") - def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph: + def build( + self, + server: MCPServerInfo, + findings: list[Finding], + *, + inventory: list[InventoryEntry] | None = None, + ) -> AttackGraph: graph = AttackGraph() graph.seed_sources_and_sinks() seed_server_surfaces(graph, server) @@ -30,6 +37,10 @@ def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph: if self.config.attack_graph_include_overlap_chains: self._add_corroborated_invokes_edges(graph, server) apply_policy_edges(graph, server, findings) + if inventory and len(inventory) >= 2: + from mcts.scoring.graph_inventory import attach_inventory_layer + + attach_inventory_layer(graph, inventory) templates = load_chain_templates() max_depth = self.config.attack_graph_max_depth if max_depth > 0: @@ -48,13 +59,18 @@ def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph: templates, top_per_template=3, ) - matched = self._attach_graph_polish(matched, counterfactuals=self.config.attack_graph_enable_counterfactuals) + matched = self._attach_graph_polish( + graph, + matched, + counterfactuals=self.config.attack_graph_enable_counterfactuals, + ) graph.matched_chains = matched graph.total_risk_score = sum(chain.chain_risk_score for chain in matched) return graph def _attach_graph_polish( self, + graph: AttackGraph, chains: list[MatchedChain], *, counterfactuals: bool, @@ -73,6 +89,8 @@ def _attach_graph_polish( update["counterfactual_remediation"] = counterfactual_for_chain( chain.template_id, chain.path.tool_names_on_path(), + graph=graph, + fix_kinds=list(template.recommended_fixes) if template else [], ) enriched.append(chain.model_copy(update=update)) return enriched diff --git a/src/mcts/scoring/graph_counterfactual.py b/src/mcts/scoring/graph_counterfactual.py index 883c2a2..83ed2a9 100644 --- a/src/mcts/scoring/graph_counterfactual.py +++ b/src/mcts/scoring/graph_counterfactual.py @@ -4,19 +4,24 @@ from typing import Any +from mcts.scoring.attack_graph import AttackGraph from mcts.scoring.graph_fixes import describe_fixes +from mcts.scoring.graph_mutate import simulate_fixes_for_template from mcts.scoring.graph_templates import load_chain_templates def counterfactual_for_chain( template_id: str, tools_on_path: list[str], + *, + graph: AttackGraph | None = None, + fix_kinds: list[str] | None = None, ) -> dict[str, Any]: """Build counterfactual payload aligned with trust-layer evidence shape.""" templates = {template.id: template for template in load_chain_templates()} template = templates.get(template_id) - fix_kinds = list(template.recommended_fixes) if template else [] - fixes = describe_fixes(fix_kinds) + kinds = fix_kinds if fix_kinds is not None else (list(template.recommended_fixes) if template else []) + fixes = describe_fixes(kinds) triggered: list[str] = [] actions: list[dict[str, str]] = [] @@ -32,10 +37,19 @@ def counterfactual_for_chain( } ) - return { + simulation: list[dict[str, Any]] = [] + if graph is not None and kinds: + simulation = simulate_fixes_for_template(graph, template_id, kinds) + + eliminating = [row for row in simulation if row.get("eliminates_template")] + payload: dict[str, Any] = { "triggered_by": triggered, - "removing_any_one_eliminates_finding": len(fixes) > 1, + "removing_any_one_eliminates_finding": len(eliminating) > 0 or len(fixes) > 1, "actions": actions, "recommended_fixes": fixes, "template_id": template_id, } + if simulation: + payload["fix_simulation"] = simulation + payload["effective_fixes"] = [row["fix_kind"] for row in eliminating] + return payload diff --git a/src/mcts/scoring/graph_inventory.py b/src/mcts/scoring/graph_inventory.py new file mode 100644 index 0000000..c9efd69 --- /dev/null +++ b/src/mcts/scoring/graph_inventory.py @@ -0,0 +1,148 @@ +"""Inventory-layer edges for multi-server attack graphs (Phase 3c).""" + +from __future__ import annotations + +import re + +from mcts.inventory.models import InventoryEntry +from mcts.scoring.attack_graph import AttackGraph +from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, NodeKind, canonical_node_id + +_READ_TOOLS = frozenset({"read_file", "get_env", "read_env", "fetch", "http_request"}) +_WRITE_TOOLS = frozenset({"write_file", "delete_file", "run_shell", "execute_command", "deploy"}) +_SENSITIVE_TOOLS = frozenset( + { + "read_file", + "write_file", + "delete_file", + "run_shell", + "execute_command", + "http_request", + "fetch", + "post_webhook", + "get_env", + "read_env", + } +) + + +def _server_slug(server_key: str) -> str: + return re.sub(r"[^a-zA-Z0-9]+", "_", server_key).strip("_") + + +def server_hub_id(server_key: str) -> str: + return canonical_node_id(NodeKind.CAPABILITY, f"inventory-{_server_slug(server_key)}") + + +def inventory_tool_id(server_key: str, tool_name: str) -> str: + return canonical_node_id(NodeKind.TOOL, f"{_server_slug(server_key)}::{tool_name}") + + +def attach_inventory_layer(graph: AttackGraph, inventory: list[InventoryEntry]) -> None: + """Merge cross-server inventory nodes and edges when fleet size >= 2.""" + if len(inventory) < 2: + return + + server_tools: dict[str, set[str]] = {} + server_meta: dict[str, InventoryEntry] = {} + + for entry in inventory: + server_key = f"{entry.client}/{entry.server_name}" + server_meta[server_key] = entry + tools = {tool.lower() for tool in entry.tools} + server_tools[server_key] = tools + + hub = server_hub_id(server_key) + graph.add_node( + NodeKind.CAPABILITY, + hub.split(":", 1)[1], + label=server_key, + layer=GraphLayer.INVENTORY, + metadata={"server_key": server_key, "client": entry.client, "server_name": entry.server_name}, + ) + for tool in sorted(entry.tools): + local = f"{_server_slug(server_key)}::{tool}" + graph.add_node( + NodeKind.TOOL, + local, + label=f"{tool} ({server_key})", + layer=GraphLayer.INVENTORY, + metadata={"server_key": server_key, "tool": tool}, + ) + graph.add_edge( + EdgeKind.EXPOSES, + hub, + inventory_tool_id(server_key, tool), + layer=GraphLayer.INVENTORY, + confidence=0.9, + reachability=1.0, + label="inventory_tool_surface", + policy=True, + ) + unscoped = canonical_node_id(NodeKind.TOOL, tool) + if unscoped in graph.nodes: + graph.add_edge( + EdgeKind.INVOKES, + unscoped, + inventory_tool_id(server_key, tool), + layer=GraphLayer.INVENTORY, + confidence=0.95, + reachability=1.0, + label="focal_server_bridge", + policy=True, + ) + + readers = [key for key, tools in server_tools.items() if tools & _READ_TOOLS] + writers = [key for key, tools in server_tools.items() if tools & _WRITE_TOOLS] + for reader in readers: + read_tools = sorted(server_tools[reader] & _READ_TOOLS) + for writer in writers: + if reader == writer: + continue + write_tools = sorted(server_tools[writer] & _WRITE_TOOLS) + if not read_tools or not write_tools: + continue + from_tool = inventory_tool_id(reader, read_tools[0]) + to_tool = inventory_tool_id(writer, write_tools[0]) + graph.add_edge( + EdgeKind.INVOKES, + from_tool, + to_tool, + layer=GraphLayer.INVENTORY, + confidence=0.75, + reachability=0.8, + label="cross_server_read_write", + policy=True, + metadata={"reader": reader, "writer": writer, "issue": "W015"}, + ) + graph.add_edge( + EdgeKind.INVOKES, + server_hub_id(reader), + server_hub_id(writer), + layer=GraphLayer.INVENTORY, + confidence=0.7, + reachability=0.75, + label="cross_server_toxic_flow", + policy=True, + metadata={"reader": reader, "writer": writer, "issue": "W015"}, + ) + + for tool in _SENSITIVE_TOOLS: + holders = sorted(key for key, tools in server_tools.items() if tool in tools) + if len(holders) < 2: + continue + for left in holders: + for right in holders: + if left >= right: + continue + graph.add_edge( + EdgeKind.INVOKES, + inventory_tool_id(left, tool), + inventory_tool_id(right, tool), + layer=GraphLayer.INVENTORY, + confidence=0.65, + reachability=0.7, + label="sensitive_tool_shadow", + policy=True, + metadata={"tool": tool, "servers": holders, "issue": "W016"}, + ) diff --git a/src/mcts/scoring/graph_mutate.py b/src/mcts/scoring/graph_mutate.py new file mode 100644 index 0000000..29d5f6a --- /dev/null +++ b/src/mcts/scoring/graph_mutate.py @@ -0,0 +1,164 @@ +"""Apply FixKind registry mutates to attack graphs (Phase 3c runtime engine).""" + +from __future__ import annotations + +import fnmatch +from typing import Any + +from mcts.scoring.attack_graph import AttackGraph +from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, parse_node_id +from mcts.scoring.graph_fixes import resolve_fix +from mcts.scoring.graph_matcher import match_template +from mcts.scoring.graph_templates import ChainTemplate, load_chain_templates + + +def clone_graph(graph: AttackGraph) -> AttackGraph: + """Deep-copy nodes and edges without matched chains.""" + copy = AttackGraph() + copy._nodes = {node_id: node.model_copy(deep=True) for node_id, node in graph.nodes.items()} + for edge in graph.edges.values(): + copy.merge_edge(edge.model_copy(deep=True)) + return copy + + +def _remove_node(graph: AttackGraph, node_id: str) -> None: + if node_id not in graph.nodes: + return + edge_ids = [ + edge.id for edge in graph.edges.values() if edge.from_node == node_id or edge.to_node == node_id + ] + for edge_id in edge_ids: + edge = graph.edges.get(edge_id) + if not edge: + continue + graph._edges.pop(edge_id, None) + if edge.kind in graph._edges_by_kind: + graph._edges_by_kind[edge.kind] = [ + eid for eid in graph._edges_by_kind[edge.kind] if eid != edge_id + ] + if edge.from_node in graph._outgoing: + graph._outgoing[edge.from_node] = [ + eid for eid in graph._outgoing[edge.from_node] if eid != edge_id + ] + graph._nodes_with_outgoing_kind.setdefault(edge.kind, set()).discard(edge.from_node) + graph._nodes.pop(node_id, None) + + +def _node_matches_pattern(node_id: str, pattern: str) -> bool: + if fnmatch.fnmatch(node_id, pattern): + return True + _, local = parse_node_id(node_id) + return fnmatch.fnmatch(local, pattern) + + +def apply_mutate_spec(graph: AttackGraph, spec: dict[str, Any]) -> None: + """Apply one registry mutate block to *graph* in place.""" + if "add_edge" in spec: + payload = spec["add_edge"] + kind = EdgeKind(str(payload["kind"])) + graph.add_edge( + kind, + str(payload["from"]), + str(payload["to"]), + layer=GraphLayer(payload["layer"]) if payload.get("layer") else None, + confidence=float(payload.get("confidence", 0.85)), + reachability=float(payload.get("reachability", 1.0)), + label=str(payload.get("label", "fix_mutate")), + policy=bool(payload.get("policy", True)), + ) + return + + if "set_reachability" in spec: + payload = spec["set_reachability"] + edge_kind = EdgeKind(str(payload["edge_kind"])) + value = float(payload["value"]) + for edge in graph.edges_of_kind(edge_kind): + edge.reachability = value + return + + if "remove_nodes" in spec: + payload = spec["remove_nodes"] + kind_value = str(payload["kind"]) + targets = [ + node_id + for node_id, node in graph.nodes.items() + if node.kind.value == kind_value or parse_node_id(node_id)[0] == kind_value + ] + for node_id in targets: + _remove_node(graph, node_id) + return + + if "remove_node" in spec: + pattern = str(spec["remove_node"]["pattern"]) + targets = [node_id for node_id in graph.nodes if _node_matches_pattern(node_id, pattern)] + for node_id in targets: + _remove_node(graph, node_id) + + +def apply_fix_kind(graph: AttackGraph, fix_kind: str) -> AttackGraph: + """Return a mutated graph copy after applying all mutates for *fix_kind*.""" + entry = resolve_fix(fix_kind) + mutated = clone_graph(graph) + if not entry: + return mutated + for spec in entry.get("mutates") or []: + if isinstance(spec, dict): + apply_mutate_spec(mutated, spec) + return mutated + + +def simulate_fix_eliminates_template( + graph: AttackGraph, + template: ChainTemplate, + fix_kind: str, +) -> bool: + """True when applying *fix_kind* removes all matches for *template*.""" + mutated = apply_fix_kind(graph, fix_kind) + return len(match_template(template, mutated)) == 0 + + +def simulate_fixes_for_template( + graph: AttackGraph, + template_id: str, + fix_kinds: list[str], +) -> list[dict[str, Any]]: + """Evaluate each fix kind against the live graph for counterfactual simulation.""" + templates = {template.id: template for template in load_chain_templates()} + template = templates.get(template_id) + if template is None: + return [] + results: list[dict[str, Any]] = [] + for kind in fix_kinds: + entry = resolve_fix(kind) or {"kind": kind} + eliminates = simulate_fix_eliminates_template(graph, template, kind) + results.append( + { + "fix_kind": kind, + "description": str(entry.get("description") or kind.replace("_", " ")), + "eliminates_template": eliminates, + "mutates_applied": len(entry.get("mutates") or []), + } + ) + return results + + +def any_fix_eliminates_template( + graph: AttackGraph, + template_id: str, + fix_kinds: list[str], +) -> bool: + rows = simulate_fixes_for_template(graph, template_id, fix_kinds) + return any(row["eliminates_template"] for row in rows) + + +def minimal_fix_set( + graph: AttackGraph, + template_id: str, + fix_kinds: list[str], +) -> list[str]: + """Greedy single-fix hits that eliminate the template (for doctor suggestions).""" + return [ + row["fix_kind"] + for row in simulate_fixes_for_template(graph, template_id, fix_kinds) + if row["eliminates_template"] + ] diff --git a/tests/scoring/test_graph_phase_3c.py b/tests/scoring/test_graph_phase_3c.py index 38b885e..99fb4d3 100644 --- a/tests/scoring/test_graph_phase_3c.py +++ b/tests/scoring/test_graph_phase_3c.py @@ -32,6 +32,107 @@ def test_counterfactual_for_chain() -> None: assert payload["recommended_fixes"] +def test_apply_fix_kind_remove_env_tool() -> None: + from mcts.scoring.attack_graph_models import EdgeKind, NodeKind + from mcts.scoring.graph_mutate import apply_fix_kind + + graph = AttackGraph() + graph.seed_sources_and_sinks() + graph.add_node(NodeKind.TOOL, "get-env", label="get-env") + graph.add_edge(EdgeKind.READS, "tool:get-env", "sink:env") + mutated = apply_fix_kind(graph, "remove_env_tool") + assert "tool:get-env" not in mutated.nodes + assert not any(edge.from_node == "tool:get-env" for edge in mutated.edges.values()) + + +def test_simulate_remove_env_tool_eliminates_http_takeover() -> None: + from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, NodeKind + from mcts.scoring.graph_mutate import simulate_fix_eliminates_template + from mcts.scoring.graph_templates import load_chain_templates + + graph = AttackGraph() + graph.seed_sources_and_sinks() + graph.add_node(NodeKind.TRANSPORT, "http", label="http", layer=GraphLayer.TRANSPORT) + graph.add_node(NodeKind.TOOL, "get-env", label="get-env") + graph.add_edge( + EdgeKind.EXPOSES, + "transport:http", + "tool:get-env", + layer=GraphLayer.TRANSPORT, + reachability=1.0, + ) + graph.add_edge(EdgeKind.READS, "tool:get-env", "sink:env", reachability=1.0) + graph.add_edge( + EdgeKind.DELIVERS_TO_CONTEXT, + "tool:get-env", + "sink:model_context", + reachability=1.0, + ) + template = next(t for t in load_chain_templates() if t.id == "HTTP_TAKEOVER") + assert simulate_fix_eliminates_template(graph, template, "remove_env_tool") + assert simulate_fix_eliminates_template(graph, template, "disable_http_transport") + + +def test_inventory_layer_adds_cross_server_edges() -> None: + from mcts.inventory.models import InventoryEntry + from mcts.scoring.attack_graph_models import GraphLayer + from mcts.scoring.graph_inventory import attach_inventory_layer + + graph = AttackGraph() + graph.seed_sources_and_sinks() + inventory = [ + InventoryEntry( + client="cursor", + config_path="/a", + server_name="reader", + tools=["read_file"], + ), + InventoryEntry( + client="cursor", + config_path="/b", + server_name="writer", + tools=["write_file"], + ), + ] + attach_inventory_layer(graph, inventory) + inventory_edges = [edge for edge in graph.edges.values() if edge.layer == GraphLayer.INVENTORY] + assert inventory_edges + assert any(edge.label == "cross_server_read_write" for edge in inventory_edges) + + +def test_counterfactual_includes_fix_simulation() -> None: + from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, NodeKind + from mcts.scoring.graph_mutate import clone_graph + + graph = AttackGraph() + graph.seed_sources_and_sinks() + graph.add_node(NodeKind.TRANSPORT, "http", label="http", layer=GraphLayer.TRANSPORT) + graph.add_node(NodeKind.TOOL, "get-env", label="get-env") + graph.add_edge( + EdgeKind.EXPOSES, + "transport:http", + "tool:get-env", + layer=GraphLayer.TRANSPORT, + reachability=1.0, + ) + graph.add_edge(EdgeKind.READS, "tool:get-env", "sink:env", reachability=1.0) + graph.add_edge( + EdgeKind.DELIVERS_TO_CONTEXT, + "tool:get-env", + "sink:model_context", + reachability=1.0, + ) + payload = counterfactual_for_chain("HTTP_TAKEOVER", ["get-env"], graph=clone_graph(graph)) + assert payload.get("fix_simulation") + assert "remove_env_tool" in payload.get("effective_fixes", []) + + +def test_config_counterfactuals_and_compress_default_on() -> None: + config = ScanConfig(target=".") + assert config.attack_graph_enable_counterfactuals is True + assert config.attack_graph_compress_for_ui is True + + def test_compress_paths_dedupes() -> None: paths = [ {"template_id": "A", "tools_on_path": ["t1"], "chain_risk_score": 1}, @@ -76,8 +177,8 @@ def test_graph_builder_attaches_recommended_fixes() -> None: def test_to_report_dict_compresses_when_requested() -> None: graph = AttackGraph() graph.matched_chains = [] - report = graph.to_report_dict(compress_for_ui=False) - assert "compression_stats" not in report + report = graph.to_report_dict(compress_for_ui=True) + assert report.get("compression_stats") is not None def test_normalize_ui_includes_layers_and_edge_class() -> None: