From a9661a90997219730e939b2b0c91fd45ad2ee696 Mon Sep 17 00:00:00 2001 From: ionfwsrijan Date: Sun, 21 Jun 2026 19:35:37 +0530 Subject: [PATCH 1/2] fix(#1137): add transaction context manager for atomic multi-write operations - Add Database.transaction() async context manager (begin/commit/rollback) - Wrap _upsert_findings_and_report in transaction (findings + report + resources) - Wrap _upsert_findings_and_report_from_scanner in transaction - Wrap cancel_task status update + audit log in transaction - Wrap replace_asset_services delete+insert in transaction - Wrap delete_task_records in transaction (replacing manual begin/commit/rollback) - Wrap upsert_vault_secret select-then-upsert in transaction - Keep existing execute() auto-commit behavior for backward compatibility --- backend/secuscan/database.py | 24 +++- backend/secuscan/executor.py | 157 +++++++++++++------------ backend/secuscan/platform_resources.py | 5 +- backend/secuscan/routes.py | 51 +++----- 4 files changed, 121 insertions(+), 116 deletions(-) diff --git a/backend/secuscan/database.py b/backend/secuscan/database.py index 0b2d02798..42bb36b80 100644 --- a/backend/secuscan/database.py +++ b/backend/secuscan/database.py @@ -3,10 +3,11 @@ """ import asyncio +import contextlib import json import sqlite3 from pathlib import Path -from typing import Any, Optional, List, Dict +from typing import Any, Optional, List, Dict, AsyncIterator import aiosqlite from .config import settings @@ -701,6 +702,27 @@ async def _backfill_risk_scores(self): ) print(f"Backfilled risk scores for {len(rows)} existing finding(s).") + @contextlib.asynccontextmanager + async def transaction(self) -> AsyncIterator["Database"]: + """Context manager for atomic transactions. + + Usage:: + + async with db.transaction(): + await db.execute("INSERT INTO ...") + await db.execute("UPDATE ...") + + If any statement raises, the entire transaction is rolled back. + On success the transaction is committed automatically. + """ + await self.begin() + try: + yield self + await self.commit() + except Exception: + await self.rollback() + raise + async def execute(self, query: str, params: tuple = ()): """Execute a write query and return the cursor (so callers can inspect rowcount).""" cursor = await self.connection.execute(query, params) diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index db078a8eb..965413dbf 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -947,20 +947,21 @@ async def cancel_task(self, task_id: str) -> bool: logger.error(f"Failed to kill docker container for {task_id}: {e}") db = await get_db() - await db.execute( - "UPDATE tasks SET status = ?, completed_at = ? WHERE id = ?", - (TaskStatus.CANCELLED.value, datetime.now().isoformat(), task_id) - ) + async with db.transaction(): + await db.execute( + "UPDATE tasks SET status = ?, completed_at = ? WHERE id = ?", + (TaskStatus.CANCELLED.value, datetime.now().isoformat(), task_id) + ) + + await db.log_audit( + "task_cancelled", + "Task cancelled by user", + task_id=task_id + ) await self._broadcast(task_id, "status", TaskStatus.CANCELLED.value) await self._invalidate_cached_views() - await db.log_audit( - "task_cancelled", - "Task cancelled by user", - task_id=task_id - ) - return True async def get_task_status(self, task_id: str) -> Optional[Dict]: @@ -1368,41 +1369,42 @@ async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plu structured_result["asset_summary"] = build_asset_summary(findings_data, asset_services) structured_result["scan_diff"] = build_scan_diff(findings_data, previous_findings) - await db.execute( - "UPDATE tasks SET structured_json = ? WHERE id = ?", - (json.dumps(structured_result), task_id) - ) + async with db.transaction(): + await db.execute( + "UPDATE tasks SET structured_json = ? WHERE id = ?", + (json.dumps(structured_result), task_id) + ) - await db.execute( - """ - INSERT INTO reports ( - id, owner_id, task_id, name, type, generated_at, status, findings, pages - ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - status = EXCLUDED.status, - findings = EXCLUDED.findings, - pages = EXCLUDED.pages - """, - ( - f"report:{task_id}", - owner_id, - task_id, - f"{plugin.name} Report", - "technical", - "ready" if status == TaskStatus.COMPLETED.value else "failed", - len(findings_data), - 1, - ), - ) + await db.execute( + """ + INSERT INTO reports ( + id, owner_id, task_id, name, type, generated_at, status, findings, pages + ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + status = EXCLUDED.status, + findings = EXCLUDED.findings, + pages = EXCLUDED.pages + """, + ( + f"report:{task_id}", + owner_id, + task_id, + f"{plugin.name} Report", + "technical", + "ready" if status == TaskStatus.COMPLETED.value else "failed", + len(findings_data), + 1, + ), + ) - await self._persist_result_resources( - db, - owner_id=owner_id, - task_id=task_id, - plugin_id=plugin_id, - target=target, - result=structured_result, - ) + await self._persist_result_resources( + db, + owner_id=owner_id, + task_id=task_id, + plugin_id=plugin_id, + target=target, + result=structured_result, + ) async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner_id: str, scanner: Any, plugin_id: str, target: str, status: str, result: Dict[str, Any]): """Persist modular scanner results into findings, and reports.""" @@ -1433,42 +1435,43 @@ async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner structured_result["asset_summary"] = build_asset_summary(findings_data, asset_services) structured_result["scan_diff"] = build_scan_diff(findings_data, previous_findings) - await db.execute( - "UPDATE tasks SET structured_json = ? WHERE id = ?", - (json.dumps(structured_result), task_id) - ) + async with db.transaction(): + await db.execute( + "UPDATE tasks SET structured_json = ? WHERE id = ?", + (json.dumps(structured_result), task_id) + ) - # Create/Update report - await db.execute( - """ - INSERT INTO reports ( - id, owner_id, task_id, name, type, generated_at, status, findings, pages - ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - status = EXCLUDED.status, - findings = EXCLUDED.findings, - pages = EXCLUDED.pages - """, - ( - f"report:{task_id}", - owner_id, - task_id, - f"{scanner.name} Report", - "professional" if status == TaskStatus.COMPLETED.value else "failed", - "ready" if status == TaskStatus.COMPLETED.value else "failed", - len(findings_data), - 2, # Professional reports are typically multi-page - ), - ) + # Create/Update report + await db.execute( + """ + INSERT INTO reports ( + id, owner_id, task_id, name, type, generated_at, status, findings, pages + ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + status = EXCLUDED.status, + findings = EXCLUDED.findings, + pages = EXCLUDED.pages + """, + ( + f"report:{task_id}", + owner_id, + task_id, + f"{scanner.name} Report", + "professional" if status == TaskStatus.COMPLETED.value else "failed", + "ready" if status == TaskStatus.COMPLETED.value else "failed", + len(findings_data), + 2, # Professional reports are typically multi-page + ), + ) - await self._persist_result_resources( - db, - owner_id=owner_id, - task_id=task_id, - plugin_id=plugin_id, - target=target, - result=structured_result, - ) + await self._persist_result_resources( + db, + owner_id=owner_id, + task_id=task_id, + plugin_id=plugin_id, + target=target, + result=structured_result, + ) async def _persist_result_resources( self, diff --git a/backend/secuscan/platform_resources.py b/backend/secuscan/platform_resources.py index 8b84ac00f..1116ab8b9 100644 --- a/backend/secuscan/platform_resources.py +++ b/backend/secuscan/platform_resources.py @@ -114,8 +114,9 @@ async def replace_asset_services( target: str, services: Iterable[Dict[str, Any]], ) -> None: - await db.execute("DELETE FROM asset_services WHERE task_id = ?", (task_id,)) - for item in services: + async with db.transaction(): + await db.execute("DELETE FROM asset_services WHERE task_id = ?", (task_id,)) + for item in services: metadata = item.get("metadata", {}) if isinstance(item.get("metadata"), dict) else {} host = str(item.get("host") or target) port = item.get("port") diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index b8e7c45ea..6c94aea8e 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -1243,8 +1243,7 @@ async def delete_task_records(task_ids: List[str]): all_task_rows.extend(rows) # Delete associated records in chunks, atomic within a transaction - await db.begin() - try: + async with db.transaction(): # Re-check running status inside the transaction to prevent the # race where a task starts running between the check and the delete. for i in range(0, len(task_ids), SQLITE_CHUNK_SIZE): @@ -1285,13 +1284,6 @@ async def delete_task_records(task_ids: List[str]): await db.execute_no_commit( f"DELETE FROM tasks WHERE id IN ({placeholders})", tuple(chunk) ) - await db.commit() - except HTTPException: - await db.rollback() - raise - except Exception: - await db.rollback() - raise # Cleanup files on disk (outside the transaction — file deletion is not # transactional; a failure here does not leave the DB in an inconsistent @@ -1475,34 +1467,21 @@ async def upsert_vault_secret( encrypted = crypto.encrypt(value) secret_id = str(uuid.uuid4()) - existing = await db.fetchone( - """ - SELECT id - FROM credential_vault - WHERE owner_id = ? AND name = ? - """, - (owner, name), - ) - - if existing: - await db.execute( - """ - UPDATE credential_vault - SET encrypted_value = ?, updated_at = datetime('now') - WHERE owner_id = ? AND name = ? - """, - (encrypted, owner, name), - ) - else: - await db.execute( - """ - INSERT INTO credential_vault - (id, owner_id, name, encrypted_value) - VALUES (?, ?, ?, ?) - """, - (secret_id, owner, name, encrypted), + async with db.transaction(): + existing = await db.fetchone( + "SELECT id FROM credential_vault WHERE owner_id = ? AND name = ?", + (owner, name), ) - + if existing: + await db.execute( + "UPDATE credential_vault SET encrypted_value = ?, updated_at = datetime('now') WHERE owner_id = ? AND name = ?", + (encrypted, owner, name), + ) + else: + await db.execute( + "INSERT INTO credential_vault (id, owner_id, name, encrypted_value) VALUES (?, ?, ?, ?)", + (secret_id, owner, name, encrypted), + ) return {"name": name, "stored": True} @router.get("/vault/{name}", dependencies=[Depends(vault_limiter)]) From 775efc1518bc0b668f59e841b98fd4b9f69e1faf Mon Sep 17 00:00:00 2001 From: ionfwsrijan Date: Sun, 21 Jun 2026 19:56:09 +0530 Subject: [PATCH 2/2] fix(#1137): fix indentation and ruff formatting - Fix indentation in replace_asset_services for loop body after wrapping in async with db.transaction() - Run ruff format on all 4 changed files --- backend/secuscan/database.py | 102 ++- backend/secuscan/executor.py | 557 +++++++++++----- backend/secuscan/platform_resources.py | 139 ++-- backend/secuscan/routes.py | 890 ++++++++++++++++++------- 4 files changed, 1167 insertions(+), 521 deletions(-) diff --git a/backend/secuscan/database.py b/backend/secuscan/database.py index 42bb36b80..d48667348 100644 --- a/backend/secuscan/database.py +++ b/backend/secuscan/database.py @@ -28,7 +28,9 @@ def __init__(self, db_path: str): def connection(self) -> aiosqlite.Connection: """Get the active database connection, raising an error if it's not connected.""" if self._connection is None: - raise RuntimeError("Database not connected. Did you forget to await connect()?") + raise RuntimeError( + "Database not connected. Did you forget to await connect()?" + ) return self._connection async def connect(self): @@ -418,13 +420,15 @@ async def _create_schema(self): "inputs_json": "TEXT NOT NULL DEFAULT '{}'", "execution_context_json": "TEXT NOT NULL DEFAULT '{}'", "preset": "TEXT", - "safe_mode": "BOOLEAN NOT NULL DEFAULT 1" + "safe_mode": "BOOLEAN NOT NULL DEFAULT 1", } for col_name, col_type in needed_cols.items(): if col_name not in existing_cols: try: - await self.execute(f"ALTER TABLE tasks ADD COLUMN {col_name} {col_type}") + await self.execute( + f"ALTER TABLE tasks ADD COLUMN {col_name} {col_type}" + ) print(f"Added missing column {col_name} to tasks table.") except Exception as e: print(f"Failed to add column {col_name}: {e}") @@ -468,7 +472,9 @@ async def _create_schema(self): for col_name, col_type in risk_cols.items(): if col_name not in existing_finding_cols: try: - await self.execute(f"ALTER TABLE findings ADD COLUMN {col_name} {col_type}") + await self.execute( + f"ALTER TABLE findings ADD COLUMN {col_name} {col_type}" + ) print(f"Added missing column {col_name} to findings table.") except Exception as e: print(f"Failed to add column {col_name}: {e}") @@ -488,7 +494,9 @@ async def _create_schema(self): for col_name, col_type in asset_service_needed.items(): if col_name not in existing_asset_service_cols: try: - await self.execute(f"ALTER TABLE asset_services ADD COLUMN {col_name} {col_type}") + await self.execute( + f"ALTER TABLE asset_services ADD COLUMN {col_name} {col_type}" + ) print(f"Added missing column {col_name} to asset_services table.") except Exception as e: print(f"Failed to add column {col_name} to asset_services: {e}") @@ -616,7 +624,9 @@ async def _create_schema(self): ALTER TABLE workflows_new RENAME TO workflows; """) await self.connection.commit() - print("Replaced workflows UNIQUE(name) constraint with UNIQUE(owner_id, name).") + print( + "Replaced workflows UNIQUE(name) constraint with UNIQUE(owner_id, name)." + ) finally: if old_fk: await self.execute("PRAGMA foreign_keys = ON") @@ -651,9 +661,9 @@ async def _run_migrations(self): if not migrations_dir.exists(): raise RuntimeError( - f"Migrations directory not found at {migrations_dir} — " - "ensure the backend package is installed correctly." - ) + f"Migrations directory not found at {migrations_dir} — " + "ensure the backend package is installed correctly." + ) for migration_file in sorted(migrations_dir.glob("*.sql")): sql = migration_file.read_text(encoding="utf-8") @@ -669,6 +679,7 @@ async def _run_migrations(self): async def _backfill_risk_scores(self): """Compute risk scores for existing findings that have none.""" from datetime import datetime, timezone + rows = await self.fetchall( "SELECT id, severity, exploitability, confidence, asset_exposure, discovered_at, risk_score FROM findings WHERE risk_score IS NULL" ) @@ -804,7 +815,6 @@ async def log_audit( ), ) - async def snapshot_workflow_version( self, workflow_id: str, @@ -858,17 +868,21 @@ async def get_workflow_versions(self, workflow_id: str) -> List[Dict]: defn = json.loads(row["definition_json"]) except (json.JSONDecodeError, TypeError): defn = {} - result.append({ - "id": row["id"], - "workflow_id": row["workflow_id"], - "version_number": row["version_number"], - "definition": defn, - "created_at": row["created_at"], - "created_by": row["created_by"], - }) + result.append( + { + "id": row["id"], + "workflow_id": row["workflow_id"], + "version_number": row["version_number"], + "definition": defn, + "created_at": row["created_at"], + "created_by": row["created_by"], + } + ) return result - async def get_workflow_version(self, workflow_id: str, version_number: int) -> Optional[Dict]: + async def get_workflow_version( + self, workflow_id: str, version_number: int + ) -> Optional[Dict]: """Return a specific version record or None if it does not exist.""" row = await self.fetchone( "SELECT id, workflow_id, version_number, definition_json, created_at, created_by " @@ -904,11 +918,20 @@ async def record_workflow_run( "INSERT INTO workflow_runs " "(id, workflow_id, version_id, version_number, triggered_by, status, task_ids_json) " "VALUES (?, ?, ?, ?, ?, 'queued', ?)", - (run_id, workflow_id, version_id, version_number, triggered_by, json.dumps(task_ids)), + ( + run_id, + workflow_id, + version_id, + version_number, + triggered_by, + json.dumps(task_ids), + ), ) return run_id - async def finalize_workflow_run(self, run_id: str, status: str, error_message: Optional[str] = None) -> None: + async def finalize_workflow_run( + self, run_id: str, status: str, error_message: Optional[str] = None + ) -> None: """Mark a workflow run as completed, failed, or cancelled with a timestamp. status must be one of: completed | failed | cancelled. @@ -929,7 +952,9 @@ async def check_workflow_run_tasks(self, run_id: str) -> Optional[str]: 'cancelled' if any task was cancelled and none are still running/queued. None if tasks are still in progress. """ - run_row = await self.fetchone("SELECT task_ids_json FROM workflow_runs WHERE id = ?", (run_id,)) + run_row = await self.fetchone( + "SELECT task_ids_json FROM workflow_runs WHERE id = ?", (run_id,) + ) if run_row is None: return None try: @@ -954,10 +979,13 @@ async def check_workflow_run_tasks(self, run_id: str) -> Optional[str]: return "cancelled" return "failed" - async def get_workflow_runs(self, workflow_id: str, limit: int = 50, offset: int = 0) -> Dict: + async def get_workflow_runs( + self, workflow_id: str, limit: int = 50, offset: int = 0 + ) -> Dict: """Return paginated run history for a workflow.""" count_row = await self.fetchone( - "SELECT COUNT(*) AS total FROM workflow_runs WHERE workflow_id = ?", (workflow_id,) + "SELECT COUNT(*) AS total FROM workflow_runs WHERE workflow_id = ?", + (workflow_id,), ) total = count_row["total"] if count_row else 0 rows = await self.fetchall( @@ -971,18 +999,20 @@ async def get_workflow_runs(self, workflow_id: str, limit: int = 50, offset: int task_ids = json.loads(row["task_ids_json"] or "[]") except (json.JSONDecodeError, TypeError): task_ids = [] - entries.append({ - "id": row["id"], - "workflow_id": row["workflow_id"], - "version_id": row["version_id"], - "version_number": row["version_number"], - "triggered_by": row["triggered_by"], - "status": row["status"], - "task_ids": task_ids, - "started_at": row["started_at"], - "completed_at": row["completed_at"], - "error_message": row["error_message"], - }) + entries.append( + { + "id": row["id"], + "workflow_id": row["workflow_id"], + "version_id": row["version_id"], + "version_number": row["version_number"], + "triggered_by": row["triggered_by"], + "status": row["status"], + "task_ids": task_ids, + "started_at": row["started_at"], + "completed_at": row["completed_at"], + "error_message": row["error_message"], + } + ) return {"total": total, "runs": entries} diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 965413dbf..02543aafc 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -27,7 +27,11 @@ from .models import NotificationDeliveryStatus, TaskStatus, ScanPhase from .ratelimit import concurrent_limiter from .risk_scoring import compute_risk_score, compute_risk_factors -from .capabilities import CapabilityEnforcer, CapabilityDeniedError, build_enforcer_from_settings +from .capabilities import ( + CapabilityEnforcer, + CapabilityDeniedError, + build_enforcer_from_settings, +) from .parser_sandbox import run_parser_in_sandbox, ParserSandboxError from .network_policy import get_policy_engine from .notification_service import process_task_notifications @@ -48,7 +52,10 @@ ) from .vault import VaultCrypto -async def _terminate_process_group(pid: int, task_id: str, grace_seconds: int = _CANCEL_GRACE_SECONDS) -> None: + +async def _terminate_process_group( + pid: int, task_id: str, grace_seconds: int = _CANCEL_GRACE_SECONDS +) -> None: """Send SIGTERM to the process group of *pid*, wait *grace_seconds*, then SIGKILL. Using a process group (via start_new_session=True on subprocess creation) @@ -81,7 +88,9 @@ async def _terminate_process_group(pid: int, task_id: str, grace_seconds: int = os.killpg(pgid, signal.SIGKILL) logger.warning( "process group %d did not exit within %ds grace — SIGKILL sent (task %s)", - pgid, grace_seconds, task_id, + pgid, + grace_seconds, + task_id, ) except (ProcessLookupError, PermissionError) as exc: logger.debug("SIGKILL to pgid %d failed: %s", pgid, exc) @@ -106,7 +115,9 @@ def _validate_risk_fields(finding: dict) -> None: exp = finding.get("exploitability") if exp is not None: if not isinstance(exp, (int, float)): - raise ValueError(f"exploitability must be numeric, got {type(exp).__name__}") + raise ValueError( + f"exploitability must be numeric, got {type(exp).__name__}" + ) if exp < 0 or exp > 10: raise ValueError(f"exploitability must be in [0, 10], got {exp}") @@ -119,7 +130,10 @@ def _validate_risk_fields(finding: dict) -> None: ae = finding.get("asset_exposure") if ae is not None and ae.lower() not in ("critical", "high", "medium", "low"): - raise ValueError(f"asset_exposure must be one of critical/high/medium/low, got {ae}") + raise ValueError( + f"asset_exposure must be one of critical/high/medium/low, got {ae}" + ) + # Modular Scanners from .scanners.port_scanner import PortScanner @@ -211,7 +225,9 @@ async def _broadcast(self, task_id: str, event_type: str, data: Any): for q in list(self._listeners[task_id]): self._enqueue_listener_event(task_id, q, event) - def _enqueue_listener_event(self, task_id: str, q: asyncio.Queue, event: Dict[str, Any]): + def _enqueue_listener_event( + self, task_id: str, q: asyncio.Queue, event: Dict[str, Any] + ): """Add an event to a bounded listener queue without unbounded memory growth.""" try: q.put_nowait(event) @@ -225,15 +241,16 @@ def _enqueue_listener_event(self, task_id: str, q: asyncio.Queue, event: Dict[st try: q.put_nowait(event) except asyncio.QueueFull: - logger.warning("Dropping stream event for slow listener on task %s", task_id) - + logger.warning( + "Dropping stream event for slow listener on task %s", task_id + ) + async def _broadcast_phase(self, task_id: str, phase: str): """Broadcast a scan phase transition and persist it to the database.""" await self._broadcast(task_id, "phase", phase) db = await get_db() await db.execute( - "UPDATE tasks SET scan_phase = ? WHERE id = ?", - (phase, task_id) + "UPDATE tasks SET scan_phase = ? WHERE id = ?", (phase, task_id) ) async def create_task( @@ -265,16 +282,16 @@ async def create_task( task_id = str(uuid.uuid4()) plugin_manager = get_plugin_manager() plugin = plugin_manager.get_plugin(plugin_id) - + if not plugin: raise ValueError(f"Plugin not found: {plugin_id}") - + # Apply preset if provided if preset and preset in plugin.presets: preset_values = plugin.presets[preset] # Merge preset with user inputs (user inputs take precedence) inputs = {**preset_values, **inputs} - + # Store task in database db = await get_db() await db.execute( @@ -296,10 +313,10 @@ async def create_task( TaskStatus.QUEUED.value, ScanPhase.QUEUED.value, consent_granted, - bool(safe_mode) - ) + bool(safe_mode), + ), ) - + # Log audit event await db.log_audit( "task_created", @@ -311,11 +328,11 @@ async def create_task( "execution_context": normalize_execution_context(execution_context), }, task_id=task_id, - plugin_id=plugin_id + plugin_id=plugin_id, ) - + return task_id - + async def mark_task_failed(self, task_id: str, reason: str) -> None: """ Mark a task as failed without running it. @@ -341,7 +358,7 @@ async def mark_task_failed(self, task_id: str, reason: str) -> None: 0, reason, task_id, - ) + ), ) await db.log_audit( "task_failed", @@ -354,7 +371,7 @@ async def mark_task_failed(self, task_id: str, reason: str) -> None: async def execute_task(self, task_id: str): """ Execute a task asynchronously. - + Args: task_id: Task identifier """ @@ -366,10 +383,17 @@ async def execute_task(self, task_id: str): # if the task was deleted or already running before this point. result = await db.execute( "UPDATE tasks SET status = ?, started_at = ? WHERE id = ? AND status = ?", - (TaskStatus.RUNNING.value, datetime.now().isoformat(), task_id, TaskStatus.QUEUED.value) + ( + TaskStatus.RUNNING.value, + datetime.now().isoformat(), + task_id, + TaskStatus.QUEUED.value, + ), ) if result.rowcount == 0: - logger.warning(f"Task {task_id} was deleted or no longer queued before execution started. Aborting.") + logger.warning( + f"Task {task_id} was deleted or no longer queued before execution started. Aborting." + ) self.running_tasks.pop(task_id, None) return await self._invalidate_cached_views() @@ -377,7 +401,7 @@ async def execute_task(self, task_id: str): # Get task details task_row = await db.fetchone( "SELECT owner_id, plugin_id, inputs_json, execution_context_json, safe_mode FROM tasks WHERE id = ?", - (task_id,) + (task_id,), ) if not task_row: @@ -408,14 +432,15 @@ async def execute_task(self, task_id: str): if plugin and plugin.category == "code": should_validate = False - # Use shared is_filesystem_target from validation to ensure # consistent filesystem detection across route and executor layers. from .validation import is_filesystem_target + is_fs = is_filesystem_target(target) if should_validate and not is_fs: from .validation import validate_target + try: # Enforce safe mode validation of target address in a thread pool is_valid, error_msg = await asyncio.wait_for( @@ -427,14 +452,18 @@ async def execute_task(self, task_id: str): task_id, f"Safe mode target validation failed: {error_msg}", ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + await self._broadcast( + task_id, "status", TaskStatus.FAILED.value + ) return except asyncio.TimeoutError: await self.mark_task_failed( task_id, "Target validation timed out (SecuScan Guardrail)", ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + await self._broadcast( + task_id, "status", TaskStatus.FAILED.value + ) return # Check before launching any scanner or subprocess. check_access() @@ -452,7 +481,10 @@ async def execute_task(self, task_id: str): timeout=float(settings.dns_resolution_timeout_seconds), ) except asyncio.TimeoutError: - allowed, reason = False, "Network policy check timed out (DNS resolution timeout)" + allowed, reason = ( + False, + "Network policy check timed out (DNS resolution timeout)", + ) if not allowed: if settings.network_policy_failure_mode == "log_only": @@ -464,7 +496,9 @@ async def execute_task(self, task_id: str): task_id, f"Network policy denied access to {target}: {reason}", ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + await self._broadcast( + task_id, "status", TaskStatus.FAILED.value + ) return # finally block handles running_tasks cleanup + limiter release # Check if this is a modular scanner or a standard plugin @@ -479,7 +513,9 @@ async def execute_task(self, task_id: str): safety_level=plugin.safety.get("level", "safe"), ) - if plugin.safety.get("level") == "exploit" and not is_offensive_validation(execution_context): + if plugin.safety.get("level") == "exploit" and not is_offensive_validation( + execution_context + ): raise ValueError( "Exploit-level plugins require an execution context with validation_mode set to 'proof' or 'controlled_extract'." ) @@ -487,19 +523,23 @@ async def execute_task(self, task_id: str): if plugin_id in MODULAR_SCANNERS: scanner_class = MODULAR_SCANNERS[plugin_id] scanner = scanner_class(task_id, db, safe_mode=safe_mode) - + logger.info(f"Executing modular scanner {plugin_id} for task {task_id}") await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - + start_time = time.time() # Run the scanner result = await scanner.run(target, inputs) duration = time.time() - start_time - + # Update task with results - final_status = TaskStatus.COMPLETED.value if result.get("status") != "failed" else TaskStatus.FAILED.value - + final_status = ( + TaskStatus.COMPLETED.value + if result.get("status") != "failed" + else TaskStatus.FAILED.value + ) + await db.execute( """ UPDATE tasks SET @@ -516,8 +556,8 @@ async def execute_task(self, task_id: str): duration, json.dumps(result), result.get("error_message"), - task_id - ) + task_id, + ), ) # Upsert findings and report using the scanner's result @@ -530,7 +570,7 @@ async def execute_task(self, task_id: str): plugin_id=plugin_id, target=target, status=final_status, - result=result + result=result, ) await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) @@ -546,26 +586,42 @@ async def execute_task(self, task_id: str): # Validate the named Docker network exists before using it. # If missing, attempt to create it automatically rather than failing. _net_check = await asyncio.create_subprocess_exec( - "docker", "network", "inspect", settings.docker_network, + "docker", + "network", + "inspect", + settings.docker_network, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL, ) await _net_check.wait() if _net_check.returncode != 0: - logger.info(f"Docker network '{settings.docker_network}' not found. Creating isolated bridge network (ICC disabled)...") + logger.info( + f"Docker network '{settings.docker_network}' not found. Creating isolated bridge network (ICC disabled)..." + ) _net_create = await asyncio.create_subprocess_exec( - "docker", "network", "create", - "--driver", "bridge", - "--opt", "com.docker.network.bridge.enable_icc=false", + "docker", + "network", + "create", + "--driver", + "bridge", + "--opt", + "com.docker.network.bridge.enable_icc=false", settings.docker_network, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL, ) await _net_create.wait() if _net_create.returncode != 0: - logger.warning("Failed to create isolated bridge network with ICC disabled. Falling back to standard bridge...") + logger.warning( + "Failed to create isolated bridge network with ICC disabled. Falling back to standard bridge..." + ) _net_create_fallback = await asyncio.create_subprocess_exec( - "docker", "network", "create", "--driver", "bridge", settings.docker_network, + "docker", + "network", + "create", + "--driver", + "bridge", + settings.docker_network, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL, ) @@ -574,9 +630,13 @@ async def execute_task(self, task_id: str): raise RuntimeError( f"Docker network '{settings.docker_network}' does not exist and could not be created automatically." ) - logger.info(f"Successfully created Docker network '{settings.docker_network}' (fallback)") + logger.info( + f"Successfully created Docker network '{settings.docker_network}' (fallback)" + ) else: - logger.info(f"Successfully created Docker network '{settings.docker_network}' with ICC disabled") + logger.info( + f"Successfully created Docker network '{settings.docker_network}' with ICC disabled" + ) docker_image = plugin.docker_image or "alpine:latest" docker_cmd = [ @@ -589,8 +649,10 @@ async def execute_task(self, task_id: str): f"{settings.sandbox_memory_mb}m", "--cpus", str(settings.sandbox_cpu_quota), - "--cap-drop", "NET_RAW", - "--network", settings.docker_network, + "--cap-drop", + "NET_RAW", + "--network", + settings.docker_network, docker_image, ] command = docker_cmd + command @@ -611,7 +673,7 @@ async def execute_task(self, task_id: str): # Save raw output raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" output = redact(output) - with open(raw_path, 'w') as f: + with open(raw_path, "w") as f: f.write(output) # Some CLI tools use non-zero exit codes for "no result" states while still @@ -642,8 +704,8 @@ async def execute_task(self, task_id: str): str(raw_path), " ".join(command), error_message, - task_id - ) + task_id, + ), ) # Upsert findings and report @@ -656,7 +718,7 @@ async def execute_task(self, task_id: str): plugin_id=plugin_id, target=target, status=final_status, - output=output + output=output, ) await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) @@ -670,9 +732,9 @@ async def execute_task(self, task_id: str): await db.log_audit( "task_completed", f"Task completed in {duration:.2f}s", - context={"task_id": task_id, "exit_code": locals().get('exit_code', 0)}, + context={"task_id": task_id, "exit_code": locals().get("exit_code", 0)}, task_id=task_id, - plugin_id=plugin_id + plugin_id=plugin_id, ) logger.info(f"Task {task_id} completed in {duration:.2f}s") @@ -683,7 +745,7 @@ async def execute_task(self, task_id: str): # Task.cancelled() returns False while the finally block is still # executing, so this is the only reliable place to write the # cancellation status to the DB. - duration = (time.time() - start_time) if 'start_time' in locals() else 0 + duration = (time.time() - start_time) if "start_time" in locals() else 0 await db.execute( """ UPDATE tasks SET @@ -698,7 +760,7 @@ async def execute_task(self, task_id: str): duration, task_id, TaskStatus.RUNNING.value, - ) + ), ) await self._broadcast(task_id, "status", TaskStatus.CANCELLED.value) await self._invalidate_cached_views() @@ -742,7 +804,7 @@ async def execute_task(self, task_id: str): logger.error(f"Task {task_id} failed: {e}", exc_info=True) # Update task as failed - duration = (time.time() - start_time) if 'start_time' in locals() else 0 + duration = (time.time() - start_time) if "start_time" in locals() else 0 await db.execute( """ UPDATE tasks SET @@ -757,8 +819,8 @@ async def execute_task(self, task_id: str): datetime.now().isoformat(), duration, str(e), - task_id - ) + task_id, + ), ) await self._broadcast(task_id, "status", TaskStatus.FAILED.value) @@ -769,7 +831,7 @@ async def execute_task(self, task_id: str): f"Task failed: {str(e)}", severity="error", context={"task_id": task_id, "error": str(e)}, - task_id=task_id + task_id=task_id, ) finally: # Always clean up: remove from the in-memory registry and @@ -777,12 +839,9 @@ async def execute_task(self, task_id: str): self.running_tasks.pop(task_id, None) self._process_pids.pop(task_id, None) await concurrent_limiter.release(task_id) - + async def _execute_command( - self, - command: list, - task_id: str, - timeout: int = 600 + self, command: list, task_id: str, timeout: int = 600 ) -> tuple: """ Execute command in subprocess and stream output. @@ -821,12 +880,16 @@ async def read_stream(): await asyncio.wait_for(read_stream(), timeout=timeout) await process.wait() self._process_pids.pop(task_id, None) - return "".join(output_lines), process.returncode if process.returncode is not None else -1 + return "".join( + output_lines + ), process.returncode if process.returncode is not None else -1 except asyncio.TimeoutError: logger.warning( "Task %s timed out after %ds — terminating process group (pid=%d)", - task_id, timeout, process.pid, + task_id, + timeout, + process.pid, ) await _terminate_process_group(process.pid, task_id) try: @@ -839,7 +902,8 @@ async def read_stream(): except asyncio.CancelledError: logger.warning( "Task %s cancelled — terminating process group (pid=%d)", - task_id, process.pid, + task_id, + process.pid, ) await _terminate_process_group(process.pid, task_id) try: @@ -874,14 +938,20 @@ def _resolve_execution_timeout(self, inputs: Dict[str, Any]) -> int: return min(timeout, settings.sandbox_timeout) return settings.sandbox_timeout - def _classify_command_result(self, plugin, output: str, exit_code: int) -> tuple[str, Optional[str]]: + def _classify_command_result( + self, plugin, output: str, exit_code: int + ) -> tuple[str, Optional[str]]: """Map raw process exit codes into task status with plugin-specific tolerances.""" normalized_output = output.lower() - if "unknown option:" in normalized_output or "flag provided but not defined:" in normalized_output: + if ( + "unknown option:" in normalized_output + or "flag provided but not defined:" in normalized_output + ): return ( TaskStatus.FAILED.value, - output or "Tool rejected one or more generated CLI options. Check the final command and raw output for details.", + output + or "Tool rejected one or more generated CLI options. Check the final command and raw output for details.", ) if exit_code == 0: @@ -938,9 +1008,11 @@ async def cancel_task(self, task_id: str) -> bool: if settings.docker_enabled: try: killer = await asyncio.create_subprocess_exec( - "docker", "kill", f"secuscan_task_{task_id}", + "docker", + "kill", + f"secuscan_task_{task_id}", stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) await killer.communicate() except Exception as e: @@ -950,20 +1022,18 @@ async def cancel_task(self, task_id: str) -> bool: async with db.transaction(): await db.execute( "UPDATE tasks SET status = ?, completed_at = ? WHERE id = ?", - (TaskStatus.CANCELLED.value, datetime.now().isoformat(), task_id) + (TaskStatus.CANCELLED.value, datetime.now().isoformat(), task_id), ) await db.log_audit( - "task_cancelled", - "Task cancelled by user", - task_id=task_id + "task_cancelled", "Task cancelled by user", task_id=task_id ) await self._broadcast(task_id, "status", TaskStatus.CANCELLED.value) await self._invalidate_cached_views() return True - + async def get_task_status(self, task_id: str) -> Optional[Dict]: """Get task status and progress""" db = await get_db() @@ -973,7 +1043,7 @@ async def get_task_status(self, task_id: str) -> Optional[Dict]: duration_seconds, exit_code, error_message, preset, inputs_json, execution_context_json FROM tasks WHERE id = ? """, - (task_id,) + (task_id,), ) if not task_row: return None @@ -984,7 +1054,7 @@ async def get_task_status(self, task_id: str) -> Optional[Dict]: if task_row["status"] == TaskStatus.QUEUED.value: queued_rows = await db.fetchall( "SELECT id FROM tasks WHERE status = ? ORDER BY created_at ASC", - (TaskStatus.QUEUED.value,) + (TaskStatus.QUEUED.value,), ) ids = [r["id"] for r in queued_rows] pending_count = len(ids) @@ -1039,10 +1109,16 @@ async def _hydrate_inputs_with_execution_context( effective_inputs["__extra_headers"] = { str(key): str(value) for key, value in headers.items() } - username = await self._read_vault_secret(db, credential_profile.get("username_secret_name")) - password = await self._read_vault_secret(db, credential_profile.get("password_secret_name")) + username = await self._read_vault_secret( + db, credential_profile.get("username_secret_name") + ) + password = await self._read_vault_secret( + db, credential_profile.get("password_secret_name") + ) if username is not None and password is not None: - token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii") + token = base64.b64encode( + f"{username}:{password}".encode("utf-8") + ).decode("ascii") effective_inputs.setdefault("__extra_headers", {}) effective_inputs["__extra_headers"]["Authorization"] = f"Basic {token}" @@ -1057,7 +1133,9 @@ async def _hydrate_inputs_with_execution_context( effective_inputs.setdefault("__extra_headers", {}) for key, value in extra_headers.items(): effective_inputs["__extra_headers"][str(key)] = str(value) - cookie_secret = await self._read_vault_secret(db, session_profile.get("cookie_secret_name")) + cookie_secret = await self._read_vault_secret( + db, session_profile.get("cookie_secret_name") + ) if cookie_secret: try: parsed = json.loads(cookie_secret) @@ -1143,31 +1221,49 @@ async def _load_previous_task_findings( ) return self._deserialize_finding_rows(rows) - def _normalize_asset_service_record(self, target: str, service: Dict[str, Any]) -> Dict[str, Any]: - metadata = service.get("metadata", {}) if isinstance(service.get("metadata"), dict) else {} + def _normalize_asset_service_record( + self, target: str, service: Dict[str, Any] + ) -> Dict[str, Any]: + metadata = ( + service.get("metadata", {}) + if isinstance(service.get("metadata"), dict) + else {} + ) host = str(service.get("host") or target) port = service.get("port") protocol = service.get("protocol") - cert_san = service.get("cert_san") or service.get("cert_sans") or metadata.get("cert_san") or metadata.get("cert_sans") or [] + cert_san = ( + service.get("cert_san") + or service.get("cert_sans") + or metadata.get("cert_san") + or metadata.get("cert_sans") + or [] + ) if not isinstance(cert_san, list): cert_san = [cert_san] fingerprint = service.get("service_fingerprint") if not fingerprint: - fingerprint = " ".join( - str(part).strip() - for part in ( - service.get("product"), - service.get("version"), - service.get("service"), - service.get("title"), + fingerprint = ( + " ".join( + str(part).strip() + for part in ( + service.get("product"), + service.get("version"), + service.get("service"), + service.get("title"), + ) + if str(part or "").strip() ) - if str(part or "").strip() - ) or None + or None + ) return { **service, "host": host, "target": target, - "asset_id": str(service.get("asset_id") or _stable_asset_id(target, host, port, protocol)), + "asset_id": str( + service.get("asset_id") + or _stable_asset_id(target, host, port, protocol) + ), "cert_san": cert_san, "metadata": metadata, "service_fingerprint": fingerprint, @@ -1195,7 +1291,9 @@ async def _build_result_contract( owner_id=owner_id, plugin_id=plugin_id, target=target, - findings=[item for item in result.get("findings", []) if isinstance(item, dict)], + findings=[ + item for item in result.get("findings", []) if isinstance(item, dict) + ], ) previous_findings = await self._load_previous_task_findings( db, @@ -1214,9 +1312,15 @@ async def _build_result_contract( structured_result["asset_services"] = asset_services structured_result["services"] = asset_services structured_result["finding_groups"] = build_finding_groups(normalized_findings) - structured_result["asset_summary"] = build_asset_summary(normalized_findings, asset_services) - structured_result["scan_diff"] = build_scan_diff(normalized_findings, previous_findings) - structured_result["severity_counts"] = self._build_severity_counts(normalized_findings) + structured_result["asset_summary"] = build_asset_summary( + normalized_findings, asset_services + ) + structured_result["scan_diff"] = build_scan_diff( + normalized_findings, previous_findings + ) + structured_result["severity_counts"] = self._build_severity_counts( + normalized_findings + ) structured_result["count"] = len(normalized_findings) return structured_result, previous_findings, asset_services @@ -1239,11 +1343,31 @@ async def _persist_finding( asset_exposure = finding.get("asset_exposure") discovered = _parse_discovered_at(finding) target_value = str(finding.get("target") or target) - metadata = finding.get("metadata", {}) if isinstance(finding.get("metadata"), dict) else {} - evidence = finding.get("evidence", []) if isinstance(finding.get("evidence"), list) else [] - asset_refs = finding.get("asset_refs", []) if isinstance(finding.get("asset_refs"), list) else [] - references = finding.get("references", []) if isinstance(finding.get("references"), list) else [] - corroborating_sources = finding.get("corroborating_sources", []) if isinstance(finding.get("corroborating_sources"), list) else [] + metadata = ( + finding.get("metadata", {}) + if isinstance(finding.get("metadata"), dict) + else {} + ) + evidence = ( + finding.get("evidence", []) + if isinstance(finding.get("evidence"), list) + else [] + ) + asset_refs = ( + finding.get("asset_refs", []) + if isinstance(finding.get("asset_refs"), list) + else [] + ) + references = ( + finding.get("references", []) + if isinstance(finding.get("references"), list) + else [] + ) + corroborating_sources = ( + finding.get("corroborating_sources", []) + if isinstance(finding.get("corroborating_sources"), list) + else [] + ) first_seen_at = str(finding.get("first_seen_at") or discovered.isoformat()) last_seen_at = str(finding.get("last_seen_at") or discovered.isoformat()) occurrence_count = int(finding.get("occurrence_count") or 1) @@ -1339,10 +1463,24 @@ async def _persist_finding( "risk_factors": risk_factors, } - async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plugin, plugin_id: str, target: str, status: str, output: str = ""): + async def _upsert_findings_and_report( + self, + db, + task_id: str, + owner_id: str, + plugin, + plugin_id: str, + target: str, + status: str, + output: str = "", + ): """Persist derived findings and report records into SQLite.""" parsed = self._parse_results(plugin, output) - structured_result, previous_findings, asset_services = await self._build_result_contract( + ( + structured_result, + previous_findings, + asset_services, + ) = await self._build_result_contract( db, task_id=task_id, owner_id=owner_id, @@ -1364,15 +1502,21 @@ async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plu ) structured_result["findings"] = findings_data - structured_result["severity_counts"] = self._build_severity_counts(findings_data) + structured_result["severity_counts"] = self._build_severity_counts( + findings_data + ) structured_result["finding_groups"] = build_finding_groups(findings_data) - structured_result["asset_summary"] = build_asset_summary(findings_data, asset_services) - structured_result["scan_diff"] = build_scan_diff(findings_data, previous_findings) + structured_result["asset_summary"] = build_asset_summary( + findings_data, asset_services + ) + structured_result["scan_diff"] = build_scan_diff( + findings_data, previous_findings + ) async with db.transaction(): await db.execute( "UPDATE tasks SET structured_json = ? WHERE id = ?", - (json.dumps(structured_result), task_id) + (json.dumps(structured_result), task_id), ) await db.execute( @@ -1406,9 +1550,23 @@ async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plu result=structured_result, ) - async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner_id: str, scanner: Any, plugin_id: str, target: str, status: str, result: Dict[str, Any]): + async def _upsert_findings_and_report_from_scanner( + self, + db, + task_id: str, + owner_id: str, + scanner: Any, + plugin_id: str, + target: str, + status: str, + result: Dict[str, Any], + ): """Persist modular scanner results into findings, and reports.""" - structured_result, previous_findings, asset_services = await self._build_result_contract( + ( + structured_result, + previous_findings, + asset_services, + ) = await self._build_result_contract( db, task_id=task_id, owner_id=owner_id, @@ -1430,15 +1588,21 @@ async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner ) structured_result["findings"] = findings_data - structured_result["severity_counts"] = self._build_severity_counts(findings_data) + structured_result["severity_counts"] = self._build_severity_counts( + findings_data + ) structured_result["finding_groups"] = build_finding_groups(findings_data) - structured_result["asset_summary"] = build_asset_summary(findings_data, asset_services) - structured_result["scan_diff"] = build_scan_diff(findings_data, previous_findings) + structured_result["asset_summary"] = build_asset_summary( + findings_data, asset_services + ) + structured_result["scan_diff"] = build_scan_diff( + findings_data, previous_findings + ) async with db.transaction(): await db.execute( "UPDATE tasks SET structured_json = ? WHERE id = ?", - (json.dumps(structured_result), task_id) + (json.dumps(structured_result), task_id), ) # Create/Update report @@ -1457,10 +1621,12 @@ async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner owner_id, task_id, f"{scanner.name} Report", - "professional" if status == TaskStatus.COMPLETED.value else "failed", + "professional" + if status == TaskStatus.COMPLETED.value + else "failed", "ready" if status == TaskStatus.COMPLETED.value else "failed", len(findings_data), - 2, # Professional reports are typically multi-page + 2, # Professional reports are typically multi-page ), ) @@ -1509,12 +1675,12 @@ def _parse_results(self, plugin, output: str) -> Dict[str, Any]: """Route to appropriate parser based on plugin metadata.""" parser_type = plugin.output.get("parser") parser_input = self._resolve_parser_input(plugin, output) - + # 1. Check for custom parser.py in plugin directory (Recommended) plugin_manager = get_plugin_manager() plugin_dir = plugin_manager.plugins_dir / plugin.id parser_path = plugin_dir / "parser.py" - + if parser_path.exists(): if not plugin_manager.verify_parser_at_exec_time(plugin, plugin_dir): raise ValueError( @@ -1540,18 +1706,28 @@ def _parse_results(self, plugin, output: str) -> Dict[str, Any]: f"Custom parser failed for plugin '{plugin.id}': {exc.reason}" ) from exc except Exception as exc: - logger.error("Unexpected error running parser sandbox for '%s': %s", plugin.id, exc) + logger.error( + "Unexpected error running parser sandbox for '%s': %s", + plugin.id, + exc, + ) raise RuntimeError( f"Custom parser encountered an unexpected error for plugin '{plugin.id}': {exc}" ) from exc # 2. Fallback to legacy built-in parsers (only reached when no parser.py exists) if parser_type == "builtin_nmap": - return self._normalize_parsed_result(plugin, parser_input, self._parse_nmap_output(parser_input)) + return self._normalize_parsed_result( + plugin, parser_input, self._parse_nmap_output(parser_input) + ) elif parser_type == "builtin_http": - return self._normalize_parsed_result(plugin, parser_input, self._parse_http_output(parser_input)) - - return self._normalize_parsed_result(plugin, parser_input, {"findings": [], "raw": parser_input}) + return self._normalize_parsed_result( + plugin, parser_input, self._parse_http_output(parser_input) + ) + + return self._normalize_parsed_result( + plugin, parser_input, {"findings": [], "raw": parser_input} + ) def _resolve_parser_input(self, plugin, output: str) -> str: """Prefer report-file content when configured, fallback to command output.""" @@ -1563,11 +1739,15 @@ def _resolve_parser_input(self, plugin, output: str) -> str: logger.info("Using parser report file for %s: %s", plugin.id, path) return path.read_text(encoding="utf-8", errors="replace") except Exception as exc: - logger.warning("Failed to read parser report file %s: %s", path, exc) + logger.warning( + "Failed to read parser report file %s: %s", path, exc + ) return output - def _normalize_parsed_result(self, plugin, parser_input: str, parsed: Any) -> Dict[str, Any]: + def _normalize_parsed_result( + self, plugin, parser_input: str, parsed: Any + ) -> Dict[str, Any]: """ Normalize parser output shape so downstream report/asset logic always receives: { findings: List[Finding], ... }. @@ -1597,7 +1777,10 @@ def _normalize_parsed_result(self, plugin, parser_input: str, parsed: Any) -> Di ] # Fallback for JSON/JSONL plugin outputs where parser returns empty or unexpected data. - if not findings and str(plugin.output.get("format", "")).lower() in {"json", "jsonl"}: + if not findings and str(plugin.output.get("format", "")).lower() in { + "json", + "jsonl", + }: findings = self._parse_json_fallback_findings(plugin, parser_input) normalized["findings"] = findings @@ -1624,7 +1807,11 @@ def _normalize_finding(self, plugin, finding: Dict[str, Any]) -> Dict[str, Any]: } normalized_severity = severity_map.get(severity, "info") - category = finding.get("category") or finding.get("type") or str(plugin.category).title() + category = ( + finding.get("category") + or finding.get("type") + or str(plugin.category).title() + ) title = finding.get("title") or finding.get("name") or "Security Finding" description = finding.get("description") or finding.get("message") or str(title) @@ -1647,15 +1834,23 @@ def _normalize_finding(self, plugin, finding: Dict[str, Any]) -> Dict[str, Any]: "validated": bool(finding.get("validated", False)), "validation_method": finding.get("validation_method"), "confidence_reason": finding.get("confidence_reason"), - "evidence": finding.get("evidence", []) if isinstance(finding.get("evidence"), list) else [], - "asset_refs": finding.get("asset_refs", []) if isinstance(finding.get("asset_refs"), list) else [], + "evidence": finding.get("evidence", []) + if isinstance(finding.get("evidence"), list) + else [], + "asset_refs": finding.get("asset_refs", []) + if isinstance(finding.get("asset_refs"), list) + else [], "service_fingerprint": finding.get("service_fingerprint"), "cpe": finding.get("cpe"), - "references": finding.get("references", []) if isinstance(finding.get("references"), list) else [], + "references": finding.get("references", []) + if isinstance(finding.get("references"), list) + else [], "asset_exposure": finding.get("asset_exposure"), } - def _parse_json_fallback_findings(self, plugin, parser_input: str) -> List[Dict[str, Any]]: + def _parse_json_fallback_findings( + self, plugin, parser_input: str + ) -> List[Dict[str, Any]]: """Best-effort conversion of JSON payloads into finding entries.""" try: data = json.loads(parser_input) @@ -1667,7 +1862,9 @@ def _parse_json_fallback_findings(self, plugin, parser_input: str) -> List[Dict[ if isinstance(data, list): for idx, item in enumerate(data, start=1): if isinstance(item, dict): - findings.append(self._json_item_to_finding(plugin, item, f"Item {idx}")) + findings.append( + self._json_item_to_finding(plugin, item, f"Item {idx}") + ) else: findings.append( self._normalize_finding( @@ -1688,7 +1885,11 @@ def _parse_json_fallback_findings(self, plugin, parser_input: str) -> List[Dict[ if isinstance(data.get(list_key), list): for idx, item in enumerate(data[list_key], start=1): if isinstance(item, dict): - findings.append(self._json_item_to_finding(plugin, item, f"{list_key} #{idx}")) + findings.append( + self._json_item_to_finding( + plugin, item, f"{list_key} #{idx}" + ) + ) if findings: return findings @@ -1696,7 +1897,9 @@ def _parse_json_fallback_findings(self, plugin, parser_input: str) -> List[Dict[ return findings - def _json_item_to_finding(self, plugin, item: Dict[str, Any], default_title: str) -> Dict[str, Any]: + def _json_item_to_finding( + self, plugin, item: Dict[str, Any], default_title: str + ) -> Dict[str, Any]: title = ( item.get("title") or item.get("name") @@ -1704,7 +1907,12 @@ def _json_item_to_finding(self, plugin, item: Dict[str, Any], default_title: str or item.get("message") or default_title ) - description = item.get("description") or item.get("detail") or item.get("message") or str(item) + description = ( + item.get("description") + or item.get("detail") + or item.get("message") + or str(item) + ) severity = item.get("severity", "info") category = item.get("category", str(plugin.category).title()) return self._normalize_finding( @@ -1723,7 +1931,7 @@ def _parse_nmap_output(self, output: str) -> Dict[str, Any]: findings = [] ports = [] services = [] - + # Regex for open ports: 80/tcp open http port_pattern = re.compile(r"(\d+)/(tcp|udp)\s+open\s+([\w-]+)") for match in port_pattern.finditer(output): @@ -1731,19 +1939,25 @@ def _parse_nmap_output(self, output: str) -> Dict[str, Any]: port_val = int(port_str) ports.append(port_val) services.append(service) - findings.append({ - "title": f"Open Port: {port_str}/{proto} ({service})", - "category": "Network Service", - "severity": "low", - "description": f"Port {port_str} is open and running {service} service.", - "remediation": "Close unnecessary ports and use a firewall to restrict access.", - "metadata": {"port": port_str, "protocol": proto, "service": service} - }) - + findings.append( + { + "title": f"Open Port: {port_str}/{proto} ({service})", + "category": "Network Service", + "severity": "low", + "description": f"Port {port_str} is open and running {service} service.", + "remediation": "Close unnecessary ports and use a firewall to restrict access.", + "metadata": { + "port": port_str, + "protocol": proto, + "service": service, + }, + } + ) + return { "open_ports": sorted(list(set(ports))), "services": sorted(list(set(services))), - "findings": findings + "findings": findings, } def _parse_http_output(self, output: str) -> Dict[str, Any]: @@ -1754,31 +1968,32 @@ def _parse_http_output(self, output: str) -> Dict[str, Any]: if server_match := re.search(r"(?i)Server:\s*(.+)", output): server = server_match[1].strip() techs.append(server) - findings.append({ - "title": f"Web Server Disclosed: {server}", - "category": "Information Disclosure", - "severity": "low", - "description": f"The web server discloses its version: {server}", - "remediation": "Disable the Server header in web server configuration.", - "metadata": {"server": server} - }) + findings.append( + { + "title": f"Web Server Disclosed: {server}", + "category": "Information Disclosure", + "severity": "low", + "description": f"The web server discloses its version: {server}", + "remediation": "Disable the Server header in web server configuration.", + "metadata": {"server": server}, + } + ) if powered_match := re.search(r"(?i)X-Powered-By:\s*(.+)", output): powered = powered_match[1].strip() techs.append(powered) - findings.append({ - "title": f"X-Powered-By Disclosed: {powered}", - "category": "Information Disclosure", - "severity": "low", - "description": f"The application discloses its technology stack: {powered}", - "remediation": "Disable the X-Powered-By header.", - "metadata": {"tech": powered} - }) + findings.append( + { + "title": f"X-Powered-By Disclosed: {powered}", + "category": "Information Disclosure", + "severity": "low", + "description": f"The application discloses its technology stack: {powered}", + "remediation": "Disable the X-Powered-By header.", + "metadata": {"tech": powered}, + } + ) - return { - "technologies": sorted(list(set(techs))), - "findings": findings - } + return {"technologies": sorted(list(set(techs))), "findings": findings} async def _dispatch_task_notifications(self, db, task_id: str) -> None: """Evaluate notification rules for all findings on a completed task.""" diff --git a/backend/secuscan/platform_resources.py b/backend/secuscan/platform_resources.py index 1116ab8b9..2c2c309dd 100644 --- a/backend/secuscan/platform_resources.py +++ b/backend/secuscan/platform_resources.py @@ -30,7 +30,9 @@ def _stable_asset_id(target: str, host: Any, port: Any, protocol: Any) -> str: return f"asset:{digest}" -async def get_target_policy(db: Database, owner_id: str, policy_id: str | None) -> Optional[Dict[str, Any]]: +async def get_target_policy( + db: Database, owner_id: str, policy_id: str | None +) -> Optional[Dict[str, Any]]: if not policy_id: return None row = await db.fetchone( @@ -40,7 +42,9 @@ async def get_target_policy(db: Database, owner_id: str, policy_id: str | None) return _deserialize_resource_row(row) -async def get_credential_profile(db: Database, owner_id: str, profile_id: str | None) -> Optional[Dict[str, Any]]: +async def get_credential_profile( + db: Database, owner_id: str, profile_id: str | None +) -> Optional[Dict[str, Any]]: if not profile_id: return None row = await db.fetchone( @@ -50,7 +54,9 @@ async def get_credential_profile(db: Database, owner_id: str, profile_id: str | return _deserialize_resource_row(row) -async def get_session_profile(db: Database, owner_id: str, profile_id: str | None) -> Optional[Dict[str, Any]]: +async def get_session_profile( + db: Database, owner_id: str, profile_id: str | None +) -> Optional[Dict[str, Any]]: if not profile_id: return None row = await db.fetchone( @@ -117,66 +123,85 @@ async def replace_asset_services( async with db.transaction(): await db.execute("DELETE FROM asset_services WHERE task_id = ?", (task_id,)) for item in services: - metadata = item.get("metadata", {}) if isinstance(item.get("metadata"), dict) else {} - host = str(item.get("host") or target) - port = item.get("port") - protocol = item.get("protocol") - asset_id = str(item.get("asset_id") or _stable_asset_id(target, host, port, protocol)) - cert_sans = item.get("cert_san") or item.get("cert_sans") or metadata.get("cert_san") or metadata.get("cert_sans") or [] - if not isinstance(cert_sans, list): - cert_sans = [cert_sans] - service_fingerprint = item.get("service_fingerprint") - if not service_fingerprint: - service_fingerprint = " ".join( - str(part).strip() - for part in ( + metadata = ( + item.get("metadata", {}) + if isinstance(item.get("metadata"), dict) + else {} + ) + host = str(item.get("host") or target) + port = item.get("port") + protocol = item.get("protocol") + asset_id = str( + item.get("asset_id") or _stable_asset_id(target, host, port, protocol) + ) + cert_sans = ( + item.get("cert_san") + or item.get("cert_sans") + or metadata.get("cert_san") + or metadata.get("cert_sans") + or [] + ) + if not isinstance(cert_sans, list): + cert_sans = [cert_sans] + service_fingerprint = item.get("service_fingerprint") + if not service_fingerprint: + service_fingerprint = ( + " ".join( + str(part).strip() + for part in ( + item.get("product"), + item.get("version"), + item.get("service"), + item.get("title"), + ) + if str(part or "").strip() + ) + or None + ) + await db.execute( + """ + INSERT INTO asset_services ( + id, owner_id, task_id, plugin_id, target, asset_id, host, ip, port, protocol, + service, product, version, cpe, confidence, title, banner, cert_subject, + cert_san_json, cert_expiry, service_fingerprint, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + str(uuid.uuid4()), + owner_id, + task_id, + plugin_id, + target, + asset_id, + host, + item.get("ip"), + item.get("port"), + item.get("protocol"), + item.get("service"), item.get("product"), item.get("version"), - item.get("service"), + item.get("cpe"), + item.get("confidence"), item.get("title"), - ) - if str(part or "").strip() - ) or None - await db.execute( - """ - INSERT INTO asset_services ( - id, owner_id, task_id, plugin_id, target, asset_id, host, ip, port, protocol, - service, product, version, cpe, confidence, title, banner, cert_subject, - cert_san_json, cert_expiry, service_fingerprint, metadata_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - str(uuid.uuid4()), - owner_id, - task_id, - plugin_id, - target, - asset_id, - host, - item.get("ip"), - item.get("port"), - item.get("protocol"), - item.get("service"), - item.get("product"), - item.get("version"), - item.get("cpe"), - item.get("confidence"), - item.get("title"), - item.get("banner"), - item.get("cert_subject"), - json.dumps(cert_sans), - item.get("cert_expiry"), - service_fingerprint, - json.dumps(metadata), - ), - ) - - -def serialize_execution_context(context: ExecutionContext | Dict[str, Any] | None) -> str: + item.get("banner"), + item.get("cert_subject"), + json.dumps(cert_sans), + item.get("cert_expiry"), + service_fingerprint, + json.dumps(metadata), + ), + ) + + +def serialize_execution_context( + context: ExecutionContext | Dict[str, Any] | None, +) -> str: return json.dumps(normalize_execution_context(context or {})) -def _deserialize_resource_row(row: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: +def _deserialize_resource_row( + row: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: if row is None: return None item = dict(row) diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index 6c94aea8e..0bb007415 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -2,7 +2,16 @@ API routes for SecuScan backend """ -from fastapi import APIRouter, HTTPException, BackgroundTasks, Response, Request, Depends, Body, Query +from fastapi import ( + APIRouter, + HTTPException, + BackgroundTasks, + Response, + Request, + Depends, + Body, + Query, +) from fastapi.responses import JSONResponse from typing import Any, Optional, List, Dict, Callable import json @@ -34,6 +43,7 @@ "build_report_filename", ] + def _parse_workflow_steps(raw_steps: Any) -> List[Dict[str, Any]]: if isinstance(raw_steps, list): parsed = raw_steps @@ -60,7 +70,10 @@ def _parse_workflow_steps(raw_steps: Any) -> List[Dict[str, Any]]: normalized.append(model.model_dump()) return normalized -def _serialize_workflow(row: Dict[str, Any], queued_task_ids: Optional[List[str]] = None) -> Dict[str, Any]: + +def _serialize_workflow( + row: Dict[str, Any], queued_task_ids: Optional[List[str]] = None +) -> Dict[str, Any]: """Return the workflow shape consumed by the frontend.""" return { "id": row["id"], @@ -84,11 +97,20 @@ def _json_payload(value: Any, fallback: str) -> str: from .cache import get_cache, invalidate_view_cache from .models import ( - TaskCreateRequest, TaskResponse, TaskResult, - PluginListResponse, ErrorResponse, BulkDeleteRequest, - NotificationRuleCreate, NotificationRuleUpdate, - NotificationChannelType, TaskStatus, - ExecutionContext, WorkflowStep, ValidationMode, EvidenceLevel, + TaskCreateRequest, + TaskResponse, + TaskResult, + PluginListResponse, + ErrorResponse, + BulkDeleteRequest, + NotificationRuleCreate, + NotificationRuleUpdate, + NotificationChannelType, + TaskStatus, + ExecutionContext, + WorkflowStep, + ValidationMode, + EvidenceLevel, NotificationDiagnosticsResponse, ) from .config import settings @@ -98,10 +120,15 @@ def _json_payload(value: Any, fallback: str) -> str: from .executor import executor from .redaction import redact_inputs from .ratelimit import ( - rate_limiter, concurrent_limiter, workflow_rate_limiter, - task_start_limiter, vault_limiter, - report_download_limiter, read_heavy_limiter, - resolve_client_identity, admin_limiter, + rate_limiter, + concurrent_limiter, + workflow_rate_limiter, + task_start_limiter, + vault_limiter, + report_download_limiter, + read_heavy_limiter, + resolve_client_identity, + admin_limiter, scheduler_tick_limiter, ) from .validation import validate_target, validate_task_start_payload, validate_url @@ -127,7 +154,9 @@ def _json_payload(value: Any, fallback: str) -> str: _EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$") -def _validate_notification_target(channel_type: NotificationChannelType, target: str) -> str: +def _validate_notification_target( + channel_type: NotificationChannelType, target: str +) -> str: cleaned = target.strip() if not cleaned: raise HTTPException(status_code=400, detail="Notification target is required") @@ -139,18 +168,19 @@ def _validate_notification_target(channel_type: NotificationChannelType, target: if settings.notification_ssrf_enabled: from .validation import resolve_and_validate_target, validate_webhook_target + ssrf_ok, ssrf_err = resolve_and_validate_target(cleaned) if not ssrf_ok: raise HTTPException( status_code=400, - detail=f"Webhook target blocked by SSRF protection: {ssrf_err}" + detail=f"Webhook target blocked by SSRF protection: {ssrf_err}", ) # Additional independent check against notification_blocked_ip_ranges target_ok, target_err = validate_webhook_target(cleaned) if not target_ok: raise HTTPException( status_code=400, - detail=f"Webhook target blocked by SSRF protection: {target_err}" + detail=f"Webhook target blocked by SSRF protection: {target_err}", ) return cleaned @@ -195,9 +225,9 @@ async def get_or_set_cached(key: str, builder): return value - - -async def require_owned_task(db, task_id: str, owner: str, columns: str = "owner_id") -> Dict[str, Any]: +async def require_owned_task( + db, task_id: str, owner: str, columns: str = "owner_id" +) -> Dict[str, Any]: """Fetch a task and enforce that it belongs to ``owner`` (issue #401). Returns the selected row on success. Raises 404 when the task does not @@ -208,7 +238,9 @@ async def require_owned_task(db, task_id: str, owner: str, columns: str = "owner if row is None: raise HTTPException(status_code=404, detail="Task not found") if row.get("owner_id") != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) return row @@ -223,7 +255,9 @@ def iter_raw_output_chunks(path: str, chunk_size: int = SSE_RAW_OUTPUT_CHUNK_SIZ def _report_generation_error_response(task_id: str, report_format: str) -> JSONResponse: - logger.exception("Report generation failed for task_id=%s format=%s", task_id, report_format) + logger.exception( + "Report generation failed for task_id=%s format=%s", task_id, report_format + ) return JSONResponse( status_code=500, content={ @@ -253,10 +287,8 @@ async def list_plugins(): plugin_manager = await get_plugin_manager_for_request() plugins = plugin_manager.list_plugins() - return PluginListResponse( - plugins=plugins, - total=len(plugins) - ) + return PluginListResponse(plugins=plugins, total=len(plugins)) + @router.get("/plugins/summary") async def get_plugins_summary(): @@ -273,9 +305,7 @@ async def get_plugins_summary(): for plugin in plugins: category = plugin.get("category", "unknown") - category_counts[category] = ( - category_counts.get(category, 0) + 1 - ) + category_counts[category] = category_counts.get(category, 0) + 1 availability = plugin.get("availability", {}) runnable = availability.get("runnable", False) @@ -288,9 +318,10 @@ async def get_plugins_summary(): "total_plugins": total_plugins, "runnable_count": runnable_count, "unavailable_count": unavailable_count, - "category_counts": dict(sorted(category_counts.items())) + "category_counts": dict(sorted(category_counts.items())), } + @router.get("/plugin/{plugin_id}/schema") async def get_plugin_schema(plugin_id: str): """Get plugin schema for UI generation""" @@ -324,7 +355,9 @@ async def start_task( # ── Payload size / field-length guard ───────────────────────────────── raw_body = await raw_request.body() execution_context = normalize_execution_context(request.execution_context) - ok, status_code, error_msg = validate_task_start_payload(raw_body, request.inputs, execution_context) + ok, status_code, error_msg = validate_task_start_payload( + raw_body, request.inputs, execution_context + ) if not ok: raise HTTPException(status_code=status_code, detail=error_msg) @@ -333,7 +366,7 @@ async def start_task( logger.warning(f"Task start failed: Consent not granted. Request: {request}") raise HTTPException( status_code=400, - detail="Consent required. You must acknowledge the legal notice." + detail="Consent required. You must acknowledge the legal notice.", ) # Get plugin @@ -342,21 +375,37 @@ async def start_task( if not plugin: logger.warning(f"Task start failed: Plugin not found: {request.plugin_id}") - raise HTTPException(status_code=404, detail=f"Plugin not found: {request.plugin_id}") + raise HTTPException( + status_code=404, detail=f"Plugin not found: {request.plugin_id}" + ) db = await get_db() - target_policy = await get_target_policy(db, owner, execution_context.get("target_policy_id")) - credential_profile = await get_credential_profile(db, owner, execution_context.get("credential_profile_id")) - session_profile = await get_session_profile(db, owner, execution_context.get("session_profile_id")) + target_policy = await get_target_policy( + db, owner, execution_context.get("target_policy_id") + ) + credential_profile = await get_credential_profile( + db, owner, execution_context.get("credential_profile_id") + ) + session_profile = await get_session_profile( + db, owner, execution_context.get("session_profile_id") + ) if execution_context.get("target_policy_id") and not target_policy: - raise HTTPException(status_code=400, detail="Target policy not found for this workspace") + raise HTTPException( + status_code=400, detail="Target policy not found for this workspace" + ) if execution_context.get("credential_profile_id") and not credential_profile: - raise HTTPException(status_code=400, detail="Credential profile not found for this workspace") + raise HTTPException( + status_code=400, detail="Credential profile not found for this workspace" + ) if execution_context.get("session_profile_id") and not session_profile: - raise HTTPException(status_code=400, detail="Session profile not found for this workspace") + raise HTTPException( + status_code=400, detail="Session profile not found for this workspace" + ) - if (credential_profile or session_profile) and not (target_policy and target_policy.get("allow_authenticated_scan")): + if (credential_profile or session_profile) and not ( + target_policy and target_policy.get("allow_authenticated_scan") + ): raise HTTPException( status_code=400, detail="Authenticated scans require a target policy with authenticated scanning enabled.", @@ -364,10 +413,13 @@ async def start_task( requires_exploit_policy = ( plugin.safety.get("level") == "exploit" - or execution_context.get("validation_mode") == ValidationMode.CONTROLLED_EXTRACT.value + or execution_context.get("validation_mode") + == ValidationMode.CONTROLLED_EXTRACT.value ) - if requires_exploit_policy and not (target_policy and target_policy.get("allow_exploit_validation")): + if requires_exploit_policy and not ( + target_policy and target_policy.get("allow_exploit_validation") + ): raise HTTPException( status_code=400, detail="Offensive validation requires a target policy that explicitly allows exploit validation.", @@ -395,13 +447,21 @@ async def start_task( try: tval = int(effective_inputs[tkey]) except (TypeError, ValueError): - raise HTTPException(status_code=400, detail=f"Invalid value for {tkey}: must be an integer") + raise HTTPException( + status_code=400, + detail=f"Invalid value for {tkey}: must be an integer", + ) if tval <= 0 or tval > settings.sandbox_timeout: - raise HTTPException(status_code=400, detail=f"{tkey} must be between 1 and {settings.sandbox_timeout} seconds") + raise HTTPException( + status_code=400, + detail=f"{tkey} must be between 1 and {settings.sandbox_timeout} seconds", + ) if target := effective_inputs.get("target"): target_str = str(target) - should_validate_target = plugin.category != "code" and not is_filesystem_target(target_str) + should_validate_target = plugin.category != "code" and not is_filesystem_target( + target_str + ) if should_validate_target: try: @@ -410,14 +470,19 @@ async def start_task( timeout=float(settings.dns_resolution_timeout_seconds), ) except asyncio.TimeoutError: - logger.warning("Task start failed: Target validation timed out for '%s'", target_str) + logger.warning( + "Task start failed: Target validation timed out for '%s'", + target_str, + ) raise HTTPException( status_code=400, detail="Target validation timed out in safe mode (SecuScan Guardrail)", ) if not is_valid: - logger.warning(f"Task start failed: Target validation failed for '{target}': {error_msg}") + logger.warning( + f"Task start failed: Target validation failed for '{target}': {error_msg}" + ) raise HTTPException(status_code=400, detail=error_msg) # Check rate limits per (client, plugin) so one client cannot exhaust @@ -425,7 +490,9 @@ async def start_task( client_id = resolve_client_identity(raw_request) can_execute, error_msg = await rate_limiter.can_execute( request.plugin_id, - plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour), + plugin.safety.get("rate_limit", {}).get( + "max_per_hour", settings.max_tasks_per_hour + ), client_id=client_id, ) @@ -452,7 +519,9 @@ async def start_task( can_acquire, error_msg = await concurrent_limiter.acquire(task_id) if not can_acquire: # Roll back: mark the DB row failed so it isn't left orphaned - await executor.mark_task_failed(task_id, reason="Concurrency limit reached; task was not started") + await executor.mark_task_failed( + task_id, reason="Concurrency limit reached; task was not started" + ) raise HTTPException(status_code=503, detail=error_msg) # Slot is held — schedule execution. @@ -468,7 +537,7 @@ async def start_task( "task_id": task_id, "status": "queued", "created_at": "now", - "stream_url": f"/api/v1/task/{task_id}/stream" + "stream_url": f"/api/v1/task/{task_id}/stream", } @router.post("/task/{task_id}/retry", dependencies=[Depends(task_start_limiter)]) @@ -547,6 +616,7 @@ async def get_task_status(task_id: str, owner: str = Depends(get_current_owner)) return status + @router.get("/task/{task_id}/stream") async def stream_task_output(task_id: str, owner: str = Depends(get_current_owner)): """Stream task output via Server-Sent Events (SSE)""" @@ -563,22 +633,25 @@ async def event_generator(): # First, send the initial status and phase yield { "event": "status", - "data": json.dumps({"status": status["status"], "scan_phase": status.get("scan_phase")}) + "data": json.dumps( + {"status": status["status"], "scan_phase": status.get("scan_phase")} + ), } # If it's already completed/failed, we just return the raw output if any and close if status["status"] in ["completed", "failed", "cancelled"]: try: db = await get_db() - task_row = await db.fetchone("SELECT raw_output_path FROM tasks WHERE id = ?", (task_id,)) + task_row = await db.fetchone( + "SELECT raw_output_path FROM tasks WHERE id = ?", (task_id,) + ) if task_row and task_row["raw_output_path"]: for chunk in iter_raw_output_chunks(task_row["raw_output_path"]): - yield { - "event": "output", - "data": json.dumps({"chunk": chunk}) - } + yield {"event": "output", "data": json.dumps({"chunk": chunk})} except Exception as exc: - logger.warning("Failed to replay raw output for task %s: %s", task_id, exc) + logger.warning( + "Failed to replay raw output for task %s: %s", task_id, exc + ) return # Otherwise, subscribe to the live task events @@ -591,19 +664,19 @@ async def event_generator(): if event["type"] == "status": yield { "event": "status", - "data": json.dumps({"status": event["data"]}) + "data": json.dumps({"status": event["data"]}), } if event["data"] in ["completed", "failed", "cancelled"]: break elif event["type"] == "phase": yield { "event": "phase", - "data": json.dumps({"scan_phase": event["data"]}) + "data": json.dumps({"scan_phase": event["data"]}), } elif event["type"] == "output": yield { "event": "output", - "data": json.dumps({"chunk": event["data"]}) + "data": json.dumps({"chunk": event["data"]}), } except asyncio.CancelledError: pass @@ -612,34 +685,49 @@ async def event_generator(): return EventSourceResponse(event_generator()) -@router.get("/task/{task_id}/report/csv", dependencies=[Depends(report_download_limiter)]) + +@router.get( + "/task/{task_id}/report/csv", dependencies=[Depends(report_download_limiter)] +) async def download_csv_report(task_id: str, owner: str = Depends(get_current_owner)): """Download task results as a CSV report.""" db = await get_db() task_row = await db.fetchone( "SELECT id, owner_id, plugin_id, tool_name, target, status, created_at, preset, inputs_json, command_used, structured_json FROM tasks WHERE id = ?", - (task_id,) + (task_id,), ) if not task_row: raise HTTPException(status_code=404, detail="Task not found") if task_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) if task_row["status"] not in ["completed", "failed"]: raise HTTPException(status_code=400, detail="Task is not finished yet") try: - structured_data = json.loads(task_row["structured_json"]) if task_row["structured_json"] else {} - csv_data = reporting.generate_csv_report(dict(task_row), {"structured": structured_data}) + structured_data = ( + json.loads(task_row["structured_json"]) + if task_row["structured_json"] + else {} + ) + csv_data = reporting.generate_csv_report( + dict(task_row), {"structured": structured_data} + ) except Exception: return _report_generation_error_response(task_id, "csv") await db.log_audit( "report_downloaded", f"CSV report downloaded for task {task_id}", - context={"format": "csv", "task_id": task_id, "plugin_id": task_row["plugin_id"]}, + context={ + "format": "csv", + "task_id": task_id, + "plugin_id": task_row["plugin_id"], + }, task_id=task_id, plugin_id=task_row["plugin_id"], ) @@ -647,37 +735,54 @@ async def download_csv_report(task_id: str, owner: str = Depends(get_current_own return Response( content=csv_data, media_type="text/csv", - headers={"Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "csv")}"'} + headers={ + "Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "csv")}"' + }, ) -@router.get("/task/{task_id}/report/html", dependencies=[Depends(report_download_limiter)]) + +@router.get( + "/task/{task_id}/report/html", dependencies=[Depends(report_download_limiter)] +) async def download_html_report(task_id: str, owner: str = Depends(get_current_owner)): """Download task results as an HTML report.""" db = await get_db() task_row = await db.fetchone( "SELECT id, owner_id, plugin_id, tool_name, target, status, created_at, preset, inputs_json, command_used, structured_json FROM tasks WHERE id = ?", - (task_id,) + (task_id,), ) if not task_row: raise HTTPException(status_code=404, detail="Task not found") if task_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) if task_row["status"] not in ["completed", "failed"]: raise HTTPException(status_code=400, detail="Task is not finished yet") try: - structured_data = json.loads(task_row["structured_json"]) if task_row["structured_json"] else {} - html_content = reporting.generate_html_report(dict(task_row), {"structured": structured_data}) + structured_data = ( + json.loads(task_row["structured_json"]) + if task_row["structured_json"] + else {} + ) + html_content = reporting.generate_html_report( + dict(task_row), {"structured": structured_data} + ) except Exception: return _report_generation_error_response(task_id, "html") await db.log_audit( "report_downloaded", f"HTML report downloaded for task {task_id}", - context={"format": "html", "task_id": task_id, "plugin_id": task_row["plugin_id"]}, + context={ + "format": "html", + "task_id": task_id, + "plugin_id": task_row["plugin_id"], + }, task_id=task_id, plugin_id=task_row["plugin_id"], ) @@ -685,37 +790,56 @@ async def download_html_report(task_id: str, owner: str = Depends(get_current_ow return Response( content=html_content, media_type="text/html", - headers={"Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "html")}"'} + headers={ + "Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "html")}"' + }, ) -@router.get("/task/{task_id}/report/pdf", dependencies=[Depends(report_download_limiter)]) + +@router.get( + "/task/{task_id}/report/pdf", dependencies=[Depends(report_download_limiter)] +) async def download_pdf_report(task_id: str, owner: str = Depends(get_current_owner)): """Download task results as a PDF report.""" db = await get_db() task_row = await db.fetchone( "SELECT id, owner_id, plugin_id, tool_name, target, status, created_at, preset, inputs_json, command_used, structured_json FROM tasks WHERE id = ?", - (task_id,) + (task_id,), ) if not task_row: raise HTTPException(status_code=404, detail="Task not found") if task_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) if task_row["status"] not in ["completed", "failed"]: raise HTTPException(status_code=400, detail="Task is not finished yet") try: - structured_data = json.loads(task_row["structured_json"]) if task_row["structured_json"] else {} - pdf_bytes = bytes(reporting.generate_pdf_report(dict(task_row), {"structured": structured_data})) + structured_data = ( + json.loads(task_row["structured_json"]) + if task_row["structured_json"] + else {} + ) + pdf_bytes = bytes( + reporting.generate_pdf_report( + dict(task_row), {"structured": structured_data} + ) + ) except Exception: return _report_generation_error_response(task_id, "pdf") await db.log_audit( "report_downloaded", f"PDF report downloaded for task {task_id}", - context={"format": "pdf", "task_id": task_id, "plugin_id": task_row["plugin_id"]}, + context={ + "format": "pdf", + "task_id": task_id, + "plugin_id": task_row["plugin_id"], + }, task_id=task_id, plugin_id=task_row["plugin_id"], ) @@ -723,38 +847,54 @@ async def download_pdf_report(task_id: str, owner: str = Depends(get_current_own return Response( content=pdf_bytes, media_type="application/pdf", - headers={"Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "pdf")}"'} + headers={ + "Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "pdf")}"' + }, ) -@router.get("/task/{task_id}/report/sarif", dependencies=[Depends(report_download_limiter)]) +@router.get( + "/task/{task_id}/report/sarif", dependencies=[Depends(report_download_limiter)] +) async def download_sarif_report(task_id: str, owner: str = Depends(get_current_owner)): """Download task results as a SARIF report.""" db = await get_db() task_row = await db.fetchone( "SELECT id, owner_id, plugin_id, tool_name, target, status, created_at, preset, inputs_json, command_used, structured_json FROM tasks WHERE id = ?", - (task_id,) + (task_id,), ) if not task_row: raise HTTPException(status_code=404, detail="Task not found") if task_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) if task_row["status"] not in ["completed", "failed"]: raise HTTPException(status_code=400, detail="Task is not finished yet") try: - structured_data = json.loads(task_row["structured_json"]) if task_row["structured_json"] else {} - sarif_data = reporting.generate_sarif_report(dict(task_row), {"structured": structured_data}) + structured_data = ( + json.loads(task_row["structured_json"]) + if task_row["structured_json"] + else {} + ) + sarif_data = reporting.generate_sarif_report( + dict(task_row), {"structured": structured_data} + ) except Exception: return _report_generation_error_response(task_id, "sarif") await db.log_audit( "report_downloaded", f"SARIF report downloaded for task {task_id}", - context={"format": "sarif", "task_id": task_id, "plugin_id": task_row["plugin_id"]}, + context={ + "format": "sarif", + "task_id": task_id, + "plugin_id": task_row["plugin_id"], + }, task_id=task_id, plugin_id=task_row["plugin_id"], ) @@ -762,7 +902,9 @@ async def download_sarif_report(task_id: str, owner: str = Depends(get_current_o return Response( content=sarif_data, media_type="application/sarif+json", - headers={"Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "sarif")}"'} + headers={ + "Content-Disposition": f'attachment; filename="{build_report_filename(dict(task_row), "sarif")}"' + }, ) @@ -778,14 +920,16 @@ async def get_task_result(task_id: str, owner: str = Depends(get_current_owner)) raw_output_path, command_used, error_message, exit_code FROM tasks WHERE id = ? """, - (task_id,) + (task_id,), ) if not task_row: raise HTTPException(status_code=404, detail="Task not found") if task_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) structured = {} if task_row["structured_json"]: @@ -806,24 +950,35 @@ async def get_task_result(task_id: str, owner: str = Depends(get_current_owner)) asset_services = deserialize_asset_service_rows(asset_rows) if not findings and isinstance(structured, dict): - findings = [item for item in structured.get("findings", []) if isinstance(item, dict)] + findings = [ + item for item in structured.get("findings", []) if isinstance(item, dict) + ] severity_counts: Dict[str, int] = {} for finding in findings: severity = str(finding.get("severity", "info")).lower() severity_counts[severity] = severity_counts.get(severity, 0) + 1 - finding_groups = structured.get("finding_groups") if isinstance(structured, dict) else None + finding_groups = ( + structured.get("finding_groups") if isinstance(structured, dict) else None + ) if not isinstance(finding_groups, list) or not finding_groups: finding_groups = build_finding_groups(findings) - asset_summary = structured.get("asset_summary") if isinstance(structured, dict) else None + asset_summary = ( + structured.get("asset_summary") if isinstance(structured, dict) else None + ) if not isinstance(asset_summary, list) or not asset_summary: asset_summary = build_asset_summary(findings, asset_services) scan_diff = structured.get("scan_diff") if isinstance(structured, dict) else None if not isinstance(scan_diff, dict): - scan_diff = {"new": [], "resolved": [], "changed": [], "summary": {"new_count": 0, "resolved_count": 0, "changed_count": 0}} + scan_diff = { + "new": [], + "resolved": [], + "changed": [], + "summary": {"new_count": 0, "resolved_count": 0, "changed_count": 0}, + } if isinstance(structured, dict): structured["findings"] = findings @@ -833,32 +988,51 @@ async def get_task_result(task_id: str, owner: str = Depends(get_current_owner)) structured["asset_services"] = asset_services structured["severity_counts"] = severity_counts - structured_summary = structured.get("summary") if isinstance(structured, dict) else None - summary: List[str] = [ - str(item) for item in structured_summary - if isinstance(item, (str, int, float)) and str(item).strip() - ] if isinstance(structured_summary, list) else [] + structured_summary = ( + structured.get("summary") if isinstance(structured, dict) else None + ) + summary: List[str] = ( + [ + str(item) + for item in structured_summary + if isinstance(item, (str, int, float)) and str(item).strip() + ] + if isinstance(structured_summary, list) + else [] + ) total_findings = len(findings) if not summary and total_findings > 0: - critical_high = severity_counts.get("critical", 0) + severity_counts.get("high", 0) + critical_high = severity_counts.get("critical", 0) + severity_counts.get( + "high", 0 + ) if critical_high > 0: - summary.append(f"Assessment identified {total_findings} security risks, including {critical_high} high-priority items requiring remediation.") + summary.append( + f"Assessment identified {total_findings} security risks, including {critical_high} high-priority items requiring remediation." + ) else: - summary.append(f"Assessment identified {total_findings} minor observations; no critical or high-severity threats were found.") + summary.append( + f"Assessment identified {total_findings} minor observations; no critical or high-severity threats were found." + ) elif not summary: - summary.append("Security analysis revealed no significant vulnerabilities or exposed risks.") + summary.append( + "Security analysis revealed no significant vulnerabilities or exposed risks." + ) if ports := structured.get("open_ports"): - summary.append(f"Perimeter analysis confirmed {len(ports)} active network entry points.") + summary.append( + f"Perimeter analysis confirmed {len(ports)} active network entry points." + ) if techs := structured.get("technologies"): - summary.append(f"Fingerprinting identified {len(techs)} unique technologies powering the target infrastructure.") + summary.append( + f"Fingerprinting identified {len(techs)} unique technologies powering the target infrastructure." + ) # Read raw output (limit to 100k for performance, but usually enough) raw_output = None if task_row["raw_output_path"]: try: - with open(task_row["raw_output_path"], 'r') as f: + with open(task_row["raw_output_path"], "r") as f: raw_output = f.read(100000) except Exception: pass @@ -873,7 +1047,9 @@ async def get_task_result(task_id: str, owner: str = Depends(get_current_owner)) "status": task_row["status"], "preset": task_row["preset"], "inputs": redact_inputs(json.loads(task_row["inputs_json"] or "{}")), - "execution_context": normalize_execution_context(json.loads(task_row["execution_context_json"] or "{}")), + "execution_context": normalize_execution_context( + json.loads(task_row["execution_context_json"] or "{}") + ), "summary": summary, "severity_counts": severity_counts, "findings": findings, @@ -885,10 +1061,12 @@ async def get_task_result(task_id: str, owner: str = Depends(get_current_owner)) "raw_output_excerpt": raw_output, "raw_output": raw_output, "command_used": task_row["command_used"], - "errors": [{"message": task_row["error_message"]}] if task_row["error_message"] else [], + "errors": [{"message": task_row["error_message"]}] + if task_row["error_message"] + else [], "error_message": task_row["error_message"], "exit_code": task_row["exit_code"], - "metadata": {} + "metadata": {}, } @@ -903,11 +1081,7 @@ async def cancel_task(task_id: str, owner: str = Depends(get_current_owner)): if not cancelled: raise HTTPException(status_code=404, detail="Task not found or not running") - return { - "task_id": task_id, - "status": "cancelled", - "cancelled_at": "now" - } + return {"task_id": task_id, "status": "cancelled", "cancelled_at": "now"} @router.get("/dashboard/summary", dependencies=[Depends(read_heavy_limiter)]) @@ -972,7 +1146,13 @@ async def build(): ) recent_findings: List[Dict] = parse_json_fields( recent_rows, - ["metadata_json", "risk_factors_json", "evidence_json", "asset_refs_json", "references_json"], + [ + "metadata_json", + "risk_factors_json", + "evidence_json", + "asset_refs_json", + "references_json", + ], ) for finding in recent_findings: if "risk_factors_json" in finding: @@ -985,10 +1165,13 @@ async def build(): finding["references"] = finding.pop("references_json") risk_scores = [ - f.get("risk_score") for f in recent_findings + f.get("risk_score") + for f in recent_findings if isinstance(f.get("risk_score"), (int, float)) ] - avg_risk_score = round(sum(risk_scores) / len(risk_scores), 1) if risk_scores else None + avg_risk_score = ( + round(sum(risk_scores) / len(risk_scores), 1) if risk_scores else None + ) return { "total_findings": total_findings, @@ -998,27 +1181,35 @@ async def build(): "low_findings": low_findings, "info_findings": info_findings, "avg_risk_score": avg_risk_score, - "last_scan_time": recent_findings[0].get("discovered_at") if recent_findings else None, + "last_scan_time": recent_findings[0].get("discovered_at") + if recent_findings + else None, "recent_findings": recent_findings, "scan_activity": { - "total": int(task_stats["total"]) if task_stats and task_stats.get("total") is not None else 0, - "completed": int(task_stats["completed"]) if task_stats and task_stats.get("completed") is not None else 0, - "running": int(task_stats["running"]) if task_stats and task_stats.get("running") is not None else 0, + "total": int(task_stats["total"]) + if task_stats and task_stats.get("total") is not None + else 0, + "completed": int(task_stats["completed"]) + if task_stats and task_stats.get("completed") is not None + else 0, + "running": int(task_stats["running"]) + if task_stats and task_stats.get("running") is not None + else 0, }, "running_tasks": parse_json_fields( await db.fetchall( "SELECT id, plugin_id, tool_name, target, status, created_at FROM tasks WHERE owner_id = ? AND status = 'running' ORDER BY created_at DESC LIMIT 5", (owner,), ), - [] + [], ), "recent_tasks": parse_json_fields( await db.fetchall( "SELECT id, plugin_id, tool_name, target, status, created_at, duration_seconds FROM tasks WHERE owner_id = ? ORDER BY created_at DESC LIMIT 5", (owner,), ), - [] - ) + [], + ), } return await get_or_set_cached(f"summary:dashboard:{owner}", build) @@ -1061,7 +1252,9 @@ async def build(): } # Cache key includes pagination params so different pages do not collide. - return await get_or_set_cached(f"findings:list:{owner}:page={page}:per_page={per_page}", build) + return await get_or_set_cached( + f"findings:list:{owner}:page={page}:per_page={per_page}", build + ) @router.get("/finding-groups", dependencies=[Depends(read_heavy_limiter)]) @@ -1086,7 +1279,9 @@ async def build(): "per_page": per_page, } - return await get_or_set_cached(f"findings:groups:{owner}:page={page}:per_page={per_page}", build) + return await get_or_set_cached( + f"findings:groups:{owner}:page={page}:per_page={per_page}", build + ) @router.get("/task/{task_id}/diff", dependencies=[Depends(read_heavy_limiter)]) @@ -1099,7 +1294,9 @@ async def get_task_diff(task_id: str, owner: str = Depends(get_current_owner)): if not task_row: raise HTTPException(status_code=404, detail="Task not found") if task_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) structured = {} if task_row["structured_json"]: @@ -1109,7 +1306,12 @@ async def get_task_diff(task_id: str, owner: str = Depends(get_current_owner)): structured = {} diff = structured.get("scan_diff") if isinstance(structured, dict) else None if not isinstance(diff, dict): - diff = {"new": [], "resolved": [], "changed": [], "summary": {"new_count": 0, "resolved_count": 0, "changed_count": 0}} + diff = { + "new": [], + "resolved": [], + "changed": [], + "summary": {"new_count": 0, "resolved_count": 0, "changed_count": 0}, + } return diff @@ -1155,7 +1357,7 @@ async def list_tasks( allowed_values = ", ".join([s.value for s in TaskStatus]) raise HTTPException( status_code=400, - detail=f"Invalid task status '{status}'. Allowed values: {allowed_values}" + detail=f"Invalid task status '{status}'. Allowed values: {allowed_values}", ) where_clauses.append("status = ?") @@ -1173,11 +1375,26 @@ async def list_tasks( if where_clauses: count_query += " WHERE " + " AND ".join(where_clauses) - count_result = await db.fetchone(count_query, tuple(params[:-2]) if where_clauses else ()) - total: int = int(count_result["total"]) if count_result and count_result.get("total") is not None else 0 + count_result = await db.fetchone( + count_query, tuple(params[:-2]) if where_clauses else () + ) + total: int = ( + int(count_result["total"]) + if count_result and count_result.get("total") is not None + else 0 + ) # Parse JSON fields and format for frontend - tasks_list = parse_json_fields(tasks, ["structured_json", "config_json", "metadata_json", "inputs_json", "execution_context_json"]) + tasks_list = parse_json_fields( + tasks, + [ + "structured_json", + "config_json", + "metadata_json", + "inputs_json", + "execution_context_json", + ], + ) for t in tasks_list: if "id" in t: t["task_id"] = t.pop("id") @@ -1202,6 +1419,7 @@ def build_page_url(page_num): if status: query_params["status"] = status return f"/api/v1/tasks?{urlencode(query_params)}" + return { "tasks": tasks_list, "pagination": { @@ -1210,13 +1428,14 @@ def build_page_url(page_num): "total_pages": total_pages, "total_items": total, "next": build_page_url(next_page), - "previous": build_page_url(prev_page) - } + "previous": build_page_url(prev_page), + }, } SQLITE_CHUNK_SIZE = 500 # safely under SQLITE_LIMIT_VARIABLE_NUMBER = 999 + async def delete_task_records(task_ids: List[str]): """Helper to delete database records and files for multiple tasks. @@ -1238,7 +1457,7 @@ async def delete_task_records(task_ids: List[str]): placeholders = ",".join(["?"] * len(chunk)) rows = await db.fetchall( f"SELECT raw_output_path FROM tasks WHERE id IN ({placeholders})", - tuple(chunk) + tuple(chunk), ) all_task_rows.extend(rows) @@ -1251,12 +1470,12 @@ async def delete_task_records(task_ids: List[str]): placeholders = ",".join(["?"] * len(chunk)) running = await db.fetchone( f"SELECT 1 FROM tasks WHERE id IN ({placeholders}) AND status = 'running' LIMIT 1", - tuple(chunk) + tuple(chunk), ) if running: raise HTTPException( status_code=400, - detail="Cannot delete running tasks. Abort them first." + detail="Cannot delete running tasks. Abort them first.", ) for i in range(0, len(task_ids), SQLITE_CHUNK_SIZE): @@ -1264,25 +1483,32 @@ async def delete_task_records(task_ids: List[str]): placeholders = ",".join(["?"] * len(chunk)) # Delete notification_history first (depends on findings via finding_id) await db.execute_no_commit( - f"DELETE FROM notification_history WHERE finding_id IN (SELECT id FROM findings WHERE task_id IN ({placeholders}))", tuple(chunk) + f"DELETE FROM notification_history WHERE finding_id IN (SELECT id FROM findings WHERE task_id IN ({placeholders}))", + tuple(chunk), ) await db.execute_no_commit( - f"DELETE FROM findings WHERE task_id IN ({placeholders})", tuple(chunk) + f"DELETE FROM findings WHERE task_id IN ({placeholders})", + tuple(chunk), ) await db.execute_no_commit( - f"DELETE FROM reports WHERE task_id IN ({placeholders})", tuple(chunk) + f"DELETE FROM reports WHERE task_id IN ({placeholders})", + tuple(chunk), ) await db.execute_no_commit( - f"DELETE FROM audit_log WHERE task_id IN ({placeholders})", tuple(chunk) + f"DELETE FROM audit_log WHERE task_id IN ({placeholders})", + tuple(chunk), ) await db.execute_no_commit( - f"DELETE FROM crawl_runs WHERE task_id IN ({placeholders})", tuple(chunk) + f"DELETE FROM crawl_runs WHERE task_id IN ({placeholders})", + tuple(chunk), ) await db.execute_no_commit( - f"DELETE FROM asset_services WHERE task_id IN ({placeholders})", tuple(chunk) + f"DELETE FROM asset_services WHERE task_id IN ({placeholders})", + tuple(chunk), ) await db.execute_no_commit( - f"DELETE FROM tasks WHERE id IN ({placeholders})", tuple(chunk) + f"DELETE FROM tasks WHERE id IN ({placeholders})", + tuple(chunk), ) # Cleanup files on disk (outside the transaction — file deletion is not @@ -1295,7 +1521,10 @@ async def delete_task_records(task_ids: List[str]): if path.exists(): path.unlink() except Exception as e: - logger.error(f"Failed to delete raw output file {row['raw_output_path']}: {e}") + logger.error( + f"Failed to delete raw output file {row['raw_output_path']}: {e}" + ) + @router.delete("/task/{task_id}") async def delete_task(task_id: str, owner: str = Depends(get_current_owner)): @@ -1307,28 +1536,33 @@ async def delete_task(task_id: str, owner: str = Depends(get_current_owner)): # cannot be deleted across owners (issue #401). existing = await db.fetchone("SELECT owner_id FROM tasks WHERE id = ?", (task_id,)) if existing is not None and existing["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this task") + raise HTTPException( + status_code=403, detail="You do not have access to this task" + ) # Check if task is running status = await executor.get_task_status(task_id) if status and status.get("status") == "running": - raise HTTPException(status_code=400, detail="Cannot delete a running task. Abort it first.") + raise HTTPException( + status_code=400, detail="Cannot delete a running task. Abort it first." + ) # If the task is currently executing but the DB hasn't been updated yet, fail closed. if task_id in executor.running_tasks: - raise HTTPException(status_code=400, detail="Cannot delete a running task. Abort it first.") + raise HTTPException( + status_code=400, detail="Cannot delete a running task. Abort it first." + ) await delete_task_records([task_id]) await invalidate_view_cache() - return { - "task_id": task_id, - "deleted": True - } + return {"task_id": task_id, "deleted": True} @router.delete("/tasks/bulk") -async def bulk_delete_tasks(request: BulkDeleteRequest, owner: str = Depends(get_current_owner)): +async def bulk_delete_tasks( + request: BulkDeleteRequest, owner: str = Depends(get_current_owner) +): """Delete multiple tasks at once (max 500 IDs per request)""" task_ids = request.root # RootModel exposes data via .root db = await get_db() @@ -1353,22 +1587,24 @@ async def bulk_delete_tasks(request: BulkDeleteRequest, owner: str = Depends(get placeholders = ",".join(["?"] * len(owned_ids)) running_tasks = await db.fetchone( f"SELECT id FROM tasks WHERE id IN ({placeholders}) AND status = 'running' LIMIT 1", - tuple(owned_ids) + tuple(owned_ids), ) if running_tasks: - raise HTTPException(status_code=400, detail="Cannot delete running tasks. Abort them first.") + raise HTTPException( + status_code=400, detail="Cannot delete running tasks. Abort them first." + ) # If the task is currently executing but the DB hasn't been updated yet, fail closed. if any(tid in executor.running_tasks for tid in owned_ids): - raise HTTPException(status_code=400, detail="Cannot delete running tasks. Abort them first.") + raise HTTPException( + status_code=400, detail="Cannot delete running tasks. Abort them first." + ) await delete_task_records(owned_ids) await invalidate_view_cache() - return { - "deleted_count": len(owned_ids), - "success": True - } + return {"deleted_count": len(owned_ids), "success": True} + @router.delete("/tasks/clear") async def clear_all_tasks(owner: str = Depends(get_current_owner)): @@ -1385,7 +1621,9 @@ async def clear_all_tasks(owner: str = Depends(get_current_owner)): (owner,), ) if running_tasks: - raise HTTPException(status_code=400, detail="Cannot clear history while tasks are running.") + raise HTTPException( + status_code=400, detail="Cannot clear history while tasks are running." + ) # Get the caller's task IDs to delete records and cleanup files own_tasks = await db.fetchall("SELECT id FROM tasks WHERE owner_id = ?", (owner,)) @@ -1401,7 +1639,7 @@ async def clear_all_tasks(owner: str = Depends(get_current_owner)): return { "cleared": True, - "message": "All scan history and associated data has been purged." + "message": "All scan history and associated data has been purged.", } @@ -1412,26 +1650,26 @@ async def get_settings(): "network": { "bind_address": settings.bind_address, "port": settings.bind_port, - "allow_remote": False + "allow_remote": False, }, "sandbox": { "engine": "docker" if settings.docker_enabled else "subprocess", "default_timeout": settings.sandbox_timeout, "resource_limits": { "cpu_quota": settings.sandbox_cpu_quota, - "memory_mb": settings.sandbox_memory_mb - } + "memory_mb": settings.sandbox_memory_mb, + }, }, "safety": { "require_consent": settings.require_consent, "safe_mode_default": settings.safe_mode_default, - "allowed_networks": settings.allowed_networks + "allowed_networks": settings.allowed_networks, }, "execution_context": { "validation_modes": [mode.value for mode in ValidationMode], "evidence_levels": [level.value for level in EvidenceLevel], "default": ExecutionContext().model_dump(), - } + }, } @@ -1492,14 +1730,9 @@ async def get_vault_secret( db = await get_db() row = await db.fetchone( - """ - SELECT encrypted_value - FROM credential_vault - WHERE owner_id = ? AND name = ? - """, + "SELECT encrypted_value FROM credential_vault WHERE owner_id = ? AND name = ?", (owner, name), ) - if not row: raise HTTPException(status_code=404, detail="Secret not found") @@ -1541,7 +1774,9 @@ async def list_target_policies(owner: str = Depends(get_current_owner)): @router.post("/target-policies") -async def create_target_policy(payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def create_target_policy( + payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): name = str(payload.get("name", "")).strip() if not name: raise HTTPException(status_code=400, detail="Target policy name is required") @@ -1573,9 +1808,14 @@ async def create_target_policy(payload: Dict[str, Any], owner: str = Depends(get @router.patch("/target-policies/{policy_id}") -async def update_target_policy(policy_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def update_target_policy( + policy_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): db = await get_db() - row = await db.fetchone("SELECT id FROM target_policies WHERE id = ? AND owner_id = ?", (policy_id, owner)) + row = await db.fetchone( + "SELECT id FROM target_policies WHERE id = ? AND owner_id = ?", + (policy_id, owner), + ) if not row: raise HTTPException(status_code=404, detail="Target policy not found") updates: List[str] = [] @@ -1583,8 +1823,14 @@ async def update_target_policy(policy_id: str, payload: Dict[str, Any], owner: s for key in ("name", "description", "default_validation_mode"): if key in payload: updates.append(f"{key} = ?") - params.append(str(payload[key]).strip() if payload[key] is not None else None) - for key in ("allow_public_targets", "allow_exploit_validation", "allow_authenticated_scan"): + params.append( + str(payload[key]).strip() if payload[key] is not None else None + ) + for key in ( + "allow_public_targets", + "allow_exploit_validation", + "allow_authenticated_scan", + ): if key in payload: updates.append(f"{key} = ?") params.append(1 if payload[key] else 0) @@ -1596,15 +1842,26 @@ async def update_target_policy(policy_id: str, payload: Dict[str, Any], owner: s params.append(_json_payload(payload["metadata"], "{}")) updates.append("updated_at = datetime('now')") params.extend([policy_id, owner]) - await db.execute(f"UPDATE target_policies SET {', '.join(updates)} WHERE id = ? AND owner_id = ?", tuple(params)) - updated = await db.fetchone("SELECT * FROM target_policies WHERE id = ?", (policy_id,)) - return deserialize_resource_rows([updated])[0] if updated else {"id": policy_id, "updated": True} + await db.execute( + f"UPDATE target_policies SET {', '.join(updates)} WHERE id = ? AND owner_id = ?", + tuple(params), + ) + updated = await db.fetchone( + "SELECT * FROM target_policies WHERE id = ?", (policy_id,) + ) + return ( + deserialize_resource_rows([updated])[0] + if updated + else {"id": policy_id, "updated": True} + ) @router.delete("/target-policies/{policy_id}") async def delete_target_policy(policy_id: str, owner: str = Depends(get_current_owner)): db = await get_db() - await db.execute("DELETE FROM target_policies WHERE id = ? AND owner_id = ?", (policy_id, owner)) + await db.execute( + "DELETE FROM target_policies WHERE id = ? AND owner_id = ?", (policy_id, owner) + ) return {"id": policy_id, "deleted": True} @@ -1619,10 +1876,14 @@ async def list_credential_profiles(owner: str = Depends(get_current_owner)): @router.post("/credential-profiles") -async def create_credential_profile(payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def create_credential_profile( + payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): name = str(payload.get("name", "")).strip() if not name: - raise HTTPException(status_code=400, detail="Credential profile name is required") + raise HTTPException( + status_code=400, detail="Credential profile name is required" + ) profile_id = str(uuid.uuid4()) db = await get_db() await db.execute( @@ -1642,14 +1903,21 @@ async def create_credential_profile(payload: Dict[str, Any], owner: str = Depend _json_payload(payload.get("login_recipe"), "{}"), ), ) - row = await db.fetchone("SELECT * FROM credential_profiles WHERE id = ?", (profile_id,)) + row = await db.fetchone( + "SELECT * FROM credential_profiles WHERE id = ?", (profile_id,) + ) return deserialize_resource_rows([row])[0] if row else {"id": profile_id} @router.patch("/credential-profiles/{profile_id}") -async def update_credential_profile(profile_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def update_credential_profile( + profile_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): db = await get_db() - row = await db.fetchone("SELECT id FROM credential_profiles WHERE id = ? AND owner_id = ?", (profile_id, owner)) + row = await db.fetchone( + "SELECT id FROM credential_profiles WHERE id = ? AND owner_id = ?", + (profile_id, owner), + ) if not row: raise HTTPException(status_code=404, detail="Credential profile not found") updates: List[str] = [] @@ -1666,15 +1934,29 @@ async def update_credential_profile(profile_id: str, payload: Dict[str, Any], ow params.append(_json_payload(payload["login_recipe"], "{}")) updates.append("updated_at = datetime('now')") params.extend([profile_id, owner]) - await db.execute(f"UPDATE credential_profiles SET {', '.join(updates)} WHERE id = ? AND owner_id = ?", tuple(params)) - updated = await db.fetchone("SELECT * FROM credential_profiles WHERE id = ?", (profile_id,)) - return deserialize_resource_rows([updated])[0] if updated else {"id": profile_id, "updated": True} + await db.execute( + f"UPDATE credential_profiles SET {', '.join(updates)} WHERE id = ? AND owner_id = ?", + tuple(params), + ) + updated = await db.fetchone( + "SELECT * FROM credential_profiles WHERE id = ?", (profile_id,) + ) + return ( + deserialize_resource_rows([updated])[0] + if updated + else {"id": profile_id, "updated": True} + ) @router.delete("/credential-profiles/{profile_id}") -async def delete_credential_profile(profile_id: str, owner: str = Depends(get_current_owner)): +async def delete_credential_profile( + profile_id: str, owner: str = Depends(get_current_owner) +): db = await get_db() - await db.execute("DELETE FROM credential_profiles WHERE id = ? AND owner_id = ?", (profile_id, owner)) + await db.execute( + "DELETE FROM credential_profiles WHERE id = ? AND owner_id = ?", + (profile_id, owner), + ) return {"id": profile_id, "deleted": True} @@ -1689,7 +1971,9 @@ async def list_session_profiles(owner: str = Depends(get_current_owner)): @router.post("/session-profiles") -async def create_session_profile(payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def create_session_profile( + payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): name = str(payload.get("name", "")).strip() if not name: raise HTTPException(status_code=400, detail="Session profile name is required") @@ -1710,14 +1994,21 @@ async def create_session_profile(payload: Dict[str, Any], owner: str = Depends(g str(payload.get("notes", "")).strip() or None, ), ) - row = await db.fetchone("SELECT * FROM session_profiles WHERE id = ?", (profile_id,)) + row = await db.fetchone( + "SELECT * FROM session_profiles WHERE id = ?", (profile_id,) + ) return deserialize_resource_rows([row])[0] if row else {"id": profile_id} @router.patch("/session-profiles/{profile_id}") -async def update_session_profile(profile_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def update_session_profile( + profile_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): db = await get_db() - row = await db.fetchone("SELECT id FROM session_profiles WHERE id = ? AND owner_id = ?", (profile_id, owner)) + row = await db.fetchone( + "SELECT id FROM session_profiles WHERE id = ? AND owner_id = ?", + (profile_id, owner), + ) if not row: raise HTTPException(status_code=404, detail="Session profile not found") updates: List[str] = [] @@ -1731,15 +2022,29 @@ async def update_session_profile(profile_id: str, payload: Dict[str, Any], owner params.append(_json_payload(payload["extra_headers"], "{}")) updates.append("updated_at = datetime('now')") params.extend([profile_id, owner]) - await db.execute(f"UPDATE session_profiles SET {', '.join(updates)} WHERE id = ? AND owner_id = ?", tuple(params)) - updated = await db.fetchone("SELECT * FROM session_profiles WHERE id = ?", (profile_id,)) - return deserialize_resource_rows([updated])[0] if updated else {"id": profile_id, "updated": True} + await db.execute( + f"UPDATE session_profiles SET {', '.join(updates)} WHERE id = ? AND owner_id = ?", + tuple(params), + ) + updated = await db.fetchone( + "SELECT * FROM session_profiles WHERE id = ?", (profile_id,) + ) + return ( + deserialize_resource_rows([updated])[0] + if updated + else {"id": profile_id, "updated": True} + ) @router.delete("/session-profiles/{profile_id}") -async def delete_session_profile(profile_id: str, owner: str = Depends(get_current_owner)): +async def delete_session_profile( + profile_id: str, owner: str = Depends(get_current_owner) +): db = await get_db() - await db.execute("DELETE FROM session_profiles WHERE id = ? AND owner_id = ?", (profile_id, owner)) + await db.execute( + "DELETE FROM session_profiles WHERE id = ? AND owner_id = ?", + (profile_id, owner), + ) return {"id": profile_id, "deleted": True} @@ -1780,14 +2085,18 @@ async def list_workflows(owner: str = Depends(get_current_owner)): @router.post("/workflows") -async def create_workflow(payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def create_workflow( + payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): name = str(payload.get("name", "")).strip() if not name: raise HTTPException(status_code=400, detail="Workflow name is required") steps = _parse_workflow_steps(payload.get("steps", [])) if not steps: - raise HTTPException(status_code=400, detail="Workflow requires at least one step") + raise HTTPException( + status_code=400, detail="Workflow requires at least one step" + ) workflow_id = str(uuid.uuid4()) schedule_seconds = payload.get("schedule_seconds") @@ -1813,13 +2122,13 @@ async def create_workflow(payload: Dict[str, Any], owner: str = Depends(get_curr async def _verify_workflow_owner(db, workflow_id: str, owner: str): """Check the workflow exists and belongs to the caller. Returns the row or raises 404/403.""" - row = await db.fetchone( - "SELECT * FROM workflows WHERE id = ?", (workflow_id,) - ) + row = await db.fetchone("SELECT * FROM workflows WHERE id = ?", (workflow_id,)) if not row: raise HTTPException(status_code=404, detail="Workflow not found") if row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this workflow") + raise HTTPException( + status_code=403, detail="You do not have access to this workflow" + ) return row @@ -1851,8 +2160,12 @@ async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_o version_number = active_version["version_number"] created_task_ids: List[str] = [] for step in steps: - execution_context = normalize_execution_context(step.get("execution_context") or {}) - target_policy = await get_target_policy(db, owner, execution_context.get("target_policy_id")) + execution_context = normalize_execution_context( + step.get("execution_context") or {} + ) + target_policy = await get_target_policy( + db, owner, execution_context.get("target_policy_id") + ) safe_mode = bool( settings.safe_mode_default and not (target_policy and target_policy.get("allow_public_targets")) @@ -1871,7 +2184,10 @@ async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_o ) asyncio.create_task(executor.execute_task(task_id)) created_task_ids.append(task_id) - await db.execute("UPDATE workflows SET last_run_at = datetime('now') WHERE id = ?", (workflow_id,)) + await db.execute( + "UPDATE workflows SET last_run_at = datetime('now') WHERE id = ?", + (workflow_id,), + ) run_id = await db.record_workflow_run( workflow_id=workflow_id, version_id=version_id, @@ -1890,8 +2206,14 @@ async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_o + @router.get("/workflows/{workflow_id}/runs") -async def list_workflow_runs(workflow_id: str, owner: str = Depends(get_current_owner), limit: int = 50, offset: int = 0): +async def list_workflow_runs( + workflow_id: str, + owner: str = Depends(get_current_owner), + limit: int = 50, + offset: int = 0, +): """Return paginated run history for a workflow.""" if limit < 1 or limit > 500: raise HTTPException(status_code=400, detail="limit must be between 1 and 500") @@ -1899,11 +2221,15 @@ async def list_workflow_runs(workflow_id: str, owner: str = Depends(get_current_ raise HTTPException(status_code=400, detail="offset must be non-negative") db = await get_db() await _verify_workflow_owner(db, workflow_id, owner) - return await db.get_workflow_runs(workflow_id=workflow_id, limit=limit, offset=offset) + return await db.get_workflow_runs( + workflow_id=workflow_id, limit=limit, offset=offset + ) @router.get("/workflows/{workflow_id}/versions") -async def list_workflow_versions(workflow_id: str, owner: str = Depends(get_current_owner)): +async def list_workflow_versions( + workflow_id: str, owner: str = Depends(get_current_owner) +): """Return all saved version snapshots for a workflow, newest first.""" db = await get_db() await _verify_workflow_owner(db, workflow_id, owner) @@ -1912,7 +2238,9 @@ async def list_workflow_versions(workflow_id: str, owner: str = Depends(get_curr @router.post("/workflows/{workflow_id}/rollback/{version_number}") -async def rollback_workflow(workflow_id: str, version_number: int, owner: str = Depends(get_current_owner)): +async def rollback_workflow( + workflow_id: str, version_number: int, owner: str = Depends(get_current_owner) +): """Restore a workflow to a previously saved version. The target version's full definition replaces the live workflow fields. @@ -1954,7 +2282,9 @@ async def rollback_workflow(workflow_id: str, version_number: int, owner: str = @router.patch("/workflows/{workflow_id}") -async def update_workflow(workflow_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner)): +async def update_workflow( + workflow_id: str, payload: Dict[str, Any], owner: str = Depends(get_current_owner) +): db = await get_db() row = await _verify_workflow_owner(db, workflow_id, owner) @@ -1978,7 +2308,9 @@ async def update_workflow(workflow_id: str, payload: Dict[str, Any], owner: str raise HTTPException(status_code=400, detail="No update fields provided") params.append(workflow_id) - await db.execute(f"UPDATE workflows SET {', '.join(updates)} WHERE id = ?", tuple(params)) + await db.execute( + f"UPDATE workflows SET {', '.join(updates)} WHERE id = ?", tuple(params) + ) updated = await db.fetchone("SELECT * FROM workflows WHERE id = ?", (workflow_id,)) if updated is None: return {"workflow_id": workflow_id, "updated": True} @@ -2001,7 +2333,9 @@ async def delete_workflow(workflow_id: str, owner: str = Depends(get_current_own return {"workflow_id": workflow_id, "deleted": True} -@router.post("/workflows/scheduler/tick", dependencies=[Depends(scheduler_tick_limiter)]) +@router.post( + "/workflows/scheduler/tick", dependencies=[Depends(scheduler_tick_limiter)] +) async def trigger_workflow_tick(): await scheduler.tick() return {"tick": "ok"} @@ -2019,12 +2353,16 @@ async def list_notification_rules(owner: str = Depends(get_current_owner)): @router.post("/notifications/rules") -async def create_notification_rule(payload: NotificationRuleCreate, owner: str = Depends(get_current_owner)): +async def create_notification_rule( + payload: NotificationRuleCreate, owner: str = Depends(get_current_owner) +): name = payload.name.strip() if not name: raise HTTPException(status_code=400, detail="Rule name is required") - target = _validate_notification_target(payload.channel_type, payload.target_url_or_email) + target = _validate_notification_target( + payload.channel_type, payload.target_url_or_email + ) rule_id = str(uuid.uuid4()) db = await get_db() await db.execute( @@ -2048,7 +2386,9 @@ async def create_notification_rule(payload: NotificationRuleCreate, owner: str = (rule_id,), ) if not row: - raise HTTPException(status_code=500, detail="Failed to create notification rule") + raise HTTPException( + status_code=500, detail="Failed to create notification rule" + ) return _serialize_notification_rule(row) @@ -2061,7 +2401,9 @@ async def _verify_notification_rule_owner(db, rule_id: str, owner: str): if not row: raise HTTPException(status_code=404, detail="Notification rule not found") if row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this notification rule") + raise HTTPException( + status_code=403, detail="You do not have access to this notification rule" + ) return row @@ -2073,7 +2415,11 @@ async def get_notification_rule(rule_id: str, owner: str = Depends(get_current_o @router.patch("/notifications/rules/{rule_id}") -async def update_notification_rule(rule_id: str, payload: NotificationRuleUpdate, owner: str = Depends(get_current_owner)): +async def update_notification_rule( + rule_id: str, + payload: NotificationRuleUpdate, + owner: str = Depends(get_current_owner), +): db = await get_db() row = await _verify_notification_rule_owner(db, rule_id, owner) @@ -2138,7 +2484,9 @@ async def update_notification_rule(rule_id: str, payload: NotificationRuleUpdate @router.delete("/notifications/rules/{rule_id}") -async def delete_notification_rule(rule_id: str, owner: str = Depends(get_current_owner)): +async def delete_notification_rule( + rule_id: str, owner: str = Depends(get_current_owner) +): db = await get_db() await _verify_notification_rule_owner(db, rule_id, owner) await db.execute("DELETE FROM notification_rules WHERE id = ?", (rule_id,)) @@ -2191,14 +2539,16 @@ async def get_finding_details(finding_id: str, owner: str = Depends(get_current_ JOIN tasks t ON f.task_id = t.id WHERE f.id = ? """, - (finding_id,) + (finding_id,), ) if not finding_row: raise HTTPException(status_code=404, detail="Finding not found") if finding_row["owner_id"] != owner: - raise HTTPException(status_code=403, detail="You do not have access to this finding") + raise HTTPException( + status_code=403, detail="You do not have access to this finding" + ) metadata = {} if finding_row["metadata_json"]: @@ -2260,30 +2610,34 @@ async def get_attack_surface(owner: str = Depends(get_current_owner)): for f in findings: target = f["target"] if target not in seen_targets: - entries.append({ - "id": str(uuid.uuid4()), - "category": f["category"], - "item": target, - "details": f"Active exposure identified in {f['category']}", - "risk": f["severity"], - "source": "Audit Scan", - "last_seen": f["discovered_at"] - }) + entries.append( + { + "id": str(uuid.uuid4()), + "category": f["category"], + "item": target, + "details": f"Active exposure identified in {f['category']}", + "risk": f["severity"], + "source": "Audit Scan", + "last_seen": f["discovered_at"], + } + ) seen_targets.add(target) # Add other scanned targets for t in tasks: target = t["target"] if target not in seen_targets: - entries.append({ - "id": str(uuid.uuid4()), - "category": "Infrastructure", - "item": target, - "details": f"Monitored via {t['tool_name']}", - "risk": "info", - "source": "Recon", - "last_seen": t["created_at"] - }) + entries.append( + { + "id": str(uuid.uuid4()), + "category": "Infrastructure", + "item": target, + "details": f"Monitored via {t['tool_name']}", + "risk": "info", + "source": "Recon", + "last_seen": t["created_at"], + } + ) seen_targets.add(target) return {"entries": entries} @@ -2305,6 +2659,7 @@ async def get_assets(owner: str = Depends(get_current_owner)): assets = [{"id": str(uuid.uuid4()), "name": row["target"]} for row in rows] return {"assets": assets} + # ── Network Policy Management Endpoints ───────────────────────────────────── from fastapi.security import APIKeyHeader @@ -2314,6 +2669,7 @@ async def get_assets(owner: str = Depends(get_current_owner)): api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + def verify_admin_access( api_key: Optional[str] = Security(api_key_header), request: Request = None, @@ -2325,14 +2681,14 @@ def verify_admin_access( if not settings.admin_api_key: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Admin API Key is not configured on the server. Please set SECUSCAN_ADMIN_API_KEY." + detail="Admin API Key is not configured on the server. Please set SECUSCAN_ADMIN_API_KEY.", ) # Entropy check: enforce a strong API key if len(settings.admin_api_key) < 16: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Admin API Key is too weak. It must be at least 16 characters long." + detail="Admin API Key is too weak. It must be at least 16 characters long.", ) candidate = api_key @@ -2353,20 +2709,25 @@ def verify_admin_access( if not candidate or not hmac.compare_digest(candidate, settings.admin_api_key): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or missing Admin API Key" + detail="Invalid or missing Admin API Key", ) return candidate + @router.get( "/admin/diagnostics/notifications", response_model=NotificationDiagnosticsResponse, - dependencies=[Depends(verify_admin_access), Depends(admin_limiter)] + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], ) async def get_notification_diagnostics(): """Get active notification delivery configuration and retry policy""" return notification_service.get_delivery_configuration() -@router.get("/admin/network-policy", dependencies=[Depends(verify_admin_access), Depends(admin_limiter)]) + +@router.get( + "/admin/network-policy", + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], +) async def get_network_policy(): """Get current network policy configuration""" engine = get_policy_engine() @@ -2377,7 +2738,11 @@ async def get_network_policy(): "audit_entries_count": len(engine.audit_entries), } -@router.post("/admin/network-policy/allow", dependencies=[Depends(verify_admin_access), Depends(admin_limiter)]) + +@router.post( + "/admin/network-policy/allow", + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], +) async def add_allow_rule(request: dict): """Add network to allowlist""" engine = get_policy_engine() @@ -2391,7 +2756,11 @@ async def add_allow_rule(request: dict): except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) -@router.post("/admin/network-policy/deny", dependencies=[Depends(verify_admin_access), Depends(admin_limiter)]) + +@router.post( + "/admin/network-policy/deny", + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], +) async def add_deny_rule(request: dict): """Add network to denylist""" engine = get_policy_engine() @@ -2405,11 +2774,13 @@ async def add_deny_rule(request: dict): except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) -@router.get("/admin/network-audit-log", dependencies=[Depends(verify_admin_access), Depends(admin_limiter)]) + +@router.get( + "/admin/network-audit-log", + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], +) async def get_audit_log( - plugin_id: Optional[str] = None, - action: Optional[str] = None, - limit: int = 100 + plugin_id: Optional[str] = None, action: Optional[str] = None, limit: int = 100 ): """Retrieve network audit log entries""" engine = get_policy_engine() @@ -2419,9 +2790,7 @@ async def get_audit_log( policy_action = PolicyAction[action.upper()] entries = engine.get_audit_entries( - plugin_id=plugin_id, - action=policy_action, - limit=limit + plugin_id=plugin_id, action=policy_action, limit=limit ) return { @@ -2429,7 +2798,11 @@ async def get_audit_log( "total": len(entries), } -@router.get("/admin/network-audit-log/export", dependencies=[Depends(verify_admin_access), Depends(admin_limiter)]) + +@router.get( + "/admin/network-audit-log/export", + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], +) async def export_audit_log(format: str = "json"): """Export audit log in specified format""" engine = get_policy_engine() @@ -2443,11 +2816,14 @@ async def export_audit_log(format: str = "json"): return Response( content=content, media_type=mime_type, - headers={"Content-Disposition": f"attachment; filename=network-audit.{format}"} + headers={"Content-Disposition": f"attachment; filename=network-audit.{format}"}, ) -@router.get("/admin/vault/diagnostics", dependencies=[Depends(verify_admin_access), Depends(admin_limiter)]) +@router.get( + "/admin/vault/diagnostics", + dependencies=[Depends(verify_admin_access), Depends(admin_limiter)], +) async def get_vault_diagnostics(): """Report non-secret diagnostics for the credential vault key. Surfaces a one-way fingerprint of the active vault key so operators can confirm key-rotation state without the key material ever leaving the server.