class WorkflowRateLimiter:
    """Rate limiter for scheduler-triggered workflow scans.

    Remembers, per workflow id, when a run was last admitted and rejects
    further runs of the same workflow until a minimum interval has elapsed.
    State is held in process memory only.
    """

    def __init__(self):
        # workflow_id -> time.monotonic() reading taken when the run was admitted.
        self._last_run: Dict[str, float] = {}
        self.lock = asyncio.Lock()

    async def check_workflow_rate_limit(self, workflow_id: str, min_interval_seconds: int) -> Tuple[bool, str]:
        """Admit or reject a run of ``workflow_id``.

        Args:
            workflow_id: Identifier of the workflow about to run.
            min_interval_seconds: Minimum number of seconds that must separate
                two admitted runs of the same workflow. Zero or a negative
                value never blocks.

        Returns:
            ``(True, "")`` when the run is admitted; the admission time is
            recorded. ``(False, message)`` when the previous admitted run is
            too recent; the recorded time is left untouched.
        """
        # Function-scope imports keep this block independent of the module's
        # top-level import list.
        import math
        import time

        async with self.lock:
            # A monotonic clock is immune to wall-clock steps (DST, NTP,
            # manual changes) that would make a naive datetime difference
            # negative or artificially large.
            now = time.monotonic()
            last = self._last_run.get(workflow_id)
            # ``is not None``: a float timestamp may legitimately be falsy.
            if last is not None:
                remaining = min_interval_seconds - (now - last)
                if remaining > 0:
                    # Round up so the hint never under-reports the wait
                    # (and never says "wait 0s" while still rejecting).
                    return False, f"Workflow rate limited: wait {math.ceil(remaining)}s between runs"
            self._last_run[workflow_id] = now
            return True, ""
rate_limiter, concurrent_limiter, workflow_rate_limiter, task_start_limiter, vault_limiter, report_download_limiter, read_heavy_limiter, resolve_client_identity, admin_limiter, @@ -1731,6 +1731,11 @@ async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_o row = await db.fetchone("SELECT * FROM workflows WHERE id = ?", (workflow_id,)) if not row: raise HTTPException(status_code=404, detail="Workflow not found") + wf_rate_ok, wf_rate_msg = await workflow_rate_limiter.check_workflow_rate_limit( + workflow_id, settings.workflow_min_interval_seconds + ) + if not wf_rate_ok: + raise HTTPException(status_code=429, detail=wf_rate_msg) steps = _parse_workflow_steps(row["steps_json"] or "[]") active_version = await db.fetchone( "SELECT id, version_number FROM workflow_versions " diff --git a/backend/secuscan/workflows.py b/backend/secuscan/workflows.py index c7ba88dc7..74eb7b0d2 100644 --- a/backend/secuscan/workflows.py +++ b/backend/secuscan/workflows.py @@ -8,7 +8,9 @@ from typing import Any, Dict, List from .database import get_db from .config import settings +from .ratelimit import workflow_rate_limiter, rate_limiter, concurrent_limiter from .executor import executor +from .auth import DEFAULT_OWNER_ID from .execution_context import normalize_execution_context from .platform_resources import get_target_policy logger = logging.getLogger(__name__) @@ -53,6 +55,14 @@ async def tick(self): for row in rows: if not self._should_run(now, row.get("last_run_at"), int(row["schedule_seconds"])): continue + + wf_rate_ok, wf_rate_msg = await workflow_rate_limiter.check_workflow_rate_limit( + row["id"], settings.workflow_min_interval_seconds + ) + if not wf_rate_ok: + logger.warning("Workflow %s skipped by rate limiter: %s", row["id"], wf_rate_msg) + continue + await self._run_workflow(row["id"], json.loads(row.get("steps_json") or "[]")) await db.execute( "UPDATE workflows SET last_run_at = datetime('now') WHERE id = ?", @@ -85,10 +95,52 @@ async def 
_run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): settings.safe_mode_default and not (target_policy and target_policy.get("allow_public_targets")) ) + + from .plugins import get_plugin_manager + from .validation import validate_target + from .network_policy import get_policy_engine + + plugin_manager = get_plugin_manager() + plugin = plugin_manager.get_plugin(plugin_id) + if not plugin: + logger.warning("Workflow %s: plugin %s not found, skipping step", workflow_id, plugin_id) + continue effective_inputs = dict(inputs) effective_inputs.pop("safe_mode", None) effective_inputs["safe_mode"] = safe_mode + if target := effective_inputs.get("target"): + target_str = str(target) + if plugin.category != "code": + try: + is_valid, error_msg = await asyncio.wait_for( + asyncio.to_thread(validate_target, target_str, safe_mode), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + if not is_valid: + logger.warning("Workflow %s: target validation failed for step %s: %s", workflow_id, plugin_id, error_msg) + continue + except asyncio.TimeoutError: + logger.warning("Workflow %s: target validation timed out for step %s", workflow_id, plugin_id) + continue + + if settings.enforce_network_policy and target_str: + engine = get_policy_engine() + allowed, reason, _ = await asyncio.wait_for( + asyncio.to_thread(engine.check_access, dest_ip=target_str, plugin_id=plugin_id, task_id=""), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + if not allowed: + logger.warning("Workflow %s: network policy denied %s: %s", workflow_id, target_str, reason) + continue + + client = f"user:{DEFAULT_OWNER_ID}" + max_per_hour = plugin.safety.get("rate_limit", {}).get("max_per_hour", settings.max_tasks_per_hour) if plugin else settings.max_tasks_per_hour + can_exec, rate_err = await rate_limiter.can_execute(plugin_id, max_per_hour, client_id=client) + if not can_exec: + logger.warning("Workflow %s: rate limit exceeded for %s: %s", workflow_id, plugin_id, rate_err) 
+ continue + task_id = await executor.create_task( plugin_id, effective_inputs, @@ -96,8 +148,15 @@ async def _run_workflow(self, workflow_id: str, steps: List[Dict[str, Any]]): preset=step.get("preset"), execution_context=execution_context, consent_granted=True, + owner_id=DEFAULT_OWNER_ID, ) + can_acquire, concurrency_err = await concurrent_limiter.acquire(task_id) + if not can_acquire: + await executor.mark_task_failed(task_id, reason="Concurrency limit reached") + logger.warning("Workflow %s: concurrency limit reached for %s", workflow_id, plugin_id) + continue + async def run_task(task_id: str) -> None: set_request_id(request_id) await executor.execute_task(task_id) diff --git a/testing/backend/unit/test_workflow_scheduler_security.py b/testing/backend/unit/test_workflow_scheduler_security.py new file mode 100644 index 000000000..34ec0ba22 --- /dev/null +++ b/testing/backend/unit/test_workflow_scheduler_security.py @@ -0,0 +1,236 @@ +""" +Tests for workflow scheduler route-level security controls. + +Verifies that the scheduler path applies target validation, rate limiting, +and concurrency controls consistent with the API path. 
class TestWorkflowRateLimiter:
    """Unit tests for ``WorkflowRateLimiter.check_workflow_rate_limit``.

    Every test receives a fresh limiter from the ``rate_limiter`` fixture, so
    no admission state leaks between tests.
    """

    @pytest.mark.asyncio
    async def test_allows_first_run(self, rate_limiter):
        """A workflow that has never run is admitted with an empty message."""
        ok, msg = await rate_limiter.check_workflow_rate_limit("wf-1", 60)
        assert ok is True
        assert msg == ""

    @pytest.mark.asyncio
    async def test_blocks_second_run_within_interval(self, rate_limiter):
        """An immediate second run of the same workflow is rejected."""
        await rate_limiter.check_workflow_rate_limit("wf-1", 60)
        ok, msg = await rate_limiter.check_workflow_rate_limit("wf-1", 60)
        assert ok is False
        assert "rate limited" in msg.lower()

    @pytest.mark.asyncio
    async def test_allows_different_workflows_independently(self, rate_limiter):
        """Admitting one workflow does not consume another workflow's budget."""
        await rate_limiter.check_workflow_rate_limit("wf-1", 60)
        ok, msg = await rate_limiter.check_workflow_rate_limit("wf-2", 60)
        assert ok is True
        assert msg == ""

    @pytest.mark.asyncio
    async def test_allows_rerun_once_interval_has_elapsed(self, rate_limiter):
        """A previously admitted workflow is admitted again after the interval.

        A zero-second interval is already elapsed on the next call, which
        exercises the re-admission branch without sleeping.
        """
        await rate_limiter.check_workflow_rate_limit("wf-1", 0)
        ok, msg = await rate_limiter.check_workflow_rate_limit("wf-1", 0)
        assert ok is True
        assert msg == ""
class TestSchedulerSecurityControls:
    """Security-control tests for ``WorkflowScheduler._run_workflow``.

    ``_run_workflow`` imports ``get_plugin_manager`` and ``validate_target``
    inside the function body, so those are patched on their defining modules.
    The limiters and the executor are patched as attributes of the shared
    instances. ``executor.create_task`` is patched in every "skips step" test
    so the test proves that no task was created, not merely that a guard ran.
    """

    @staticmethod
    def _network_plugin_manager():
        """Return a mocked plugin manager resolving every id to a network plugin."""
        plugin = MagicMock()
        plugin.category = "network"
        plugin.safety = {"rate_limit": {"max_per_hour": 50}}
        plugin.fields = []
        manager = MagicMock()
        manager.get_plugin.return_value = plugin
        return manager

    @pytest.mark.asyncio
    async def test_skips_step_when_plugin_not_found(self, scheduler):
        """An unknown plugin id skips the step before any task is created."""
        steps = [{"plugin_id": "nonexistent-plugin", "inputs": {}}]
        missing_plugin_manager = MagicMock()
        missing_plugin_manager.get_plugin.return_value = None
        with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock) as mock_get_db, \
             patch("backend.secuscan.plugins.get_plugin_manager", return_value=missing_plugin_manager), \
             patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create:

            mock_db = AsyncMock()
            mock_get_db.return_value = mock_db

            await scheduler._run_workflow("wf-1", steps)
            mock_db.record_workflow_run.assert_not_called()
            mock_create.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_step_when_target_validation_fails(self, scheduler):
        """A target rejected by ``validate_target`` never reaches the executor."""
        steps = [{
            "plugin_id": "nmap",
            "inputs": {"target": "invalid-target"},
        }]
        with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \
             patch("backend.secuscan.plugins.get_plugin_manager", return_value=self._network_plugin_manager()), \
             patch("backend.secuscan.validation.validate_target", return_value=(False, "Target not allowed")) as mock_val, \
             patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create:

            await scheduler._run_workflow("wf-1", steps)
            mock_val.assert_called_once()
            mock_create.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_step_when_rate_limit_exceeded(self, scheduler):
        """A step refused by the per-plugin rate limiter creates no task."""
        steps = [{
            "plugin_id": "nmap",
            "inputs": {"target": "example.com"},
        }]
        with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \
             patch("backend.secuscan.plugins.get_plugin_manager", return_value=self._network_plugin_manager()), \
             patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.rate_limiter.can_execute", new_callable=AsyncMock, return_value=(False, "Rate limit exceeded")) as mock_rate, \
             patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create:

            await scheduler._run_workflow("wf-1", steps)
            mock_rate.assert_called_once()
            mock_create.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_applies_safe_mode_consistently(self, scheduler):
        """A step-supplied ``safe_mode=False`` is overridden before task creation."""
        steps = [{
            "plugin_id": "nmap",
            "inputs": {"target": "example.com", "safe_mode": False},
        }]
        with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \
             patch("backend.secuscan.plugins.get_plugin_manager", return_value=self._network_plugin_manager()), \
             patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.rate_limiter.can_execute", new_callable=AsyncMock, return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", new_callable=AsyncMock, return_value=(True, "")), \
             patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1") as mock_create:

            await scheduler._run_workflow("wf-1", steps)
            args, kwargs = mock_create.call_args
            # Inputs are the second positional argument of create_task; fall
            # back to the keyword form in case the call style changes.
            inputs = args[1] if len(args) > 1 else kwargs.get("inputs", {})
            assert "safe_mode" in inputs
            assert inputs["safe_mode"] is True

    @pytest.mark.asyncio
    async def test_acquires_concurrency_slot(self, scheduler):
        """A created task requests a concurrency slot under its own task id."""
        steps = [{
            "plugin_id": "nmap",
            "inputs": {"target": "example.com"},
        }]
        with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \
             patch("backend.secuscan.plugins.get_plugin_manager", return_value=self._network_plugin_manager()), \
             patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.rate_limiter.can_execute", new_callable=AsyncMock, return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", new_callable=AsyncMock, return_value=(True, "")) as mock_acquire, \
             patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1"):

            await scheduler._run_workflow("wf-1", steps)
            mock_acquire.assert_called_once_with("task-1")

    @pytest.mark.asyncio
    async def test_skips_step_when_concurrency_limit_reached(self, scheduler):
        """A task that cannot get a concurrency slot is marked failed."""
        steps = [{
            "plugin_id": "nmap",
            "inputs": {"target": "example.com"},
        }]
        with patch("backend.secuscan.workflows.get_db", new_callable=AsyncMock), \
             patch("backend.secuscan.plugins.get_plugin_manager", return_value=self._network_plugin_manager()), \
             patch("backend.secuscan.validation.validate_target", return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.rate_limiter.can_execute", new_callable=AsyncMock, return_value=(True, "")), \
             patch("backend.secuscan.ratelimit.concurrent_limiter.acquire", new_callable=AsyncMock, return_value=(False, "Concurrency limit reached")), \
             patch("backend.secuscan.executor.executor.create_task", new_callable=AsyncMock, return_value="task-1"), \
             patch("backend.secuscan.executor.executor.mark_task_failed", new_callable=AsyncMock) as mock_fail:

            await scheduler._run_workflow("wf-1", steps)
            mock_fail.assert_awaited_once_with("task-1", reason="Concurrency limit reached")
class TestTickRateLimiting:
    """Tests that ``WorkflowScheduler.tick`` consults the workflow rate limiter."""

    @staticmethod
    def _db_with_one_workflow():
        """Return a mocked DB whose ``fetchall`` yields one never-run workflow."""
        db_mock = AsyncMock()
        db_mock.fetchall.return_value = [{
            "id": "wf-1",
            "name": "test",
            "schedule_seconds": 60,
            "last_run_at": None,
            "steps_json": "[]",
        }]
        return db_mock

    @pytest.mark.asyncio
    async def test_tick_applies_workflow_rate_limiter(self, scheduler):
        """An admitted workflow is rate-checked with the configured interval, then run."""
        # Read the interval from settings instead of hard-coding the default,
        # so an environment override of the setting cannot break the test.
        from backend.secuscan.config import settings

        db_mock = self._db_with_one_workflow()
        with patch("backend.secuscan.workflows.get_db", return_value=db_mock), \
             patch.object(scheduler, "_run_workflow", new_callable=AsyncMock) as mock_run, \
             patch("backend.secuscan.workflows.workflow_rate_limiter.check_workflow_rate_limit", new_callable=AsyncMock) as mock_rate:

            mock_rate.return_value = (True, "")
            await scheduler.tick()
            mock_rate.assert_called_once_with("wf-1", settings.workflow_min_interval_seconds)
            # tick passes the workflow id and the decoded steps ("[]" -> []).
            mock_run.assert_awaited_once_with("wf-1", [])

    @pytest.mark.asyncio
    async def test_tick_skips_rate_limited_workflow(self, scheduler):
        """A workflow refused by the rate limiter is not run on this tick."""
        db_mock = self._db_with_one_workflow()
        with patch("backend.secuscan.workflows.get_db", return_value=db_mock), \
             patch.object(scheduler, "_run_workflow", new_callable=AsyncMock) as mock_run, \
             patch("backend.secuscan.workflows.workflow_rate_limiter.check_workflow_rate_limit", new_callable=AsyncMock) as mock_rate:

            mock_rate.return_value = (False, "Workflow rate limited: wait 30s between runs")
            await scheduler.tick()
            mock_rate.assert_called_once()
            mock_run.assert_not_called()