Showing ${compression.compressed_count} of ${compression.original_count} matched paths (UI compression).
`
+ : "";
+ list.innerHTML = compressionNote + paths
.map((path, idx) => {
const template = path.template_id ? `How tools could be chained together for a multi-step attack (read → exfiltrate, etc.).
View related issues →
-
Each arrow shows a possible step between tools.
+
Each arrow shows a possible step between tools. Policy edges are dashed; inferred edges are muted.
+
diff --git a/src/mcts/scoring/attack_graph.py b/src/mcts/scoring/attack_graph.py
index 1342400..fa05982 100644
--- a/src/mcts/scoring/attack_graph.py
+++ b/src/mcts/scoring/attack_graph.py
@@ -197,12 +197,17 @@ def filter_layers(self, layers: Iterable[GraphLayer | str]) -> AttackGraph:
filtered.total_risk_score = self.total_risk_score
return filtered
- def to_report_dict(self) -> dict[str, Any]:
+ def to_report_dict(self, *, compress_for_ui: bool = False) -> dict[str, Any]:
node_layers = {n.layer.value for n in self._nodes.values()}
edge_layers = {e.layer.value for e in self._edges.values()}
layers_present = sorted(node_layers | edge_layers)
paths = [self._path_to_dict(chain) for chain in self.matched_chains]
- return {
+ compression_stats: dict[str, Any] | None = None
+ if compress_for_ui:
+ from mcts.scoring.graph_compress import compress_paths
+
+ paths, compression_stats = compress_paths(paths)
+ payload = {
"version": 3,
"nodes": [n.model_dump(mode="json") for n in self._nodes.values()],
"edges": [e.model_dump(mode="json") for e in self._edges.values()],
@@ -211,6 +216,9 @@ def to_report_dict(self) -> dict[str, Any]:
"total_risk_score": round(self.total_risk_score, 2),
"layers_present": layers_present,
}
+ if compression_stats:
+ payload["compression_stats"] = compression_stats
+ return payload
def _path_to_dict(self, chain: MatchedChain) -> dict[str, Any]:
path = chain.path
@@ -229,6 +237,8 @@ def _path_to_dict(self, chain: MatchedChain) -> dict[str, Any]:
"finding_ids": [
chain.legacy_finding_id or f"chain-{chain.template_id.lower().replace('_', '-')}"
],
+ "recommended_fixes": chain.recommended_fixes,
+ "counterfactual_remediation": chain.counterfactual_remediation,
}
def to_findings(self) -> list[Finding]:
diff --git a/src/mcts/scoring/attack_graph_builder.py b/src/mcts/scoring/attack_graph_builder.py
index bf8d183..ba3d09a 100644
--- a/src/mcts/scoring/attack_graph_builder.py
+++ b/src/mcts/scoring/attack_graph_builder.py
@@ -2,11 +2,14 @@
from __future__ import annotations
+from typing import Any
+
from mcts.core.config import ScanConfig
+from mcts.inventory.models import InventoryEntry
from mcts.mcp.models import MCPServerInfo
from mcts.reporting.models import Finding
from mcts.scoring.attack_graph import AttackGraph
-from mcts.scoring.attack_graph_models import EdgeKind, canonical_node_id
+from mcts.scoring.attack_graph_models import EdgeKind, MatchedChain, canonical_node_id
from mcts.scoring.attack_graph_policy import apply_policy_edges, seed_server_surfaces
from mcts.scoring.attack_graph_producers import export_all_edges
from mcts.scoring.graph_matcher import match_all_templates
@@ -19,7 +22,13 @@ class GraphBuilder:
def __init__(self, config: ScanConfig | None = None) -> None:
self.config = config or ScanConfig(target=".")
- def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph:
+ def build(
+ self,
+ server: MCPServerInfo,
+ findings: list[Finding],
+ *,
+ inventory: list[InventoryEntry] | None = None,
+ ) -> AttackGraph:
graph = AttackGraph()
graph.seed_sources_and_sinks()
seed_server_surfaces(graph, server)
@@ -28,6 +37,10 @@ def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph:
if self.config.attack_graph_include_overlap_chains:
self._add_corroborated_invokes_edges(graph, server)
apply_policy_edges(graph, server, findings)
+ if inventory and len(inventory) >= 2:
+ from mcts.scoring.graph_inventory import attach_inventory_layer
+
+ attach_inventory_layer(graph, inventory)
templates = load_chain_templates()
max_depth = self.config.attack_graph_max_depth
if max_depth > 0:
@@ -46,10 +59,42 @@ def build(self, server: MCPServerInfo, findings: list[Finding]) -> AttackGraph:
templates,
top_per_template=3,
)
+ matched = self._attach_graph_polish(
+ graph,
+ matched,
+ counterfactuals=self.config.attack_graph_enable_counterfactuals,
+ )
graph.matched_chains = matched
graph.total_risk_score = sum(chain.chain_risk_score for chain in matched)
return graph
+ def _attach_graph_polish(
+ self,
+ graph: AttackGraph,
+ chains: list[MatchedChain],
+ *,
+ counterfactuals: bool,
+ ) -> list[MatchedChain]:
+ from mcts.scoring.graph_counterfactual import counterfactual_for_chain
+ from mcts.scoring.graph_fixes import describe_fixes
+ from mcts.scoring.graph_templates import load_chain_templates
+
+ templates = {template.id: template for template in load_chain_templates()}
+ enriched: list[MatchedChain] = []
+ for chain in chains:
+ template = templates.get(chain.template_id)
+ fixes = describe_fixes(list(template.recommended_fixes)) if template else []
+ update: dict[str, Any] = {"recommended_fixes": fixes}
+ if counterfactuals:
+ update["counterfactual_remediation"] = counterfactual_for_chain(
+ chain.template_id,
+ chain.path.tool_names_on_path(),
+ graph=graph,
+ fix_kinds=list(template.recommended_fixes) if template else [],
+ )
+ enriched.append(chain.model_copy(update=update))
+ return enriched
+
def _add_corroborated_invokes_edges(self, graph: AttackGraph, server: MCPServerInfo) -> None:
"""Optional overlap chains — disabled by default (spec forever default False)."""
tools = server.tools
diff --git a/src/mcts/scoring/attack_graph_models.py b/src/mcts/scoring/attack_graph_models.py
index 27895bf..e163235 100644
--- a/src/mcts/scoring/attack_graph_models.py
+++ b/src/mcts/scoring/attack_graph_models.py
@@ -236,3 +236,5 @@ class MatchedChain(BaseModel):
trust_boundary_crossings: int = 0
explanation: list[ExplanationStep] = Field(default_factory=list)
legacy_finding_id: str | None = None
+ counterfactual_remediation: dict[str, Any] | None = None
+ recommended_fixes: list[dict[str, Any]] = Field(default_factory=list)
diff --git a/src/mcts/scoring/fixes/registry.yaml b/src/mcts/scoring/fixes/registry.yaml
index d5b96a1..c8a965e 100644
--- a/src/mcts/scoring/fixes/registry.yaml
+++ b/src/mcts/scoring/fixes/registry.yaml
@@ -20,3 +20,66 @@ remove_env_tool:
mutates:
- remove_node:
pattern: tool:get-env
+
+bind_localhost_only:
+ description: Bind HTTP/SSE listeners to 127.0.0.1 only
+
+disable_network_egress:
+ description: Remove or block outbound network tools and httpx/fetch sinks
+
+add_url_policy:
+ description: Enforce URL allowlist and block RFC1918/metadata hosts
+
+disable_prompt_egress:
+ description: Prevent prompt handlers from performing network egress
+
+disable_elicitation:
+ description: Disable client URL elicitation tools and CAP-05 surfaces
+
+require_user_confirmation:
+ description: Require explicit user confirmation before elicitation or sampling
+
+disable_sampling:
+ description: Disable client sampling capability exposure
+
+gate_sampling_behind_auth:
+ description: Require authenticated transport before sampling is enabled
+
+scope_git_tools:
+ description: Require repository allowlist on all git read tools
+
+add_repo_allowlist:
+ description: Add mandatory --repository / repo scope CLI flag
+
+atomic_open:
+ description: Open files with O_NOFOLLOW and atomic read-after-validate
+
+disable_symlink_follow:
+ description: Use lstat per directory entry; never follow symlinks on read
+
+isolate_read_and_exec_tools:
+ description: Split read and exec tools across trust boundaries or separate servers
+
+validate_inputs_before_exec:
+ description: Validate and sandbox shell/exec inputs before execution
+
+remove_egress_tool:
+ description: Remove webhook/fetch/egress tools from production surface
+
+scope_read_tools:
+ description: Restrict read tools to scoped roots and path allowlists
+
+block_credential_tools:
+ description: Remove get-env and credential-access tools from production
+
+use_resource_link_not_inline_blob:
+ description: Return resource links instead of inline session resource blobs
+
+cap_session_resources:
+ description: Cap session resource count and TTL; deny gzip-to-resource staging
+
+cap_memory_writes:
+ description: Limit memory graph write tools and observation size
+
+trust_model_documentation:
+ description: Document memory trust model and cross-session read risks
diff --git a/src/mcts/scoring/graph_compress.py b/src/mcts/scoring/graph_compress.py
new file mode 100644
index 0000000..7b68523
--- /dev/null
+++ b/src/mcts/scoring/graph_compress.py
@@ -0,0 +1,54 @@
+"""Path compression for attack graph dashboard export (Phase 3c)."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def _path_key(path: dict[str, Any]) -> tuple[str, tuple[str, ...]]:
+ template_id = str(path.get("template_id") or "")
+ tools = tuple(sorted(path.get("tools_on_path") or path.get("nodes") or []))
+ return template_id, tools
+
+
+def compress_paths(
+ paths: list[dict[str, Any]],
+ *,
+ max_total: int = 12,
+ max_per_template: int = 2,
+) -> tuple[list[dict[str, Any]], dict[str, Any]]:
+ """Dedupe by template + tool set; keep highest chain_risk_score paths."""
+ if not paths or len(paths) <= max_total:
+ return paths, {"original_count": len(paths), "compressed_count": len(paths), "dropped": 0}
+
+ by_template: dict[str, list[dict[str, Any]]] = {}
+ for path in paths:
+ template_id = str(path.get("template_id") or "unknown")
+ by_template.setdefault(template_id, []).append(path)
+
+ ranked: list[dict[str, Any]] = []
+ seen_keys: set[tuple[str, tuple[str, ...]]] = set()
+ for _template_id, group in sorted(by_template.items()):
+ group_sorted = sorted(
+ group,
+ key=lambda row: float(row.get("chain_risk_score") or 0),
+ reverse=True,
+ )
+ kept = 0
+ for path in group_sorted:
+ key = _path_key(path)
+ if key in seen_keys:
+ continue
+ seen_keys.add(key)
+ ranked.append(path)
+ kept += 1
+ if kept >= max_per_template:
+ break
+
+ ranked.sort(key=lambda row: float(row.get("chain_risk_score") or 0), reverse=True)
+ compressed = ranked[:max_total]
+ return compressed, {
+ "original_count": len(paths),
+ "compressed_count": len(compressed),
+ "dropped": max(0, len(paths) - len(compressed)),
+ }
diff --git a/src/mcts/scoring/graph_counterfactual.py b/src/mcts/scoring/graph_counterfactual.py
new file mode 100644
index 0000000..83ed2a9
--- /dev/null
+++ b/src/mcts/scoring/graph_counterfactual.py
@@ -0,0 +1,55 @@
+"""Counterfactual remediation for matched attack graph chains (Phase 3c)."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from mcts.scoring.attack_graph import AttackGraph
+from mcts.scoring.graph_fixes import describe_fixes
+from mcts.scoring.graph_mutate import simulate_fixes_for_template
+from mcts.scoring.graph_templates import load_chain_templates
+
+
+def counterfactual_for_chain(
+ template_id: str,
+ tools_on_path: list[str],
+ *,
+ graph: AttackGraph | None = None,
+ fix_kinds: list[str] | None = None,
+) -> dict[str, Any]:
+ """Build counterfactual payload aligned with trust-layer evidence shape."""
+ templates = {template.id: template for template in load_chain_templates()}
+ template = templates.get(template_id)
+ kinds = fix_kinds if fix_kinds is not None else (list(template.recommended_fixes) if template else [])
+ fixes = describe_fixes(kinds)
+
+ triggered: list[str] = []
+ actions: list[dict[str, str]] = []
+ tool_label = ", ".join(tools_on_path) if tools_on_path else "matched path"
+ for fix in fixes:
+ kind = str(fix.get("kind", ""))
+ description = str(fix.get("description") or kind.replace("_", " "))
+ triggered.append(f"{template_id}: {description} ({tool_label})")
+ actions.append(
+ {
+ "action": description,
+ "removes": kind or template_id,
+ }
+ )
+
+ simulation: list[dict[str, Any]] = []
+ if graph is not None and kinds:
+ simulation = simulate_fixes_for_template(graph, template_id, kinds)
+
+ eliminating = [row for row in simulation if row.get("eliminates_template")]
+ payload: dict[str, Any] = {
+ "triggered_by": triggered,
+ "removing_any_one_eliminates_finding": len(eliminating) > 0 or len(fixes) > 1,
+ "actions": actions,
+ "recommended_fixes": fixes,
+ "template_id": template_id,
+ }
+ if simulation:
+ payload["fix_simulation"] = simulation
+ payload["effective_fixes"] = [row["fix_kind"] for row in eliminating]
+ return payload
diff --git a/src/mcts/scoring/graph_explain.py b/src/mcts/scoring/graph_explain.py
index 870011f..feb5d03 100644
--- a/src/mcts/scoring/graph_explain.py
+++ b/src/mcts/scoring/graph_explain.py
@@ -86,6 +86,21 @@ def matched_chain_to_finding(template_id: str, chains: list[MatchedChain]) -> Fi
"explanation": [step.model_dump(mode="json") for step in chain.explanation],
}
)
+ evidence_kwargs: dict[str, Any] = {
+ "template_id": template_id,
+ "path_proven": top.path.hop_count >= 2,
+ "chain_confidence": round(top.path_confidence, 3),
+ "path_reachability": round(top.path_reachability, 3),
+ "chain_risk_score": round(top.chain_risk_score, 3),
+ "trust_boundary_crossings": top.trust_boundary_crossings,
+ "exploit_cost": template.exploit_cost,
+ "paths": paths_payload,
+ "finding_class": template.finding_class,
+ }
+ if top.counterfactual_remediation:
+ evidence_kwargs["counterfactual_remediation"] = top.counterfactual_remediation
+ if top.recommended_fixes:
+ evidence_kwargs["recommended_fixes"] = top.recommended_fixes
builder = (
FindingBuilder(
finding_id=finding_id,
@@ -97,17 +112,7 @@ def matched_chain_to_finding(template_id: str, chains: list[MatchedChain]) -> Fi
)
.confidence(top.path_confidence)
.technique("MCTS-T-attack-graph")
- .evidence(
- template_id=template_id,
- path_proven=top.path.hop_count >= 2,
- chain_confidence=round(top.path_confidence, 3),
- path_reachability=round(top.path_reachability, 3),
- chain_risk_score=round(top.chain_risk_score, 3),
- trust_boundary_crossings=top.trust_boundary_crossings,
- exploit_cost=template.exploit_cost,
- paths=paths_payload,
- finding_class=template.finding_class,
- )
+ .evidence(**evidence_kwargs)
.fact(
rule_id=template_id,
match=template.title,
diff --git a/src/mcts/scoring/graph_fixes.py b/src/mcts/scoring/graph_fixes.py
new file mode 100644
index 0000000..db53524
--- /dev/null
+++ b/src/mcts/scoring/graph_fixes.py
@@ -0,0 +1,43 @@
+"""FixKind registry for attack graph template remediations (Phase 3c)."""
+
+from __future__ import annotations
+
+from functools import lru_cache
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+FIXES_REGISTRY_PATH = Path(__file__).resolve().parent / "fixes" / "registry.yaml"
+
+
+@lru_cache(maxsize=1)
+def load_fixes_registry() -> dict[str, dict[str, Any]]:
+ if not FIXES_REGISTRY_PATH.exists():
+ return {}
+ raw = yaml.safe_load(FIXES_REGISTRY_PATH.read_text(encoding="utf-8")) or {}
+ return {str(key): value for key, value in raw.items() if isinstance(value, dict)}
+
+
+def resolve_fix(kind: str) -> dict[str, Any] | None:
+ entry = load_fixes_registry().get(kind)
+ if not entry:
+ return None
+ return {"kind": kind, **entry}
+
+
+def describe_fixes(fix_kinds: list[str]) -> list[dict[str, Any]]:
+ """Map template recommended_fixes keys to registry descriptions."""
+ described: list[dict[str, Any]] = []
+ for kind in fix_kinds:
+ entry = resolve_fix(kind)
+ if entry:
+ described.append(entry)
+ else:
+ described.append(
+ {
+ "kind": kind,
+ "description": kind.replace("_", " "),
+ }
+ )
+ return described
diff --git a/src/mcts/scoring/graph_inventory.py b/src/mcts/scoring/graph_inventory.py
new file mode 100644
index 0000000..c9efd69
--- /dev/null
+++ b/src/mcts/scoring/graph_inventory.py
@@ -0,0 +1,148 @@
+"""Inventory-layer edges for multi-server attack graphs (Phase 3c)."""
+
+from __future__ import annotations
+
+import re
+
+from mcts.inventory.models import InventoryEntry
+from mcts.scoring.attack_graph import AttackGraph
+from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, NodeKind, canonical_node_id
+
+_READ_TOOLS = frozenset({"read_file", "get_env", "read_env", "fetch", "http_request"})
+_WRITE_TOOLS = frozenset({"write_file", "delete_file", "run_shell", "execute_command", "deploy"})
+_SENSITIVE_TOOLS = frozenset(
+ {
+ "read_file",
+ "write_file",
+ "delete_file",
+ "run_shell",
+ "execute_command",
+ "http_request",
+ "fetch",
+ "post_webhook",
+ "get_env",
+ "read_env",
+ }
+)
+
+
+def _server_slug(server_key: str) -> str:
+ return re.sub(r"[^a-zA-Z0-9]+", "_", server_key).strip("_")
+
+
+def server_hub_id(server_key: str) -> str:
+ return canonical_node_id(NodeKind.CAPABILITY, f"inventory-{_server_slug(server_key)}")
+
+
+def inventory_tool_id(server_key: str, tool_name: str) -> str:
+ return canonical_node_id(NodeKind.TOOL, f"{_server_slug(server_key)}::{tool_name}")
+
+
+def attach_inventory_layer(graph: AttackGraph, inventory: list[InventoryEntry]) -> None:
+ """Merge cross-server inventory nodes and edges when fleet size >= 2."""
+ if len(inventory) < 2:
+ return
+
+ server_tools: dict[str, set[str]] = {}
+ server_meta: dict[str, InventoryEntry] = {}
+
+ for entry in inventory:
+ server_key = f"{entry.client}/{entry.server_name}"
+ server_meta[server_key] = entry
+ tools = {tool.lower() for tool in entry.tools}
+ server_tools[server_key] = tools
+
+ hub = server_hub_id(server_key)
+ graph.add_node(
+ NodeKind.CAPABILITY,
+ hub.split(":", 1)[1],
+ label=server_key,
+ layer=GraphLayer.INVENTORY,
+ metadata={"server_key": server_key, "client": entry.client, "server_name": entry.server_name},
+ )
+ for tool in sorted(entry.tools):
+ local = f"{_server_slug(server_key)}::{tool}"
+ graph.add_node(
+ NodeKind.TOOL,
+ local,
+ label=f"{tool} ({server_key})",
+ layer=GraphLayer.INVENTORY,
+ metadata={"server_key": server_key, "tool": tool},
+ )
+ graph.add_edge(
+ EdgeKind.EXPOSES,
+ hub,
+ inventory_tool_id(server_key, tool),
+ layer=GraphLayer.INVENTORY,
+ confidence=0.9,
+ reachability=1.0,
+ label="inventory_tool_surface",
+ policy=True,
+ )
+ unscoped = canonical_node_id(NodeKind.TOOL, tool)
+ if unscoped in graph.nodes:
+ graph.add_edge(
+ EdgeKind.INVOKES,
+ unscoped,
+ inventory_tool_id(server_key, tool),
+ layer=GraphLayer.INVENTORY,
+ confidence=0.95,
+ reachability=1.0,
+ label="focal_server_bridge",
+ policy=True,
+ )
+
+ readers = [key for key, tools in server_tools.items() if tools & _READ_TOOLS]
+ writers = [key for key, tools in server_tools.items() if tools & _WRITE_TOOLS]
+ for reader in readers:
+ read_tools = sorted(server_tools[reader] & _READ_TOOLS)
+ for writer in writers:
+ if reader == writer:
+ continue
+ write_tools = sorted(server_tools[writer] & _WRITE_TOOLS)
+ if not read_tools or not write_tools:
+ continue
+ from_tool = inventory_tool_id(reader, read_tools[0])
+ to_tool = inventory_tool_id(writer, write_tools[0])
+ graph.add_edge(
+ EdgeKind.INVOKES,
+ from_tool,
+ to_tool,
+ layer=GraphLayer.INVENTORY,
+ confidence=0.75,
+ reachability=0.8,
+ label="cross_server_read_write",
+ policy=True,
+ metadata={"reader": reader, "writer": writer, "issue": "W015"},
+ )
+ graph.add_edge(
+ EdgeKind.INVOKES,
+ server_hub_id(reader),
+ server_hub_id(writer),
+ layer=GraphLayer.INVENTORY,
+ confidence=0.7,
+ reachability=0.75,
+ label="cross_server_toxic_flow",
+ policy=True,
+ metadata={"reader": reader, "writer": writer, "issue": "W015"},
+ )
+
+ for tool in _SENSITIVE_TOOLS:
+ holders = sorted(key for key, tools in server_tools.items() if tool in tools)
+ if len(holders) < 2:
+ continue
+ for left in holders:
+ for right in holders:
+ if left >= right:
+ continue
+ graph.add_edge(
+ EdgeKind.INVOKES,
+ inventory_tool_id(left, tool),
+ inventory_tool_id(right, tool),
+ layer=GraphLayer.INVENTORY,
+ confidence=0.65,
+ reachability=0.7,
+ label="sensitive_tool_shadow",
+ policy=True,
+ metadata={"tool": tool, "servers": holders, "issue": "W016"},
+ )
diff --git a/src/mcts/scoring/graph_mutate.py b/src/mcts/scoring/graph_mutate.py
new file mode 100644
index 0000000..29d5f6a
--- /dev/null
+++ b/src/mcts/scoring/graph_mutate.py
@@ -0,0 +1,164 @@
+"""Apply FixKind registry mutates to attack graphs (Phase 3c runtime engine)."""
+
+from __future__ import annotations
+
+import fnmatch
+from typing import Any
+
+from mcts.scoring.attack_graph import AttackGraph
+from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, parse_node_id
+from mcts.scoring.graph_fixes import resolve_fix
+from mcts.scoring.graph_matcher import match_template
+from mcts.scoring.graph_templates import ChainTemplate, load_chain_templates
+
+
+def clone_graph(graph: AttackGraph) -> AttackGraph:
+ """Deep-copy nodes and edges without matched chains."""
+ copy = AttackGraph()
+ copy._nodes = {node_id: node.model_copy(deep=True) for node_id, node in graph.nodes.items()}
+ for edge in graph.edges.values():
+ copy.merge_edge(edge.model_copy(deep=True))
+ return copy
+
+
+def _remove_node(graph: AttackGraph, node_id: str) -> None:
+ if node_id not in graph.nodes:
+ return
+ edge_ids = [
+ edge.id for edge in graph.edges.values() if edge.from_node == node_id or edge.to_node == node_id
+ ]
+ for edge_id in edge_ids:
+ edge = graph.edges.get(edge_id)
+ if not edge:
+ continue
+ graph._edges.pop(edge_id, None)
+ if edge.kind in graph._edges_by_kind:
+ graph._edges_by_kind[edge.kind] = [
+ eid for eid in graph._edges_by_kind[edge.kind] if eid != edge_id
+ ]
+ if edge.from_node in graph._outgoing:
+ graph._outgoing[edge.from_node] = [
+ eid for eid in graph._outgoing[edge.from_node] if eid != edge_id
+ ]
+ graph._nodes_with_outgoing_kind.setdefault(edge.kind, set()).discard(edge.from_node)
+ graph._nodes.pop(node_id, None)
+
+
+def _node_matches_pattern(node_id: str, pattern: str) -> bool:
+ if fnmatch.fnmatch(node_id, pattern):
+ return True
+ _, local = parse_node_id(node_id)
+ return fnmatch.fnmatch(local, pattern)
+
+
+def apply_mutate_spec(graph: AttackGraph, spec: dict[str, Any]) -> None:
+ """Apply one registry mutate block to *graph* in place."""
+ if "add_edge" in spec:
+ payload = spec["add_edge"]
+ kind = EdgeKind(str(payload["kind"]))
+ graph.add_edge(
+ kind,
+ str(payload["from"]),
+ str(payload["to"]),
+ layer=GraphLayer(payload["layer"]) if payload.get("layer") else None,
+ confidence=float(payload.get("confidence", 0.85)),
+ reachability=float(payload.get("reachability", 1.0)),
+ label=str(payload.get("label", "fix_mutate")),
+ policy=bool(payload.get("policy", True)),
+ )
+ return
+
+ if "set_reachability" in spec:
+ payload = spec["set_reachability"]
+ edge_kind = EdgeKind(str(payload["edge_kind"]))
+ value = float(payload["value"])
+ for edge in graph.edges_of_kind(edge_kind):
+ edge.reachability = value
+ return
+
+ if "remove_nodes" in spec:
+ payload = spec["remove_nodes"]
+ kind_value = str(payload["kind"])
+ targets = [
+ node_id
+ for node_id, node in graph.nodes.items()
+ if node.kind.value == kind_value or parse_node_id(node_id)[0] == kind_value
+ ]
+ for node_id in targets:
+ _remove_node(graph, node_id)
+ return
+
+ if "remove_node" in spec:
+ pattern = str(spec["remove_node"]["pattern"])
+ targets = [node_id for node_id in graph.nodes if _node_matches_pattern(node_id, pattern)]
+ for node_id in targets:
+ _remove_node(graph, node_id)
+
+
+def apply_fix_kind(graph: AttackGraph, fix_kind: str) -> AttackGraph:
+ """Return a mutated graph copy after applying all mutates for *fix_kind*."""
+ entry = resolve_fix(fix_kind)
+ mutated = clone_graph(graph)
+ if not entry:
+ return mutated
+ for spec in entry.get("mutates") or []:
+ if isinstance(spec, dict):
+ apply_mutate_spec(mutated, spec)
+ return mutated
+
+
+def simulate_fix_eliminates_template(
+ graph: AttackGraph,
+ template: ChainTemplate,
+ fix_kind: str,
+) -> bool:
+ """True when applying *fix_kind* removes all matches for *template*."""
+ mutated = apply_fix_kind(graph, fix_kind)
+ return len(match_template(template, mutated)) == 0
+
+
+def simulate_fixes_for_template(
+ graph: AttackGraph,
+ template_id: str,
+ fix_kinds: list[str],
+) -> list[dict[str, Any]]:
+ """Evaluate each fix kind against the live graph for counterfactual simulation."""
+ templates = {template.id: template for template in load_chain_templates()}
+ template = templates.get(template_id)
+ if template is None:
+ return []
+ results: list[dict[str, Any]] = []
+ for kind in fix_kinds:
+ entry = resolve_fix(kind) or {"kind": kind}
+ eliminates = simulate_fix_eliminates_template(graph, template, kind)
+ results.append(
+ {
+ "fix_kind": kind,
+ "description": str(entry.get("description") or kind.replace("_", " ")),
+ "eliminates_template": eliminates,
+ "mutates_applied": len(entry.get("mutates") or []),
+ }
+ )
+ return results
+
+
+def any_fix_eliminates_template(
+ graph: AttackGraph,
+ template_id: str,
+ fix_kinds: list[str],
+) -> bool:
+ rows = simulate_fixes_for_template(graph, template_id, fix_kinds)
+ return any(row["eliminates_template"] for row in rows)
+
+
+def minimal_fix_set(
+ graph: AttackGraph,
+ template_id: str,
+ fix_kinds: list[str],
+) -> list[str]:
+ """Greedy single-fix hits that eliminate the template (for doctor suggestions)."""
+ return [
+ row["fix_kind"]
+ for row in simulate_fixes_for_template(graph, template_id, fix_kinds)
+ if row["eliminates_template"]
+ ]
diff --git a/src/mcts/scoring/graph_suggest.py b/src/mcts/scoring/graph_suggest.py
new file mode 100644
index 0000000..2b6acb1
--- /dev/null
+++ b/src/mcts/scoring/graph_suggest.py
@@ -0,0 +1,34 @@
+"""Suggest attack-graph remediations from a scan report (Phase 3c)."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+from mcts.scoring.graph_fixes import describe_fixes
+from mcts.scoring.graph_templates import load_chain_templates
+
+
+def suggest_fixes_from_report(report_path: Path) -> list[dict[str, Any]]:
+ payload = json.loads(report_path.read_text(encoding="utf-8"))
+ attack_graph = payload.get("attack_graph") or {}
+ templates = {template.id: template for template in load_chain_templates()}
+ suggestions: list[dict[str, Any]] = []
+ seen: set[str] = set()
+
+ for template_id in attack_graph.get("templates_matched") or []:
+ if template_id in seen:
+ continue
+ seen.add(template_id)
+ template = templates.get(str(template_id))
+ if not template:
+ continue
+ suggestions.append(
+ {
+ "template_id": template_id,
+ "title": template.title,
+ "recommended_fixes": describe_fixes(list(template.recommended_fixes)),
+ }
+ )
+ return suggestions
diff --git a/src/mcts/scoring/graph_ui.py b/src/mcts/scoring/graph_ui.py
index 5e7c43f..b2f45b9 100644
--- a/src/mcts/scoring/graph_ui.py
+++ b/src/mcts/scoring/graph_ui.py
@@ -22,6 +22,9 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]:
"label": node.get("label") or _short_label(node.get("id", "")),
"type": node.get("kind", "tool"),
"kind": node.get("kind"),
+ "layer": node.get("layer"),
+ "trust": node.get("trust"),
+ "sensitivity": node.get("sensitivity"),
}
)
@@ -31,13 +34,18 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]:
continue
src = edge.get("from_node") or edge.get("from")
dst = edge.get("to_node") or edge.get("to")
+ policy = bool(edge.get("policy", False))
+ evidence_strength = edge.get("evidence_strength") or "static"
edges.append(
{
"from": src,
"to": dst,
"label": edge.get("label") or edge.get("kind", ""),
"kind": edge.get("kind"),
- "policy": edge.get("policy", False),
+ "layer": edge.get("layer"),
+ "policy": policy,
+ "evidence_strength": evidence_strength,
+ "edge_class": "policy" if policy else _edge_class(evidence_strength),
}
)
@@ -56,9 +64,16 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]:
"chain_risk_score": path.get("chain_risk_score"),
"explanation": path.get("explanation") or [],
"finding_ids": path.get("finding_ids") or [],
+ "recommended_fixes": path.get("recommended_fixes") or [],
+ "counterfactual_remediation": path.get("counterfactual_remediation"),
}
)
+ layers_present = graph.get("layers_present") or sorted(
+ {layer for node in nodes if (layer := node.get("layer"))}
+ | {layer for edge in edges if (layer := edge.get("layer"))}
+ )
+
return {
"version": version,
"nodes": nodes,
@@ -66,7 +81,8 @@ def normalize_attack_graph_for_ui(graph: dict[str, Any]) -> dict[str, Any]:
"paths": paths,
"templates_matched": graph.get("templates_matched") or [],
"total_risk_score": graph.get("total_risk_score"),
- "layers_present": graph.get("layers_present") or [],
+ "layers_present": layers_present,
+ "compression_stats": graph.get("compression_stats"),
}
@@ -94,9 +110,23 @@ def format_attack_path_explanation(finding: Any) -> str:
lines.append(f"{idx}. {step.get('message', step)}")
else:
lines.append(f"{idx}. {step}")
+ counterfactual = evidence.get("counterfactual_remediation")
+ if isinstance(counterfactual, dict) and counterfactual.get("actions"):
+ lines.append("Counterfactual fixes:")
+ for action in counterfactual.get("actions") or []:
+ if isinstance(action, dict):
+ lines.append(f"- {action.get('action', action)}")
return "\n".join(lines) if lines else (finding.description or "")
+def _edge_class(evidence_strength: str) -> str:
+ if evidence_strength == "runtime":
+ return "runtime"
+ if evidence_strength == "heuristic":
+ return "inferred"
+ return "proven"
+
+
def _short_label(node_id: str) -> str:
if ":" in node_id:
return node_id.split(":", 1)[1]
diff --git a/tests/scoring/test_graph_phase_3c.py b/tests/scoring/test_graph_phase_3c.py
new file mode 100644
index 0000000..99fb4d3
--- /dev/null
+++ b/tests/scoring/test_graph_phase_3c.py
@@ -0,0 +1,245 @@
+"""Phase 3c graph polish: fixes registry, counterfactuals, compression, UI."""
+
+from __future__ import annotations
+
+from mcts.core.config import ScanConfig
+from mcts.core.scanner import Scanner
+from mcts.scoring.attack_graph import AttackGraph
+from mcts.scoring.attack_graph_builder import GraphBuilder
+from mcts.scoring.attack_graph_models import EdgeKind, GraphEdge, GraphLayer, NodeKind
+from mcts.scoring.graph_compress import compress_paths
+from mcts.scoring.graph_counterfactual import counterfactual_for_chain
+from mcts.scoring.graph_fixes import describe_fixes, load_fixes_registry
+from mcts.scoring.graph_ui import normalize_attack_graph_for_ui
+
+
+def test_fixes_registry_loads() -> None:
+ registry = load_fixes_registry()
+ assert "add_http_auth" in registry
+ assert "remove_env_tool" in registry
+
+
+def test_describe_fixes_uses_registry() -> None:
+ fixes = describe_fixes(["add_http_auth", "unknown_fix_kind"])
+ assert fixes[0]["description"]
+ assert fixes[1]["kind"] == "unknown_fix_kind"
+
+
+def test_counterfactual_for_chain() -> None:
+ payload = counterfactual_for_chain("HTTP_TAKEOVER", ["get-env"])
+ assert payload["template_id"] == "HTTP_TAKEOVER"
+ assert payload["actions"]
+ assert payload["recommended_fixes"]
+
+
+def test_apply_fix_kind_remove_env_tool() -> None:
+ from mcts.scoring.attack_graph_models import EdgeKind, NodeKind
+ from mcts.scoring.graph_mutate import apply_fix_kind
+
+ graph = AttackGraph()
+ graph.seed_sources_and_sinks()
+ graph.add_node(NodeKind.TOOL, "get-env", label="get-env")
+ graph.add_edge(EdgeKind.READS, "tool:get-env", "sink:env")
+ mutated = apply_fix_kind(graph, "remove_env_tool")
+ assert "tool:get-env" not in mutated.nodes
+ assert not any(edge.from_node == "tool:get-env" for edge in mutated.edges.values())
+
+
+def test_simulate_remove_env_tool_eliminates_http_takeover() -> None:
+ from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, NodeKind
+ from mcts.scoring.graph_mutate import simulate_fix_eliminates_template
+ from mcts.scoring.graph_templates import load_chain_templates
+
+ graph = AttackGraph()
+ graph.seed_sources_and_sinks()
+ graph.add_node(NodeKind.TRANSPORT, "http", label="http", layer=GraphLayer.TRANSPORT)
+ graph.add_node(NodeKind.TOOL, "get-env", label="get-env")
+ graph.add_edge(
+ EdgeKind.EXPOSES,
+ "transport:http",
+ "tool:get-env",
+ layer=GraphLayer.TRANSPORT,
+ reachability=1.0,
+ )
+ graph.add_edge(EdgeKind.READS, "tool:get-env", "sink:env", reachability=1.0)
+ graph.add_edge(
+ EdgeKind.DELIVERS_TO_CONTEXT,
+ "tool:get-env",
+ "sink:model_context",
+ reachability=1.0,
+ )
+ template = next(t for t in load_chain_templates() if t.id == "HTTP_TAKEOVER")
+ assert simulate_fix_eliminates_template(graph, template, "remove_env_tool")
+ assert simulate_fix_eliminates_template(graph, template, "disable_http_transport")
+
+
+def test_inventory_layer_adds_cross_server_edges() -> None:
+ from mcts.inventory.models import InventoryEntry
+ from mcts.scoring.attack_graph_models import GraphLayer
+ from mcts.scoring.graph_inventory import attach_inventory_layer
+
+ graph = AttackGraph()
+ graph.seed_sources_and_sinks()
+ inventory = [
+ InventoryEntry(
+ client="cursor",
+ config_path="/a",
+ server_name="reader",
+ tools=["read_file"],
+ ),
+ InventoryEntry(
+ client="cursor",
+ config_path="/b",
+ server_name="writer",
+ tools=["write_file"],
+ ),
+ ]
+ attach_inventory_layer(graph, inventory)
+ inventory_edges = [edge for edge in graph.edges.values() if edge.layer == GraphLayer.INVENTORY]
+ assert inventory_edges
+ assert any(edge.label == "cross_server_read_write" for edge in inventory_edges)
+
+
+def test_counterfactual_includes_fix_simulation() -> None:
+ from mcts.scoring.attack_graph_models import EdgeKind, GraphLayer, NodeKind
+ from mcts.scoring.graph_mutate import clone_graph
+
+ graph = AttackGraph()
+ graph.seed_sources_and_sinks()
+ graph.add_node(NodeKind.TRANSPORT, "http", label="http", layer=GraphLayer.TRANSPORT)
+ graph.add_node(NodeKind.TOOL, "get-env", label="get-env")
+ graph.add_edge(
+ EdgeKind.EXPOSES,
+ "transport:http",
+ "tool:get-env",
+ layer=GraphLayer.TRANSPORT,
+ reachability=1.0,
+ )
+ graph.add_edge(EdgeKind.READS, "tool:get-env", "sink:env", reachability=1.0)
+ graph.add_edge(
+ EdgeKind.DELIVERS_TO_CONTEXT,
+ "tool:get-env",
+ "sink:model_context",
+ reachability=1.0,
+ )
+ payload = counterfactual_for_chain("HTTP_TAKEOVER", ["get-env"], graph=clone_graph(graph))
+ assert payload.get("fix_simulation")
+ assert "remove_env_tool" in payload.get("effective_fixes", [])
+
+
+def test_config_counterfactuals_and_compress_default_on() -> None:
+ config = ScanConfig(target=".")
+ assert config.attack_graph_enable_counterfactuals is True
+ assert config.attack_graph_compress_for_ui is True
+
+
+def test_compress_paths_dedupes() -> None:
+ paths = [
+ {"template_id": "A", "tools_on_path": ["t1"], "chain_risk_score": 1},
+ {"template_id": "A", "tools_on_path": ["t1"], "chain_risk_score": 2},
+ {"template_id": "B", "tools_on_path": ["t2"], "chain_risk_score": 5},
+ ]
+ compressed, stats = compress_paths(paths, max_total=2, max_per_template=1)
+ assert len(compressed) == 2
+ assert stats["dropped"] == 1
+
+
+def test_graph_builder_attaches_recommended_fixes() -> None:
+ graph = AttackGraph()
+ graph.seed_sources_and_sinks()
+ graph.add_node(NodeKind.TOOL, "trigger-url-elicitation", label="trigger-url-elicitation")
+ graph.add_node(NodeKind.CAPABILITY, "elicitation", label="elicitation")
+ graph.merge_edge(
+ GraphEdge(
+ id="edge-elicit",
+ kind=EdgeKind.TRIGGERS,
+ from_node="tool:trigger-url-elicitation",
+ to_node="capability:elicitation",
+ layer=GraphLayer.TRUST_BOUNDARY,
+ confidence=0.75,
+ reachability=0.8,
+ )
+ )
+ from mcts.mcp.models import MCPServerInfo, MCPTool
+
+ server = MCPServerInfo(
+ name="demo",
+ tools=[MCPTool(name="trigger-url-elicitation", description="elicitation tool")],
+ )
+ builder = GraphBuilder(config=ScanConfig(target=".", attack_graph_enable_counterfactuals=True))
+ built = builder.build(server, [])
+ elicit = next((c for c in built.matched_chains if c.template_id == "ELICIT_PHISH"), None)
+ assert elicit is not None
+ assert elicit.recommended_fixes
+ assert elicit.counterfactual_remediation
+
+
+def test_to_report_dict_compresses_when_requested() -> None:
+ graph = AttackGraph()
+ graph.matched_chains = []
+ report = graph.to_report_dict(compress_for_ui=True)
+ assert report.get("compression_stats") is not None
+
+
+def test_normalize_ui_includes_layers_and_edge_class() -> None:
+ raw = {
+ "version": 3,
+ "nodes": [
+ {
+ "id": "tool:fetch",
+ "kind": "tool",
+ "label": "fetch",
+ "layer": "mcp_surface",
+ "trust": "semi_trusted",
+ "sensitivity": "medium",
+ }
+ ],
+ "edges": [
+ {
+ "from_node": "tool:fetch",
+ "to_node": "sink:external_network",
+ "kind": "EGRESS",
+ "layer": "dataflow",
+ "policy": False,
+ "evidence_strength": "static",
+ }
+ ],
+ "paths": [],
+ "layers_present": ["dataflow", "mcp_surface"],
+ }
+ ui = normalize_attack_graph_for_ui(raw)
+ assert ui["nodes"][0]["trust"] == "semi_trusted"
+ assert ui["edges"][0]["edge_class"] == "proven"
+ assert "dataflow" in ui["layers_present"]
+
+
+def test_suggest_fixes_from_report(tmp_path) -> None:
+ from mcts.scoring.graph_suggest import suggest_fixes_from_report
+
+ report = {
+ "attack_graph": {
+ "templates_matched": ["ELICIT_PHISH", "ELICIT_PHISH"],
+ }
+ }
+ path = tmp_path / "scan.json"
+ path.write_text(__import__("json").dumps(report), encoding="utf-8")
+ rows = suggest_fixes_from_report(path)
+ assert len(rows) == 1
+ assert rows[0]["template_id"] == "ELICIT_PHISH"
+
+
+def test_scanner_counterfactual_flag() -> None:
+ target = "tests/fixtures/monorepo-mini/src/everything/tools/trigger-url-elicitation.ts"
+ report = Scanner(
+ ScanConfig(
+ target=target,
+ surface_depth="full",
+ attack_graph_enable_counterfactuals=True,
+ )
+ ).run()
+ chain = next(
+ f
+ for f in report.findings
+ if f.analyzer == "attack_graph" and f.evidence.get("template_id") == "ELICIT_PHISH"
+ )
+ assert chain.evidence.get("counterfactual_remediation")