-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexplainability.py
More file actions
135 lines (111 loc) · 5.45 KB
/
Copy pathexplainability.py
File metadata and controls
135 lines (111 loc) · 5.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
explainability.py — SHAP-based explanation for artifact suspicion scores.
Surfaces top-N features driving each classification result.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import shap
from config import cfg
from classifier import ClassificationResult, ForensicClassifier
@dataclass
class Explanation:
    """Explanation of a single artifact's suspicion score.

    Attributes are documented as ``#`` comments below:
    """

    # Source path of the explained artifact.
    artifact_path: str
    # Classifier score that is being explained.
    suspicion_score: float
    # Classifier label attached to the artifact.
    label: str
    # (feature_name, shap_value) pairs; ExplainabilityEngine.explain orders
    # them by descending absolute SHAP value.
    top_features: List[Tuple[str, float]]
    # Human-readable summary string ending in a "because: ..." clause.
    narrative: str
# Feature names must align with ForensicClassifier._featurize() numeric block
_NUMERIC_FEATURE_NAMES = [
"file_size_mb",
"is_deleted",
"network_anomaly",
"financial_site",
"timestomp_detected",
"visit_count_norm",
]
class ExplainabilityEngine:
    """
    Wraps a ForensicClassifier with a SHAP KernelExplainer.

    KernelExplainer is model-agnostic, so this works whether the classifier's
    head is fitted (its ``predict_proba`` is explained) or not (a fixed
    weighted sum over the numeric feature block is explained instead).
    """

    # Numeric features for which _build_narrative has a dedicated
    # plain-language reason; their SHAP entries are not repeated as
    # "(SHAP +x.xx)" clauses to avoid duplicate reasons in the narrative.
    _NARRATED_FEATURES = frozenset(
        {"is_deleted", "network_anomaly", "timestomp_detected", "financial_site"}
    )

    def __init__(self, classifier: ForensicClassifier, background_results: List[ClassificationResult]):
        """
        Build the SHAP explainer from a sample of classified artifacts.

        Args:
            classifier: classifier whose suspicion scores are explained.
            background_results: results whose ``feature_vector``s form the SHAP
                background. Each vector is taken to be the embedding dims
                followed by the numeric block named in _NUMERIC_FEATURE_NAMES.

        Raises:
            IndexError: if ``background_results`` is empty.
        """
        self._clf = classifier
        # Everything before the numeric block is treated as embedding dims.
        embed_dim = background_results[0].feature_vector.shape[0] - len(_NUMERIC_FEATURE_NAMES)
        self._embed_dim = embed_dim
        self._feature_names = [f"bert_{i}" for i in range(embed_dim)] + _NUMERIC_FEATURE_NAMES
        background = np.vstack([r.feature_vector for r in background_results])
        # Deduplicate rows before k-means to avoid a k > n_distinct error.
        unique_bg = np.unique(background, axis=0)
        k = min(cfg.shap_background_k, len(unique_bg))
        bg_summary = shap.kmeans(unique_bg, k)
        self._explainer = shap.KernelExplainer(self._predict_fn, bg_summary)

    def explain(self, result: ClassificationResult, top_n: int = 5) -> Explanation:
        """
        Explain one classification result.

        ``top_features`` holds the numeric (non-``bert_*``) features that rank
        within the overall ``top_n`` by absolute SHAP value. If none do, for
        example because embedding dims dominate, it holds the ``top_n`` numeric
        features ranked by absolute SHAP value instead.

        Args:
            result: classification result to explain.
            top_n: maximum number of features to report.

        Returns:
            An Explanation carrying the features and a narrative string.
        """
        fv = result.feature_vector.reshape(1, -1)
        shap_values = self._explainer.shap_values(fv, nsamples=cfg.shap_nsamples, silent=True)
        sv = self._first_row(shap_values)

        # Pair names with SHAP values, most impactful (by |value|) first.
        pairs = sorted(zip(self._feature_names, sv), key=lambda x: abs(x[1]), reverse=True)
        top = [(name, float(val)) for name, val in pairs[:top_n] if not name.startswith("bert_")]
        if not top:
            # Top features are all embedding dims: fall back to the numeric
            # block, still ordered by |SHAP| (pairs is already sorted).
            numeric = [(name, float(val)) for name, val in pairs if not name.startswith("bert_")]
            top = numeric[:top_n]

        narrative = self._build_narrative(result, top)
        return Explanation(
            artifact_path=result.artifact.source_path,
            suspicion_score=result.suspicion_score,
            label=result.label,
            top_features=top,
            narrative=narrative,
        )

    def explain_batch(self, results: List[ClassificationResult], top_n: int = 5) -> List[Explanation]:
        """Explain each result in order; see ``explain`` for details."""
        return [self.explain(r, top_n) for r in results]

    # ── Internal ──────────────────────────────────────────────────────────────

    @staticmethod
    def _first_row(shap_values) -> np.ndarray:
        """
        Return the 1-D attribution vector for the single explained row.

        ``_predict_fn`` returns one output per row, so KernelExplainer normally
        yields a (1, n_features) array. The list-per-output and
        (1, n_features, n_outputs) layouts are also accepted; in those cases
        the last output is used.
        """
        if isinstance(shap_values, list):
            # NOTE(review): assumes the last output is the positive class.
            shap_values = shap_values[-1]
        sv = np.asarray(shap_values)[0]
        if sv.ndim > 1:
            # (n_features, n_outputs) layout: keep the last output column.
            sv = sv[:, -1]
        return sv

    def _predict_fn(self, X: np.ndarray) -> np.ndarray:
        """
        Score each row of ``X`` for KernelExplainer.

        With a fitted head, return its column-1 ``predict_proba``. Otherwise
        return a fixed weighted sum over the numeric block (base 0.1, plus
        is_deleted*0.2, network_anomaly*0.35, timestomp_detected*0.3 and
        financial_site*0.15), capped at 1.0.
        """
        if self._clf._head_fitted:
            return self._clf._head.predict_proba(X)[:, 1]
        # Heuristic fallback, vectorised over all rows. KernelExplainer calls
        # this with many perturbed rows, so a per-row Python loop is costly.
        numeric = np.atleast_2d(X)[:, -len(_NUMERIC_FEATURE_NAMES):]
        score = np.full(numeric.shape[0], 0.1)
        score = score + numeric[:, 1] * 0.2   # is_deleted
        score = score + numeric[:, 2] * 0.35  # network_anomaly
        score = score + numeric[:, 4] * 0.3   # timestomp_detected
        score = score + numeric[:, 3] * 0.15  # financial_site
        return np.minimum(score, 1.0)

    @staticmethod
    def _build_narrative(result: ClassificationResult, top: List[Tuple[str, float]]) -> str:
        """
        Build a "[LABEL] score=x.xx — because: ..." summary string.

        Reasons come first from the artifact's own ``features`` dict and
        timestamp. They are followed by any strongly positive (> 0.05) SHAP
        features from ``top`` that do not already have a dedicated reason.
        """
        f = result.artifact.features
        reasons = []

        ts = result.artifact.timestamp
        if ts:
            import datetime
            hour = datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc).hour
            # Off-hours window: 23:00 through 04:59 UTC.
            if hour < 5 or hour > 22:
                reasons.append(f"created/modified at {hour:02d}:00 UTC (off-hours)")
        if f.get("is_deleted"):
            reasons.append("file was deleted")
        if f.get("possible_timestomp"):
            reasons.append("$SI/$FN timestamp mismatch (possible timestomping)")
        if f.get("anomaly"):
            reasons.append(f"network anomaly: {f.get('bytes', 0)} bytes to {f.get('dst', '?')}")
        if f.get("type") == "persistence":
            reasons.append(f"persistence key: {f.get('value', '')}")
        if f.get("type") == "usb_mount":
            reasons.append(f"USB device mounted: {f.get('device', '')}")
        if f.get("is_financial"):
            reasons.append("accessed financial website")
        if f.get("visit_count", 0) > 30:
            reasons.append(f"visited {f['visit_count']} times")

        # Append strongly positive SHAP features not already narrated above.
        for name, val in top:
            if val > 0.05 and name not in ExplainabilityEngine._NARRATED_FEATURES:
                reasons.append(f"{name.replace('_', ' ')} (SHAP +{val:.2f})")

        body = "; ".join(reasons) if reasons else "no dominant features"
        return f"[{result.label.upper()}] score={result.suspicion_score:.2f} — because: {body}"