Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Continuous integration for prompt-db: lint, unit tests, and a CLI smoke
# test, run on every push to main and every pull request targeting main.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# The job only checks out code and runs local tooling, so the GITHUB_TOKEN
# needs nothing beyond read access to repository contents.
permissions:
  contents: read

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Versions are quoted so YAML keeps them as strings rather than
        # parsing them as floats.
        python-version: ["3.11", "3.12", "3.13"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      # Editable install with the "dev" extra, which is expected to provide
      # ruff and pytest used by the steps below.
      - name: Install dependencies
        run: pip install -e ".[dev]"

      # Both lint rules and formatting are enforced; --check makes the
      # formatter fail instead of rewriting files.
      - name: Lint with ruff
        run: |
          ruff check src/ tests/
          ruff format --check src/ tests/

      - name: Run tests
        run: pytest tests/ -v --tb=short

      # Smoke test: build a database from the repository data and make sure
      # the resulting file can be opened and summarized by the CLI.
      - name: Verify build command
        run: |
          prompt-db build --data-dir . --output /tmp/ci_test.db --force
          prompt-db --db /tmp/ci_test.db stats
79 changes: 55 additions & 24 deletions src/prompt_database/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def _resolve_db(ctx: click.Context) -> Path:

@click.group()
@click.version_option(__version__, prog_name="prompt-db")
@click.option("--db", default=str(DEFAULT_DB), envvar="PROMPT_DB_PATH", help="Path to SQLite database")
@click.option(
"--db", default=str(DEFAULT_DB), envvar="PROMPT_DB_PATH", help="Path to SQLite database"
)
@click.pass_context
def main(ctx: click.Context, db: str) -> None:
"""Prompt injection attack database for defensive AI security research."""
Expand Down Expand Up @@ -61,7 +63,9 @@ def build(ctx: click.Context, data_dir: str, output: str, force: bool) -> None:

total_added = sum(r["added"] for r in results.values())
total_skipped = sum(r["skipped"] for r in results.values())
console.print(f"\n[green]Done![/green] {total_added} prompts added, {total_skipped} duplicates skipped.")
console.print(
f"\n[green]Done![/green] {total_added} prompts added, {total_skipped} duplicates skipped."
)
console.print(f"Database: {out_path} ({out_path.stat().st_size / 1024 / 1024:.1f} MB)")


Expand All @@ -82,31 +86,31 @@ def stats(ctx: click.Context) -> None:
with PromptDatabase(db_path) as db:
s = db.stats()

console.print(f"\n[bold]Prompt Database Statistics[/bold]")
console.print("\n[bold]Prompt Database Statistics[/bold]")
console.print(f" Total prompts: {s['total_prompts']:,}")
console.print(f" Verified: {s['verified']:,}")
console.print(f" Curated: {s['curated']:,}")
console.print(f" Test results: {s['test_results']:,}")
console.print(f" Variations: {s['variations']:,}")
console.print(f" Avg sophistication: {s['avg_sophistication']}")

console.print(f"\n[bold]By Technique[/bold]")
console.print("\n[bold]By Technique[/bold]")
table = Table(show_header=True)
table.add_column("Technique", style="cyan")
table.add_column("Count", justify="right")
for tech, count in sorted(s["by_technique"].items(), key=lambda x: -x[1]):
table.add_row(tech, str(count))
console.print(table)

console.print(f"\n[bold]By Complexity[/bold]")
console.print("\n[bold]By Complexity[/bold]")
table = Table(show_header=True)
table.add_column("Complexity", style="yellow")
table.add_column("Count", justify="right")
for comp, count in sorted(s["by_complexity"].items(), key=lambda x: -x[1]):
table.add_row(comp, str(count))
console.print(table)

console.print(f"\n[bold]By Source[/bold]")
console.print("\n[bold]By Source[/bold]")
table = Table(show_header=True)
table.add_column("Source", style="green")
table.add_column("Count", justify="right")
Expand Down Expand Up @@ -166,7 +170,10 @@ def search(

for r in results:
content_preview = r["content"][:120].replace("\n", " ") if not full else r["content"]
console.print(f" [cyan]#{r['id']}[/cyan] [{r['technique']}] [{r['complexity']}] score={r['sophistication_score']}")
score = r["sophistication_score"]
console.print(
f" [cyan]#{r['id']}[/cyan] [{r['technique']}] [{r['complexity']}] score={score}"
)
console.print(f" {content_preview}")
if not full:
console.print(f" [dim]source={r['source']}[/dim]")
Expand Down Expand Up @@ -219,7 +226,16 @@ def export(

buf = io.StringIO()
if prompts:
writer = csv.DictWriter(buf, fieldnames=["id", "content", "technique", "complexity", "sophistication_score", "source", "success_rate"])
fields = [
"id",
"content",
"technique",
"complexity",
"sophistication_score",
"source",
"success_rate",
]
writer = csv.DictWriter(buf, fieldnames=fields)
writer.writeheader()
for p in prompts:
writer.writerow({k: p.get(k) for k in writer.fieldnames})
Expand Down Expand Up @@ -260,7 +276,8 @@ def export_garak_cmd(

with PromptDatabase(db_path) as db:
count = export_garak(
db, Path(output),
db,
Path(output),
technique=technique,
min_sophistication=min_score,
limit=limit,
Expand Down Expand Up @@ -297,7 +314,8 @@ def export_ps_fuzz_cmd(

with PromptDatabase(db_path) as db:
count = export_ps_fuzz(
db, Path(output),
db,
Path(output),
technique=technique,
min_sophistication=min_score,
limit=limit,
Expand Down Expand Up @@ -334,13 +352,16 @@ def export_dataset_cmd(

with PromptDatabase(db_path) as db:
count = export_dataset(
db, Path(output),
db,
Path(output),
technique=technique,
min_sophistication=min_score,
limit=limit,
)

console.print(f"[green]Exported {count} prompts to {output} (HuggingFace dataset format)[/green]")
console.print(
f"[green]Exported {count} prompts to {output} (HuggingFace dataset format)[/green]"
)


# =============================================================================
Expand Down Expand Up @@ -383,7 +404,7 @@ def info(ctx: click.Context, prompt_id: int) -> None:
cats = [f"{c['code']} ({c['name']})" for c in prompt["categories"]]
console.print(f" Categories: {', '.join(cats)}")

console.print(f"\n[bold]Content:[/bold]")
console.print("\n[bold]Content:[/bold]")
console.print(prompt["content"])
console.print()

Expand All @@ -397,7 +418,9 @@ def info(ctx: click.Context, prompt_id: int) -> None:
@click.option("--model", "-m", required=True, help="Target model (e.g., gpt-4, claude-sonnet-4-5)")
@click.option("--api-base", default="https://api.openai.com/v1", help="API base URL")
@click.option("--api-key", envvar="OPENAI_API_KEY", help="API key (or set OPENAI_API_KEY)")
@click.option("--system-prompt", default="You are a helpful assistant.", help="System prompt for target")
@click.option(
"--system-prompt", default="You are a helpful assistant.", help="System prompt for target"
)
@click.option("--technique", "-t", help="Filter prompts by technique")
@click.option("--min-score", type=int, help="Minimum sophistication score")
@click.option("--limit", "-n", default=10, type=int, help="Number of prompts to test")
Expand Down Expand Up @@ -459,7 +482,7 @@ def test_prompt_cmd(
for p in prompts_to_test:
preview = p["content"][:80].replace("\n", " ")
console.print(f" #{p['id']} [{p['technique']}] {preview}...")
console.print(f"\n[dim]Dry run — no API calls made.[/dim]")
console.print("\n[dim]Dry run — no API calls made.[/dim]")
return

results_summary = {"SUCCESS": 0, "FAIL": 0, "PARTIAL": 0, "ERROR": 0}
Expand Down Expand Up @@ -489,15 +512,17 @@ def test_prompt_cmd(
results_summary[result.result] += 1

color = {
"SUCCESS": "red", "FAIL": "green",
"PARTIAL": "yellow", "ERROR": "dim",
"SUCCESS": "red",
"FAIL": "green",
"PARTIAL": "yellow",
"ERROR": "dim",
}.get(result.result, "white")
console.print(
f"[{color}]{result.result}[/{color}] "
f"(conf={result.confidence_score:.2f}, {result.response_time_ms:.0f}ms)"
)

console.print(f"\n[bold]Results Summary[/bold]")
console.print("\n[bold]Results Summary[/bold]")
console.print(f" [red]SUCCESS (attack worked):[/red] {results_summary['SUCCESS']}")
console.print(f" [green]FAIL (model defended):[/green] {results_summary['FAIL']}")
console.print(f" [yellow]PARTIAL (ambiguous):[/yellow] {results_summary['PARTIAL']}")
Expand Down Expand Up @@ -574,19 +599,23 @@ def audit(ctx: click.Context, source: str | None, show_remove: bool) -> None:
preview = row["content"][:100].replace("\n", " ")
console.print(f" [red]REMOVE[/red] #{row['id']} [{src}] {preview}")

console.print(f"\n[bold]Data Quality Audit[/bold]")
console.print("\n[bold]Data Quality Audit[/bold]")
console.print(f" Total prompts: {len(rows):,}")
console.print(f" [green]Keep: {keep_count:,}[/green]")
console.print(f" [yellow]Review: {review_count:,}[/yellow]")
console.print(f" [red]Remove: {remove_count:,}[/red]")

console.print(f"\n[bold]By Source[/bold]")
console.print("\n[bold]By Source[/bold]")
table = Table(show_header=True)
table.add_column("Source", style="cyan")
table.add_column("Keep", justify="right", style="green")
table.add_column("Review", justify="right", style="yellow")
table.add_column("Remove", justify="right", style="red")
for src in sorted(by_source, key=lambda s: -(by_source[s]["keep"] + by_source[s]["review"] + by_source[s]["remove"])):

def _total(s: str) -> int:
    """Return the negated keep + review + remove count for source ``s``.

    The value is negated so that an ascending ``sorted`` lists the
    sources with the most prompts first.
    """
    counts = by_source[s]
    combined = counts["keep"] + counts["review"] + counts["remove"]
    return -combined

for src in sorted(by_source, key=_total):
counts = by_source[src]
table.add_row(src, str(counts["keep"]), str(counts["review"]), str(counts["remove"]))
console.print(table)
Expand Down Expand Up @@ -638,14 +667,16 @@ def curate(ctx: click.Context, dry_run: bool, min_quality: int) -> None:
if assessment["quality_score"] < min_quality:
if not dry_run:
db.conn.execute(
"UPDATE prompts SET is_active = 0, updated_at = datetime('now') WHERE id = ?",
"UPDATE prompts SET is_active = 0, updated_at = datetime('now') "
"WHERE id = ?",
(row["id"],),
)
deactivated += 1
elif assessment["quality_score"] >= 50:
if not dry_run:
db.conn.execute(
"UPDATE prompts SET is_curated = 1, updated_at = datetime('now') WHERE id = ?",
"UPDATE prompts SET is_curated = 1, updated_at = datetime('now') "
"WHERE id = ?",
(row["id"],),
)
curated += 1
Expand All @@ -654,7 +685,7 @@ def curate(ctx: click.Context, dry_run: bool, min_quality: int) -> None:
db.conn.commit()

action = "Would deactivate" if dry_run else "Deactivated"
console.print(f"\n[bold]Curation Results[/bold]")
console.print("\n[bold]Curation Results[/bold]")
console.print(f" Total prompts: {len(rows):,}")
console.print(f" [red]{action}: {deactivated:,} (quality < {min_quality})[/red]")
console.print(f" [green]Curated: {curated:,} (quality >= 50)[/green]")
Expand Down
41 changes: 22 additions & 19 deletions src/prompt_database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def init_schema(self) -> None:
def schema_version(self) -> int:
"""Return current schema version, or 0 if not initialized."""
try:
row = self.conn.execute(
"SELECT MAX(version) FROM schema_version"
).fetchone()
row = self.conn.execute("SELECT MAX(version) FROM schema_version").fetchone()
return row[0] or 0
except sqlite3.OperationalError:
return 0
Expand Down Expand Up @@ -139,22 +137,19 @@ def add_prompt(
def _apply_tags(self, prompt_id: int, tag_names: list[str]) -> None:
for name in tag_names:
self.conn.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (name,))
tag_id = self.conn.execute(
"SELECT id FROM tags WHERE name = ?", (name,)
).fetchone()[0]
tag_id = self.conn.execute("SELECT id FROM tags WHERE name = ?", (name,)).fetchone()[0]
self.conn.execute(
"INSERT OR IGNORE INTO prompt_tags (prompt_id, tag_id) VALUES (?, ?)",
(prompt_id, tag_id),
)

def _apply_categories(self, prompt_id: int, category_codes: list[str]) -> None:
for code in category_codes:
row = self.conn.execute(
"SELECT id FROM categories WHERE code = ?", (code,)
).fetchone()
row = self.conn.execute("SELECT id FROM categories WHERE code = ?", (code,)).fetchone()
if row:
self.conn.execute(
"INSERT OR IGNORE INTO prompt_categories (prompt_id, category_id) VALUES (?, ?)",
"INSERT OR IGNORE INTO prompt_categories "
"(prompt_id, category_id) VALUES (?, ?)",
(prompt_id, row[0]),
)

Expand Down Expand Up @@ -292,7 +287,9 @@ def filter_prompts(
if curated_only:
conditions.append("p.is_curated = 1")
if tag:
joins.append("JOIN prompt_tags pt ON pt.prompt_id = p.id JOIN tags t ON t.id = pt.tag_id")
joins.append(
"JOIN prompt_tags pt ON pt.prompt_id = p.id JOIN tags t ON t.id = pt.tag_id"
)
conditions.append("t.name = ?")
params.append(tag)
if category_code:
Expand Down Expand Up @@ -327,13 +324,11 @@ def stats(self) -> dict[str, Any]:
verified = self.conn.execute(
"SELECT COUNT(*) FROM prompts WHERE is_verified = 1"
).fetchone()[0]
curated = self.conn.execute(
"SELECT COUNT(*) FROM prompts WHERE is_curated = 1"
).fetchone()[0]
curated = self.conn.execute("SELECT COUNT(*) FROM prompts WHERE is_curated = 1").fetchone()[
0
]
test_count = self.conn.execute("SELECT COUNT(*) FROM test_results").fetchone()[0]
variation_count = self.conn.execute(
"SELECT COUNT(*) FROM prompt_variations"
).fetchone()[0]
variation_count = self.conn.execute("SELECT COUNT(*) FROM prompt_variations").fetchone()[0]

by_technique = dict(
self.conn.execute(
Expand Down Expand Up @@ -486,7 +481,8 @@ def export_prompts(
limit_clause = f"LIMIT {limit}" if limit else ""

rows = self.conn.execute(
f"SELECT * FROM prompts WHERE {where} ORDER BY sophistication_score DESC {limit_clause}",
f"SELECT * FROM prompts WHERE {where} "
f"ORDER BY sophistication_score DESC {limit_clause}",
params,
).fetchall()

Expand All @@ -496,7 +492,14 @@ def export_prompts(
d["tags"] = self._get_prompt_tags(d["id"])
d["categories"] = self._get_prompt_categories(d["id"])
# Parse JSON fields
for field in ("matched_patterns", "target_models", "paper_ids", "cve_ids", "reference_urls"):
json_fields = (
"matched_patterns",
"target_models",
"paper_ids",
"cve_ids",
"reference_urls",
)
for field in json_fields:
if d.get(field) and isinstance(d[field], str):
try:
d[field] = json.loads(d[field])
Expand Down
17 changes: 8 additions & 9 deletions src/prompt_database/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import json
from pathlib import Path
from typing import Any

from prompt_database.db import PromptDatabase

Expand Down Expand Up @@ -80,19 +79,19 @@ def export_ps_fuzz(
content = p["content"].replace("\\", "\\\\").replace('"', '\\"')
# Multi-line content uses YAML literal block
if "\n" in content:
yaml_lines.append(f" - name: \"{name}\"")
yaml_lines.append(f" category: \"{p['technique']}\"")
yaml_lines.append(f" complexity: \"{p['complexity']}\"")
yaml_lines.append(f' - name: "{name}"')
yaml_lines.append(f' category: "{p["technique"]}"')
yaml_lines.append(f' complexity: "{p["complexity"]}"')
yaml_lines.append(f" sophistication: {p['sophistication_score']}")
yaml_lines.append(f" prompt: |")
yaml_lines.append(" prompt: |")
for line in content.split("\n"):
yaml_lines.append(f" {line}")
else:
yaml_lines.append(f" - name: \"{name}\"")
yaml_lines.append(f" category: \"{p['technique']}\"")
yaml_lines.append(f" complexity: \"{p['complexity']}\"")
yaml_lines.append(f' - name: "{name}"')
yaml_lines.append(f' category: "{p["technique"]}"')
yaml_lines.append(f' complexity: "{p["complexity"]}"')
yaml_lines.append(f" sophistication: {p['sophistication_score']}")
yaml_lines.append(f" prompt: \"{content}\"")
yaml_lines.append(f' prompt: "{content}"')

output_path.write_text("\n".join(yaml_lines) + "\n", encoding="utf-8")
return len(prompts)
Expand Down
Loading
Loading