Skip to content
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import { createPromptDefense } from '@stackone/defender';

// Create defense with Tier 1 (patterns) + Tier 2 (ML classifier)
// blockHighRisk: true enables the allowed/blocked decision
// Tier 1 (patterns) + Tier 2 (ML classifier) are both on by default.
// blockHighRisk: true enables the allowed/blocked decision.
const defense = createPromptDefense({
enableTier2: true,
blockHighRisk: true,
useDefaultToolRules: true, // Enable built-in per-tool base risk and field-handling rules (risky-field overrides always apply)
});
Expand Down Expand Up @@ -105,9 +106,10 @@ Create a defense instance.
```typescript
const defense = createPromptDefense({
enableTier1: true, // Pattern detection (default: true)
enableTier2: true, // ML classification (default: false)
enableTier2: true, // ML classification (default: true) — set false to disable
blockHighRisk: true, // Block high/critical content (default: false)
useDefaultToolRules: true, // Enable built-in per-tool base risk and field-handling rules (default: false)
tier2Fields: ['subject', 'body', 'snippet'], // Scope Tier 2 to specific fields (default: all fields)
defaultRiskLevel: 'medium',
});
```
Expand Down Expand Up @@ -164,14 +166,13 @@ console.log(result.matches); // [{ pattern: '...', severity: 'high', ... }
ONNX mode auto-loads the bundled model on first `defendToolResult()` call. Use `warmupTier2()` at startup to avoid first-call latency:

```typescript
// ONNX mode (default) — optional warmup to pre-load at startup
const defense = createPromptDefense({ enableTier2: true });
// ONNX mode (default) — Tier 2 is on by default, warmup is optional
const defense = createPromptDefense();
await defense.warmupTier2(); // optional, avoids ~1-2s first-call latency

// MLP mode (legacy) — requires loading weights explicitly
import { createPromptDefense, MLP_WEIGHTS } from '@stackone/defender';
const mlpDefense = createPromptDefense({
enableTier2: true,
tier2Config: { mode: 'mlp' },
});
mlpDefense.loadTier2Weights(MLP_WEIGHTS);
Expand All @@ -187,7 +188,6 @@ import { generateText, tool } from 'ai';
import { createPromptDefense } from '@stackone/defender';

const defense = createPromptDefense({
enableTier2: true,
blockHighRisk: true,
useDefaultToolRules: true,
});
Expand Down
1 change: 0 additions & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ export const DEFAULT_CUMULATIVE_RISK_THRESHOLDS = {
* Default Tier 2 configuration
*/
export const DEFAULT_TIER2_CONFIG = {
enabled: false, // Disabled until implemented
mode: "onnx" as const,
highRiskThreshold: 0.8,
mediumRiskThreshold: 0.5,
Expand Down
63 changes: 52 additions & 11 deletions src/core/prompt-defense.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ export interface DefenseResult {
patternsByField: Record<string, string[]>;
/** Tier 2 ML score (0.0 = safe, 1.0 = injection), undefined if Tier 2 not enabled */
tier2Score?: number;
/** Reason Tier 2 was skipped (e.g. "No strings extracted") when tier2Score is undefined */
tier2SkipReason?: string;
/** The sentence with the highest Tier 2 score */
maxSentence?: string;
/** Total processing time in milliseconds */
Expand All @@ -45,22 +47,44 @@ export interface DefenseResult {

/**
* Recursively extract all string values from an object.
* Used to collect text content from tool results for Tier 2 classification.
* When `fields` is provided, only strings under matching field keys are collected;
* the traversal still descends into non-matching keys to find matching ones deeper.
*/
function extractStrings(obj: unknown): string[] {
function extractStrings(obj: unknown, fields?: string[]): string[] {
const strings: string[] = [];

function traverse(value: unknown): void {
function collectAll(value: unknown): void {
if (typeof value === "string") {
strings.push(value);
} else if (Array.isArray(value)) {
for (const item of value) {
traverse(item);
}
for (const item of value) collectAll(item);
} else if (value && typeof value === "object") {
for (const v of Object.values(value)) {
traverse(v);
for (const v of Object.values(value)) collectAll(v);
}
}

if (!fields || fields.length === 0) {
collectAll(obj);
return strings;
}

// Use a Set for O(1) key lookups during traversal
const fieldSet = new Set(fields);

function traverse(value: unknown): void {
if (Array.isArray(value)) {
for (const item of value) traverse(item);
} else if (value && typeof value === "object") {
for (const [k, v] of Object.entries(value as Record<string, unknown>)) {
if (fieldSet.has(k)) {
collectAll(v);
} else {
traverse(v);
Comment thread
hiskudin marked this conversation as resolved.
}
Comment thread
hiskudin marked this conversation as resolved.
}
} else if (typeof value === "string") {
// Plain string — no field keys to filter on, fall back to collecting it
strings.push(value);
}
}

Expand All @@ -76,7 +100,7 @@ export interface PromptDefenseOptions {
config?: Partial<PromptDefenseConfig>;
/** Enable Tier 1 classification */
enableTier1?: boolean;
/** Enable Tier 2 ML classification */
/** Enable Tier 2 ML classification (default: true — set false to disable) */
enableTier2?: boolean;
/** Tier 2 classifier configuration */
tier2Config?: Partial<Tier2ClassifierConfig>;
Expand All @@ -91,6 +115,12 @@ export interface PromptDefenseOptions {
* Defaults to false — tool rules are opt-in to avoid unexpected risk level inflation.
*/
useDefaultToolRules?: boolean;
/**
* Only run Tier 2 on strings extracted from these field names.
* Strings under any other field key are skipped.
* If omitted, Tier 2 runs on all strings in the tool result.
*/
tier2Fields?: string[];
}

/**
Expand All @@ -114,6 +144,7 @@ export class PromptDefense {
private toolResultSanitizer: ToolResultSanitizer;
private patternDetector: PatternDetector;
private tier2Classifier: Tier2Classifier | null = null;
private tier2Fields: string[] | undefined;

constructor(options: PromptDefenseOptions = {}) {
// Build configuration
Expand All @@ -124,6 +155,8 @@ export class PromptDefense {
this.config.blockHighRisk = options.blockHighRisk;
Comment thread
hiskudin marked this conversation as resolved.
}

this.tier2Fields = options.tier2Fields ?? this.config.tier2?.tier2Fields;

// Initialize components
this.toolResultSanitizer = createToolResultSanitizer({
riskyFields: this.config.riskyFields,
Expand All @@ -141,7 +174,7 @@ export class PromptDefense {
this.patternDetector = createPatternDetector();

// Initialize Tier 2 classifier if enabled
if (options.enableTier2) {
if (options.enableTier2 ?? true) {
Comment thread
hiskudin marked this conversation as resolved.
this.tier2Classifier = createTier2Classifier(options.tier2Config);
if (options.tier2Weights) {
this.tier2Classifier.loadWeights(options.tier2Weights);
Expand Down Expand Up @@ -212,11 +245,12 @@ export class PromptDefense {

// Tier 2: sentence-level ML classification on raw (unsanitized) value
let tier2Score: number | undefined;
let tier2SkipReason: string | undefined;
let maxSentence: string | undefined;
let tier2Risk: RiskLevel = "low";

if (this.tier2Classifier) {
const strings = extractStrings(value);
const strings = extractStrings(value, this.tier2Fields);
const combinedText = strings.join("\n\n");
Comment thread
hiskudin marked this conversation as resolved.

if (combinedText.length > 0) {
Expand All @@ -225,7 +259,13 @@ export class PromptDefense {
tier2Score = tier2Result.score;
tier2Risk = this.tier2Classifier.getRiskLevel(tier2Result.score);
maxSentence = tier2Result.maxSentence;
} else {
tier2SkipReason = tier2Result.skipReason;
}
} else {
tier2SkipReason = this.tier2Fields?.length
? "No strings found in tier2Fields"
: "No strings extracted from tool result";
}
}

Expand Down Expand Up @@ -257,6 +297,7 @@ export class PromptDefense {
fieldsSanitized,
patternsByField: patternsRemovedByField,
tier2Score,
tier2SkipReason,
maxSentence,
latencyMs: performance.now() - startTime,
};
Expand Down
2 changes: 1 addition & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ export {
type PromptDefenseOptions,
} from "./core/prompt-defense";
// Types
export type { RiskLevel, Tier1Result } from "./types";
export type { RiskLevel, Tier1Result, ToolSanitizationRule } from "./types";
8 changes: 6 additions & 2 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,6 @@ export interface PromptDefenseConfig {
};
/** Tier 2 configuration */
tier2: {
/** Whether Tier 2 is enabled */
enabled: boolean;
/** Inference mode: 'onnx' for fine-tuned MiniLM, 'mlp' for frozen embeddings + MLP head */
mode?: "mlp" | "onnx";
/** Score threshold for high risk */
Expand All @@ -317,6 +315,12 @@ export interface PromptDefenseConfig {
mediumRiskThreshold: number;
/** Size threshold to skip Tier 2 (bytes) */
skipBelowSize: number;
/**
* Only run Tier 2 on strings extracted from these field names.
* Strings under any other field key are skipped.
* If omitted, Tier 2 runs on all strings in the tool result.
*/
tier2Fields?: string[];
};
/** Whether to block high/critical risk by default */
blockHighRisk: boolean;
Expand Down
Loading