diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c1a1221..cd9c209f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -250,7 +250,7 @@ jobs: if: steps.run_benchmarks.outcome == 'failure' run: | echo "::warning::Performance benchmark thresholds exceeded or benchmarks failed to run. Check the job logs for details." - frontend-checks: + frontend-run-checks: needs: [detect-changes, formatting-hygiene] if: | always() && @@ -304,3 +304,18 @@ jobs: run: npm run test - name: Build frontend run: npm run build + + frontend-checks: + needs: [frontend-run-checks] + if: always() + runs-on: ubuntu-latest + steps: + - name: Check matrix results + run: | + if [ "${{ needs.frontend-run-checks.result }}" = "success" ] || [ "${{ needs.frontend-run-checks.result }}" = "skipped" ]; then + echo "Frontend checks completed successfully or skipped" + exit 0 + else + echo "Frontend checks failed" + exit 1 + fi diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index db078a8e..90c43a98 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -226,7 +226,7 @@ def _enqueue_listener_event(self, task_id: str, q: asyncio.Queue, event: Dict[st q.put_nowait(event) except asyncio.QueueFull: logger.warning("Dropping stream event for slow listener on task %s", task_id) - + async def _broadcast_phase(self, task_id: str, phase: str): """Broadcast a scan phase transition and persist it to the database.""" await self._broadcast(task_id, "phase", phase) @@ -265,16 +265,16 @@ async def create_task( task_id = str(uuid.uuid4()) plugin_manager = get_plugin_manager() plugin = plugin_manager.get_plugin(plugin_id) - + if not plugin: raise ValueError(f"Plugin not found: {plugin_id}") - + # Apply preset if provided if preset and preset in plugin.presets: preset_values = plugin.presets[preset] # Merge preset with user inputs (user inputs take precedence) inputs = {**preset_values, **inputs} - + # Store task in database db = await get_db() await db.execute( @@ -299,7 +299,7 @@ async def create_task( bool(safe_mode) ) ) - + # Log audit event await db.log_audit( "task_created", @@ -313,9 +313,9 @@ async def create_task( task_id=task_id, plugin_id=plugin_id ) - + return task_id - + async def mark_task_failed(self, task_id: str, reason: str) -> None: """ Mark a task as failed without running it. @@ -351,15 +351,299 @@ async def mark_task_failed(self, task_id: str, reason: str) -> None: task_id=task_id, ) - async def execute_task(self, task_id: str): + async def _enforce_guardrails( + self, + target: str, + plugin_id: str, + safe_mode: bool, + task_id: str, + ) -> bool: + """Enforce Safe Mode target validation and Network Policy access checks. + + Returns: + True if all checks pass or are bypassed. False if any check blocks execution. + """ + if not target: + return True + + plugin_manager = get_plugin_manager() + plugin = plugin_manager.get_plugin(plugin_id) + should_validate = True + if plugin and plugin.category == "code": + should_validate = False + + # Use shared is_filesystem_target from validation to ensure + # consistent filesystem detection across route and executor layers. + from .validation import is_filesystem_target + is_fs = is_filesystem_target(target) + + if should_validate and not is_fs: + from .validation import validate_target + try: + # Enforce safe mode validation of target address in a thread pool + is_valid, error_msg = await asyncio.wait_for( + asyncio.to_thread(validate_target, target, safe_mode), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + if not is_valid: + await self.mark_task_failed( + task_id, + f"Safe mode target validation failed: {error_msg}", + ) + await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + return False + except asyncio.TimeoutError: + await self.mark_task_failed( + task_id, + "Target validation timed out (SecuScan Guardrail)", + ) + await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + return False + + # Check before launching any scanner or subprocess. check_access() + # writes an audit entry on every path, so no extra logging needed. + if settings.enforce_network_policy: + engine = get_policy_engine() + try: + allowed, reason, _ = await asyncio.wait_for( + asyncio.to_thread( + engine.check_access, + dest_ip=target, + plugin_id=plugin_id, + task_id=task_id, + ), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + except asyncio.TimeoutError: + allowed, reason = False, "Network policy check timed out (DNS resolution timeout)" + + if not allowed: + if settings.network_policy_failure_mode == "log_only": + logger.warning( + f"[Log Only] Network policy violation allowed for {target}: {reason}" + ) + else: + await self.mark_task_failed( + task_id, + f"Network policy denied access to {target}: {reason}", + ) + await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + return False + + return True + + async def _ensure_docker_network(self) -> None: + """Validate and automatically create the configured Docker network if missing.""" + _net_check = await asyncio.create_subprocess_exec( + "docker", "network", "inspect", settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await _net_check.wait() + if _net_check.returncode == 0: + return + + logger.info(f"Docker network '{settings.docker_network}' not found. Creating isolated bridge network (ICC disabled)...") + _net_create = await asyncio.create_subprocess_exec( + "docker", "network", "create", + "--driver", "bridge", + "--opt", "com.docker.network.bridge.enable_icc=false", + settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await _net_create.wait() + if _net_create.returncode == 0: + logger.info(f"Successfully created Docker network '{settings.docker_network}' with ICC disabled") + return + + logger.warning("Failed to create isolated bridge network with ICC disabled. Falling back to standard bridge...") + _net_create_fallback = await asyncio.create_subprocess_exec( + "docker", "network", "create", "--driver", "bridge", settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await _net_create_fallback.wait() + if _net_create_fallback.returncode != 0: + raise RuntimeError( + f"Docker network '{settings.docker_network}' does not exist and could not be created automatically." + ) + logger.info(f"Successfully created Docker network '{settings.docker_network}' (fallback)") + + async def _execute_modular_scanner( + self, + db, + task_id: str, + owner_id: str, + plugin_id: str, + target: str, + inputs: Dict[str, Any], + safe_mode: bool, + ) -> tuple[str, float]: + """Execute a modular scanner and persist findings/report.""" + scanner_class = MODULAR_SCANNERS[plugin_id] + scanner = scanner_class(task_id, db, safe_mode=safe_mode) + + logger.info(f"Executing modular scanner {plugin_id} for task {task_id}") + await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) + await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) + + start_time = time.time() + result = await scanner.run(target, inputs) + duration = time.time() - start_time + + final_status = ( + TaskStatus.COMPLETED.value + if result.get("status") != "failed" + else TaskStatus.FAILED.value + ) + + await db.execute( + """ + UPDATE tasks SET + status = ?, + completed_at = ?, + duration_seconds = ?, + structured_json = ?, + error_message = ? + WHERE id = ? + """, + ( + final_status, + datetime.now().isoformat(), + duration, + json.dumps(result), + result.get("error_message"), + task_id, + ), + ) + + await self._broadcast_phase(task_id, ScanPhase.PARSING.value) + await self._upsert_findings_and_report_from_scanner( + db=db, + task_id=task_id, + owner_id=owner_id, + scanner=scanner, + plugin_id=plugin_id, + target=target, + status=final_status, + result=result, + ) + await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) + return final_status, duration + + async def _execute_standard_scanner( + self, + db, + task_id: str, + owner_id: str, + plugin: Any, + plugin_id: str, + target: str, + inputs: Dict[str, Any], + ) -> tuple[str, float, int]: + """Execute a standard CLI/Docker plugin and persist findings/report.""" + plugin_manager = get_plugin_manager() + command = plugin_manager.build_command(plugin_id, inputs) + + if not command: + raise ValueError("Failed to build command") + + # Apply Docker Sandboxing if enabled + if settings.docker_enabled: + await self._ensure_docker_network() + docker_image = plugin.docker_image or "alpine:latest" + docker_cmd = [ + "docker", + "run", + "--rm", + "--name", + f"secuscan_task_{task_id}", + "--memory", + f"{settings.sandbox_memory_mb}m", + "--cpus", + str(settings.sandbox_cpu_quota), + "--cap-drop", "NET_RAW", + "--network", settings.docker_network, + docker_image, + ] + command = docker_cmd + command + + logger.info(f"Executing task {task_id}: {' '.join(command)}") + await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) + await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) + + # Execute command + start_time = time.time() + output, exit_code = await self._execute_command( + command, + task_id, + timeout=self._resolve_execution_timeout(inputs), + ) + duration = time.time() - start_time + + # Save raw output + raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" + output = redact(output) + with open(raw_path, 'w') as f: + f.write(output) + + # Classify result + final_status, error_message = self._classify_command_result( + plugin=plugin, + output=output, + exit_code=exit_code, + ) + + await db.execute( + """ + UPDATE tasks SET + status = ?, + completed_at = ?, + duration_seconds = ?, + exit_code = ?, + raw_output_path = ?, + command_used = ?, + error_message = ? + WHERE id = ? + """, + ( + final_status, + datetime.now().isoformat(), + duration, + exit_code, + str(raw_path), + " ".join(command), + error_message, + task_id, + ), + ) + + # Upsert findings and report + await self._broadcast_phase(task_id, ScanPhase.PARSING.value) + await self._upsert_findings_and_report( + db=db, + task_id=task_id, + owner_id=owner_id, + plugin=plugin, + plugin_id=plugin_id, + target=target, + status=final_status, + output=output, + ) + await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) + return final_status, duration, exit_code + + async def execute_task(self, task_id: str) -> None: """ Execute a task asynchronously. - + Args: task_id: Task identifier """ db = await get_db() self.running_tasks[task_id] = asyncio.current_task() + start_time = time.time() try: # Update status to running — use optimistic lock to detect @@ -399,73 +683,8 @@ async def execute_task(self, task_id: str): ) # ── Safe Mode & Network policy enforcement ─────────────────────── - # Enforce Safe Mode target validation inside TaskExecutor to guarantee - # that all execution paths (manual API, workflows, scheduled tasks) are protected. - if target: - plugin_manager = get_plugin_manager() - plugin = plugin_manager.get_plugin(plugin_id) - should_validate = True - if plugin and plugin.category == "code": - should_validate = False - - - # Use shared is_filesystem_target from validation to ensure - # consistent filesystem detection across route and executor layers. - from .validation import is_filesystem_target - is_fs = is_filesystem_target(target) - - if should_validate and not is_fs: - from .validation import validate_target - try: - # Enforce safe mode validation of target address in a thread pool - is_valid, error_msg = await asyncio.wait_for( - asyncio.to_thread(validate_target, target, safe_mode), - timeout=float(settings.dns_resolution_timeout_seconds), - ) - if not is_valid: - await self.mark_task_failed( - task_id, - f"Safe mode target validation failed: {error_msg}", - ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) - return - except asyncio.TimeoutError: - await self.mark_task_failed( - task_id, - "Target validation timed out (SecuScan Guardrail)", - ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) - return - - # Check before launching any scanner or subprocess. check_access() - # writes an audit entry on every path, so no extra logging needed. - if target and settings.enforce_network_policy: - engine = get_policy_engine() - try: - allowed, reason, _ = await asyncio.wait_for( - asyncio.to_thread( - engine.check_access, - dest_ip=target, - plugin_id=plugin_id, - task_id=task_id, - ), - timeout=float(settings.dns_resolution_timeout_seconds), - ) - except asyncio.TimeoutError: - allowed, reason = False, "Network policy check timed out (DNS resolution timeout)" - - if not allowed: - if settings.network_policy_failure_mode == "log_only": - logger.warning( - f"[Log Only] Network policy violation allowed for {target}: {reason}" - ) - else: - await self.mark_task_failed( - task_id, - f"Network policy denied access to {target}: {reason}", - ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) - return # finally block handles running_tasks cleanup + limiter release + if not await self._enforce_guardrails(target, plugin_id, safe_mode, task_id): + return # Check if this is a modular scanner or a standard plugin plugin_manager = get_plugin_manager() @@ -485,180 +704,26 @@ async def execute_task(self, task_id: str): ) if plugin_id in MODULAR_SCANNERS: - scanner_class = MODULAR_SCANNERS[plugin_id] - scanner = scanner_class(task_id, db, safe_mode=safe_mode) - - logger.info(f"Executing modular scanner {plugin_id} for task {task_id}") - await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) - await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - - start_time = time.time() - # Run the scanner - result = await scanner.run(target, inputs) - duration = time.time() - start_time - - # Update task with results - final_status = TaskStatus.COMPLETED.value if result.get("status") != "failed" else TaskStatus.FAILED.value - - await db.execute( - """ - UPDATE tasks SET - status = ?, - completed_at = ?, - duration_seconds = ?, - structured_json = ?, - error_message = ? - WHERE id = ? - """, - ( - final_status, - datetime.now().isoformat(), - duration, - json.dumps(result), - result.get("error_message"), - task_id - ) - ) - - # Upsert findings and report using the scanner's result - await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - await self._upsert_findings_and_report_from_scanner( + final_status, duration = await self._execute_modular_scanner( db=db, task_id=task_id, owner_id=owner_id, - scanner=scanner, plugin_id=plugin_id, target=target, - status=final_status, - result=result + inputs=inputs, + safe_mode=safe_mode, ) - await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) - + exit_code = 0 else: - # Standard Plugin Execution - command = plugin_manager.build_command(plugin_id, inputs) - - if not command: - raise ValueError("Failed to build command") - - # Apply Docker Sandboxing if enabled - if settings.docker_enabled: - # Validate the named Docker network exists before using it. - # If missing, attempt to create it automatically rather than failing. - _net_check = await asyncio.create_subprocess_exec( - "docker", "network", "inspect", settings.docker_network, - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await _net_check.wait() - if _net_check.returncode != 0: - logger.info(f"Docker network '{settings.docker_network}' not found. Creating isolated bridge network (ICC disabled)...") - _net_create = await asyncio.create_subprocess_exec( - "docker", "network", "create", - "--driver", "bridge", - "--opt", "com.docker.network.bridge.enable_icc=false", - settings.docker_network, - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await _net_create.wait() - if _net_create.returncode != 0: - logger.warning("Failed to create isolated bridge network with ICC disabled. Falling back to standard bridge...") - _net_create_fallback = await asyncio.create_subprocess_exec( - "docker", "network", "create", "--driver", "bridge", settings.docker_network, - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await _net_create_fallback.wait() - if _net_create_fallback.returncode != 0: - raise RuntimeError( - f"Docker network '{settings.docker_network}' does not exist and could not be created automatically." - ) - logger.info(f"Successfully created Docker network '{settings.docker_network}' (fallback)") - else: - logger.info(f"Successfully created Docker network '{settings.docker_network}' with ICC disabled") - - docker_image = plugin.docker_image or "alpine:latest" - docker_cmd = [ - "docker", - "run", - "--rm", - "--name", - f"secuscan_task_{task_id}", - "--memory", - f"{settings.sandbox_memory_mb}m", - "--cpus", - str(settings.sandbox_cpu_quota), - "--cap-drop", "NET_RAW", - "--network", settings.docker_network, - docker_image, - ] - command = docker_cmd + command - - logger.info(f"Executing task {task_id}: {' '.join(command)}") - await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) - await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - - # Execute command - start_time = time.time() - output, exit_code = await self._execute_command( - command, - task_id, - timeout=self._resolve_execution_timeout(inputs), - ) - duration = time.time() - start_time - - # Save raw output - raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" - output = redact(output) - with open(raw_path, 'w') as f: - f.write(output) - - # Some CLI tools use non-zero exit codes for "no result" states while still - # producing a complete, parseable report. Let plugin metadata opt into that. - final_status, error_message = self._classify_command_result( - plugin=plugin, - output=output, - exit_code=exit_code, - ) - - await db.execute( - """ - UPDATE tasks SET - status = ?, - completed_at = ?, - duration_seconds = ?, - exit_code = ?, - raw_output_path = ?, - command_used = ?, - error_message = ? - WHERE id = ? - """, - ( - final_status, - datetime.now().isoformat(), - duration, - exit_code, - str(raw_path), - " ".join(command), - error_message, - task_id - ) - ) - - # Upsert findings and report - await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - await self._upsert_findings_and_report( + final_status, duration, exit_code = await self._execute_standard_scanner( db=db, task_id=task_id, owner_id=owner_id, plugin=plugin, plugin_id=plugin_id, target=target, - status=final_status, - output=output + inputs=inputs, ) - await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) await self._dispatch_task_notifications(db, task_id) @@ -670,7 +735,7 @@ async def execute_task(self, task_id: str): await db.log_audit( "task_completed", f"Task completed in {duration:.2f}s", - context={"task_id": task_id, "exit_code": locals().get('exit_code', 0)}, + context={"task_id": task_id, "exit_code": exit_code}, task_id=task_id, plugin_id=plugin_id ) @@ -678,11 +743,6 @@ async def execute_task(self, task_id: str): logger.info(f"Task {task_id} completed in {duration:.2f}s") except asyncio.CancelledError: - # CancelledError inherits from BaseException, not Exception — - # it bypasses the broad except below, so we handle it explicitly. - # Task.cancelled() returns False while the finally block is still - # executing, so this is the only reliable place to write the - # cancellation status to the DB. duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ @@ -740,8 +800,6 @@ async def execute_task(self, task_id: str): except Exception as e: logger.error(f"Task {task_id} failed: {e}", exc_info=True) - - # Update task as failed duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ @@ -772,12 +830,10 @@ async def execute_task(self, task_id: str): task_id=task_id ) finally: - # Always clean up: remove from the in-memory registry and - # release the concurrency slot regardless of how the task ended. self.running_tasks.pop(task_id, None) self._process_pids.pop(task_id, None) await concurrent_limiter.release(task_id) - + async def _execute_command( self, command: list, @@ -962,7 +1018,7 @@ async def cancel_task(self, task_id: str) -> bool: ) return True - + async def get_task_status(self, task_id: str) -> Optional[Dict]: """Get task status and progress""" db = await get_db() @@ -1506,12 +1562,12 @@ def _parse_results(self, plugin, output: str) -> Dict[str, Any]: """Route to appropriate parser based on plugin metadata.""" parser_type = plugin.output.get("parser") parser_input = self._resolve_parser_input(plugin, output) - + # 1. Check for custom parser.py in plugin directory (Recommended) plugin_manager = get_plugin_manager() plugin_dir = plugin_manager.plugins_dir / plugin.id parser_path = plugin_dir / "parser.py" - + if parser_path.exists(): if not plugin_manager.verify_parser_at_exec_time(plugin, plugin_dir): raise ValueError( @@ -1547,7 +1603,7 @@ def _parse_results(self, plugin, output: str) -> Dict[str, Any]: return self._normalize_parsed_result(plugin, parser_input, self._parse_nmap_output(parser_input)) elif parser_type == "builtin_http": return self._normalize_parsed_result(plugin, parser_input, self._parse_http_output(parser_input)) - + return self._normalize_parsed_result(plugin, parser_input, {"findings": [], "raw": parser_input}) def _resolve_parser_input(self, plugin, output: str) -> str: @@ -1720,7 +1776,7 @@ def _parse_nmap_output(self, output: str) -> Dict[str, Any]: findings = [] ports = [] services = [] - + # Regex for open ports: 80/tcp open http port_pattern = re.compile(r"(\d+)/(tcp|udp)\s+open\s+([\w-]+)") for match in port_pattern.finditer(output): @@ -1736,7 +1792,7 @@ def _parse_nmap_output(self, output: str) -> Dict[str, Any]: "remediation": "Close unnecessary ports and use a firewall to restrict access.", "metadata": {"port": port_str, "protocol": proto, "service": service} }) - + return { "open_ports": sorted(list(set(ports))), "services": sorted(list(set(services))), diff --git a/frontend/src/pages/Reports.tsx b/frontend/src/pages/Reports.tsx index b690b787..8a073f89 100644 --- a/frontend/src/pages/Reports.tsx +++ b/frontend/src/pages/Reports.tsx @@ -445,7 +445,7 @@ export default function Reports() { className="bg-rag-green border-4 border-black px-3 py-2 text-[9px] font-black uppercase tracking-widest text-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)] hover:shadow-none hover:translate-x-1 hover:translate-y-1 transition-all" title="Download PDF Report" > - PDF + Quick PDF {(() => { const ordered = preferredFormat diff --git a/testing/backend/unit/test_database_workflow_versions.py b/testing/backend/unit/test_database_workflow_versions.py index 15a5a038..db9d9d77 100644 --- a/testing/backend/unit/test_database_workflow_versions.py +++ b/testing/backend/unit/test_database_workflow_versions.py @@ -1,270 +1,209 @@ """ Unit tests for database workflow version methods. """ -import asyncio import uuid +import json +import pytest +import pytest_asyncio from backend.secuscan.database import Database -def run(coro): - return asyncio.run(coro) - - -def make_db(): - return Database(":memory:") +@pytest_asyncio.fixture +async def db(): + database = Database(":memory:") + await database.connect() + # Insert dummy workflows since workflow_versions and workflow_runs have a foreign key to workflows + await database.execute( + "INSERT INTO workflows (id, name, steps_json) VALUES (?, ?, ?)", + ("wf-1", "Default Workflow", "[]") + ) + await database.execute( + "INSERT INTO workflows (id, name, steps_json) VALUES (?, ?, ?)", + ("wf-test-1", "Test Workflow 1", "[]") + ) + await database.execute( + "INSERT INTO workflows (id, name, steps_json) VALUES (?, ?, ?)", + ("wf-A", "Workflow A", "[]") + ) + await database.execute( + "INSERT INTO workflows (id, name, steps_json) VALUES (?, ?, ?)", + ("wf-B", "Workflow B", "[]") + ) + yield database + await database.disconnect() class TestSnapshotWorkflowVersion: - def test_first_snapshot_has_version_1(self): - db = make_db() - run(db.connect()) - try: - v = run(db.snapshot_workflow_version( - "wf-test-1", "Test WF", 60, True, [{"plugin_id": "nmap"}] - )) - assert v["version_number"] == 1 - assert v["workflow_id"] == "wf-test-1" - assert v["created_by"] == "system" - finally: - run(db.disconnect()) - - def test_subsequent_snapshots_increment_version(self): - db = make_db() - run(db.connect()) - try: - v1 = run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - v2 = run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - assert v2["version_number"] == v1["version_number"] + 1 - finally: - run(db.disconnect()) - - def test_snapshot_stores_definition(self): - db = make_db() - run(db.connect()) - try: - steps = [{"plugin_id": "nmap", "inputs": {"target": "127.0.0.1"}}] - v = run(db.snapshot_workflow_version("wf-1", "My WF", 120, False, steps)) - assert v["definition"]["name"] == "My WF" - assert v["definition"]["schedule_seconds"] == 120 - assert v["definition"]["enabled"] is False - assert v["definition"]["steps"] == steps - finally: - run(db.disconnect()) - - def test_snapshots_across_workflows_independent(self): - db = make_db() - run(db.connect()) - try: - v_a1 = run(db.snapshot_workflow_version("wf-A", "A", 60, True, [])) - v_b1 = run(db.snapshot_workflow_version("wf-B", "B", 60, True, [])) - v_a2 = run(db.snapshot_workflow_version("wf-A", "A", 60, True, [])) - assert v_a1["version_number"] == 1 - assert v_b1["version_number"] == 1 - assert v_a2["version_number"] == 2 - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_first_snapshot_has_version_1(self, db): + v = await db.snapshot_workflow_version( + "wf-test-1", "Test WF", 60, True, [{"plugin_id": "nmap"}] + ) + assert v["version_number"] == 1 + assert v["workflow_id"] == "wf-test-1" + assert v["created_by"] == "system" + + @pytest.mark.asyncio + async def test_subsequent_snapshots_increment_version(self, db): + v1 = await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + v2 = await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + assert v2["version_number"] == v1["version_number"] + 1 + + @pytest.mark.asyncio + async def test_snapshot_stores_definition(self, db): + steps = [{"plugin_id": "nmap", "inputs": {"target": "127.0.0.1"}}] + v = await db.snapshot_workflow_version("wf-1", "My WF", 120, False, steps) + assert v["definition"]["name"] == "My WF" + assert v["definition"]["schedule_seconds"] == 120 + assert v["definition"]["enabled"] is False + assert v["definition"]["steps"] == steps + + @pytest.mark.asyncio + async def test_snapshots_across_workflows_independent(self, db): + v_a1 = await db.snapshot_workflow_version("wf-A", "A", 60, True, []) + v_b1 = await db.snapshot_workflow_version("wf-B", "B", 60, True, []) + v_a2 = await db.snapshot_workflow_version("wf-A", "A", 60, True, []) + assert v_a1["version_number"] == 1 + assert v_b1["version_number"] == 1 + assert v_a2["version_number"] == 2 class TestGetWorkflowVersions: - def test_returns_all_versions_newest_first(self): - db = make_db() - run(db.connect()) - try: - run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - versions = run(db.get_workflow_versions("wf-1")) - assert len(versions) == 3 - assert versions[0]["version_number"] == 3 - assert versions[1]["version_number"] == 2 - assert versions[2]["version_number"] == 1 - finally: - run(db.disconnect()) - - def test_returns_empty_for_unknown_workflow(self): - db = make_db() - run(db.connect()) - try: - versions = run(db.get_workflow_versions("does-not-exist")) - assert versions == [] - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_returns_all_versions_newest_first(self, db): + await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + versions = await db.get_workflow_versions("wf-1") + assert len(versions) == 3 + assert versions[0]["version_number"] == 3 + assert versions[1]["version_number"] == 2 + assert versions[2]["version_number"] == 1 + + @pytest.mark.asyncio + async def test_returns_empty_for_unknown_workflow(self, db): + versions = await db.get_workflow_versions("does-not-exist") + assert versions == [] class TestGetWorkflowVersion: - def test_returns_specific_version(self): - db = make_db() - run(db.connect()) - try: - created = run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - found = run(db.get_workflow_version("wf-1", created["version_number"])) - assert found is not None - assert found["id"] == created["id"] - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_returns_specific_version(self, db): + created = await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + found = await db.get_workflow_version("wf-1", created["version_number"]) + assert found is not None + assert found["id"] == created["id"] - def test_returns_none_for_missing_workflow(self): - db = make_db() - run(db.connect()) - try: - result = run(db.get_workflow_version("wf-does-not-exist", 99)) - assert result is None - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_returns_none_for_missing_workflow(self, db): + result = await db.get_workflow_version("wf-does-not-exist", 99) + assert result is None - def test_returns_none_for_missing_version_number(self): - db = make_db() - run(db.connect()) - try: - run(db.snapshot_workflow_version("wf-1", "WF", 60, True, [])) - result = run(db.get_workflow_version("wf-1", 99)) - assert result is None - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_returns_none_for_missing_version_number(self, db): + await db.snapshot_workflow_version("wf-1", "WF", 60, True, []) + result = await db.get_workflow_version("wf-1", 99) + assert result is None class TestRecordWorkflowRun: - def test_inserts_queued_run(self): - db = make_db() - run(db.connect()) - try: - run_id = run(db.record_workflow_run("wf-1", None, 1, ["t1", "t2"], "manual")) - assert run_id is not None - run_row = run(db.fetchone("SELECT status, triggered_by FROM workflow_runs WHERE id = ?", (run_id,))) - assert run_row["status"] == "queued" - assert run_row["triggered_by"] == "manual" - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_inserts_queued_run(self, db): + run_id = await db.record_workflow_run("wf-1", None, 1, ["t1", "t2"], "manual") + assert run_id is not None + run_row = await db.fetchone("SELECT status, triggered_by FROM workflow_runs WHERE id = ?", (run_id,)) + assert run_row["status"] == "queued" + assert run_row["triggered_by"] == "manual" - def test_inserts_empty_task_list(self): - db = make_db() - run(db.connect()) - try: - run_id = run(db.record_workflow_run("wf-1", None, 1, [], "scheduler")) - raw = run(db.fetchone("SELECT task_ids_json FROM workflow_runs WHERE id = ?", (run_id,))) - assert raw["task_ids_json"] == "[]" - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_inserts_empty_task_list(self, db): + run_id = await db.record_workflow_run("wf-1", None, 1, [], "scheduler") + raw = await db.fetchone("SELECT task_ids_json FROM workflow_runs WHERE id = ?", (run_id,)) + assert raw["task_ids_json"] == "[]" class TestFinalizeWorkflowRun: - def test_sets_status_and_timestamp(self): - db = make_db() - run(db.connect()) - try: - run_id = run(db.record_workflow_run("wf-1", None, 1, [], "manual")) - run(db.finalize_workflow_run(run_id, "completed")) - run_row = run(db.fetchone("SELECT status, completed_at FROM workflow_runs WHERE id = ?", (run_id,))) - assert run_row["status"] == "completed" - assert run_row["completed_at"] is not None - finally: - run(db.disconnect()) - - def test_finalize_with_error_message(self): - db = make_db() - run(db.connect()) - try: - run_id = run(db.record_workflow_run("wf-1", None, 1, [], "manual")) - run(db.finalize_workflow_run(run_id, "failed", error_message="Plugin not found")) - run_row = run(db.fetchone("SELECT status, error_message FROM workflow_runs WHERE id = ?", (run_id,))) - assert run_row["status"] == "failed" - assert run_row["error_message"] == "Plugin not found" - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_sets_status_and_timestamp(self, db): + run_id = await db.record_workflow_run("wf-1", None, 1, [], "manual") + await db.finalize_workflow_run(run_id, "completed") + run_row = await db.fetchone("SELECT status, completed_at FROM workflow_runs WHERE id = ?", (run_id,)) + assert run_row["status"] == "completed" + assert run_row["completed_at"] is not None + + @pytest.mark.asyncio + async def test_finalize_with_error_message(self, db): + run_id = await db.record_workflow_run("wf-1", None, 1, [], "manual") + await db.finalize_workflow_run(run_id, "failed", error_message="Plugin not found") + run_row = await db.fetchone("SELECT status, error_message FROM workflow_runs WHERE id = ?", (run_id,)) + assert run_row["status"] == "failed" + assert run_row["error_message"] == "Plugin not found" class TestCheckWorkflowRunTasks: - def test_empty_run_returns_completed(self): - db = make_db() - run(db.connect()) - try: - run_id = run(db.record_workflow_run("wf-1", None, 1, [], "manual")) - result = run(db.check_workflow_run_tasks(run_id)) - assert result == "completed" - finally: - run(db.disconnect()) - - def test_all_tasks_completed_returns_completed(self): - db = make_db() - run(db.connect()) - try: - task_ids = [] - for _ in range(3): - tid = uuid.uuid4().hex - run(db.execute( - "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, execution_context_json, status) VALUES (?, ?, ?, ?, ?, ?, ?)", - (tid, "nmap", "nmap", "127.0.0.1", "{}", "{}", "completed"), - )) - task_ids.append(tid) - run_id = run(db.record_workflow_run("wf-1", None, 1, task_ids, "manual")) - result = run(db.check_workflow_run_tasks(run_id)) - assert result == "completed" - finally: - run(db.disconnect()) - - def test_still_running_returns_none(self): - db = make_db() - run(db.connect()) - try: + @pytest.mark.asyncio + async def test_empty_run_returns_completed(self, db): + run_id = await db.record_workflow_run("wf-1", None, 1, [], "manual") + result = await db.check_workflow_run_tasks(run_id) + assert result == "completed" + + @pytest.mark.asyncio + async def test_all_tasks_completed_returns_completed(self, db): + task_ids = [] + for _ in range(3): tid = uuid.uuid4().hex - run(db.execute( + await db.execute( "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, execution_context_json, status) VALUES (?, ?, ?, ?, ?, ?, ?)", - (tid, "nmap", "nmap", "127.0.0.1", "{}", "{}", "running"), - )) - run_id = run(db.record_workflow_run("wf-1", None, 1, [tid], "manual")) - result = run(db.check_workflow_run_tasks(run_id)) - assert result is None - finally: - run(db.disconnect()) - - def test_any_task_failed_returns_failed(self): - db = make_db() - run(db.connect()) - try: - tid = uuid.uuid4().hex - run(db.execute( - "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, execution_context_json, status) VALUES (?, ?, ?, ?, ?, ?, ?)", - (tid, "nmap", "nmap", "127.0.0.1", "{}", "{}", "failed"), - )) - run_id = run(db.record_workflow_run("wf-1", None, 1, [tid], "manual")) - result = run(db.check_workflow_run_tasks(run_id)) - assert result == "failed" - finally: - run(db.disconnect()) - - def test_missing_run_id_returns_none(self): - db = make_db() - run(db.connect()) - try: - result = run(db.check_workflow_run_tasks("no-such-run")) - assert result is None - finally: - run(db.disconnect()) + (tid, "nmap", "nmap", "127.0.0.1", "{}", "{}", "completed"), + ) + task_ids.append(tid) + run_id = await db.record_workflow_run("wf-1", None, 1, task_ids, "manual") + result = await db.check_workflow_run_tasks(run_id) + assert result == "completed" + + @pytest.mark.asyncio + async def test_still_running_returns_none(self, db): + tid = uuid.uuid4().hex + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, execution_context_json, status) VALUES (?, ?, ?, ?, ?, ?, ?)", + (tid, "nmap", "nmap", "127.0.0.1", "{}", "{}", "running"), + ) + run_id = await db.record_workflow_run("wf-1", None, 1, [tid], "manual") + result = await db.check_workflow_run_tasks(run_id) + assert result is None + + @pytest.mark.asyncio + async def test_any_task_failed_returns_failed(self, db): + tid = uuid.uuid4().hex + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, execution_context_json, status) VALUES (?, ?, ?, ?, ?, ?, ?)", + (tid, "nmap", "nmap", "127.0.0.1", "{}", "{}", "failed"), + ) + run_id = await db.record_workflow_run("wf-1", None, 1, [tid], "manual") + result = await db.check_workflow_run_tasks(run_id) + assert result == "failed" + + @pytest.mark.asyncio + async def test_missing_run_id_returns_none(self, db): + result = await db.check_workflow_run_tasks("no-such-run") + assert result is None class TestGetWorkflowRuns: - def test_returns_paginated_run_history(self): - db = make_db() - run(db.connect()) - try: - for _ in range(3): - run_id = run(db.record_workflow_run("wf-1", None, 1, [], "manual")) - run(db.finalize_workflow_run(run_id, "completed")) - result = run(db.get_workflow_runs("wf-1", limit=10)) - assert result["total"] == 3 - assert len(result["runs"]) == 3 - finally: - run(db.disconnect()) - - def test_respects_limit_and_offset(self): - db = make_db() - run(db.connect()) - try: - for _ in range(3): - run_id = run(db.record_workflow_run("wf-1", None, 1, [], "manual")) - run(db.finalize_workflow_run(run_id, "completed")) - result = run(db.get_workflow_runs("wf-1", limit=1, offset=1)) - assert result["total"] == 3 - assert len(result["runs"]) == 1 - finally: - run(db.disconnect()) + @pytest.mark.asyncio + async def test_returns_paginated_run_history(self, db): + for _ in range(3): + run_id = await db.record_workflow_run("wf-1", None, 1, [], "manual") + await db.finalize_workflow_run(run_id, "completed") + result = await db.get_workflow_runs("wf-1", limit=10) + assert result["total"] == 3 + assert len(result["runs"]) == 3 + + @pytest.mark.asyncio + async def test_respects_limit_and_offset(self, db): + for _ in range(3): + run_id = await db.record_workflow_run("wf-1", None, 1, [], "manual") + await db.finalize_workflow_run(run_id, "completed") + result = await db.get_workflow_runs("wf-1", limit=1, offset=1) + assert result["total"] == 3 + assert len(result["runs"]) == 1 diff --git a/testing/backend/unit/test_executor.py b/testing/backend/unit/test_executor.py index dc611ed3..55c513aa 100644 --- a/testing/backend/unit/test_executor.py +++ b/testing/backend/unit/test_executor.py @@ -710,6 +710,172 @@ async def _wait(): await db.disconnect() +# --------------------------------------------------------------------------- +# Direct tests for extracted helper methods +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_enforce_guardrails_empty_target(): + executor = TaskExecutor() + # If target is empty, enforce_guardrails should immediately return True + res = await executor._enforce_guardrails("", "nmap", False, "task-1") + assert res is True + + +@pytest.mark.asyncio +async def test_enforce_guardrails_validation_failure(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + executor = TaskExecutor() + task_id = str(uuid.uuid4()) + + # Pre-populate task in DB so mark_task_failed works + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (task_id, "nmap", "nmap", "127.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 1) + ) + + with patch("backend.secuscan.executor.get_plugin_manager") as mock_pm, \ + patch("backend.secuscan.executor.asyncio.to_thread") as mock_to_thread: + + mock_plugin = MagicMock() + mock_plugin.category = "Network" + mock_pm.return_value.get_plugin.return_value = mock_plugin + + # validate_target returns (False, "invalid target") + mock_to_thread.return_value = (False, "invalid target") + + res = await executor._enforce_guardrails("127.0.0.1", "nmap", True, task_id) + assert res is False + + row = await db.fetchone("SELECT status, error_message FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.FAILED.value + assert "Safe mode target validation failed" in row["error_message"] + await db.disconnect() + + +@pytest.mark.asyncio +async def test_enforce_guardrails_network_policy_failure(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + executor = TaskExecutor() + task_id = str(uuid.uuid4()) + + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (task_id, "nmap", "nmap", "10.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 0) + ) + + mock_engine = MagicMock() + mock_engine.check_access.return_value = (False, "Blocked by policy", None) + + with patch("backend.secuscan.executor.settings") as mock_settings, \ + patch("backend.secuscan.executor.get_policy_engine", return_value=mock_engine), \ + patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: + + mock_settings.enforce_network_policy = True + mock_settings.network_policy_failure_mode = "block" + mock_settings.dns_resolution_timeout_seconds = 5 + + mock_plugin = MagicMock() + mock_plugin.category = "Network" + mock_pm.return_value.get_plugin.return_value = mock_plugin + + res = await executor._enforce_guardrails("10.0.0.1", "nmap", False, task_id) + assert res is False + + row = await db.fetchone("SELECT status, error_message FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.FAILED.value + assert "Network policy denied access" in row["error_message"] + await db.disconnect() + + +@pytest.mark.asyncio +async def test_ensure_docker_network_exists(): + executor = TaskExecutor() + + proc = MagicMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("backend.secuscan.executor.asyncio.create_subprocess_exec", return_value=proc) as mock_create: + await executor._ensure_docker_network() + + # Should only call inspect network once + mock_create.assert_called_once_with( + "docker", "network", "inspect", settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + + +@pytest.mark.asyncio +async def test_execute_modular_scanner(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + task_id = str(uuid.uuid4()) + owner_id = str(uuid.uuid4()) + + # Insert task in DB + await db.execute( + """ + INSERT INTO tasks (id, owner_id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (task_id, owner_id, "mock_scanner", "mock_scanner", "127.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 0) + ) + + class MockScanner: + name = "MockScanner" + def __init__(self, task_id, db, safe_mode=False): + self.task_id = task_id + self.db = db + self.safe_mode = safe_mode + + async def run(self, target, inputs): + return { + "status": "completed", + "findings": [ + { + "title": "Mock Finding", + "category": "Mock Category", + "severity": "low", + "description": "Mock description", + } + ], + "asset_services": [] + } + + executor = TaskExecutor() + + # We patch the MODULAR_SCANNERS dictionary in backend.secuscan.executor + with patch.dict("backend.secuscan.executor.MODULAR_SCANNERS", {"mock_scanner": MockScanner}): + status, duration = await executor._execute_modular_scanner( + db=db, + task_id=task_id, + owner_id=owner_id, + plugin_id="mock_scanner", + target="127.0.0.1", + inputs={}, + safe_mode=False + ) + + assert status == TaskStatus.COMPLETED.value + assert duration >= 0 + + # Verify task updated in DB + row = await db.fetchone("SELECT status, structured_json FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.COMPLETED.value + structured = json.loads(row["structured_json"]) + assert len(structured["findings"]) == 1 + assert structured["findings"][0]["title"] == "Mock Finding" + + await db.disconnect() + + @pytest.mark.asyncio async def test_execute_task_aborts_when_task_no_longer_queued(setup_test_environment): """ @@ -747,6 +913,59 @@ async def test_execute_task_aborts_when_task_no_longer_queued(setup_test_environ await db.disconnect() +@pytest.mark.asyncio +async def test_execute_standard_scanner(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + task_id = str(uuid.uuid4()) + owner_id = str(uuid.uuid4()) + + # Insert task in DB + await db.execute( + """ + INSERT INTO tasks (id, owner_id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (task_id, owner_id, "mock_cli_plugin", "mock_cli_plugin", "127.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 0) + ) + + executor = TaskExecutor() + + mock_plugin = MagicMock() + mock_plugin.id = "mock_cli_plugin" + mock_plugin.name = "mock_cli_plugin" + mock_plugin.docker_image = None + + with patch("backend.secuscan.executor.get_plugin_manager") as mock_pm, \ + patch.object(executor, "_execute_command", return_value=("Mock output\n", 0)) as mock_exec, \ + patch.object(executor, "_classify_command_result", return_value=(TaskStatus.COMPLETED.value, None)) as mock_classify, \ + patch.object(executor, "_upsert_findings_and_report") as mock_upsert: + + mock_pm.return_value.build_command.return_value = ["ping", "127.0.0.1"] + + status, duration, exit_code = await executor._execute_standard_scanner( + db=db, + task_id=task_id, + owner_id=owner_id, + plugin=mock_plugin, + plugin_id="mock_cli_plugin", + target="127.0.0.1", + inputs={} + ) + + assert status == TaskStatus.COMPLETED.value + assert exit_code == 0 + assert duration >= 0 + + # Verify task updated in DB + row = await db.fetchone("SELECT status, exit_code FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.COMPLETED.value + assert row["exit_code"] == 0 + + await db.disconnect() + + @pytest.mark.asyncio async def test_execute_task_aborts_when_task_deleted_before_running(setup_test_environment): """ diff --git a/testing/backend/unit/test_risk_scoring.py b/testing/backend/unit/test_risk_scoring.py index 820d7978..ccf6ba19 100644 --- a/testing/backend/unit/test_risk_scoring.py +++ b/testing/backend/unit/test_risk_scoring.py @@ -296,6 +296,12 @@ async def test_backfill_sets_risk_score_on_null_findings(self, setup_test_enviro await init_db(settings.database_path) db = await get_db() + # Insert referenced task first to satisfy foreign key constraint + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target) VALUES (?, ?, ?, ?)", + ("task-1", "test", "test", "example.com") + ) + finding_id = "test-finding-001" await db.execute( """ @@ -335,6 +341,12 @@ async def test_backfill_idempotent(self, setup_test_environment): await init_db(settings.database_path) db = await get_db() + # Insert referenced task first to satisfy foreign key constraint + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target) VALUES (?, ?, ?, ?)", + ("task-2", "test", "test", "example.com") + ) + finding_id = "test-finding-002" await db.execute( """