diff --git a/chrome-extension/background.js b/chrome-extension/background.js index 664d0118..1dbf8056 100644 --- a/chrome-extension/background.js +++ b/chrome-extension/background.js @@ -174,10 +174,21 @@ async function handlePIICheck(text) { // Pull the latest backendUrl every time — the service worker may have been // torn down since startup, which resets in-memory `backendUrl` to the // default even if the user configured a different one in options. - const { backendUrl: storedUrl } = await chrome.storage.sync.get({ + const stored = await chrome.storage.sync.get({ backendUrl: DEFAULT_API_BASE, + disabledLabels: [], + allLabels: [], }); - backendUrl = storedUrl || DEFAULT_API_BASE; + backendUrl = stored.backendUrl || DEFAULT_API_BASE; + + // Build enabled_labels: all known labels minus the ones the user disabled. + // An empty array is sent when all labels are enabled (backend treats it as "all"). + const disabledSet = new Set(stored.disabledLabels || []); + const allLabels = stored.allLabels || []; + const enabledLabels = + allLabels.length > 0 && disabledSet.size > 0 + ? 
allLabels.filter((l) => !disabledSet.has(l)) + : []; const url = `${backendUrl}/api/pii/check`; @@ -186,7 +197,7 @@ async function handlePIICheck(text) { response = await fetch(url, { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ message: text }), + body: JSON.stringify({ message: text, enabled_labels: enabledLabels }), signal: AbortSignal.timeout(10000), }); } catch (e) { diff --git a/chrome-extension/options.css b/chrome-extension/options.css index 881b9dcb..a8b350f2 100644 --- a/chrome-extension/options.css +++ b/chrome-extension/options.css @@ -79,7 +79,9 @@ body { min-height: 100vh; padding: 48px 24px; display: flex; - justify-content: center; + flex-direction: column; + align-items: center; + gap: 24px; } .card { @@ -307,3 +309,167 @@ input[type="url"]:focus, .save-error { color: var(--err); } + +/* ---------- Multiple cards ---------- */ + +/* ---------- Label grid (PII types) ---------- */ + +.label-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(160px, 1fr)); + gap: 8px; + margin-bottom: 20px; +} + +.label-item { + display: flex; + align-items: center; + gap: 8px; + font-size: 12px; + font-family: var(--mono); + color: var(--text); + cursor: pointer; + padding: 6px 8px; + border-radius: var(--radius-sm); + background: var(--bg-subtle); + user-select: none; +} + +.label-item input[type="checkbox"] { + accent-color: var(--brand); + width: 14px; + height: 14px; + flex-shrink: 0; + cursor: pointer; +} + +.label-loading { + font-size: 12px; + color: var(--text-muted); +} + +/* ---------- Custom patterns ---------- */ + +.pattern-form { + margin-bottom: 16px; +} + +.pattern-form-row { + display: flex; + gap: 8px; + align-items: center; + flex-wrap: wrap; + margin-bottom: 8px; +} + +.pattern-form-row .input-mono { + flex: 1; + min-width: 120px; +} + +.pattern-form-row .pattern-replacement { + flex: 1.2; +} + +.pattern-preview-row { + display: flex; + gap: 8px; + align-items: center; +} + 
+.pattern-preview-row .input-mono { + flex: 1; +} + +.pattern-preview-result { + font-size: 12px; + font-family: var(--mono); + color: var(--text-muted); + white-space: nowrap; +} + +.pattern-error { + color: var(--err); + min-height: 1.2em; +} + +.pattern-list { + display: flex; + flex-direction: column; + gap: 6px; + border-top: 0.5px solid var(--border); + padding-top: 16px; +} + +.pattern-row { + display: flex; + align-items: center; + gap: 10px; + padding: 8px 10px; + background: var(--bg-subtle); + border-radius: var(--radius-sm); + font-size: 12px; + flex-wrap: wrap; +} + +.pattern-label { + font-family: var(--mono); + font-weight: 600; + color: var(--brand); + flex-shrink: 0; +} + +.pattern-regex-val { + font-family: var(--mono); + color: var(--code-text); + flex: 1; + word-break: break-all; +} + +.pattern-replacement-val { + font-family: var(--mono); + color: var(--ok); + font-size: 12px; + flex-shrink: 0; + white-space: nowrap; +} + +.btn-danger { + padding: 5px 10px; + background: transparent; + color: var(--err); + border: 0.5px solid var(--err); + border-radius: var(--radius-sm); + font-size: 12px; + cursor: pointer; + flex-shrink: 0; + transition: background 0.15s ease; +} + +.btn-danger:hover { + background: rgba(226, 75, 74, 0.1); +} + +.btn-secondary { + padding: 5px 10px; + background: transparent; + color: var(--text-muted); + border: 0.5px solid var(--border-strong); + border-radius: var(--radius-sm); + font-size: 12px; + cursor: pointer; + flex-shrink: 0; + transition: background 0.15s ease; +} + +.btn-secondary:hover { + background: var(--bg-subtle); +} + + +.pattern-row .input-mono { + padding: 5px 8px; + font-size: 12px; + flex: 1; + min-width: 80px; +} diff --git a/chrome-extension/options.html b/chrome-extension/options.html index 48ddcda3..fa9b1b53 100644 --- a/chrome-extension/options.html +++ b/chrome-extension/options.html @@ -73,6 +73,82 @@ + +
No custom patterns yet.
'; + return; + } + + patternList.innerHTML = ""; + for (const p of patterns) { + const row = document.createElement("div"); + row.className = "pattern-row"; + row.dataset.id = p.id; + + row.innerHTML = ` + ${escHtml(p.label)} +${escHtml(p.regex)}
+ ${escHtml(p.replacement || "")}
+
+
+ `;
+ patternList.appendChild(row);
+ }
+
+ patternList.querySelectorAll("[data-action=edit]").forEach((btn) => {
+ btn.addEventListener("click", () => {
+ const id = Number(btn.dataset.id);
+ openEditRow(id);
+ });
+ });
+
+ patternList.querySelectorAll("[data-action=delete]").forEach((btn) => {
+ btn.addEventListener("click", async () => {
+ const id = Number(btn.dataset.id);
+ await deletePattern(id);
+ });
+ });
+ }
+
+ function openEditRow(id) {
+ const p = patterns.find((x) => x.id === id);
+ if (!p) return;
+ const row = patternList.querySelector(`[data-id="${id}"]`);
+ if (!row) return;
+
+ row.innerHTML = `
+
+
+
+
+
+
+ `;
+
+ row.querySelector("[data-action=cancel-edit]").addEventListener("click", () => {
+ renderPatterns();
+ });
+
+ row.querySelector("[data-action=save-edit]").addEventListener("click", async () => {
+ const label = row.querySelector(".edit-label").value.trim().toUpperCase();
+ const regex = row.querySelector(".edit-regex").value.trim();
+ const replacement = row.querySelector(".edit-replacement").value.trim();
+ const errEl = row.querySelector(".edit-error");
+ const regexErr = validateRegex(regex);
+ if (regexErr) {
+ errEl.textContent = regexErr;
+ return;
+ }
+ await updatePattern(id, label, regex, replacement, p.enabled, errEl);
+ });
+ }
+
+  /**
+   * Send an edited custom pattern to the backend with PUT, swap the returned
+   * record into the local `patterns` list and re-render it.
+   * Any failure is written into `errEl` when one is supplied.
+   */
+  async function updatePattern(id, label, regex, replacement, enabled, errEl) {
+    const apiBase = await getBackendUrl();
+    try {
+      const requestInit = {
+        method: "PUT",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({ label, regex, replacement, enabled }),
+      };
+      const response = await fetch(`${apiBase}/api/pii/patterns/${id}`, requestInit);
+      if (!response.ok) {
+        throw new Error(await response.text());
+      }
+      const saved = await response.json();
+      patterns = patterns.map((existing) => (existing.id === id ? saved : existing));
+      renderPatterns();
+    } catch (err) {
+      if (errEl) {
+        errEl.textContent = `Failed to save: ${err.message}`;
+      }
+    }
+  }
+
+  /**
+   * Delete a custom pattern on the backend, then remove it from the local
+   * `patterns` list, re-render, and call loadLabels().
+   * Any failure is reported to the user through alert().
+   */
+  async function deletePattern(id) {
+    const apiBase = await getBackendUrl();
+    try {
+      const endpoint = `${apiBase}/api/pii/patterns/${id}`;
+      const response = await fetch(endpoint, { method: "DELETE" });
+      if (!response.ok) {
+        throw new Error(await response.text());
+      }
+      patterns = patterns.filter((existing) => existing.id !== id);
+      renderPatterns();
+      loadLabels();
+    } catch (err) {
+      alert(`Failed to delete pattern: ${err.message}`);
+    }
+  }
+
+  // Create a new custom pattern from the form: validate the regex locally,
+  // POST it to the backend, then append it to the list and reset the form.
+  patternForm.addEventListener("submit", async (event) => {
+    event.preventDefault();
+
+    // Both failure paths show their message in the same element, in the error colour.
+    const showError = (message) => {
+      patternRegexError.textContent = message;
+      patternRegexError.style.color = "var(--err)";
+    };
+
+    const label = patternLabel.value.trim().toUpperCase();
+    const regex = patternRegex.value.trim();
+    const replacement = patternReplacement.value.trim();
+
+    const validationError = validateRegex(regex);
+    if (validationError) {
+      showError(validationError);
+      return;
+    }
+
+    const apiBase = await getBackendUrl();
+    try {
+      const response = await fetch(`${apiBase}/api/pii/patterns`, {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({ label, regex, replacement }),
+      });
+      if (!response.ok) {
+        throw new Error(await response.text());
+      }
+      const created = await response.json();
+      patterns.push(created);
+      renderPatterns();
+      loadLabels();
+      patternForm.reset();
+      patternPreviewResult.textContent = "";
+    } catch (err) {
+      showError(`Failed to save: ${err.message}`);
+    }
+  });
+
+  /**
+   * Escape a value so it can be interpolated into an innerHTML template
+   * (element content or a quoted attribute) without being parsed as markup.
+   *
+   * `&` is replaced first so the entities produced by the later rules are
+   * not escaped a second time. null/undefined become the empty string and
+   * other non-string values are stringified instead of throwing.
+   *
+   * @param {*} str - value to escape
+   * @returns {string} the HTML-escaped text
+   */
+  function escHtml(str) {
+    return String(str ?? "")
+      .replace(/&/g, "&amp;")
+      .replace(/</g, "&lt;")
+      .replace(/>/g, "&gt;")
+      .replace(/"/g, "&quot;")
+      .replace(/'/g, "&#39;");
+  }
+
+ loadPatterns();
});
diff --git a/src/backend/pii/database.go b/src/backend/pii/database.go
index feb991ce..7ee5afa6 100644
--- a/src/backend/pii/database.go
+++ b/src/backend/pii/database.go
@@ -69,6 +69,24 @@ const (
roleUser = "user"
)
+// CustomPattern holds a user-defined regex-based PII detection rule.
+// It mirrors one row of the custom_patterns table and is serialised as JSON
+// using the snake_case keys in the struct tags.
+type CustomPattern struct {
+ // ID is the row's primary key (AUTOINCREMENT in the SQLite schema).
+ ID int64 `json:"id"`
+ // Label is the entity label reported for matches of this pattern.
+ Label string `json:"label"`
+ // Regex is the pattern source; the regex detector compiles it with regexp.Compile.
+ Regex string `json:"regex"`
+ // Replacement is copied onto each matching entity as its Replacement override.
+ Replacement string `json:"replacement"`
+ // Enabled is stored as 0/1 in SQLite; the regex detector skips disabled patterns.
+ Enabled bool `json:"enabled"`
+ // CreatedAt is read from the created_at TEXT column (defaults to datetime('now')).
+ CreatedAt string `json:"created_at"`
+}
+
+// CustomPatternDB is the interface for managing user-defined regex patterns.
+// The notes below describe the SQLitePIIMappingDB implementation in this package;
+// other implementations are presumably expected to match — TODO confirm.
+type CustomPatternDB interface {
+ // ListPatterns returns every stored pattern (SQLite: ordered by id, non-nil slice on success).
+ ListPatterns(ctx context.Context) ([]CustomPattern, error)
+ // CreatePattern stores a new pattern (SQLite: inserted with enabled = 1) and returns the stored row.
+ CreatePattern(ctx context.Context, label, regex, replacement string) (CustomPattern, error)
+ // UpdatePattern overwrites label, regex, replacement and enabled for id and returns the stored row
+ // (SQLite: returns an error when no row with that id exists).
+ UpdatePattern(ctx context.Context, id int64, label, regex, replacement string, enabled bool) (CustomPattern, error)
+ // DeletePattern removes the pattern with the given id (SQLite: returns an error when no row was deleted).
+ DeletePattern(ctx context.Context, id int64) error
+}
+
// LoggingDB defines the interface for logging operations
type LoggingDB interface {
// InsertLog inserts a log entry (automatically parses OpenAI messages if applicable)
@@ -170,6 +188,15 @@ func createSQLiteTables(ctx context.Context, db *sql.DB) error {
`CREATE INDEX IF NOT EXISTS idx_logs_blocked ON logs(blocked)`,
`CREATE INDEX IF NOT EXISTS idx_logs_direction ON logs(direction)`,
`CREATE INDEX IF NOT EXISTS idx_logs_model ON logs(model)`,
+
+ `CREATE TABLE IF NOT EXISTS custom_patterns (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ label TEXT NOT NULL,
+ regex TEXT NOT NULL,
+ replacement TEXT DEFAULT '',
+ enabled INTEGER NOT NULL DEFAULT 1,
+ created_at TEXT DEFAULT (datetime('now'))
+ )`,
}
for _, query := range queries {
@@ -178,6 +205,10 @@ func createSQLiteTables(ctx context.Context, db *sql.DB) error {
}
}
+ // Migrate existing DBs: add columns if not present (errors mean column already exists).
+ _, _ = db.ExecContext(ctx, `ALTER TABLE custom_patterns ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1`)
+ _, _ = db.ExecContext(ctx, `ALTER TABLE custom_patterns ADD COLUMN replacement TEXT DEFAULT ''`)
+
return nil
}
@@ -719,3 +750,90 @@ func (s *SQLitePIIMappingDB) ClearLogs(ctx context.Context) error {
log.Println("[SQLiteDB] All logs cleared")
return nil
}
+
+// ListPatterns loads every row of custom_patterns in ascending id order.
+// On success the returned slice is never nil: an empty table yields an
+// empty slice. On any failure it returns nil and a wrapped error.
+func (s *SQLitePIIMappingDB) ListPatterns(ctx context.Context) ([]CustomPattern, error) {
+	const query = `SELECT id, label, regex, replacement, enabled, created_at FROM custom_patterns ORDER BY id`
+
+	rows, err := s.db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query custom patterns: %w", err)
+	}
+	defer rows.Close()
+
+	// Start from an empty, non-nil slice so the success path never returns nil.
+	result := []CustomPattern{}
+	for rows.Next() {
+		var (
+			row     CustomPattern
+			enabled int // SQLite stores the flag as 0/1
+		)
+		if scanErr := rows.Scan(&row.ID, &row.Label, &row.Regex, &row.Replacement, &enabled, &row.CreatedAt); scanErr != nil {
+			return nil, fmt.Errorf("failed to scan custom pattern: %w", scanErr)
+		}
+		row.Enabled = enabled != 0
+		result = append(result, row)
+	}
+	if iterErr := rows.Err(); iterErr != nil {
+		return nil, fmt.Errorf("error iterating custom patterns: %w", iterErr)
+	}
+	return result, nil
+}
+
+// CreatePattern inserts a new custom regex pattern (enabled by default) and
+// returns the row as stored, including the generated id and created_at.
+//
+// Every failure is wrapped with context and paired with a zero CustomPattern,
+// so callers never receive a half-populated value next to a non-nil error.
+func (s *SQLitePIIMappingDB) CreatePattern(ctx context.Context, label, regex, replacement string) (CustomPattern, error) {
+	result, err := s.db.ExecContext(ctx,
+		`INSERT INTO custom_patterns (label, regex, replacement, enabled) VALUES (?, ?, ?, 1)`,
+		label, regex, replacement,
+	)
+	if err != nil {
+		return CustomPattern{}, fmt.Errorf("failed to create custom pattern: %w", err)
+	}
+	id, err := result.LastInsertId()
+	if err != nil {
+		return CustomPattern{}, fmt.Errorf("failed to get inserted id: %w", err)
+	}
+
+	// Read the row back so database-generated columns (id, created_at) are returned.
+	var p CustomPattern
+	var enabledInt int
+	err = s.db.QueryRowContext(ctx,
+		`SELECT id, label, regex, replacement, enabled, created_at FROM custom_patterns WHERE id = ?`, id,
+	).Scan(&p.ID, &p.Label, &p.Regex, &p.Replacement, &enabledInt, &p.CreatedAt)
+	if err != nil {
+		return CustomPattern{}, fmt.Errorf("failed to read back custom pattern %d: %w", id, err)
+	}
+	p.Enabled = enabledInt != 0
+	return p, nil
+}
+
+// UpdatePattern overwrites label, regex, replacement and enabled state of the
+// pattern with the given id and returns the row as stored afterwards.
+//
+// When no row with that id exists it returns a "pattern <id> not found" error.
+// Every failure is paired with a zero CustomPattern, never a half-populated one.
+func (s *SQLitePIIMappingDB) UpdatePattern(ctx context.Context, id int64, label, regex, replacement string, enabled bool) (CustomPattern, error) {
+	// SQLite stores the flag as an INTEGER 0/1.
+	enabledInt := 0
+	if enabled {
+		enabledInt = 1
+	}
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE custom_patterns SET label = ?, regex = ?, replacement = ?, enabled = ? WHERE id = ?`,
+		label, regex, replacement, enabledInt, id,
+	)
+	if err != nil {
+		return CustomPattern{}, fmt.Errorf("failed to update custom pattern: %w", err)
+	}
+
+	// Read the row back; a missing row here means the UPDATE matched nothing.
+	var p CustomPattern
+	var enabledResult int
+	err = s.db.QueryRowContext(ctx,
+		`SELECT id, label, regex, replacement, enabled, created_at FROM custom_patterns WHERE id = ?`, id,
+	).Scan(&p.ID, &p.Label, &p.Regex, &p.Replacement, &enabledResult, &p.CreatedAt)
+	if err == sql.ErrNoRows {
+		return CustomPattern{}, fmt.Errorf("pattern %d not found", id)
+	}
+	if err != nil {
+		return CustomPattern{}, fmt.Errorf("failed to read back custom pattern %d: %w", id, err)
+	}
+	p.Enabled = enabledResult != 0
+	return p, nil
+}
+
+// DeletePattern removes the custom regex pattern with the given id.
+//
+// It returns a "pattern <id> not found" error when no row was deleted. A
+// failure to obtain the affected-row count is reported as its own error
+// rather than being mistaken for "not found".
+func (s *SQLitePIIMappingDB) DeletePattern(ctx context.Context, id int64) error {
+	result, err := s.db.ExecContext(ctx, `DELETE FROM custom_patterns WHERE id = ?`, id)
+	if err != nil {
+		return fmt.Errorf("failed to delete custom pattern: %w", err)
+	}
+	n, err := result.RowsAffected()
+	if err != nil {
+		return fmt.Errorf("failed to confirm deletion of custom pattern %d: %w", id, err)
+	}
+	if n == 0 {
+		return fmt.Errorf("pattern %d not found", id)
+	}
+	return nil
+}
diff --git a/src/backend/pii/database_test.go b/src/backend/pii/database_test.go
index df3553cf..410c706c 100644
--- a/src/backend/pii/database_test.go
+++ b/src/backend/pii/database_test.go
@@ -880,3 +880,80 @@ func TestInsertLog_EntitiesRoundTrip(t *testing.T) {
t.Errorf("expected confidence 0.88, got %f", entries[0].Confidence)
}
}
+
+// --- CustomPattern CRUD tests ---
+
+// TestCustomPattern_CreateAndList checks that a created pattern comes back with
+// the stored fields, is enabled by default, and is the only entry in the list.
+func TestCustomPattern_CreateAndList(t *testing.T) {
+	ctx := context.Background()
+	db := newTestDB(t)
+
+	created, err := db.CreatePattern(ctx, "EMPLOYEE_ID", `EMP-\d{4}`, "[EMPLOYEE_ID]")
+	if err != nil {
+		t.Fatalf("CreatePattern: %v", err)
+	}
+	if created.ID == 0 {
+		t.Fatal("expected non-zero ID")
+	}
+	if !(created.Label == "EMPLOYEE_ID" && created.Regex == `EMP-\d{4}`) {
+		t.Errorf("unexpected pattern fields: %+v", created)
+	}
+	if got := created.Replacement; got != "[EMPLOYEE_ID]" {
+		t.Errorf("expected replacement [EMPLOYEE_ID], got %q", got)
+	}
+	if !created.Enabled {
+		t.Error("new pattern should be enabled by default")
+	}
+
+	all, err := db.ListPatterns(ctx)
+	if err != nil {
+		t.Fatalf("ListPatterns: %v", err)
+	}
+	if !(len(all) == 1 && all[0].ID == created.ID) {
+		t.Errorf("expected 1 pattern with id %d, got %v", created.ID, all)
+	}
+}
+
+// TestCustomPattern_Update checks that UpdatePattern overwrites every mutable
+// field of an existing pattern and returns the stored row.
+func TestCustomPattern_Update(t *testing.T) {
+	db := newTestDB(t)
+	ctx := context.Background()
+
+	// Fail fast on a broken fixture instead of updating a zero-valued ID.
+	p, err := db.CreatePattern(ctx, "OLD", `old`, "")
+	if err != nil {
+		t.Fatalf("CreatePattern: %v", err)
+	}
+	updated, err := db.UpdatePattern(ctx, p.ID, "NEW", `new`, "[NEW]", false)
+	if err != nil {
+		t.Fatalf("UpdatePattern: %v", err)
+	}
+	if updated.ID != p.ID {
+		t.Errorf("expected updated pattern to keep id %d, got %d", p.ID, updated.ID)
+	}
+	if updated.Label != "NEW" || updated.Regex != "new" {
+		t.Errorf("unexpected updated fields: %+v", updated)
+	}
+	if updated.Replacement != "[NEW]" {
+		t.Errorf("expected replacement [NEW], got %q", updated.Replacement)
+	}
+	if updated.Enabled {
+		t.Error("expected enabled=false after update")
+	}
+}
+
+// TestCustomPattern_Delete checks that a deleted pattern disappears from the
+// list and that deleting the same id again is reported as an error.
+func TestCustomPattern_Delete(t *testing.T) {
+	db := newTestDB(t)
+	ctx := context.Background()
+
+	// Fail fast on a broken fixture instead of deleting a zero-valued ID.
+	p, err := db.CreatePattern(ctx, "X", `x`, "")
+	if err != nil {
+		t.Fatalf("CreatePattern: %v", err)
+	}
+	if err := db.DeletePattern(ctx, p.ID); err != nil {
+		t.Fatalf("DeletePattern: %v", err)
+	}
+	list, err := db.ListPatterns(ctx)
+	if err != nil {
+		t.Fatalf("ListPatterns: %v", err)
+	}
+	if len(list) != 0 {
+		t.Errorf("expected 0 patterns after delete, got %d", len(list))
+	}
+	// DeletePattern returns an error when no row was affected.
+	if err := db.DeletePattern(ctx, p.ID); err == nil {
+		t.Error("expected an error when deleting an already-deleted pattern")
+	}
+}
+
+// TestCustomPattern_EmptyListIsNotNil checks that listing an empty table
+// yields an empty, non-nil slice.
+func TestCustomPattern_EmptyListIsNotNil(t *testing.T) {
+	store := newTestDB(t)
+
+	patterns, err := store.ListPatterns(context.Background())
+	switch {
+	case err != nil:
+		t.Fatalf("ListPatterns: %v", err)
+	case patterns == nil:
+		t.Error("ListPatterns should return empty slice, not nil")
+	}
+}
diff --git a/src/backend/pii/detectors/types.go b/src/backend/pii/detectors/types.go
index 8219beef..a19d0b15 100644
--- a/src/backend/pii/detectors/types.go
+++ b/src/backend/pii/detectors/types.go
@@ -13,9 +13,10 @@ type DetectorOutput struct {
// Entity represents a detected PII entity
type Entity struct {
- Text string `json:"text"`
- Label string `json:"label"`
- StartPos int `json:"start_pos"`
- EndPos int `json:"end_pos"`
- Confidence float64 `json:"confidence"`
+ Text string `json:"text"`
+ Label string `json:"label"`
+ StartPos int `json:"start_pos"`
+ EndPos int `json:"end_pos"`
+ Confidence float64 `json:"confidence"`
+ Replacement string `json:"replacement,omitempty"` // custom override; empty means use generator
}
diff --git a/src/backend/pii/generator_service.go b/src/backend/pii/generator_service.go
index 527bb03b..37bc144a 100644
--- a/src/backend/pii/generator_service.go
+++ b/src/backend/pii/generator_service.go
@@ -37,6 +37,16 @@ const (
labelUsername = "USERNAME"
)
+// AllLabels is the full set of PII entity type labels supported by the
+// built-in model, one entry per label constant declared above.
+var AllLabels = []string{
+	labelSurname,
+	labelFirstName,
+	labelBuildingNum,
+	labelDateOfBirth,
+	labelEmail,
+	labelPhoneNumber,
+	labelCity,
+	labelURL,
+	labelCompanyName,
+	labelState,
+	labelZip,
+	labelStreet,
+	labelCountry,
+	labelSSN,
+	labelDriverLicenseNum,
+	labelPassportID,
+	labelNationalID,
+	labelIDCardNum,
+	labelTaxNum,
+	labelLicensePlateNum,
+	labelPassword,
+	labelIBAN,
+	labelAge,
+	labelSecurityToken,
+	labelCreditCardNumber,
+	labelUsername,
+}
+
// GeneratorService handles PII replacement generation
type GeneratorService struct {
rng *rand.Rand
@@ -91,6 +101,7 @@ func (s *GeneratorService) getGeneratorForLabel(label string) func(string) strin
return generator
}
- // Return generic generator for unknown labels
- return func(original string) string { return piiGenerators.GenericGenerator(s.rng, original) }
+ // Unknown label (e.g. custom regex pattern) — use label name as placeholder
+ // so the AI retains semantic context about what was redacted.
+ return func(original string) string { return "[" + label + "]" }
}
diff --git a/src/backend/pii/masking_service.go b/src/backend/pii/masking_service.go
index 4977fcfc..a7c559b3 100644
--- a/src/backend/pii/masking_service.go
+++ b/src/backend/pii/masking_service.go
@@ -31,6 +31,7 @@ type DetectorProvider interface {
type MaskingService struct {
detectorProvider DetectorProvider
generator *GeneratorService
+ patternDB CustomPatternDB
mapping *PIIMapping // optional persistent original<->dummy store; nil disables reuse
}
@@ -45,8 +46,15 @@ func NewMaskingService(detectorProvider DetectorProvider, generator *GeneratorSe
}
}
-// MaskText detects PII in text and returns masked text with mappings
-func (s *MaskingService) MaskText(text string, logPrefix string) MaskedResult {
+// SetPatternDB wires in the custom-regex pattern store used during masking.
+// While no store is set (patternDB is nil) MaskText skips custom patterns.
+// NOTE(review): the field is assigned without locking — presumably this is
+// only called during handler construction; confirm.
+func (s *MaskingService) SetPatternDB(db CustomPatternDB) {
+ s.patternDB = db
+}
+
+// MaskText detects PII in text and returns masked text with mappings.
+// enabledLabels restricts which label types are masked; nil or empty means all labels.
+// Custom regex patterns (if any) are filtered by the same set: when enabledLabels is
+// non-empty, only enabled patterns whose label is listed are applied.
+func (s *MaskingService) MaskText(text string, logPrefix string, enabledLabels []string) MaskedResult {
detector, err := s.detectorProvider.GetDetector()
if err != nil {
log.Printf("%s ❌ Failed to get detector: %v", logPrefix, err)
@@ -67,7 +75,48 @@ func (s *MaskingService) MaskText(text string, logPrefix string) MaskedResult {
}
}
- if len(piiFound.Entities) == 0 {
+ entities := piiFound.Entities
+
+ // Filter model entities to the caller-specified label set.
+ if len(enabledLabels) > 0 {
+ enabled := make(map[string]bool, len(enabledLabels))
+ for _, l := range enabledLabels {
+ enabled[l] = true
+ }
+ filtered := entities[:0]
+ for _, e := range entities {
+ if enabled[e.Label] {
+ filtered = append(filtered, e)
+ }
+ }
+ entities = filtered
+ }
+
+ // Append custom regex matches, respecting the same enabledLabels filter.
+ // When enabledLabels is empty (proxy pipeline), all enabled patterns run.
+ // When enabledLabels is non-empty (extension flow), only patterns whose label
+ // is in the enabled set are applied — so the PII types checkbox controls them too.
+ if s.patternDB != nil {
+ if patterns, err := s.patternDB.ListPatterns(context.Background()); err == nil && len(patterns) > 0 {
+ if len(enabledLabels) > 0 {
+ enabled := make(map[string]bool, len(enabledLabels))
+ for _, l := range enabledLabels {
+ enabled[l] = true
+ }
+ filtered := patterns[:0]
+ for _, p := range patterns {
+ if enabled[p.Label] {
+ filtered = append(filtered, p)
+ }
+ }
+ patterns = filtered
+ }
+ rd := newRegexDetector(patterns)
+ entities = append(entities, rd.detect(text)...)
+ }
+ }
+
+ if len(entities) == 0 {
log.Printf("%s No PII detected", logPrefix)
return MaskedResult{
MaskedText: text,
@@ -76,14 +125,18 @@ func (s *MaskingService) MaskText(text string, logPrefix string) MaskedResult {
}
}
- log.Printf("%s ⚠️ PII detected: %d entities", logPrefix, len(piiFound.Entities))
+ // Deduplicate: when entities overlap, keep the longer span; ties go to
+ // the later-appended entity (custom regex takes precedence over ML model
+ // because custom patterns are appended after ML entities).
+ entities = deduplicateEntities(entities)
+
+ log.Printf("%s ⚠️ PII detected: %d entities", logPrefix, len(entities))
// Create mapping of original text to masked text
maskedToOriginal := make(map[string]string)
maskedText := text
// Sort entities by start position in descending order to avoid position shifts
- entities := piiFound.Entities
for i := 0; i < len(entities)-1; i++ {
for j := 0; j < len(entities)-i-1; j++ {
if entities[j].StartPos < entities[j+1].StartPos {
@@ -157,6 +210,42 @@ func (s *MaskingService) MaskText(text string, logPrefix string) MaskedResult {
}
}
+// deduplicateEntities removes overlapping entities, preferring the longer span.
+//
+// The input slice is sorted in place and its backing array is reused for the
+// result, so callers must use the returned slice and not the argument.
+// The result is ordered by ascending StartPos.
+//
+// Tie-breaking: of two overlapping spans with equal length, the one that comes
+// later in sorted order wins. For identical spans that is the entity that was
+// later in the input (the sort never swaps equal keys), which is how custom
+// regex entities appended after ML entities take precedence. For equal-length
+// spans with different starts, the later-starting span wins regardless of source.
+func deduplicateEntities(entities []detectors.Entity) []detectors.Entity {
+ // Bubble sort: ascending by start, then longer spans first. Equal keys are
+ // never swapped, so input order is preserved among identical spans.
+ n := len(entities)
+ for i := 0; i < n-1; i++ {
+ for j := 0; j < n-i-1; j++ {
+ a, b := entities[j], entities[j+1]
+ if a.StartPos > b.StartPos || (a.StartPos == b.StartPos && (a.EndPos-a.StartPos) < (b.EndPos-b.StartPos)) {
+ entities[j], entities[j+1] = b, a
+ }
+ }
+ }
+
+ // Compact in place: the write index never overtakes the read index.
+ result := entities[:0]
+ for _, e := range entities {
+ if len(result) == 0 {
+ result = append(result, e)
+ continue
+ }
+ // Each entity is only compared against the most recently kept one.
+ prev := &result[len(result)-1]
+ // Spans that merely touch (e.StartPos == prev.EndPos) do not overlap.
+ if e.StartPos < prev.EndPos {
+ // Overlapping: keep the longer span; on equal length e replaces prev.
+ eLen := e.EndPos - e.StartPos
+ prevLen := prev.EndPos - prev.StartPos
+ if eLen >= prevLen {
+ *prev = e
+ }
+ continue
+ }
+ result = append(result, e)
+ }
+ return result
+}
+
// RestorePII restores masked PII text back to original text using the stored mapping
func (s *MaskingService) RestorePII(text string, maskedToOriginal map[string]string) string {
// Replace all occurrences of masked text with original text
diff --git a/src/backend/pii/masking_service_test.go b/src/backend/pii/masking_service_test.go
new file mode 100644
index 00000000..7a73809f
--- /dev/null
+++ b/src/backend/pii/masking_service_test.go
@@ -0,0 +1,62 @@
+package pii
+
+import (
+ "testing"
+
+ detectors "github.com/hannes/kiji-private/src/backend/pii/detectors"
+)
+
+// ent builds a test entity with the given label and span; Text is a fixed placeholder.
+func ent(label string, start, end int) detectors.Entity {
+	var e detectors.Entity
+	e.Text = "x"
+	e.Label = label
+	e.StartPos = start
+	e.EndPos = end
+	return e
+}
+
+// TestDeduplicateEntities_NoOverlap: disjoint spans are all kept.
+func TestDeduplicateEntities_NoOverlap(t *testing.T) {
+	got := deduplicateEntities([]detectors.Entity{ent("A", 0, 5), ent("B", 6, 10)})
+	if count := len(got); count != 2 {
+		t.Fatalf("expected 2 entities, got %d", count)
+	}
+}
+
+// TestDeduplicateEntities_ExactOverlapKeepsLast: for two identical spans the
+// entry that comes last in the input (the custom regex one) must survive.
+func TestDeduplicateEntities_ExactOverlapKeepsLast(t *testing.T) {
+	mlEntity := ent("ML", 0, 8)
+	customEntity := ent("CUSTOM", 0, 8)
+
+	got := deduplicateEntities([]detectors.Entity{mlEntity, customEntity})
+	if count := len(got); count != 1 {
+		t.Fatalf("expected 1 entity, got %d", count)
+	}
+	if winner := got[0].Label; winner != "CUSTOM" {
+		t.Errorf("expected CUSTOM to win, got %s", winner)
+	}
+}
+
+// TestDeduplicateEntities_LongerSpanWins: when one span contains another,
+// only the longer one is kept.
+func TestDeduplicateEntities_LongerSpanWins(t *testing.T) {
+	shortSpan := ent("SHORT", 2, 6)
+	longSpan := ent("LONG", 0, 10)
+
+	got := deduplicateEntities([]detectors.Entity{shortSpan, longSpan})
+	if count := len(got); count != 1 {
+		t.Fatalf("expected 1 entity, got %d", count)
+	}
+	if winner := got[0].Label; winner != "LONG" {
+		t.Errorf("expected LONG to win, got %s", winner)
+	}
+}
+
+// TestDeduplicateEntities_NonOverlappingPreserveOrder: disjoint spans given out
+// of order are all kept and come back sorted ascending by start position.
+func TestDeduplicateEntities_NonOverlappingPreserveOrder(t *testing.T) {
+	got := deduplicateEntities([]detectors.Entity{ent("C", 10, 15), ent("A", 0, 3), ent("B", 5, 8)})
+	if count := len(got); count != 3 {
+		t.Fatalf("expected 3 entities, got %d", count)
+	}
+
+	inOrder := true
+	for i, want := range []string{"A", "B", "C"} {
+		if got[i].Label != want {
+			inOrder = false
+		}
+	}
+	if !inOrder {
+		t.Errorf("unexpected order: %v", got)
+	}
+}
+
+// TestDeduplicateEntities_Empty: a nil input yields an empty result.
+func TestDeduplicateEntities_Empty(t *testing.T) {
+	if count := len(deduplicateEntities(nil)); count != 0 {
+		t.Fatalf("expected empty, got %d", count)
+	}
+}
diff --git a/src/backend/pii/regex_detector.go b/src/backend/pii/regex_detector.go
new file mode 100644
index 00000000..edddda4a
--- /dev/null
+++ b/src/backend/pii/regex_detector.go
@@ -0,0 +1,49 @@
+package pii
+
+import (
+ "regexp"
+
+ detectors "github.com/hannes/kiji-private/src/backend/pii/detectors"
+)
+
+// regexDetector finds PII entities by running a set of compiled custom patterns.
+type regexDetector struct {
+ // patterns holds the patterns that newRegexDetector compiled successfully.
+ patterns []compiledPattern
+}
+
+// compiledPattern is one CustomPattern with its regex already compiled.
+type compiledPattern struct {
+ // label is reported as the Label of every entity this pattern matches.
+ label string
+ // replacement is copied to the Replacement field of each matching entity.
+ replacement string
+ // re is the compiled form of the pattern's Regex.
+ re *regexp.Regexp
+}
+
+// newRegexDetector compiles the enabled patterns into a detector.
+//
+// Disabled patterns are skipped. A pattern whose regex does not compile is
+// also skipped, so one bad pattern cannot disable the others; no error is
+// reported for it. The result is never nil, even when nothing compiled.
+func newRegexDetector(patterns []CustomPattern) *regexDetector {
+	compiled := make([]compiledPattern, 0, len(patterns))
+	for _, p := range patterns {
+		if !p.Enabled {
+			continue
+		}
+		re, err := regexp.Compile(p.Regex)
+		if err != nil {
+			// Deliberate best-effort: ignore patterns Go's regexp cannot compile.
+			continue
+		}
+		compiled = append(compiled, compiledPattern{label: p.Label, replacement: p.Replacement, re: re})
+	}
+	return &regexDetector{patterns: compiled}
+}
+
+// detect runs every compiled pattern over text and returns one entity per
+// match, grouped by pattern in pattern order. StartPos and EndPos are the
+// byte offsets reported by FindAllStringIndex and Confidence is always 1.0.
+// The result is nil when nothing matches.
+func (d *regexDetector) detect(text string) []detectors.Entity {
+	var found []detectors.Entity
+	for i := range d.patterns {
+		cp := &d.patterns[i]
+		for _, span := range cp.re.FindAllStringIndex(text, -1) {
+			start, end := span[0], span[1]
+			hit := detectors.Entity{
+				Text:        text[start:end],
+				Label:       cp.label,
+				StartPos:    start,
+				EndPos:      end,
+				Confidence:  1.0,
+				Replacement: cp.replacement,
+			}
+			found = append(found, hit)
+		}
+	}
+	return found
+}
diff --git a/src/backend/proxy/handler.go b/src/backend/proxy/handler.go
index ffc9544d..8f5eb114 100644
--- a/src/backend/proxy/handler.go
+++ b/src/backend/proxy/handler.go
@@ -33,8 +33,14 @@ type Handler struct {
detector *pii.Detector
responseProcessor *processor.ResponseProcessor
maskingService *piiServices.MaskingService
- loggingDB piiServices.LoggingDB // Database or in-memory storage for logging
- mappingDB piiServices.PIIMappingDB // Same instance as loggingDB, for mapping operations
+ loggingDB piiServices.LoggingDB // Database or in-memory storage for logging
+ mappingDB piiServices.PIIMappingDB // Same instance as loggingDB, for mapping operations
+ patternDB piiServices.CustomPatternDB // Same instance as loggingDB, for custom pattern CRUD
+}
+
+// PatternDB returns the custom regex pattern store, for use by API handlers.
+// It returns the handler's patternDB field as-is, so the result is nil when
+// that field was never set.
+func (h *Handler) PatternDB() piiServices.CustomPatternDB {
+ return h.patternDB
}
// ReloadModel reloads the PII model from the specified directory
@@ -297,7 +303,7 @@ func (h *Handler) maskPIIInText(text string, logPrefix string) (string, map[stri
// Model is unhealthy - return text unchanged
return text, make(map[string]string), []pii.Entity{}
}
- result := h.maskingService.MaskText(text, logPrefix)
+ result := h.maskingService.MaskText(text, logPrefix, nil)
return result.MaskedText, result.MaskedToOriginal, result.Entities
}
@@ -306,6 +312,16 @@ func (h *Handler) MaskPIIInText(text string) (string, map[string]string, []pii.E
return h.maskPIIInText(text, "[PIICheck]")
}
+// MaskPIIInTextFiltered masks PII in text, limiting detection to the given label set.
+// If enabledLabels is nil or empty, all labels are active. A non-empty set is passed
+// to MaskText, which applies it to model entities and to custom regex patterns alike.
+//
+// It returns the masked text, the masked-to-original mapping and the detected entities.
+// When no masking service is configured the text is returned unchanged with an empty
+// mapping and an empty entity slice.
+// NOTE(review): unlike maskPIIInText, no model-health check is visible here before
+// calling MaskText — confirm that is intended.
+func (h *Handler) MaskPIIInTextFiltered(text string, enabledLabels []string) (string, map[string]string, []pii.Entity) {
+ if h.maskingService == nil {
+ return text, make(map[string]string), []pii.Entity{}
+ }
+ result := h.maskingService.MaskText(text, "[PIICheck]", enabledLabels)
+ return result.MaskedText, result.MaskedToOriginal, result.Entities
+}
+
// ProcessedRequest contains the result of processing a request through the PII pipeline
type ProcessedRequest struct {
RedactedBody []byte
@@ -578,6 +594,22 @@ func NewHandler(cfg *config.Config) (*Handler, error) {
CustomProvider: customProvider,
}
+ // Create services
+ // MaskingService now uses ModelManager as a DetectorProvider, so it always gets
+ // the current detector after hot reloads
+ generatorService := piiServices.NewGeneratorService()
+ maskingService := piiServices.NewMaskingService(modelManager, generatorService)
+ // patternDB is wired in below, after the SQLite DB is initialised
+
+ var responseProcessor *processor.ResponseProcessor
+ if detector != nil {
+ responseProcessor = processor.NewResponseProcessor(&detector, cfg.Logging)
+ } else {
+ // Model is unhealthy at startup - log warning but allow server to start
+ log.Printf("[Handler] Creating handler with unhealthy model - PII detection disabled until model is fixed")
+ responseProcessor = nil
+ }
+
// Initialize SQLite database
ctx := context.Background()
dbConfig := piiServices.DatabaseConfig{
@@ -592,6 +624,7 @@ func NewHandler(cfg *config.Config) (*Handler, error) {
// Set debug mode based on config
loggingDB.SetDebugMode(cfg.Logging.DebugMode)
+ maskingService.SetPatternDB(db)
// Create services
// MaskingService now uses ModelManager as a DetectorProvider, so it always gets
@@ -630,6 +663,7 @@ func NewHandler(cfg *config.Config) (*Handler, error) {
maskingService: maskingService,
loggingDB: loggingDB,
mappingDB: loggingDB.(piiServices.PIIMappingDB), // Same instance, different interface
+ patternDB: db,
}, nil
}
diff --git a/src/backend/server/server.go b/src/backend/server/server.go
index ba1cec04..860355b4 100644
--- a/src/backend/server/server.go
+++ b/src/backend/server/server.go
@@ -1,6 +1,7 @@
package server
import (
+ "context"
"encoding/json"
"fmt"
"io/fs"
@@ -8,11 +9,14 @@ import (
"net/http"
"os"
"path/filepath"
+ "regexp"
+ "strconv"
"sync"
"time"
"github.com/hannes/kiji-private/src/backend/config"
"github.com/hannes/kiji-private/src/backend/paths"
+ pii "github.com/hannes/kiji-private/src/backend/pii"
"github.com/hannes/kiji-private/src/backend/providers"
"github.com/hannes/kiji-private/src/backend/proxy"
"golang.org/x/time/rate"
@@ -236,6 +240,9 @@ func (s *Server) Start() error {
mux.HandleFunc("/api/proxy/transparent/toggle", s.handleTransparentProxyToggle)
mux.HandleFunc("/api/pii/check", s.handlePIICheck)
mux.HandleFunc("/api/pii/confidence", s.handlePIIConfidence)
+ mux.HandleFunc("/api/pii/labels", s.handlePIILabels)
+ mux.HandleFunc("/api/pii/patterns", s.handlePIIPatterns)
+ mux.HandleFunc("/api/pii/patterns/{id}", s.handlePIIPattern)
// Add provider endpoints
mux.Handle(providers.ProviderSubpathOpenAI, s.handler) // same as Mistral
@@ -342,6 +349,8 @@ func (s *Server) startTransparentProxy() {
s.handlePIICheck(w, r)
case "/api/pii/confidence":
s.handlePIIConfidence(w, r)
+ case "/api/pii/labels":
+ s.handlePIILabels(w, r)
default:
// All other HTTP/HTTPS requests go to transparent proxy
s.transparentProxy.ServeHTTP(w, r)
@@ -555,7 +564,8 @@ func (s *Server) handleModelSecurity(w http.ResponseWriter, r *http.Request) {
// PIICheckRequest represents the request body for PII checking
type PIICheckRequest struct {
- Message string `json:"message"`
+ Message string `json:"message"`
+ EnabledLabels []string `json:"enabled_labels,omitempty"`
}
// DetectedEntity represents a single detected PII entity with its label and
@@ -616,7 +626,7 @@ func (s *Server) handlePIICheck(w http.ResponseWriter, r *http.Request) {
}
// Use the handler's masking service to check for PII
- maskedText, maskedToOriginal, entities := s.handler.MaskPIIInText(req.Message)
+ maskedText, maskedToOriginal, entities := s.handler.MaskPIIInTextFiltered(req.Message, req.EnabledLabels)
// masked -> original map (consumed by UI)
entityDetails := make(map[string]string)
@@ -901,3 +911,228 @@ func (s *Server) Close() error {
}
return nil
}
+
+// handlePIILabels handles GET /api/pii/labels.
+//
+// It responds with {"labels": [...]}: the entity types reported by modelLabels,
+// followed by the labels of user-defined custom patterns that are not already
+// in that list, so they appear in the entity-type UI.
+// OPTIONS requests are answered with 200 once the CORS handler has run; any
+// other non-GET method is rejected with 405.
+func (s *Server) handlePIILabels(w http.ResponseWriter, r *http.Request) {
+	s.corsHandler(w, r)
+	if r.Method == http.MethodOptions {
+		w.WriteHeader(http.StatusOK)
+		return
+	}
+	if r.Method != http.MethodGet {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+
+	// Work on a private, non-nil copy. modelLabels may hand back the shared
+	// pii.AllLabels slice, and appending to that directly could write into its
+	// backing array from concurrent requests. Starting from an empty (not nil)
+	// slice also makes an empty result encode as [] rather than null.
+	modelSet := s.modelLabels()
+	labels := make([]string, 0, len(modelSet))
+	labels = append(labels, modelSet...)
+
+	// Append custom pattern labels that aren't already in the model set.
+	if db := s.handler.PatternDB(); db != nil {
+		ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+		defer cancel()
+		patterns, err := db.ListPatterns(ctx)
+		if err != nil {
+			// Best effort: still serve the model labels, but leave a trace
+			// instead of dropping the error silently.
+			log.Printf("Failed to list custom patterns for labels response: %v", err)
+		} else {
+			seen := make(map[string]bool, len(labels))
+			for _, l := range labels {
+				seen[l] = true
+			}
+			for _, p := range patterns {
+				if !seen[p.Label] {
+					labels = append(labels, p.Label)
+					seen[p.Label] = true
+				}
+			}
+		}
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	if err := json.NewEncoder(w).Encode(map[string]interface{}{"labels": labels}); err != nil {
+		log.Printf("Failed to encode labels response: %v", err)
+	}
+}
+
+// modelLabels reads entity type labels from label_mappings.json in the model
+// directory, stripping BIO prefixes ("B-"/"I-"), dropping the "O" and "IGNORE"
+// entries, and deduplicating.
+//
+// The result is sorted alphabetically: the labels are collected by ranging over
+// a map, whose iteration order differs from call to call, so without sorting
+// the list would come back in a different order on every request.
+//
+// Falls back to pii.AllLabels if the file can't be read or parsed, or if it
+// yields no usable labels. The fallback slice is returned as-is (not copied),
+// so callers must not modify it in place.
+func (s *Server) modelLabels() []string {
+	mappingPath := filepath.Join(s.config.ResolveModelDirectory(), "label_mappings.json")
+	data, err := os.ReadFile(mappingPath) // #nosec G304 — path derived from validated config
+	if err != nil {
+		return pii.AllLabels
+	}
+
+	var mapping struct {
+		PII struct {
+			Label2ID map[string]int `json:"label2id"`
+		} `json:"pii"`
+	}
+	if err := json.Unmarshal(data, &mapping); err != nil {
+		return pii.AllLabels
+	}
+
+	seen := make(map[string]bool, len(mapping.PII.Label2ID))
+	labels := make([]string, 0, len(mapping.PII.Label2ID))
+	for raw := range mapping.PII.Label2ID {
+		name := raw
+		// "B-EMAIL" and "I-EMAIL" both collapse to "EMAIL".
+		if len(raw) > 2 && (raw[:2] == "B-" || raw[:2] == "I-") {
+			name = raw[2:]
+		}
+		if name == "O" || name == "IGNORE" || seen[name] {
+			continue
+		}
+		seen[name] = true
+		labels = append(labels, name)
+	}
+
+	// A file that parses but carries no pii.label2id entries is as unusable as
+	// a missing one; use the built-in list rather than reporting no labels.
+	if len(labels) == 0 {
+		return pii.AllLabels
+	}
+
+	sort.Strings(labels)
+	return labels
+}
+
+// maxPatternLength is the largest regex, measured with len (bytes), that the
+// pattern endpoints accept.
+const maxPatternLength = 500
+
+// regexpCompile compiles pattern with the standard regexp package and returns
+// the compiled expression, or the compile error if pattern is not valid.
+func regexpCompile(pattern string) (*regexp.Regexp, error) {
+	compiled, err := regexp.Compile(pattern)
+	return compiled, err
+}
+
+// handlePIIPatterns handles GET and POST /api/pii/patterns.
+//
+// GET responds with {"patterns": [...]} as returned by the pattern store.
+// POST expects a JSON body with "label", "regex" and "replacement". Label and
+// regex are required; the regex must be at most maxPatternLength bytes long and
+// must compile. On success the created pattern is returned with 201 Created.
+//
+// OPTIONS is answered with 200 once the CORS handler has run. A nil pattern
+// store yields 503, any other method 405. Store calls are bounded by a
+// 5-second timeout derived from the request context.
+func (s *Server) handlePIIPatterns(w http.ResponseWriter, r *http.Request) {
+	// Upper bound on the POST body. The payload is three short strings, so
+	// 64 KiB is generous while stopping a client from sending an unbounded body.
+	const maxBodyBytes = 64 << 10
+
+	s.corsHandler(w, r)
+	if r.Method == http.MethodOptions {
+		w.WriteHeader(http.StatusOK)
+		return
+	}
+
+	db := s.handler.PatternDB()
+	if db == nil {
+		http.Error(w, "Pattern store unavailable", http.StatusServiceUnavailable)
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
+	defer cancel()
+
+	switch r.Method {
+	case http.MethodGet:
+		patterns, err := db.ListPatterns(ctx)
+		if err != nil {
+			// The client only gets a generic message; keep the cause in the log.
+			log.Printf("Failed to list patterns: %v", err)
+			http.Error(w, "Failed to list patterns", http.StatusInternalServerError)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		if err := json.NewEncoder(w).Encode(map[string]interface{}{"patterns": patterns}); err != nil {
+			log.Printf("Failed to encode patterns response: %v", err)
+		}
+
+	case http.MethodPost:
+		var req struct {
+			Label       string `json:"label"`
+			Regex       string `json:"regex"`
+			Replacement string `json:"replacement"`
+		}
+		// Reading past the limit makes Decode fail, which is reported as a
+		// bad request below.
+		r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
+		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+			http.Error(w, "Invalid request body", http.StatusBadRequest)
+			return
+		}
+		if req.Label == "" || req.Regex == "" {
+			http.Error(w, "label and regex are required", http.StatusBadRequest)
+			return
+		}
+		if len(req.Regex) > maxPatternLength {
+			http.Error(w, fmt.Sprintf("regex exceeds maximum length of %d characters", maxPatternLength), http.StatusBadRequest)
+			return
+		}
+		// Reject patterns that do not compile before they reach the store.
+		if _, err := regexpCompile(req.Regex); err != nil {
+			http.Error(w, "invalid regex: "+err.Error(), http.StatusBadRequest)
+			return
+		}
+		pattern, err := db.CreatePattern(ctx, req.Label, req.Regex, req.Replacement)
+		if err != nil {
+			http.Error(w, "Failed to create pattern: "+err.Error(), http.StatusInternalServerError)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusCreated)
+		if err := json.NewEncoder(w).Encode(pattern); err != nil {
+			log.Printf("Failed to encode create pattern response: %v", err)
+		}
+
+	default:
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+}
+
+// handlePIIPattern handles PUT and DELETE /api/pii/patterns/{id}.
+//
+// The {id} path value must parse as a base-10 int64, otherwise 400 is returned.
+// PUT expects a JSON body with "label", "regex", "replacement" and an optional
+// "enabled"; label and regex are required, the regex must be at most
+// maxPatternLength bytes long and must compile. When "enabled" is omitted the
+// pattern is stored as enabled. The updated pattern is returned as JSON.
+// DELETE removes the pattern and responds with a success flag.
+//
+// OPTIONS is answered with 200 once the CORS handler has run. A nil pattern
+// store yields 503, any other method 405. Store calls are bounded by a
+// 5-second timeout derived from the request context.
+func (s *Server) handlePIIPattern(w http.ResponseWriter, r *http.Request) {
+	// Upper bound on the PUT body. The payload is a few short fields, so
+	// 64 KiB is generous while stopping a client from sending an unbounded body.
+	const maxBodyBytes = 64 << 10
+
+	s.corsHandler(w, r)
+	if r.Method == http.MethodOptions {
+		w.WriteHeader(http.StatusOK)
+		return
+	}
+
+	idStr := r.PathValue("id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "Invalid pattern id", http.StatusBadRequest)
+		return
+	}
+
+	db := s.handler.PatternDB()
+	if db == nil {
+		http.Error(w, "Pattern store unavailable", http.StatusServiceUnavailable)
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
+	defer cancel()
+
+	switch r.Method {
+	case http.MethodPut:
+		var req struct {
+			Label       string `json:"label"`
+			Regex       string `json:"regex"`
+			Replacement string `json:"replacement"`
+			Enabled     *bool  `json:"enabled"`
+		}
+		// Reading past the limit makes Decode fail, which is reported as a
+		// bad request below.
+		r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
+		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+			http.Error(w, "Invalid request body", http.StatusBadRequest)
+			return
+		}
+		if req.Label == "" || req.Regex == "" {
+			http.Error(w, "label and regex are required", http.StatusBadRequest)
+			return
+		}
+		if len(req.Regex) > maxPatternLength {
+			http.Error(w, fmt.Sprintf("regex exceeds maximum length of %d characters", maxPatternLength), http.StatusBadRequest)
+			return
+		}
+		// Reject patterns that do not compile before they reach the store.
+		if _, err := regexpCompile(req.Regex); err != nil {
+			http.Error(w, "invalid regex: "+err.Error(), http.StatusBadRequest)
+			return
+		}
+		// Enabled is a pointer so an omitted field can be told apart from an
+		// explicit false; omitted means enabled.
+		enabled := true
+		if req.Enabled != nil {
+			enabled = *req.Enabled
+		}
+		pattern, err := db.UpdatePattern(ctx, id, req.Label, req.Regex, req.Replacement, enabled)
+		if err != nil {
+			http.Error(w, "Failed to update pattern: "+err.Error(), http.StatusInternalServerError)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		if err := json.NewEncoder(w).Encode(pattern); err != nil {
+			log.Printf("Failed to encode update pattern response: %v", err)
+		}
+
+	case http.MethodDelete:
+		if err := db.DeletePattern(ctx, id); err != nil {
+			http.Error(w, "Failed to delete pattern: "+err.Error(), http.StatusInternalServerError)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		if err := json.NewEncoder(w).Encode(map[string]interface{}{responseFieldSuccess: true}); err != nil {
+			log.Printf("Failed to encode delete pattern response: %v", err)
+		}
+
+	default:
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+}