Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 89 additions & 37 deletions backend/secuscan/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

import asyncio
import contextlib
import json
import sqlite3
from pathlib import Path
from typing import Any, Optional, List, Dict
from typing import Any, Optional, List, Dict, AsyncIterator

import aiosqlite
from .config import settings
Expand All @@ -27,7 +28,9 @@ def __init__(self, db_path: str):
def connection(self) -> aiosqlite.Connection:
"""Get the active database connection, raising an error if it's not connected."""
if self._connection is None:
raise RuntimeError("Database not connected. Did you forget to await connect()?")
raise RuntimeError(
"Database not connected. Did you forget to await connect()?"
)
return self._connection

async def connect(self):
Expand Down Expand Up @@ -417,13 +420,15 @@ async def _create_schema(self):
"inputs_json": "TEXT NOT NULL DEFAULT '{}'",
"execution_context_json": "TEXT NOT NULL DEFAULT '{}'",
"preset": "TEXT",
"safe_mode": "BOOLEAN NOT NULL DEFAULT 1"
"safe_mode": "BOOLEAN NOT NULL DEFAULT 1",
}

for col_name, col_type in needed_cols.items():
if col_name not in existing_cols:
try:
await self.execute(f"ALTER TABLE tasks ADD COLUMN {col_name} {col_type}")
await self.execute(
f"ALTER TABLE tasks ADD COLUMN {col_name} {col_type}"
)
print(f"Added missing column {col_name} to tasks table.")
except Exception as e:
print(f"Failed to add column {col_name}: {e}")
Expand Down Expand Up @@ -467,7 +472,9 @@ async def _create_schema(self):
for col_name, col_type in risk_cols.items():
if col_name not in existing_finding_cols:
try:
await self.execute(f"ALTER TABLE findings ADD COLUMN {col_name} {col_type}")
await self.execute(
f"ALTER TABLE findings ADD COLUMN {col_name} {col_type}"
)
print(f"Added missing column {col_name} to findings table.")
except Exception as e:
print(f"Failed to add column {col_name}: {e}")
Expand All @@ -487,7 +494,9 @@ async def _create_schema(self):
for col_name, col_type in asset_service_needed.items():
if col_name not in existing_asset_service_cols:
try:
await self.execute(f"ALTER TABLE asset_services ADD COLUMN {col_name} {col_type}")
await self.execute(
f"ALTER TABLE asset_services ADD COLUMN {col_name} {col_type}"
)
print(f"Added missing column {col_name} to asset_services table.")
except Exception as e:
print(f"Failed to add column {col_name} to asset_services: {e}")
Expand Down Expand Up @@ -615,7 +624,9 @@ async def _create_schema(self):
ALTER TABLE workflows_new RENAME TO workflows;
""")
await self.connection.commit()
print("Replaced workflows UNIQUE(name) constraint with UNIQUE(owner_id, name).")
print(
"Replaced workflows UNIQUE(name) constraint with UNIQUE(owner_id, name)."
)
finally:
if old_fk:
await self.execute("PRAGMA foreign_keys = ON")
Expand Down Expand Up @@ -650,9 +661,9 @@ async def _run_migrations(self):

if not migrations_dir.exists():
raise RuntimeError(
f"Migrations directory not found at {migrations_dir} — "
"ensure the backend package is installed correctly."
)
f"Migrations directory not found at {migrations_dir} — "
"ensure the backend package is installed correctly."
)

for migration_file in sorted(migrations_dir.glob("*.sql")):
sql = migration_file.read_text(encoding="utf-8")
Expand All @@ -668,6 +679,7 @@ async def _run_migrations(self):
async def _backfill_risk_scores(self):
"""Compute risk scores for existing findings that have none."""
from datetime import datetime, timezone

rows = await self.fetchall(
"SELECT id, severity, exploitability, confidence, asset_exposure, discovered_at, risk_score FROM findings WHERE risk_score IS NULL"
)
Expand Down Expand Up @@ -701,6 +713,27 @@ async def _backfill_risk_scores(self):
)
print(f"Backfilled risk scores for {len(rows)} existing finding(s).")

@contextlib.asynccontextmanager
async def transaction(self) -> AsyncIterator["Database"]:
"""Context manager for atomic transactions.

Usage::

async with db.transaction():
await db.execute("INSERT INTO ...")
await db.execute("UPDATE ...")

If any statement raises, the entire transaction is rolled back.
On success the transaction is committed automatically.
"""
await self.begin()
try:
yield self
await self.commit()
except Exception:
await self.rollback()
raise

async def execute(self, query: str, params: tuple = ()):
"""Execute a write query and return the cursor (so callers can inspect rowcount)."""
cursor = await self.connection.execute(query, params)
Expand Down Expand Up @@ -782,7 +815,6 @@ async def log_audit(
),
)


async def snapshot_workflow_version(
self,
workflow_id: str,
Expand Down Expand Up @@ -836,17 +868,21 @@ async def get_workflow_versions(self, workflow_id: str) -> List[Dict]:
defn = json.loads(row["definition_json"])
except (json.JSONDecodeError, TypeError):
defn = {}
result.append({
"id": row["id"],
"workflow_id": row["workflow_id"],
"version_number": row["version_number"],
"definition": defn,
"created_at": row["created_at"],
"created_by": row["created_by"],
})
result.append(
{
"id": row["id"],
"workflow_id": row["workflow_id"],
"version_number": row["version_number"],
"definition": defn,
"created_at": row["created_at"],
"created_by": row["created_by"],
}
)
return result

async def get_workflow_version(self, workflow_id: str, version_number: int) -> Optional[Dict]:
async def get_workflow_version(
self, workflow_id: str, version_number: int
) -> Optional[Dict]:
"""Return a specific version record or None if it does not exist."""
row = await self.fetchone(
"SELECT id, workflow_id, version_number, definition_json, created_at, created_by "
Expand Down Expand Up @@ -882,11 +918,20 @@ async def record_workflow_run(
"INSERT INTO workflow_runs "
"(id, workflow_id, version_id, version_number, triggered_by, status, task_ids_json) "
"VALUES (?, ?, ?, ?, ?, 'queued', ?)",
(run_id, workflow_id, version_id, version_number, triggered_by, json.dumps(task_ids)),
(
run_id,
workflow_id,
version_id,
version_number,
triggered_by,
json.dumps(task_ids),
),
)
return run_id

async def finalize_workflow_run(self, run_id: str, status: str, error_message: Optional[str] = None) -> None:
async def finalize_workflow_run(
self, run_id: str, status: str, error_message: Optional[str] = None
) -> None:
"""Mark a workflow run as completed, failed, or cancelled with a timestamp.

status must be one of: completed | failed | cancelled.
Expand All @@ -907,7 +952,9 @@ async def check_workflow_run_tasks(self, run_id: str) -> Optional[str]:
'cancelled' if any task was cancelled and none are still running/queued.
None if tasks are still in progress.
"""
run_row = await self.fetchone("SELECT task_ids_json FROM workflow_runs WHERE id = ?", (run_id,))
run_row = await self.fetchone(
"SELECT task_ids_json FROM workflow_runs WHERE id = ?", (run_id,)
)
if run_row is None:
return None
try:
Expand All @@ -932,10 +979,13 @@ async def check_workflow_run_tasks(self, run_id: str) -> Optional[str]:
return "cancelled"
return "failed"

async def get_workflow_runs(self, workflow_id: str, limit: int = 50, offset: int = 0) -> Dict:
async def get_workflow_runs(
self, workflow_id: str, limit: int = 50, offset: int = 0
) -> Dict:
"""Return paginated run history for a workflow."""
count_row = await self.fetchone(
"SELECT COUNT(*) AS total FROM workflow_runs WHERE workflow_id = ?", (workflow_id,)
"SELECT COUNT(*) AS total FROM workflow_runs WHERE workflow_id = ?",
(workflow_id,),
)
total = count_row["total"] if count_row else 0
rows = await self.fetchall(
Expand All @@ -949,18 +999,20 @@ async def get_workflow_runs(self, workflow_id: str, limit: int = 50, offset: int
task_ids = json.loads(row["task_ids_json"] or "[]")
except (json.JSONDecodeError, TypeError):
task_ids = []
entries.append({
"id": row["id"],
"workflow_id": row["workflow_id"],
"version_id": row["version_id"],
"version_number": row["version_number"],
"triggered_by": row["triggered_by"],
"status": row["status"],
"task_ids": task_ids,
"started_at": row["started_at"],
"completed_at": row["completed_at"],
"error_message": row["error_message"],
})
entries.append(
{
"id": row["id"],
"workflow_id": row["workflow_id"],
"version_id": row["version_id"],
"version_number": row["version_number"],
"triggered_by": row["triggered_by"],
"status": row["status"],
"task_ids": task_ids,
"started_at": row["started_at"],
"completed_at": row["completed_at"],
"error_message": row["error_message"],
}
)
return {"total": total, "runs": entries}


Expand Down
Loading
Loading