diff --git a/README.md b/README.md index 4d0db25..15f3069 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,9 @@ import { createPromptDefense } from '@stackone/defender'; // Create defense with Tier 1 (patterns) + Tier 2 (ML classifier) // blockHighRisk: true enables the allowed/blocked decision +// Tier 1 (patterns) + Tier 2 (ML classifier) are both on by default. +// blockHighRisk: true enables the allowed/blocked decision. const defense = createPromptDefense({ - enableTier2: true, blockHighRisk: true, useDefaultToolRules: true, // Enable built-in per-tool base risk and field-handling rules (risky-field overrides always apply) }); @@ -105,9 +106,10 @@ Create a defense instance. ```typescript const defense = createPromptDefense({ enableTier1: true, // Pattern detection (default: true) - enableTier2: true, // ML classification (default: false) + enableTier2: true, // ML classification (default: true) — set false to disable blockHighRisk: true, // Block high/critical content (default: false) useDefaultToolRules: true, // Enable built-in per-tool base risk and field-handling rules (default: false) + tier2Fields: ['subject', 'body', 'snippet'], // Scope Tier 2 to specific fields (default: all fields) defaultRiskLevel: 'medium', }); ``` @@ -164,14 +166,13 @@ console.log(result.matches); // [{ pattern: '...', severity: 'high', ... } ONNX mode auto-loads the bundled model on first `defendToolResult()` call. 
Use `warmupTier2()` at startup to avoid first-call latency: ```typescript -// ONNX mode (default) — optional warmup to pre-load at startup -const defense = createPromptDefense({ enableTier2: true }); +// ONNX mode (default) — Tier 2 is on by default, warmup is optional +const defense = createPromptDefense(); await defense.warmupTier2(); // optional, avoids ~1-2s first-call latency // MLP mode (legacy) — requires loading weights explicitly import { createPromptDefense, MLP_WEIGHTS } from '@stackone/defender'; const mlpDefense = createPromptDefense({ - enableTier2: true, tier2Config: { mode: 'mlp' }, }); mlpDefense.loadTier2Weights(MLP_WEIGHTS); @@ -187,7 +188,6 @@ import { generateText, tool } from 'ai'; import { createPromptDefense } from '@stackone/defender'; const defense = createPromptDefense({ - enableTier2: true, blockHighRisk: true, useDefaultToolRules: true, }); diff --git a/src/config.ts b/src/config.ts index 9a24efc..043cd8f 100644 --- a/src/config.ts +++ b/src/config.ts @@ -169,7 +169,6 @@ export const DEFAULT_CUMULATIVE_RISK_THRESHOLDS = { * Default Tier 2 configuration */ export const DEFAULT_TIER2_CONFIG = { - enabled: false, // Disabled until implemented mode: "onnx" as const, highRiskThreshold: 0.8, mediumRiskThreshold: 0.5, diff --git a/src/core/prompt-defense.ts b/src/core/prompt-defense.ts index 4210c4a..a3205f3 100644 --- a/src/core/prompt-defense.ts +++ b/src/core/prompt-defense.ts @@ -37,6 +37,8 @@ export interface DefenseResult { patternsByField: Record; /** Tier 2 ML score (0.0 = safe, 1.0 = injection), undefined if Tier 2 not enabled */ tier2Score?: number; + /** Reason Tier 2 was skipped (e.g. "No strings extracted") when tier2Score is undefined */ + tier2SkipReason?: string; /** The sentence with the highest Tier 2 score */ maxSentence?: string; /** Total processing time in milliseconds */ @@ -45,22 +47,44 @@ export interface DefenseResult { /** * Recursively extract all string values from an object. 
- * Used to collect text content from tool results for Tier 2 classification. + * When `fields` is provided, only strings under matching field keys are collected; + * the traversal still descends into non-matching keys to find matching ones deeper. */ -function extractStrings(obj: unknown): string[] { +function extractStrings(obj: unknown, fields?: string[]): string[] { const strings: string[] = []; - function traverse(value: unknown): void { + function collectAll(value: unknown): void { if (typeof value === "string") { strings.push(value); } else if (Array.isArray(value)) { - for (const item of value) { - traverse(item); - } + for (const item of value) collectAll(item); } else if (value && typeof value === "object") { - for (const v of Object.values(value)) { - traverse(v); + for (const v of Object.values(value)) collectAll(v); + } + } + + if (!fields || fields.length === 0) { + collectAll(obj); + return strings; + } + + // Use a Set for O(1) key lookups during traversal + const fieldSet = new Set(fields); + + function traverse(value: unknown, underKey = false): void { + if (Array.isArray(value)) { + for (const item of value) traverse(item, underKey); + } else if (value && typeof value === "object") { + for (const [k, v] of Object.entries(value as Record)) { + if (fieldSet.has(k)) { + collectAll(v); + } else { + traverse(v, true); + } } + } else if (!underKey && typeof value === "string") { + // Keyless string (the result itself, or an item of a top-level array): no field name to filter on, so collect it. Strings reached through a non-matching key are skipped. + strings.push(value); } } @@ -76,7 +100,7 @@ export interface PromptDefenseOptions { config?: Partial; /** Enable Tier 1 classification */ enableTier1?: boolean; - /** Enable Tier 2 ML classification */ + /** Enable Tier 2 ML classification (default: true — set false to disable) */ enableTier2?: boolean; /** Tier 2 classifier configuration */ tier2Config?: Partial; @@ -91,6 +115,12 @@ export interface PromptDefenseOptions { * Defaults to false — tool rules are opt-in to avoid unexpected risk level inflation. 
*/ useDefaultToolRules?: boolean; + /** + * Only run Tier 2 on strings extracted from these field names. + * Strings under any other field key are skipped. + * If omitted, Tier 2 runs on all strings in the tool result. + */ + tier2Fields?: string[]; } /** @@ -114,6 +144,7 @@ export class PromptDefense { private toolResultSanitizer: ToolResultSanitizer; private patternDetector: PatternDetector; private tier2Classifier: Tier2Classifier | null = null; + private tier2Fields: string[] | undefined; constructor(options: PromptDefenseOptions = {}) { // Build configuration @@ -124,6 +155,8 @@ export class PromptDefense { this.config.blockHighRisk = options.blockHighRisk; } + this.tier2Fields = options.tier2Fields ?? this.config.tier2?.tier2Fields; + // Initialize components this.toolResultSanitizer = createToolResultSanitizer({ riskyFields: this.config.riskyFields, @@ -141,7 +174,7 @@ export class PromptDefense { this.patternDetector = createPatternDetector(); // Initialize Tier 2 classifier if enabled - if (options.enableTier2) { + if (options.enableTier2 ?? 
true) { this.tier2Classifier = createTier2Classifier(options.tier2Config); if (options.tier2Weights) { this.tier2Classifier.loadWeights(options.tier2Weights); @@ -212,11 +245,12 @@ export class PromptDefense { // Tier 2: sentence-level ML classification on raw (unsanitized) value let tier2Score: number | undefined; + let tier2SkipReason: string | undefined; let maxSentence: string | undefined; let tier2Risk: RiskLevel = "low"; if (this.tier2Classifier) { - const strings = extractStrings(value); + const strings = extractStrings(value, this.tier2Fields); const combinedText = strings.join("\n\n"); if (combinedText.length > 0) { @@ -225,7 +259,13 @@ export class PromptDefense { tier2Score = tier2Result.score; tier2Risk = this.tier2Classifier.getRiskLevel(tier2Result.score); maxSentence = tier2Result.maxSentence; + } else { + tier2SkipReason = tier2Result.skipReason; } + } else { + tier2SkipReason = this.tier2Fields?.length + ? "No strings found in tier2Fields" + : "No strings extracted from tool result"; } } @@ -257,6 +297,7 @@ export class PromptDefense { fieldsSanitized, patternsByField: patternsRemovedByField, tier2Score, + tier2SkipReason, maxSentence, latencyMs: performance.now() - startTime, }; diff --git a/src/index.ts b/src/index.ts index 1fbad19..1d39910 100644 --- a/src/index.ts +++ b/src/index.ts @@ -27,4 +27,4 @@ export { type PromptDefenseOptions, } from "./core/prompt-defense"; // Types -export type { RiskLevel, Tier1Result } from "./types"; +export type { RiskLevel, Tier1Result, ToolSanitizationRule } from "./types"; diff --git a/src/types.ts b/src/types.ts index 6323c85..3b7f665 100644 --- a/src/types.ts +++ b/src/types.ts @@ -307,8 +307,6 @@ export interface PromptDefenseConfig { }; /** Tier 2 configuration */ tier2: { - /** Whether Tier 2 is enabled */ - enabled: boolean; /** Inference mode: 'onnx' for fine-tuned MiniLM, 'mlp' for frozen embeddings + MLP head */ mode?: "mlp" | "onnx"; /** Score threshold for high risk */ @@ -317,6 +315,12 @@ export 
interface PromptDefenseConfig { mediumRiskThreshold: number; /** Size threshold to skip Tier 2 (bytes) */ skipBelowSize: number; + /** + * Only run Tier 2 on strings extracted from these field names. + * Strings under any other field key are skipped. + * If omitted, Tier 2 runs on all strings in the tool result. + */ + tier2Fields?: string[]; }; /** Whether to block high/critical risk by default */ blockHighRisk: boolean;