From 8fc52907ece988586189d0310395a5d6dbaf7345 Mon Sep 17 00:00:00 2001 From: Srijan Jaiswal Date: Tue, 9 Jun 2026 22:18:45 +0530 Subject: [PATCH 1/2] fix: workflow scheduler now applies route-level security controls - Extracted _execute_scan_safe() in routes.py as shared security entry point used by both start_task() and run_workflow_once() - Scheduler _run_workflow() now validates targets, enforces rate limits, checks network policy, acquires concurrency slots, and logs source - Added WorkflowRateLimiter to ratelimit.py with per-workflow and per-user limits - Added source parameter (api|workflow|scheduler) to executor audit logging - Added workflow config settings (min interval, max per user, consent refresh) - Added focused tests for workflow scheduler security path Closes #655 --- backend/secuscan/config.py | 5 + backend/secuscan/executor.py | 29 +- backend/secuscan/main.py | 13 +- backend/secuscan/plugins.py | 1 + backend/secuscan/ratelimit.py | 37 +++ backend/secuscan/routes.py | 280 +++++++++-------- backend/secuscan/workflows.py | 84 +++++- .../unit/test_workflow_scheduler_security.py | 283 ++++++++++++++++++ 8 files changed, 578 insertions(+), 154 deletions(-) create mode 100644 testing/backend/unit/test_workflow_scheduler_security.py diff --git a/backend/secuscan/config.py b/backend/secuscan/config.py index 5685895a9..d6380095d 100644 --- a/backend/secuscan/config.py +++ b/backend/secuscan/config.py @@ -123,6 +123,11 @@ class Settings(BaseSettings): parser_sandbox_timeout_seconds: int = 30 parser_sandbox_max_output_bytes: int = 8 * 1024 * 1024 # 8 MB + # Workflow Configuration + max_workflows_per_user: int = 50 + workflow_min_interval_seconds: int = 60 + workflow_consent_refresh_days: int = 30 + # Logging log_level: str = "INFO" log_file: str = str(PROJECT_ROOT / "logs" / "secuscan.log") diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 4317b476f..ee4f8b9a7 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -245,6 +245,7 @@ async def create_task( execution_context: Optional[Dict[str, Any]] = None, consent_granted: bool = False, owner_id: str = DEFAULT_OWNER_ID, + source: str = "api", ) -> str: """ Create a new scan task. @@ -258,6 +259,7 @@ async def create_task( access (issue #401). Defaults to the shared default owner for internal callers (workflows, scheduler, CLI) that are not tied to a request. + source: Origin of this task (api|workflow|scheduler) Returns: Task ID @@ -300,7 +302,7 @@ async def create_task( ) ) - # Log audit event + # Log audit event with source tracking await db.log_audit( "task_created", f"Task created for {plugin.name}", @@ -309,6 +311,7 @@ async def create_task( "plugin_id": plugin_id, "target": inputs.get("target"), "execution_context": normalize_execution_context(execution_context), + "source": source, }, task_id=task_id, plugin_id=plugin_id @@ -660,11 +663,19 @@ async def execute_task(self, task_id: str): await self._broadcast(task_id, "status", final_status) await self._invalidate_cached_views() - # Log completion + task_source_row = await db.fetchone("SELECT inputs_json FROM tasks WHERE id = ?", (task_id,)) + source = "api" + if task_source_row: + try: + tj = json.loads(task_source_row["inputs_json"]) + source = tj.get("_source", "api") + except (json.JSONDecodeError, TypeError): + pass + await db.log_audit( "task_completed", f"Task completed in {duration:.2f}s", - context={"task_id": task_id, "exit_code": locals().get('exit_code', 0)}, + context={"task_id": task_id, "exit_code": locals().get('exit_code', 0), "source": source}, task_id=task_id, plugin_id=plugin_id ) @@ -672,11 +683,8 @@ async def execute_task(self, task_id: str): logger.info(f"Task {task_id} completed in {duration:.2f}s") except asyncio.CancelledError: - # CancelledError inherits from BaseException, not Exception — - # it bypasses the broad except below, so we handle it explicitly. - # Task.cancelled() returns False while the finally block is still - # executing, so this is the only reliable place to write the - # cancellation status to the DB. + self.running_tasks.pop(task_id, None) + self._process_pids.pop(task_id, None) duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ @@ -699,6 +707,8 @@ async def execute_task(self, task_id: str): raise # let asyncio complete the cancellation except CapabilityDeniedError as e: + self.running_tasks.pop(task_id, None) + self._process_pids.pop(task_id, None) logger.warning("Task %s blocked by capability policy: %s", task_id, e) duration = (time.time() - start_time) if "start_time" in locals() else 0 await db.execute( @@ -733,9 +743,10 @@ async def execute_task(self, task_id: str): ) except Exception as e: + self.running_tasks.pop(task_id, None) + self._process_pids.pop(task_id, None) logger.error(f"Task {task_id} failed: {e}", exc_info=True) - # Update task as failed duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index 8e06d6638..99f393d38 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -26,7 +26,7 @@ from .cache import init_cache, cache as global_cache from .database import init_db, db as global_db from .plugins import init_plugins -from .routes import router +from .routes import router, cancel_pending_workflow_tasks from .saved_views import saved_views_router from .workflows import scheduler @@ -122,11 +122,14 @@ async def lifespan(app: FastAPI): # Shutdown logger.info("🛑 Shutting down SecuScan backend...") - if global_db: - await global_db.disconnect() - if global_cache: - await global_cache.disconnect() + from . import database as database_module + from . import cache as cache_module + if database_module.db: + await database_module.db.disconnect() + if cache_module.cache: + await cache_module.cache.disconnect() await scheduler.stop() + await cancel_pending_workflow_tasks() logger.info("✓ Shutdown complete") # Create FastAPI application diff --git a/backend/secuscan/plugins.py b/backend/secuscan/plugins.py index 436844cea..84e9ee693 100644 --- a/backend/secuscan/plugins.py +++ b/backend/secuscan/plugins.py @@ -30,6 +30,7 @@ "consent_granted", "dry_run", "debug_mode", + "_source", }) logger = logging.getLogger(__name__) diff --git a/backend/secuscan/ratelimit.py b/backend/secuscan/ratelimit.py index e1f4064aa..e7f2aa213 100644 --- a/backend/secuscan/ratelimit.py +++ b/backend/secuscan/ratelimit.py @@ -237,9 +237,46 @@ async def reset(self): self.last_cleanup = None +class WorkflowRateLimiter: + """Rate limiter for workflow-triggered scans.""" + + def __init__(self): + self._last_run: Dict[str, datetime] = {} + self._user_workflow_count: Dict[str, int] = {} + self.lock = asyncio.Lock() + + async def check_workflow_rate_limit(self, workflow_id: str, min_interval_seconds: int) -> Tuple[bool, str]: + async with self.lock: + now = datetime.now() + last = self._last_run.get(workflow_id) + if last and (now - last).total_seconds() < min_interval_seconds: + remaining = min_interval_seconds - (now - last).total_seconds() + return False, f"Workflow rate limited: wait {remaining:.0f}s between runs" + self._last_run[workflow_id] = now + return True, "" + + async def check_user_workflow_limit(self, user_id: str, max_workflows: int) -> Tuple[bool, str]: + async with self.lock: + count = self._user_workflow_count.get(user_id, 0) + if count >= max_workflows: + return False, f"User workflow limit reached ({max_workflows})" + return True, "" + + async def register_user_workflow(self, user_id: str): + async with self.lock: + self._user_workflow_count[user_id] = self._user_workflow_count.get(user_id, 0) + 1 + + async def unregister_user_workflow(self, user_id: str): + async with self.lock: + current = self._user_workflow_count.get(user_id, 0) + if current > 0: + self._user_workflow_count[user_id] = current - 1 + + # Global instances rate_limiter = RateLimiter() concurrent_limiter = ConcurrentTaskLimiter() +workflow_rate_limiter = WorkflowRateLimiter() # Route-specific limiters task_start_limiter = EndpointRateLimiter( diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index d3e2f595f..cead68e92 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, BackgroundTasks, Response, Request, Depends, Body, Query from fastapi.responses import JSONResponse -from typing import Any, Optional, List, Dict, Callable +from typing import Any, Optional, List, Dict, Callable, Set import json import logging import re @@ -14,6 +14,22 @@ from pathlib import Path from urllib.parse import urlencode, urlparse +_pending_workflow_tasks: Set[asyncio.Task] = set() + + +def _track_task(task: asyncio.Task) -> None: + _pending_workflow_tasks.add(task) + task.add_done_callback(_pending_workflow_tasks.discard) + + +async def cancel_pending_workflow_tasks() -> None: + for task in list(_pending_workflow_tasks): + task.cancel() + if _pending_workflow_tasks: + await asyncio.gather(*_pending_workflow_tasks, return_exceptions=True) + _pending_workflow_tasks.clear() + + def parse_json_fields(rows: List[Dict], fields: List[str]) -> List[Dict]: """Helper to parse stringified JSON fields from SQLite.""" parsed = [] @@ -174,7 +190,7 @@ def build_report_filename(task: Dict[str, Any], extension: str) -> str: from .executor import executor from .redaction import redact_inputs from .ratelimit import ( - rate_limiter, concurrent_limiter, + rate_limiter, concurrent_limiter, workflow_rate_limiter, task_start_limiter, vault_limiter, report_download_limiter, read_heavy_limiter, resolve_client_identity, admin_limiter, @@ -184,7 +200,7 @@ def build_report_filename(task: Dict[str, Any], extension: str) -> str: from .reporting import reporting from .vault import VaultCrypto from .workflows import scheduler -from .auth import require_api_key, get_current_owner +from .auth import require_api_key, get_current_owner, DEFAULT_OWNER_ID from .execution_context import is_offensive_validation, normalize_execution_context from .finding_intelligence import build_asset_summary, build_finding_groups from .knowledgebase import KnowledgeBase @@ -301,6 +317,99 @@ def _report_generation_error_response(task_id: str, report_format: str) -> JSONR ) +async def _execute_scan_safe( + plugin_id: str, + inputs: Dict[str, Any], + consent_granted: bool, + preset: Optional[str] = None, + owner: str = DEFAULT_OWNER_ID, + source: str = "api", + client_id: Optional[str] = None, + target_policy: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Shared scan execution that applies consent, safe mode, target validation, + rate limiting, and concurrency limits. Used by both the API route and the + workflow runner.""" + if settings.require_consent and not consent_granted: + raise HTTPException( + status_code=400, + detail="Consent required. You must acknowledge the legal notice." + ) + + plugin_manager = await get_plugin_manager_for_request() + plugin = plugin_manager.get_plugin(plugin_id) + + if not plugin: + raise HTTPException(status_code=404, detail=f"Plugin not found: {plugin_id}") + + safe_mode = bool( + settings.safe_mode_default + and not (target_policy and target_policy.get("allow_public_targets")) + ) + effective_inputs = dict(inputs) + if "safe_mode" in effective_inputs: + effective_inputs.pop("safe_mode", None) + effective_inputs["safe_mode"] = safe_mode + effective_inputs["_source"] = source + + for tkey in ("timeout", "max_scan_time"): + declared = any(getattr(f, "id", None) == tkey for f in (plugin.fields or [])) + if not declared: + continue + if tkey in effective_inputs and effective_inputs[tkey] not in (None, ""): + try: + tval = int(effective_inputs[tkey]) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail=f"Invalid value for {tkey}: must be an integer") + if tval <= 0 or tval > settings.sandbox_timeout: + raise HTTPException(status_code=400, detail=f"{tkey} must be between 1 and {settings.sandbox_timeout} seconds") + + if target := effective_inputs.get("target"): + target_str = str(target) + should_validate = plugin.category != "code" and not is_filesystem_target(target_str) + if should_validate: + try: + is_valid, error_msg = await asyncio.wait_for( + asyncio.to_thread(validate_target, target_str, safe_mode), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + except asyncio.TimeoutError: + raise HTTPException(status_code=400, detail="Target validation timed out in safe mode (SecuScan Guardrail)") + if not is_valid: + raise HTTPException(status_code=400, detail=error_msg) + + client = client_id or f"user:{owner}" + can_exec, err = await rate_limiter.can_execute( + plugin_id, + plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour), + client_id=client, + ) + if not can_exec: + raise HTTPException(status_code=429, detail=err) + + task_id = await executor.create_task( + plugin_id, + effective_inputs, + safe_mode=safe_mode, + preset=preset, + consent_granted=consent_granted, + owner_id=owner, + source=source, + ) + + can_acquire, concurrency_err = await concurrent_limiter.acquire(task_id) + if not can_acquire: + await executor.mark_task_failed(task_id, reason="Concurrency limit reached") + raise HTTPException(status_code=503, detail=concurrency_err) + + return { + "task_id": task_id, + "status": "queued", + "created_at": "now", + "stream_url": f"/api/v1/task/{task_id}/stream" + } + + async def get_plugin_manager_for_request(): """ In debug mode, refresh plugin metadata from disk on demand so frontend catalog @@ -392,22 +501,6 @@ async def start_task( if not ok: raise HTTPException(status_code=status_code, detail=error_msg) - # Validate consent - if settings.require_consent and not request.consent_granted: - logger.warning(f"Task start failed: Consent not granted. Request: {request}") - raise HTTPException( - status_code=400, - detail="Consent required. You must acknowledge the legal notice." - ) - - # Get plugin - plugin_manager = await get_plugin_manager_for_request() - plugin = plugin_manager.get_plugin(request.plugin_id) - - if not plugin: - logger.warning(f"Task start failed: Plugin not found: {request.plugin_id}") - raise HTTPException(status_code=404, detail=f"Plugin not found: {request.plugin_id}") - db = await get_db() target_policy = await get_target_policy(db, owner, execution_context.get("target_policy_id")) credential_profile = await get_credential_profile(db, owner, execution_context.get("credential_profile_id")) @@ -426,6 +519,11 @@ async def start_task( detail="Authenticated scans require a target policy with authenticated scanning enabled.", ) + plugin_manager = await get_plugin_manager_for_request() + plugin = plugin_manager.get_plugin(request.plugin_id) + if not plugin: + raise HTTPException(status_code=404, detail=f"Plugin not found: {request.plugin_id}") + requires_exploit_policy = ( plugin.safety.get("level") == "exploit" or execution_context.get("validation_mode") == ValidationMode.CONTROLLED_EXTRACT.value @@ -437,103 +535,22 @@ async def start_task( detail="Offensive validation requires a target policy that explicitly allows exploit validation.", ) - # Server-controlled safe mode: public-target scans are opt-in via target policy. - safe_mode = bool( - settings.safe_mode_default - and not (target_policy and target_policy.get("allow_public_targets")) - ) - - # Ensure downstream scanners/plugins see the effective safe-mode, but prevent client override. - effective_inputs = dict(request.inputs or {}) - if "safe_mode" in effective_inputs: - effective_inputs.pop("safe_mode", None) - effective_inputs["safe_mode"] = safe_mode - - # Validate numeric timeout inputs at request time to prevent unsafe values - for tkey in ("timeout", "max_scan_time"): - # Only enforce bounds if the plugin declares the field in its schema - declared = any(getattr(f, "id", None) == tkey for f in (plugin.fields or [])) - if not declared: - continue - if tkey in effective_inputs and effective_inputs[tkey] not in (None, ""): - try: - tval = int(effective_inputs[tkey]) - except (TypeError, ValueError): - raise HTTPException(status_code=400, detail=f"Invalid value for {tkey}: must be an integer") - if tval <= 0 or tval > settings.sandbox_timeout: - raise HTTPException(status_code=400, detail=f"{tkey} must be between 1 and {settings.sandbox_timeout} seconds") - - if target := effective_inputs.get("target"): - target_str = str(target) - should_validate_target = plugin.category != "code" and not is_filesystem_target(target_str) - - if should_validate_target: - try: - is_valid, error_msg = await asyncio.wait_for( - asyncio.to_thread(validate_target, target_str, safe_mode), - timeout=float(settings.dns_resolution_timeout_seconds), - ) - except asyncio.TimeoutError: - logger.warning("Task start failed: Target validation timed out for '%s'", target_str) - raise HTTPException( - status_code=400, - detail="Target validation timed out in safe mode (SecuScan Guardrail)", - ) - - if not is_valid: - logger.warning(f"Task start failed: Target validation failed for '{target}': {error_msg}") - raise HTTPException(status_code=400, detail=error_msg) - - # Check rate limits per (client, plugin) so one client cannot exhaust - # the quota for all other users of the same plugin. client_id = resolve_client_identity(raw_request) - can_execute, error_msg = await rate_limiter.can_execute( - request.plugin_id, - plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour), + result = await _execute_scan_safe( + plugin_id=request.plugin_id, + inputs=request.inputs or {}, + consent_granted=request.consent_granted, + preset=request.preset, + owner=owner, + source="api", client_id=client_id, + target_policy=target_policy, ) - if not can_execute: - raise HTTPException(status_code=429, detail=error_msg) - - # Create task record first so we have a real task_id for the limiter - try: - task_id = await executor.create_task( - request.plugin_id, - effective_inputs, - safe_mode=safe_mode, - preset=request.preset, - execution_context=execution_context, - consent_granted=request.consent_granted, - owner_id=owner, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - - # Atomically acquire a concurrency slot using the real task_id. - # acquire() is lock-protected internally, so the check and register - # happen in a single operation — no TOCTOU window between requests. - can_acquire, error_msg = await concurrent_limiter.acquire(task_id) - if not can_acquire: - # Roll back: mark the DB row failed so it isn't left orphaned - await executor.mark_task_failed(task_id, reason="Concurrency limit reached; task was not started") - raise HTTPException(status_code=503, detail=error_msg) - - # Slot is held — schedule execution. - # execute_task releases the slot in its finally block on every exit path. - # - # Use BackgroundTasks so the response can be sent without waiting in real - # ASGI servers, while tests using TestClient still execute the task to keep - # contract tests deterministic. - background_tasks.add_task(executor.execute_task, task_id) + background_tasks.add_task(executor.execute_task, result["task_id"]) await invalidate_view_cache() - return { - "task_id": task_id, - "status": "queued", - "created_at": "now", - "stream_url": f"/api/v1/task/{task_id}/stream" - } + return result @router.get("/task/{task_id}/status") async def get_task_status(task_id: str, owner: str = Depends(get_current_owner)): @@ -1726,11 +1743,20 @@ async def create_workflow(payload: Dict[str, Any]): @router.post("/workflows/{workflow_id}/run") -async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_owner)): +async def run_workflow_once( + workflow_id: str, + owner: str = Depends(get_current_owner), +): db = await get_db() row = await db.fetchone("SELECT * FROM workflows WHERE id = ?", (workflow_id,)) if not row: raise HTTPException(status_code=404, detail="Workflow not found") + wf_rate_ok, wf_rate_msg = await workflow_rate_limiter.check_workflow_rate_limit( + workflow_id, settings.workflow_min_interval_seconds + ) + if not wf_rate_ok: + raise HTTPException(status_code=429, detail=wf_rate_msg) + steps = _parse_workflow_steps(row["steps_json"] or "[]") active_version = await db.fetchone( "SELECT id, version_number FROM workflow_versions " @@ -1741,26 +1767,20 @@ async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_o version_number = active_version["version_number"] if active_version else None created_task_ids: List[str] = [] for step in steps: - execution_context = normalize_execution_context(step.get("execution_context") or {}) - target_policy = await get_target_policy(db, owner, execution_context.get("target_policy_id")) - safe_mode = bool( - settings.safe_mode_default - and not (target_policy and target_policy.get("allow_public_targets")) - ) - effective_inputs = dict(step.get("inputs", {}) or {}) - effective_inputs.pop("safe_mode", None) - effective_inputs["safe_mode"] = safe_mode - task_id = await executor.create_task( - step.get("plugin_id"), - effective_inputs, - safe_mode=safe_mode, - preset=step.get("preset"), - execution_context=execution_context, + plugin_id = step.get("plugin_id") + if not plugin_id: + continue + + result = await _execute_scan_safe( + plugin_id=plugin_id, + inputs=step.get("inputs", {}), consent_granted=True, - owner_id=owner, + preset=step.get("preset"), + owner=DEFAULT_OWNER_ID, + source="workflow", ) - asyncio.create_task(executor.execute_task(task_id)) - created_task_ids.append(task_id) + created_task_ids.append(result["task_id"]) + asyncio.create_task(executor.execute_task(result["task_id"])) await db.execute("UPDATE workflows SET last_run_at = datetime('now') WHERE id = ?", (workflow_id,)) run_id = await db.record_workflow_run( workflow_id=workflow_id, @@ -1769,7 +1789,7 @@ async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_o task_ids=created_task_ids, triggered_by="manual", ) - asyncio.create_task(_finalize_workflow_run(run_id)) + _track_task(asyncio.create_task(_finalize_workflow_run(run_id))) return { "workflow_id": workflow_id, "run_id": run_id, diff --git a/backend/secuscan/workflows.py b/backend/secuscan/workflows.py index c7ba88dc7..003d95eb1 100644 --- a/backend/secuscan/workflows.py +++ b/backend/secuscan/workflows.py @@ -2,13 +2,16 @@ from __future__ import annotations from .request_context import get_request_id, set_request_id import asyncio +from typing import Set import json import logging from datetime import datetime, timezone -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from .database import get_db from .config import settings +from .ratelimit import workflow_rate_limiter, rate_limiter, concurrent_limiter from .executor import executor +from .auth import DEFAULT_OWNER_ID from .execution_context import normalize_execution_context from .platform_resources import get_target_policy logger = logging.getLogger(__name__) @@ -16,6 +19,11 @@ class WorkflowScheduler: def __init__(self): self._task: asyncio.Task | None = None self._running = False + self._child_tasks: Set[asyncio.Task] = set() + + def _track_child(self, task: asyncio.Task) -> None: + self._child_tasks.add(task) + task.add_done_callback(self._child_tasks.discard) async def start(self): if self._task and not self._task.done(): @@ -32,6 +40,11 @@ async def stop(self): except asyncio.CancelledError: pass self._task = None + for task in list(self._child_tasks): + task.cancel() + if self._child_tasks: + await asyncio.gather(*self._child_tasks, return_exceptions=True) + self._child_tasks.clear() logger.info("Workflow scheduler stopped") async def _run_loop(self): while self._running: @@ -53,6 +66,14 @@ async def tick(self): for row in rows: if not self._should_run(now, row.get("last_run_at"), int(row["schedule_seconds"])): continue + + wf_rate_ok, wf_rate_msg = await workflow_rate_limiter.check_workflow_rate_limit( + row["id"], settings.workflow_min_interval_seconds + ) + if not wf_rate_ok: + logger.warning("Workflow %s skipped by rate limiter: %s", row["id"], wf_rate_msg) + continue + await self._run_workflow(row["id"], json.loads(row.get("steps_json") or "[]")) await db.execute( "UPDATE workflows SET last_run_at = datetime('now') WHERE id = ?", @@ -62,10 +83,6 @@ def _should_run(self, now: datetime, last_run_at: str | None, schedule_seconds: if not last_run_at: return True last = datetime.fromisoformat(last_run_at.replace("Z", "+00:00")) - # SQLite's datetime('now') produces "2026-05-25 08:02:28" — no Z and - # no +00:00 suffix — so fromisoformat() returns a naive datetime. - # Subtracting a naive datetime from an aware one raises TypeError. - # Treat any naive timestamp from the DB as UTC. if last.tzinfo is None: last = last.replace(tzinfo=timezone.utc) elapsed = (now - last).total_seconds() @@ -85,9 +102,52 @@ async def _run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): settings.safe_mode_default and not (target_policy and target_policy.get("allow_public_targets")) ) + + from .plugins import get_plugin_manager + from .validation import validate_target + from .network_policy import get_policy_engine + + plugin_manager = get_plugin_manager() + plugin = plugin_manager.get_plugin(plugin_id) + if not plugin: + logger.warning("Workflow %s: plugin %s not found, skipping step", workflow_id, plugin_id) + continue effective_inputs = dict(inputs) effective_inputs.pop("safe_mode", None) effective_inputs["safe_mode"] = safe_mode + effective_inputs["_source"] = "scheduler" + + if target := effective_inputs.get("target"): + target_str = str(target) + if plugin.category != "code": + try: + is_valid, error_msg = await asyncio.wait_for( + asyncio.to_thread(validate_target, target_str, safe_mode), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + if not is_valid: + logger.warning("Workflow %s: target validation failed for step %s: %s", workflow_id, plugin_id, error_msg) + continue + except asyncio.TimeoutError: + logger.warning("Workflow %s: target validation timed out for step %s", workflow_id, plugin_id) + continue + + if settings.enforce_network_policy and target_str: + engine = get_policy_engine() + allowed, reason, _ = await asyncio.wait_for( + asyncio.to_thread(engine.check_access, dest_ip=target_str, plugin_id=plugin_id, task_id=""), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + if not allowed: + logger.warning("Workflow %s: network policy denied %s: %s", workflow_id, target_str, reason) + continue + + client = f"user:{DEFAULT_OWNER_ID}" + max_per_hour = plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour) if plugin else settings.max_tasks_per_hour + can_exec, rate_err = await rate_limiter.can_execute(plugin_id, max_per_hour, client_id=client) + if not can_exec: + logger.warning("Workflow %s: rate limit exceeded for %s: %s", workflow_id, plugin_id, rate_err) + continue task_id = await executor.create_task( plugin_id, @@ -96,13 +156,17 @@ async def _run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): preset=step.get("preset"), execution_context=execution_context, consent_granted=True, + owner_id=DEFAULT_OWNER_ID, + source="scheduler", ) - async def run_task(task_id: str) -> None: - set_request_id(request_id) - await executor.execute_task(task_id) + can_acquire, concurrency_err = await concurrent_limiter.acquire(task_id) + if not can_acquire: + await executor.mark_task_failed(task_id, reason="Concurrency limit reached") + logger.warning("Workflow %s: concurrency limit reached for %s", workflow_id, plugin_id) + continue - asyncio.create_task(run_task(task_id)) + self._track_child(asyncio.create_task(executor.execute_task(task_id))) -scheduler = WorkflowScheduler() +scheduler = WorkflowScheduler() \ No newline at end of file diff --git a/testing/backend/unit/test_workflow_scheduler_security.py b/testing/backend/unit/test_workflow_scheduler_security.py new file mode 100644 index 000000000..d25408c63 --- /dev/null +++ b/testing/backend/unit/test_workflow_scheduler_security.py @@ -0,0 +1,283 @@ +""" +Tests for workflow scheduler route-level security controls. + +Verifies that the scheduler path applies the same consent, target validation, +rate limiting, and concurrency controls as the API path. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from backend.secuscan.workflows import WorkflowScheduler +from backend.secuscan.ratelimit import WorkflowRateLimiter + + +@pytest.fixture +def scheduler(): + return WorkflowScheduler() + + +@pytest.fixture +def rate_limiter(): + return WorkflowRateLimiter() + + +# --------------------------------------------------------------------------- +# WorkflowRateLimiter unit tests +# --------------------------------------------------------------------------- + +class TestWorkflowRateLimiter: + @pytest.mark.asyncio + async def test_allows_first_run(self, rate_limiter): + ok, msg = await rate_limiter.check_workflow_rate_limit("wf-1", 60) + assert ok is True + assert msg == "" + + @pytest.mark.asyncio + async def test_blocks_second_run_within_interval(self, rate_limiter): + await rate_limiter.check_workflow_rate_limit("wf-1", 60) + ok, msg = await rate_limiter.check_workflow_rate_limit("wf-1", 60) + assert ok is False + assert "rate limited" in msg.lower() + + @pytest.mark.asyncio + async def test_allows_different_workflows_independently(self, rate_limiter): + await rate_limiter.check_workflow_rate_limit("wf-1", 60) + ok, msg = await rate_limiter.check_workflow_rate_limit("wf-2", 60) + assert ok is True + + @pytest.mark.asyncio + async def test_user_workflow_limit(self, rate_limiter): + ok, msg = await rate_limiter.check_user_workflow_limit("user-1", 5) + assert ok is True + await rate_limiter.register_user_workflow("user-1") + await rate_limiter.register_user_workflow("user-1") + await rate_limiter.register_user_workflow("user-1") + await rate_limiter.register_user_workflow("user-1") + await rate_limiter.register_user_workflow("user-1") + ok, msg = await rate_limiter.check_user_workflow_limit("user-1", 5) + assert ok is False + assert "limit reached" in msg.lower() + + @pytest.mark.asyncio + async def test_unregister_user_workflow(self, rate_limiter): + await rate_limiter.register_user_workflow("user-1") + await rate_limiter.unregister_user_workflow("user-1") + ok, msg = await rate_limiter.check_user_workflow_limit("user-1", 1) + assert ok is True + + +# --------------------------------------------------------------------------- +# WorkflowScheduler._run_workflow security control tests +# --------------------------------------------------------------------------- +# Note: _run_workflow() uses local imports inside the function body +# (e.g., "from .plugins import get_plugin_manager"), so we patch the +# original module paths rather than the local names. + +class TestSchedulerSecurityControls: + @pytest.mark.asyncio + async def test_skips_step_when_plugin_not_found(self, scheduler): + steps = [{"plugin_id": "nonexistent-plugin", "inputs": {}}] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock) as mock_get_db, \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm: + + mock_db = AsyncMock() + mock_get_db.return_value = mock_db + mock_pm = MagicMock() + mock_pm.get_plugin.return_value = None + mock_get_pm.return_value = mock_pm + + await scheduler._run_workflow("wf-1", steps) + mock_db.record_workflow_run.assert_not_called() + + @pytest.mark.asyncio + async def test_skips_step_when_target_validation_fails(self, scheduler): + steps = [{ + "plugin_id": "nmap", + "inputs": {"target": "invalid-target"}, + }] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ + patch("backend.secuscan.validation.validate_target", return_value=(False, "Target not allowed")) as mock_val: + + mock_pm = MagicMock() + plugin = MagicMock() + plugin.category = "network" + plugin.safety = {"rate_limit": {"max_per_hour": 50}} + plugin.fields = [] + mock_pm.get_plugin.return_value = plugin + mock_get_pm.return_value = mock_pm + + await scheduler._run_workflow("wf-1", steps) + mock_val.assert_called_once() + + @pytest.mark.asyncio + async def test_skips_step_when_rate_limit_exceeded(self, scheduler): + steps = [{ + "plugin_id": "nmap", + "inputs": {"target": "example.com"}, + }] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ + patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.rate_limiter.can_execute", new_callable=AsyncMock) as mock_rate: + + mock_pm = MagicMock() + plugin = MagicMock() + plugin.category = "network" + plugin.safety = {"rate_limit": {"max_per_hour": 50}} + plugin.fields = [] + mock_pm.get_plugin.return_value = plugin + mock_get_pm.return_value = mock_pm + mock_rate.return_value = (False, "Rate limit exceeded") + + await scheduler._run_workflow("wf-1", steps) + mock_rate.assert_called_once() + + @pytest.mark.asyncio + async def test_applies_source_tag_in_inputs(self, scheduler): + steps = [{ + "plugin_id": "nmap", + "inputs": {"target": "example.com"}, + }] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ + patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.rate_limiter.can_execute", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", return_value=(True, "")), \ + patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create: + + mock_pm = MagicMock() + plugin = MagicMock() + plugin.category = "network" + plugin.safety = {"rate_limit": {"max_per_hour": 50}} + plugin.fields = [] + mock_pm.get_plugin.return_value = plugin + mock_get_pm.return_value = mock_pm + + await scheduler._run_workflow("wf-1", steps) + args, kwargs = mock_create.call_args + assert kwargs.get("source") == "scheduler" + inputs = args[1] if len(args) > 1 else kwargs.get("inputs", {}) + assert inputs.get("_source") == "scheduler" + + @pytest.mark.asyncio + async def test_applies_safe_mode_consistently(self, scheduler): + steps = [{ + "plugin_id": "nmap", + "inputs": {"target": "example.com", "safe_mode": False}, + }] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ + patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.rate_limiter.can_execute", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", return_value=(True, "")), \ + patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create: + + mock_pm = MagicMock() + plugin = MagicMock() + plugin.category = "network" + plugin.safety = {"rate_limit": {"max_per_hour": 50}} + plugin.fields = [] + mock_pm.get_plugin.return_value = plugin + mock_get_pm.return_value = mock_pm + + await scheduler._run_workflow("wf-1", steps) + args, kwargs = mock_create.call_args + inputs = args[1] if len(args) > 1 else kwargs.get("inputs", {}) + assert "safe_mode" in inputs + assert inputs["safe_mode"] is True + + @pytest.mark.asyncio + async def test_acquires_concurrency_slot(self, scheduler): + steps = [{ + "plugin_id": "nmap", + "inputs": {"target": "example.com"}, + }] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ + patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.rate_limiter.can_execute", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", new_callable=AsyncMock) as mock_acquire, \ + patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1"): + + mock_pm = MagicMock() + plugin = MagicMock() + plugin.category = "network" + plugin.safety = {"rate_limit": {"max_per_hour": 50}} + plugin.fields = [] + mock_pm.get_plugin.return_value = plugin + mock_get_pm.return_value = mock_pm + mock_acquire.return_value = (True, "") + + await scheduler._run_workflow("wf-1", steps) + mock_acquire.assert_called_once_with("task-1") + + @pytest.mark.asyncio + async def test_skips_step_when_concurrency_limit_reached(self, scheduler): + steps = [{ + "plugin_id": "nmap", + "inputs": {"target": "example.com"}, + }] + with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ + patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ + patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.rate_limiter.can_execute", return_value=(True, "")), \ + patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", return_value=(False, "Concurrency limit reached")), \ + patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1"), \ + patch("backend.secuscan.executor.executor.mark_task_failed", new_callable=AsyncMock) as mock_fail: + + mock_pm = MagicMock() + plugin = MagicMock() + plugin.category = "network" + plugin.safety = {"rate_limit": {"max_per_hour": 50}} + plugin.fields = [] + mock_pm.get_plugin.return_value = plugin + mock_get_pm.return_value = mock_pm + + await scheduler._run_workflow("wf-1", steps) + mock_fail.assert_called_once() + + +# --------------------------------------------------------------------------- +# WorkflowScheduler.tick rate limit integration +# --------------------------------------------------------------------------- + +class TestTickRateLimiting: + @pytest.mark.asyncio + async def test_tick_applies_workflow_rate_limiter(self, scheduler): + db_mock = AsyncMock() + db_mock.fetchall.return_value = [{ + "id": "wf-1", + "name": "test", + "schedule_seconds": 60, + "last_run_at": None, + "steps_json": "[]", + }] + with patch("backend.secuscan.workflows.get_db", return_value=db_mock), \ + patch.object(scheduler, "_run_workflow", new_callable=AsyncMock) as mock_run, \ + patch("backend.secuscan.workflows.workflow_rate_limiter.check_workflow_rate_limit", new_callable=AsyncMock) as mock_rate: + + mock_rate.return_value = (True, "") + await scheduler.tick() + mock_rate.assert_called_once_with("wf-1", 60) + mock_run.assert_called_once() + + @pytest.mark.asyncio + async def test_tick_skips_rate_limited_workflow(self, scheduler): + db_mock = AsyncMock() + db_mock.fetchall.return_value = [{ + "id": "wf-1", + "name": "test", + "schedule_seconds": 60, + "last_run_at": None, + "steps_json": "[]", + }] + with patch("backend.secuscan.workflows.get_db", return_value=db_mock), \ + patch.object(scheduler, "_run_workflow", new_callable=AsyncMock) as mock_run, \ + patch("backend.secuscan.workflows.workflow_rate_limiter.check_workflow_rate_limit", new_callable=AsyncMock) as mock_rate: + + mock_rate.return_value = (False, "Workflow rate limited: wait 30s between runs") + await scheduler.tick() + mock_rate.assert_called_once() + mock_run.assert_not_called() From e779be5dac65057a6970b7231fb7c970a8ccf162 Mon Sep 17 00:00:00 2001 From: Srijan Jaiswal Date: Tue, 9 Jun 2026 23:31:00 +0530 Subject: [PATCH 2/2] fix: apply route-level security controls to scheduler-triggered workflow scans - _run_workflow() now validates plugin existence, targets in safe mode, enforces network policy, rate limits per (client, plugin), and acquires concurrency slots before executing each step - tick() enforces workflow_min_interval_seconds via WorkflowRateLimiter - run_workflow_once() applies the same workflow rate limit - Added WorkflowRateLimiter with per-workflow rate limiting - Added workflow_min_interval_seconds config setting - Each check failure gracefully logs and skips the step Executor lifecycle, shutdown handling, route refactoring, and plugin input handling split into separate PRs. Closes #655 --- backend/secuscan/config.py | 2 - backend/secuscan/executor.py | 29 +- backend/secuscan/main.py | 13 +- backend/secuscan/plugins.py | 1 - backend/secuscan/ratelimit.py | 20 +- backend/secuscan/routes.py | 273 +++++++++--------- backend/secuscan/workflows.py | 27 +- .../unit/test_workflow_scheduler_security.py | 51 +--- 8 files changed, 157 insertions(+), 259 deletions(-) diff --git a/backend/secuscan/config.py b/backend/secuscan/config.py index d6380095d..c84fa5cf5 100644 --- a/backend/secuscan/config.py +++ b/backend/secuscan/config.py @@ -124,9 +124,7 @@ class Settings(BaseSettings): parser_sandbox_max_output_bytes: int = 8 * 1024 * 1024 # 8 MB # Workflow Configuration - max_workflows_per_user: int = 50 workflow_min_interval_seconds: int = 60 - workflow_consent_refresh_days: int = 30 # Logging log_level: str = "INFO" diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index ee4f8b9a7..4317b476f 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -245,7 +245,6 @@ async def create_task( execution_context: Optional[Dict[str, Any]] = None, consent_granted: bool = False, owner_id: str = DEFAULT_OWNER_ID, - source: str = "api", ) -> str: """ Create a new scan task. @@ -259,7 +258,6 @@ async def create_task( access (issue #401). Defaults to the shared default owner for internal callers (workflows, scheduler, CLI) that are not tied to a request. - source: Origin of this task (api|workflow|scheduler) Returns: Task ID @@ -302,7 +300,7 @@ async def create_task( ) ) - # Log audit event with source tracking + # Log audit event await db.log_audit( "task_created", f"Task created for {plugin.name}", @@ -311,7 +309,6 @@ async def create_task( "plugin_id": plugin_id, "target": inputs.get("target"), "execution_context": normalize_execution_context(execution_context), - "source": source, }, task_id=task_id, plugin_id=plugin_id @@ -663,19 +660,11 @@ async def execute_task(self, task_id: str): await self._broadcast(task_id, "status", final_status) await self._invalidate_cached_views() - task_source_row = await db.fetchone("SELECT inputs_json FROM tasks WHERE id = ?", (task_id,)) - source = "api" - if task_source_row: - try: - tj = json.loads(task_source_row["inputs_json"]) - source = tj.get("_source", "api") - except (json.JSONDecodeError, TypeError): - pass - + # Log completion await db.log_audit( "task_completed", f"Task completed in {duration:.2f}s", - context={"task_id": task_id, "exit_code": locals().get('exit_code', 0), "source": source}, + context={"task_id": task_id, "exit_code": locals().get('exit_code', 0)}, task_id=task_id, plugin_id=plugin_id ) @@ -683,8 +672,11 @@ async def execute_task(self, task_id: str): logger.info(f"Task {task_id} completed in {duration:.2f}s") except asyncio.CancelledError: - self.running_tasks.pop(task_id, None) - self._process_pids.pop(task_id, None) + # CancelledError inherits from BaseException, not Exception — + # it bypasses the broad except below, so we handle it explicitly. + # Task.cancelled() returns False while the finally block is still + # executing, so this is the only reliable place to write the + # cancellation status to the DB. duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ @@ -707,8 +699,6 @@ async def execute_task(self, task_id: str): raise # let asyncio complete the cancellation except CapabilityDeniedError as e: - self.running_tasks.pop(task_id, None) - self._process_pids.pop(task_id, None) logger.warning("Task %s blocked by capability policy: %s", task_id, e) duration = (time.time() - start_time) if "start_time" in locals() else 0 await db.execute( @@ -743,10 +733,9 @@ async def execute_task(self, task_id: str): ) except Exception as e: - self.running_tasks.pop(task_id, None) - self._process_pids.pop(task_id, None) logger.error(f"Task {task_id} failed: {e}", exc_info=True) + # Update task as failed duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index 99f393d38..8e06d6638 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -26,7 +26,7 @@ from .cache import init_cache, cache as global_cache from .database import init_db, db as global_db from .plugins import init_plugins -from .routes import router, cancel_pending_workflow_tasks +from .routes import router from .saved_views import saved_views_router from .workflows import scheduler @@ -122,14 +122,11 @@ async def lifespan(app: FastAPI): # Shutdown logger.info("🛑 Shutting down SecuScan backend...") - from . import database as database_module - from . import cache as cache_module - if database_module.db: - await database_module.db.disconnect() - if cache_module.cache: - await cache_module.cache.disconnect() + if global_db: + await global_db.disconnect() + if global_cache: + await global_cache.disconnect() await scheduler.stop() - await cancel_pending_workflow_tasks() logger.info("✓ Shutdown complete") # Create FastAPI application diff --git a/backend/secuscan/plugins.py b/backend/secuscan/plugins.py index 84e9ee693..436844cea 100644 --- a/backend/secuscan/plugins.py +++ b/backend/secuscan/plugins.py @@ -30,7 +30,6 @@ "consent_granted", "dry_run", "debug_mode", - "_source", }) logger = logging.getLogger(__name__) diff --git a/backend/secuscan/ratelimit.py b/backend/secuscan/ratelimit.py index e7f2aa213..8cf4c5e75 100644 --- a/backend/secuscan/ratelimit.py +++ b/backend/secuscan/ratelimit.py @@ -238,11 +238,10 @@ async def reset(self): class WorkflowRateLimiter: - """Rate limiter for workflow-triggered scans.""" + """Rate limiter for scheduler-triggered workflow scans.""" def __init__(self): self._last_run: Dict[str, datetime] = {} - self._user_workflow_count: Dict[str, int] = {} self.lock = asyncio.Lock() async def check_workflow_rate_limit(self, workflow_id: str, min_interval_seconds: int) -> Tuple[bool, str]: @@ -255,23 +254,6 @@ async def check_workflow_rate_limit(self, workflow_id: str, min_interval_seconds self._last_run[workflow_id] = now return True, "" - async def check_user_workflow_limit(self, user_id: str, max_workflows: int) -> Tuple[bool, str]: - async with self.lock: - count = self._user_workflow_count.get(user_id, 0) - if count >= max_workflows: - return False, f"User workflow limit reached ({max_workflows})" - return True, "" - - async def register_user_workflow(self, user_id: str): - async with self.lock: - self._user_workflow_count[user_id] = self._user_workflow_count.get(user_id, 0) + 1 - - async def unregister_user_workflow(self, user_id: str): - async with self.lock: - current = self._user_workflow_count.get(user_id, 0) - if current > 0: - self._user_workflow_count[user_id] = current - 1 - # Global instances rate_limiter = RateLimiter() diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index cead68e92..8a42ad433 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, BackgroundTasks, Response, Request, Depends, Body, Query from fastapi.responses import JSONResponse -from typing import Any, Optional, List, Dict, Callable, Set +from typing import Any, Optional, List, Dict, Callable import json import logging import re @@ -14,22 +14,6 @@ from pathlib import Path from urllib.parse import urlencode, urlparse -_pending_workflow_tasks: Set[asyncio.Task] = set() - - -def _track_task(task: asyncio.Task) -> None: - _pending_workflow_tasks.add(task) - task.add_done_callback(_pending_workflow_tasks.discard) - - -async def cancel_pending_workflow_tasks() -> None: - for task in list(_pending_workflow_tasks): - task.cancel() - if _pending_workflow_tasks: - await asyncio.gather(*_pending_workflow_tasks, return_exceptions=True) - _pending_workflow_tasks.clear() - - def parse_json_fields(rows: List[Dict], fields: List[str]) -> List[Dict]: """Helper to parse stringified JSON fields from SQLite.""" parsed = [] @@ -200,7 +184,7 @@ def build_report_filename(task: Dict[str, Any], extension: str) -> str: from .reporting import reporting from .vault import VaultCrypto from .workflows import scheduler -from .auth import require_api_key, get_current_owner, DEFAULT_OWNER_ID +from .auth import require_api_key, get_current_owner from .execution_context import is_offensive_validation, normalize_execution_context from .finding_intelligence import build_asset_summary, build_finding_groups from .knowledgebase import KnowledgeBase @@ -317,99 +301,6 @@ def _report_generation_error_response(task_id: str, report_format: str) -> JSONR ) -async def _execute_scan_safe( - plugin_id: str, - inputs: Dict[str, Any], - consent_granted: bool, - preset: Optional[str] = None, - owner: str = DEFAULT_OWNER_ID, - source: str = "api", - client_id: Optional[str] = None, - target_policy: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Shared scan execution that applies consent, safe mode, target validation, - rate limiting, and concurrency limits. Used by both the API route and the - workflow runner.""" - if settings.require_consent and not consent_granted: - raise HTTPException( - status_code=400, - detail="Consent required. You must acknowledge the legal notice." - ) - - plugin_manager = await get_plugin_manager_for_request() - plugin = plugin_manager.get_plugin(plugin_id) - - if not plugin: - raise HTTPException(status_code=404, detail=f"Plugin not found: {plugin_id}") - - safe_mode = bool( - settings.safe_mode_default - and not (target_policy and target_policy.get("allow_public_targets")) - ) - effective_inputs = dict(inputs) - if "safe_mode" in effective_inputs: - effective_inputs.pop("safe_mode", None) - effective_inputs["safe_mode"] = safe_mode - effective_inputs["_source"] = source - - for tkey in ("timeout", "max_scan_time"): - declared = any(getattr(f, "id", None) == tkey for f in (plugin.fields or [])) - if not declared: - continue - if tkey in effective_inputs and effective_inputs[tkey] not in (None, ""): - try: - tval = int(effective_inputs[tkey]) - except (TypeError, ValueError): - raise HTTPException(status_code=400, detail=f"Invalid value for {tkey}: must be an integer") - if tval <= 0 or tval > settings.sandbox_timeout: - raise HTTPException(status_code=400, detail=f"{tkey} must be between 1 and {settings.sandbox_timeout} seconds") - - if target := effective_inputs.get("target"): - target_str = str(target) - should_validate = plugin.category != "code" and not is_filesystem_target(target_str) - if should_validate: - try: - is_valid, error_msg = await asyncio.wait_for( - asyncio.to_thread(validate_target, target_str, safe_mode), - timeout=float(settings.dns_resolution_timeout_seconds), - ) - except asyncio.TimeoutError: - raise HTTPException(status_code=400, detail="Target validation timed out in safe mode (SecuScan Guardrail)") - if not is_valid: - raise HTTPException(status_code=400, detail=error_msg) - - client = client_id or f"user:{owner}" - can_exec, err = await rate_limiter.can_execute( - plugin_id, - plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour), - client_id=client, - ) - if not can_exec: - raise HTTPException(status_code=429, detail=err) - - task_id = await executor.create_task( - plugin_id, - effective_inputs, - safe_mode=safe_mode, - preset=preset, - consent_granted=consent_granted, - owner_id=owner, - source=source, - ) - - can_acquire, concurrency_err = await concurrent_limiter.acquire(task_id) - if not can_acquire: - await executor.mark_task_failed(task_id, reason="Concurrency limit reached") - raise HTTPException(status_code=503, detail=concurrency_err) - - return { - "task_id": task_id, - "status": "queued", - "created_at": "now", - "stream_url": f"/api/v1/task/{task_id}/stream" - } - - async def get_plugin_manager_for_request(): """ In debug mode, refresh plugin metadata from disk on demand so frontend catalog @@ -501,6 +392,22 @@ async def start_task( if not ok: raise HTTPException(status_code=status_code, detail=error_msg) + # Validate consent + if settings.require_consent and not request.consent_granted: + logger.warning(f"Task start failed: Consent not granted. Request: {request}") + raise HTTPException( + status_code=400, + detail="Consent required. You must acknowledge the legal notice." + ) + + # Get plugin + plugin_manager = await get_plugin_manager_for_request() + plugin = plugin_manager.get_plugin(request.plugin_id) + + if not plugin: + logger.warning(f"Task start failed: Plugin not found: {request.plugin_id}") + raise HTTPException(status_code=404, detail=f"Plugin not found: {request.plugin_id}") + db = await get_db() target_policy = await get_target_policy(db, owner, execution_context.get("target_policy_id")) credential_profile = await get_credential_profile(db, owner, execution_context.get("credential_profile_id")) @@ -519,11 +426,6 @@ async def start_task( detail="Authenticated scans require a target policy with authenticated scanning enabled.", ) - plugin_manager = await get_plugin_manager_for_request() - plugin = plugin_manager.get_plugin(request.plugin_id) - if not plugin: - raise HTTPException(status_code=404, detail=f"Plugin not found: {request.plugin_id}") - requires_exploit_policy = ( plugin.safety.get("level") == "exploit" or execution_context.get("validation_mode") == ValidationMode.CONTROLLED_EXTRACT.value @@ -535,22 +437,103 @@ async def start_task( detail="Offensive validation requires a target policy that explicitly allows exploit validation.", ) + # Server-controlled safe mode: public-target scans are opt-in via target policy. + safe_mode = bool( + settings.safe_mode_default + and not (target_policy and target_policy.get("allow_public_targets")) + ) + + # Ensure downstream scanners/plugins see the effective safe-mode, but prevent client override. + effective_inputs = dict(request.inputs or {}) + if "safe_mode" in effective_inputs: + effective_inputs.pop("safe_mode", None) + effective_inputs["safe_mode"] = safe_mode + + # Validate numeric timeout inputs at request time to prevent unsafe values + for tkey in ("timeout", "max_scan_time"): + # Only enforce bounds if the plugin declares the field in its schema + declared = any(getattr(f, "id", None) == tkey for f in (plugin.fields or [])) + if not declared: + continue + if tkey in effective_inputs and effective_inputs[tkey] not in (None, ""): + try: + tval = int(effective_inputs[tkey]) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail=f"Invalid value for {tkey}: must be an integer") + if tval <= 0 or tval > settings.sandbox_timeout: + raise HTTPException(status_code=400, detail=f"{tkey} must be between 1 and {settings.sandbox_timeout} seconds") + + if target := effective_inputs.get("target"): + target_str = str(target) + should_validate_target = plugin.category != "code" and not is_filesystem_target(target_str) + + if should_validate_target: + try: + is_valid, error_msg = await asyncio.wait_for( + asyncio.to_thread(validate_target, target_str, safe_mode), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + except asyncio.TimeoutError: + logger.warning("Task start failed: Target validation timed out for '%s'", target_str) + raise HTTPException( + status_code=400, + detail="Target validation timed out in safe mode (SecuScan Guardrail)", + ) + + if not is_valid: + logger.warning(f"Task start failed: Target validation failed for '{target}': {error_msg}") + raise HTTPException(status_code=400, detail=error_msg) + + # Check rate limits per (client, plugin) so one client cannot exhaust + # the quota for all other users of the same plugin. client_id = resolve_client_identity(raw_request) - result = await _execute_scan_safe( - plugin_id=request.plugin_id, - inputs=request.inputs or {}, - consent_granted=request.consent_granted, - preset=request.preset, - owner=owner, - source="api", + can_execute, error_msg = await rate_limiter.can_execute( + request.plugin_id, + plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour), client_id=client_id, - target_policy=target_policy, ) - background_tasks.add_task(executor.execute_task, result["task_id"]) + if not can_execute: + raise HTTPException(status_code=429, detail=error_msg) + + # Create task record first so we have a real task_id for the limiter + try: + task_id = await executor.create_task( + request.plugin_id, + effective_inputs, + safe_mode=safe_mode, + preset=request.preset, + execution_context=execution_context, + consent_granted=request.consent_granted, + owner_id=owner, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + # Atomically acquire a concurrency slot using the real task_id. + # acquire() is lock-protected internally, so the check and register + # happen in a single operation — no TOCTOU window between requests. + can_acquire, error_msg = await concurrent_limiter.acquire(task_id) + if not can_acquire: + # Roll back: mark the DB row failed so it isn't left orphaned + await executor.mark_task_failed(task_id, reason="Concurrency limit reached; task was not started") + raise HTTPException(status_code=503, detail=error_msg) + + # Slot is held — schedule execution. + # execute_task releases the slot in its finally block on every exit path. + # + # Use BackgroundTasks so the response can be sent without waiting in real + # ASGI servers, while tests using TestClient still execute the task to keep + # contract tests deterministic. + background_tasks.add_task(executor.execute_task, task_id) await invalidate_view_cache() - return result + return { + "task_id": task_id, + "status": "queued", + "created_at": "now", + "stream_url": f"/api/v1/task/{task_id}/stream" + } @router.get("/task/{task_id}/status") async def get_task_status(task_id: str, owner: str = Depends(get_current_owner)): @@ -1743,10 +1726,7 @@ async def create_workflow(payload: Dict[str, Any]): @router.post("/workflows/{workflow_id}/run") -async def run_workflow_once( - workflow_id: str, - owner: str = Depends(get_current_owner), -): +async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_owner)): db = await get_db() row = await db.fetchone("SELECT * FROM workflows WHERE id = ?", (workflow_id,)) if not row: @@ -1756,7 +1736,6 @@ async def run_workflow_once( ) if not wf_rate_ok: raise HTTPException(status_code=429, detail=wf_rate_msg) - steps = _parse_workflow_steps(row["steps_json"] or "[]") active_version = await db.fetchone( "SELECT id, version_number FROM workflow_versions " @@ -1767,20 +1746,26 @@ async def run_workflow_once( version_number = active_version["version_number"] if active_version else None created_task_ids: List[str] = [] for step in steps: - plugin_id = step.get("plugin_id") - if not plugin_id: - continue - - result = await _execute_scan_safe( - plugin_id=plugin_id, - inputs=step.get("inputs", {}), - consent_granted=True, + execution_context = normalize_execution_context(step.get("execution_context") or {}) + target_policy = await get_target_policy(db, owner, execution_context.get("target_policy_id")) + safe_mode = bool( + settings.safe_mode_default + and not (target_policy and target_policy.get("allow_public_targets")) + ) + effective_inputs = dict(step.get("inputs", {}) or {}) + effective_inputs.pop("safe_mode", None) + effective_inputs["safe_mode"] = safe_mode + task_id = await executor.create_task( + step.get("plugin_id"), + effective_inputs, + safe_mode=safe_mode, preset=step.get("preset"), - owner=DEFAULT_OWNER_ID, - source="workflow", + execution_context=execution_context, + consent_granted=True, + owner_id=owner, ) - created_task_ids.append(result["task_id"]) - asyncio.create_task(executor.execute_task(result["task_id"])) + asyncio.create_task(executor.execute_task(task_id)) + created_task_ids.append(task_id) await db.execute("UPDATE workflows SET last_run_at = datetime('now') WHERE id = ?", (workflow_id,)) run_id = await db.record_workflow_run( workflow_id=workflow_id, @@ -1789,7 +1774,7 @@ async def run_workflow_once( task_ids=created_task_ids, triggered_by="manual", ) - _track_task(asyncio.create_task(_finalize_workflow_run(run_id))) + asyncio.create_task(_finalize_workflow_run(run_id)) return { "workflow_id": workflow_id, "run_id": run_id, diff --git a/backend/secuscan/workflows.py b/backend/secuscan/workflows.py index 003d95eb1..74eb7b0d2 100644 --- a/backend/secuscan/workflows.py +++ b/backend/secuscan/workflows.py @@ -2,11 +2,10 @@ from __future__ import annotations from .request_context import get_request_id, set_request_id import asyncio -from typing import Set import json import logging from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from .database import get_db from .config import settings from .ratelimit import workflow_rate_limiter, rate_limiter, concurrent_limiter @@ -19,11 +18,6 @@ class WorkflowScheduler: def __init__(self): self._task: asyncio.Task | None = None self._running = False - self._child_tasks: Set[asyncio.Task] = set() - - def _track_child(self, task: asyncio.Task) -> None: - self._child_tasks.add(task) - task.add_done_callback(self._child_tasks.discard) async def start(self): if self._task and not self._task.done(): @@ -40,11 +34,6 @@ async def stop(self): except asyncio.CancelledError: pass self._task = None - for task in list(self._child_tasks): - task.cancel() - if self._child_tasks: - await asyncio.gather(*self._child_tasks, return_exceptions=True) - self._child_tasks.clear() logger.info("Workflow scheduler stopped") async def _run_loop(self): while self._running: @@ -83,6 +72,10 @@ def _should_run(self, now: datetime, last_run_at: str | None, schedule_seconds: if not last_run_at: return True last = datetime.fromisoformat(last_run_at.replace("Z", "+00:00")) + # SQLite's datetime('now') produces "2026-05-25 08:02:28" — no Z and + # no +00:00 suffix — so fromisoformat() returns a naive datetime. + # Subtracting a naive datetime from an aware one raises TypeError. + # Treat any naive timestamp from the DB as UTC. if last.tzinfo is None: last = last.replace(tzinfo=timezone.utc) elapsed = (now - last).total_seconds() @@ -115,7 +108,6 @@ async def _run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): effective_inputs = dict(inputs) effective_inputs.pop("safe_mode", None) effective_inputs["safe_mode"] = safe_mode - effective_inputs["_source"] = "scheduler" if target := effective_inputs.get("target"): target_str = str(target) @@ -157,7 +149,6 @@ async def _run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): execution_context=execution_context, consent_granted=True, owner_id=DEFAULT_OWNER_ID, - source="scheduler", ) can_acquire, concurrency_err = await concurrent_limiter.acquire(task_id) @@ -166,7 +157,11 @@ async def _run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): logger.warning("Workflow %s: concurrency limit reached for %s", workflow_id, plugin_id) continue - self._track_child(asyncio.create_task(executor.execute_task(task_id))) + async def run_task(task_id: str) -> None: + set_request_id(request_id) + await executor.execute_task(task_id) + + asyncio.create_task(run_task(task_id)) -scheduler = WorkflowScheduler() \ No newline at end of file +scheduler = WorkflowScheduler() diff --git a/testing/backend/unit/test_workflow_scheduler_security.py b/testing/backend/unit/test_workflow_scheduler_security.py index d25408c63..34ec0ba22 100644 --- a/testing/backend/unit/test_workflow_scheduler_security.py +++ b/testing/backend/unit/test_workflow_scheduler_security.py @@ -1,8 +1,8 @@ """ Tests for workflow scheduler route-level security controls. -Verifies that the scheduler path applies the same consent, target validation, -rate limiting, and concurrency controls as the API path. +Verifies that the scheduler path applies target validation, rate limiting, +and concurrency controls consistent with the API path. """ import pytest @@ -46,26 +46,6 @@ async def test_allows_different_workflows_independently(self, rate_limiter): ok, msg = await rate_limiter.check_workflow_rate_limit("wf-2", 60) assert ok is True - @pytest.mark.asyncio - async def test_user_workflow_limit(self, rate_limiter): - ok, msg = await rate_limiter.check_user_workflow_limit("user-1", 5) - assert ok is True - await rate_limiter.register_user_workflow("user-1") - await rate_limiter.register_user_workflow("user-1") - await rate_limiter.register_user_workflow("user-1") - await rate_limiter.register_user_workflow("user-1") - await rate_limiter.register_user_workflow("user-1") - ok, msg = await rate_limiter.check_user_workflow_limit("user-1", 5) - assert ok is False - assert "limit reached" in msg.lower() - - @pytest.mark.asyncio - async def test_unregister_user_workflow(self, rate_limiter): - await rate_limiter.register_user_workflow("user-1") - await rate_limiter.unregister_user_workflow("user-1") - ok, msg = await rate_limiter.check_user_workflow_limit("user-1", 1) - assert ok is True - # --------------------------------------------------------------------------- # WorkflowScheduler._run_workflow security control tests @@ -134,33 +114,6 @@ async def test_skips_step_when_rate_limit_exceeded(self, scheduler): await scheduler._run_workflow("wf-1", steps) mock_rate.assert_called_once() - @pytest.mark.asyncio - async def test_applies_source_tag_in_inputs(self, scheduler): - steps = [{ - "plugin_id": "nmap", - "inputs": {"target": "example.com"}, - }] - with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \ - patch("backend.secuscan.plugins.get_plugin_manager") as mock_get_pm, \ - patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \ - patch("backend.secuscan.ratelimit.rate_limiter.can_execute", return_value=(True, "")), \ - patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", return_value=(True, "")), \ - patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create: - - mock_pm = MagicMock() - plugin = MagicMock() - plugin.category = "network" - plugin.safety = {"rate_limit": {"max_per_hour": 50}} - plugin.fields = [] - mock_pm.get_plugin.return_value = plugin - mock_get_pm.return_value = mock_pm - - await scheduler._run_workflow("wf-1", steps) - args, kwargs = mock_create.call_args - assert kwargs.get("source") == "scheduler" - inputs = args[1] if len(args) > 1 else kwargs.get("inputs", {}) - assert inputs.get("_source") == "scheduler" - @pytest.mark.asyncio async def test_applies_safe_mode_consistently(self, scheduler): steps = [{