diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index db078a8e..90c43a98 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -226,7 +226,7 @@ def _enqueue_listener_event(self, task_id: str, q: asyncio.Queue, event: Dict[st q.put_nowait(event) except asyncio.QueueFull: logger.warning("Dropping stream event for slow listener on task %s", task_id) - + async def _broadcast_phase(self, task_id: str, phase: str): """Broadcast a scan phase transition and persist it to the database.""" await self._broadcast(task_id, "phase", phase) @@ -265,16 +265,16 @@ async def create_task( task_id = str(uuid.uuid4()) plugin_manager = get_plugin_manager() plugin = plugin_manager.get_plugin(plugin_id) - + if not plugin: raise ValueError(f"Plugin not found: {plugin_id}") - + # Apply preset if provided if preset and preset in plugin.presets: preset_values = plugin.presets[preset] # Merge preset with user inputs (user inputs take precedence) inputs = {**preset_values, **inputs} - + # Store task in database db = await get_db() await db.execute( @@ -299,7 +299,7 @@ async def create_task( bool(safe_mode) ) ) - + # Log audit event await db.log_audit( "task_created", @@ -313,9 +313,9 @@ async def create_task( task_id=task_id, plugin_id=plugin_id ) - + return task_id - + async def mark_task_failed(self, task_id: str, reason: str) -> None: """ Mark a task as failed without running it. @@ -351,15 +351,299 @@ async def mark_task_failed(self, task_id: str, reason: str) -> None: task_id=task_id, ) - async def execute_task(self, task_id: str): + async def _enforce_guardrails( + self, + target: str, + plugin_id: str, + safe_mode: bool, + task_id: str, + ) -> bool: + """Enforce Safe Mode target validation and Network Policy access checks. + + Returns: + True if all checks pass or are bypassed. False if any check blocks execution. + """ + if not target: + return True + + plugin_manager = get_plugin_manager() + plugin = plugin_manager.get_plugin(plugin_id) + should_validate = True + if plugin and plugin.category == "code": + should_validate = False + + # Use shared is_filesystem_target from validation to ensure + # consistent filesystem detection across route and executor layers. + from .validation import is_filesystem_target + is_fs = is_filesystem_target(target) + + if should_validate and not is_fs: + from .validation import validate_target + try: + # Enforce safe mode validation of target address in a thread pool + is_valid, error_msg = await asyncio.wait_for( + asyncio.to_thread(validate_target, target, safe_mode), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + if not is_valid: + await self.mark_task_failed( + task_id, + f"Safe mode target validation failed: {error_msg}", + ) + await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + return False + except asyncio.TimeoutError: + await self.mark_task_failed( + task_id, + "Target validation timed out (SecuScan Guardrail)", + ) + await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + return False + + # Check before launching any scanner or subprocess. check_access() + # writes an audit entry on every path, so no extra logging needed. + if settings.enforce_network_policy: + engine = get_policy_engine() + try: + allowed, reason, _ = await asyncio.wait_for( + asyncio.to_thread( + engine.check_access, + dest_ip=target, + plugin_id=plugin_id, + task_id=task_id, + ), + timeout=float(settings.dns_resolution_timeout_seconds), + ) + except asyncio.TimeoutError: + allowed, reason = False, "Network policy check timed out (DNS resolution timeout)" + + if not allowed: + if settings.network_policy_failure_mode == "log_only": + logger.warning( + f"[Log Only] Network policy violation allowed for {target}: {reason}" + ) + else: + await self.mark_task_failed( + task_id, + f"Network policy denied access to {target}: {reason}", + ) + await self._broadcast(task_id, "status", TaskStatus.FAILED.value) + return False + + return True + + async def _ensure_docker_network(self) -> None: + """Validate and automatically create the configured Docker network if missing.""" + _net_check = await asyncio.create_subprocess_exec( + "docker", "network", "inspect", settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await _net_check.wait() + if _net_check.returncode == 0: + return + + logger.info(f"Docker network '{settings.docker_network}' not found. Creating isolated bridge network (ICC disabled)...") + _net_create = await asyncio.create_subprocess_exec( + "docker", "network", "create", + "--driver", "bridge", + "--opt", "com.docker.network.bridge.enable_icc=false", + settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await _net_create.wait() + if _net_create.returncode == 0: + logger.info(f"Successfully created Docker network '{settings.docker_network}' with ICC disabled") + return + + logger.warning("Failed to create isolated bridge network with ICC disabled. Falling back to standard bridge...") + _net_create_fallback = await asyncio.create_subprocess_exec( + "docker", "network", "create", "--driver", "bridge", settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await _net_create_fallback.wait() + if _net_create_fallback.returncode != 0: + raise RuntimeError( + f"Docker network '{settings.docker_network}' does not exist and could not be created automatically." + ) + logger.info(f"Successfully created Docker network '{settings.docker_network}' (fallback)") + + async def _execute_modular_scanner( + self, + db, + task_id: str, + owner_id: str, + plugin_id: str, + target: str, + inputs: Dict[str, Any], + safe_mode: bool, + ) -> tuple[str, float]: + """Execute a modular scanner and persist findings/report.""" + scanner_class = MODULAR_SCANNERS[plugin_id] + scanner = scanner_class(task_id, db, safe_mode=safe_mode) + + logger.info(f"Executing modular scanner {plugin_id} for task {task_id}") + await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) + await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) + + start_time = time.time() + result = await scanner.run(target, inputs) + duration = time.time() - start_time + + final_status = ( + TaskStatus.COMPLETED.value + if result.get("status") != "failed" + else TaskStatus.FAILED.value + ) + + await db.execute( + """ + UPDATE tasks SET + status = ?, + completed_at = ?, + duration_seconds = ?, + structured_json = ?, + error_message = ? + WHERE id = ? + """, + ( + final_status, + datetime.now().isoformat(), + duration, + json.dumps(result), + result.get("error_message"), + task_id, + ), + ) + + await self._broadcast_phase(task_id, ScanPhase.PARSING.value) + await self._upsert_findings_and_report_from_scanner( + db=db, + task_id=task_id, + owner_id=owner_id, + scanner=scanner, + plugin_id=plugin_id, + target=target, + status=final_status, + result=result, + ) + await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) + return final_status, duration + + async def _execute_standard_scanner( + self, + db, + task_id: str, + owner_id: str, + plugin: Any, + plugin_id: str, + target: str, + inputs: Dict[str, Any], + ) -> tuple[str, float, int]: + """Execute a standard CLI/Docker plugin and persist findings/report.""" + plugin_manager = get_plugin_manager() + command = plugin_manager.build_command(plugin_id, inputs) + + if not command: + raise ValueError("Failed to build command") + + # Apply Docker Sandboxing if enabled + if settings.docker_enabled: + await self._ensure_docker_network() + docker_image = plugin.docker_image or "alpine:latest" + docker_cmd = [ + "docker", + "run", + "--rm", + "--name", + f"secuscan_task_{task_id}", + "--memory", + f"{settings.sandbox_memory_mb}m", + "--cpus", + str(settings.sandbox_cpu_quota), + "--cap-drop", "NET_RAW", + "--network", settings.docker_network, + docker_image, + ] + command = docker_cmd + command + + logger.info(f"Executing task {task_id}: {' '.join(command)}") + await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) + await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) + + # Execute command + start_time = time.time() + output, exit_code = await self._execute_command( + command, + task_id, + timeout=self._resolve_execution_timeout(inputs), + ) + duration = time.time() - start_time + + # Save raw output + raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" + output = redact(output) + with open(raw_path, 'w') as f: + f.write(output) + + # Classify result + final_status, error_message = self._classify_command_result( + plugin=plugin, + output=output, + exit_code=exit_code, + ) + + await db.execute( + """ + UPDATE tasks SET + status = ?, + completed_at = ?, + duration_seconds = ?, + exit_code = ?, + raw_output_path = ?, + command_used = ?, + error_message = ? + WHERE id = ? + """, + ( + final_status, + datetime.now().isoformat(), + duration, + exit_code, + str(raw_path), + " ".join(command), + error_message, + task_id, + ), + ) + + # Upsert findings and report + await self._broadcast_phase(task_id, ScanPhase.PARSING.value) + await self._upsert_findings_and_report( + db=db, + task_id=task_id, + owner_id=owner_id, + plugin=plugin, + plugin_id=plugin_id, + target=target, + status=final_status, + output=output, + ) + await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) + return final_status, duration, exit_code + + async def execute_task(self, task_id: str) -> None: """ Execute a task asynchronously. - + Args: task_id: Task identifier """ db = await get_db() self.running_tasks[task_id] = asyncio.current_task() + start_time = time.time() try: # Update status to running — use optimistic lock to detect @@ -399,73 +683,8 @@ async def execute_task(self, task_id: str): ) # ── Safe Mode & Network policy enforcement ─────────────────────── - # Enforce Safe Mode target validation inside TaskExecutor to guarantee - # that all execution paths (manual API, workflows, scheduled tasks) are protected. - if target: - plugin_manager = get_plugin_manager() - plugin = plugin_manager.get_plugin(plugin_id) - should_validate = True - if plugin and plugin.category == "code": - should_validate = False - - - # Use shared is_filesystem_target from validation to ensure - # consistent filesystem detection across route and executor layers. - from .validation import is_filesystem_target - is_fs = is_filesystem_target(target) - - if should_validate and not is_fs: - from .validation import validate_target - try: - # Enforce safe mode validation of target address in a thread pool - is_valid, error_msg = await asyncio.wait_for( - asyncio.to_thread(validate_target, target, safe_mode), - timeout=float(settings.dns_resolution_timeout_seconds), - ) - if not is_valid: - await self.mark_task_failed( - task_id, - f"Safe mode target validation failed: {error_msg}", - ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) - return - except asyncio.TimeoutError: - await self.mark_task_failed( - task_id, - "Target validation timed out (SecuScan Guardrail)", - ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) - return - - # Check before launching any scanner or subprocess. check_access() - # writes an audit entry on every path, so no extra logging needed. - if target and settings.enforce_network_policy: - engine = get_policy_engine() - try: - allowed, reason, _ = await asyncio.wait_for( - asyncio.to_thread( - engine.check_access, - dest_ip=target, - plugin_id=plugin_id, - task_id=task_id, - ), - timeout=float(settings.dns_resolution_timeout_seconds), - ) - except asyncio.TimeoutError: - allowed, reason = False, "Network policy check timed out (DNS resolution timeout)" - - if not allowed: - if settings.network_policy_failure_mode == "log_only": - logger.warning( - f"[Log Only] Network policy violation allowed for {target}: {reason}" - ) - else: - await self.mark_task_failed( - task_id, - f"Network policy denied access to {target}: {reason}", - ) - await self._broadcast(task_id, "status", TaskStatus.FAILED.value) - return # finally block handles running_tasks cleanup + limiter release + if not await self._enforce_guardrails(target, plugin_id, safe_mode, task_id): + return # Check if this is a modular scanner or a standard plugin plugin_manager = get_plugin_manager() @@ -485,180 +704,26 @@ async def execute_task(self, task_id: str): ) if plugin_id in MODULAR_SCANNERS: - scanner_class = MODULAR_SCANNERS[plugin_id] - scanner = scanner_class(task_id, db, safe_mode=safe_mode) - - logger.info(f"Executing modular scanner {plugin_id} for task {task_id}") - await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) - await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - - start_time = time.time() - # Run the scanner - result = await scanner.run(target, inputs) - duration = time.time() - start_time - - # Update task with results - final_status = TaskStatus.COMPLETED.value if result.get("status") != "failed" else TaskStatus.FAILED.value - - await db.execute( - """ - UPDATE tasks SET - status = ?, - completed_at = ?, - duration_seconds = ?, - structured_json = ?, - error_message = ? - WHERE id = ? - """, - ( - final_status, - datetime.now().isoformat(), - duration, - json.dumps(result), - result.get("error_message"), - task_id - ) - ) - - # Upsert findings and report using the scanner's result - await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - await self._upsert_findings_and_report_from_scanner( + final_status, duration = await self._execute_modular_scanner( db=db, task_id=task_id, owner_id=owner_id, - scanner=scanner, plugin_id=plugin_id, target=target, - status=final_status, - result=result + inputs=inputs, + safe_mode=safe_mode, ) - await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) - + exit_code = 0 else: - # Standard Plugin Execution - command = plugin_manager.build_command(plugin_id, inputs) - - if not command: - raise ValueError("Failed to build command") - - # Apply Docker Sandboxing if enabled - if settings.docker_enabled: - # Validate the named Docker network exists before using it. - # If missing, attempt to create it automatically rather than failing. - _net_check = await asyncio.create_subprocess_exec( - "docker", "network", "inspect", settings.docker_network, - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await _net_check.wait() - if _net_check.returncode != 0: - logger.info(f"Docker network '{settings.docker_network}' not found. Creating isolated bridge network (ICC disabled)...") - _net_create = await asyncio.create_subprocess_exec( - "docker", "network", "create", - "--driver", "bridge", - "--opt", "com.docker.network.bridge.enable_icc=false", - settings.docker_network, - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await _net_create.wait() - if _net_create.returncode != 0: - logger.warning("Failed to create isolated bridge network with ICC disabled. Falling back to standard bridge...") - _net_create_fallback = await asyncio.create_subprocess_exec( - "docker", "network", "create", "--driver", "bridge", settings.docker_network, - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await _net_create_fallback.wait() - if _net_create_fallback.returncode != 0: - raise RuntimeError( - f"Docker network '{settings.docker_network}' does not exist and could not be created automatically." - ) - logger.info(f"Successfully created Docker network '{settings.docker_network}' (fallback)") - else: - logger.info(f"Successfully created Docker network '{settings.docker_network}' with ICC disabled") - - docker_image = plugin.docker_image or "alpine:latest" - docker_cmd = [ - "docker", - "run", - "--rm", - "--name", - f"secuscan_task_{task_id}", - "--memory", - f"{settings.sandbox_memory_mb}m", - "--cpus", - str(settings.sandbox_cpu_quota), - "--cap-drop", "NET_RAW", - "--network", settings.docker_network, - docker_image, - ] - command = docker_cmd + command - - logger.info(f"Executing task {task_id}: {' '.join(command)}") - await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) - await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - - # Execute command - start_time = time.time() - output, exit_code = await self._execute_command( - command, - task_id, - timeout=self._resolve_execution_timeout(inputs), - ) - duration = time.time() - start_time - - # Save raw output - raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" - output = redact(output) - with open(raw_path, 'w') as f: - f.write(output) - - # Some CLI tools use non-zero exit codes for "no result" states while still - # producing a complete, parseable report. Let plugin metadata opt into that. - final_status, error_message = self._classify_command_result( - plugin=plugin, - output=output, - exit_code=exit_code, - ) - - await db.execute( - """ - UPDATE tasks SET - status = ?, - completed_at = ?, - duration_seconds = ?, - exit_code = ?, - raw_output_path = ?, - command_used = ?, - error_message = ? - WHERE id = ? - """, - ( - final_status, - datetime.now().isoformat(), - duration, - exit_code, - str(raw_path), - " ".join(command), - error_message, - task_id - ) - ) - - # Upsert findings and report - await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - await self._upsert_findings_and_report( + final_status, duration, exit_code = await self._execute_standard_scanner( db=db, task_id=task_id, owner_id=owner_id, plugin=plugin, plugin_id=plugin_id, target=target, - status=final_status, - output=output + inputs=inputs, ) - await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) await self._dispatch_task_notifications(db, task_id) @@ -670,7 +735,7 @@ async def execute_task(self, task_id: str): await db.log_audit( "task_completed", f"Task completed in {duration:.2f}s", - context={"task_id": task_id, "exit_code": locals().get('exit_code', 0)}, + context={"task_id": task_id, "exit_code": exit_code}, task_id=task_id, plugin_id=plugin_id ) @@ -678,11 +743,6 @@ async def execute_task(self, task_id: str): logger.info(f"Task {task_id} completed in {duration:.2f}s") except asyncio.CancelledError: - # CancelledError inherits from BaseException, not Exception — - # it bypasses the broad except below, so we handle it explicitly. - # Task.cancelled() returns False while the finally block is still - # executing, so this is the only reliable place to write the - # cancellation status to the DB. duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ @@ -740,8 +800,6 @@ async def execute_task(self, task_id: str): except Exception as e: logger.error(f"Task {task_id} failed: {e}", exc_info=True) - - # Update task as failed duration = (time.time() - start_time) if 'start_time' in locals() else 0 await db.execute( """ @@ -772,12 +830,10 @@ async def execute_task(self, task_id: str): task_id=task_id ) finally: - # Always clean up: remove from the in-memory registry and - # release the concurrency slot regardless of how the task ended. self.running_tasks.pop(task_id, None) self._process_pids.pop(task_id, None) await concurrent_limiter.release(task_id) - + async def _execute_command( self, command: list, @@ -962,7 +1018,7 @@ async def cancel_task(self, task_id: str) -> bool: ) return True - + async def get_task_status(self, task_id: str) -> Optional[Dict]: """Get task status and progress""" db = await get_db() @@ -1506,12 +1562,12 @@ def _parse_results(self, plugin, output: str) -> Dict[str, Any]: """Route to appropriate parser based on plugin metadata.""" parser_type = plugin.output.get("parser") parser_input = self._resolve_parser_input(plugin, output) - + # 1. Check for custom parser.py in plugin directory (Recommended) plugin_manager = get_plugin_manager() plugin_dir = plugin_manager.plugins_dir / plugin.id parser_path = plugin_dir / "parser.py" - + if parser_path.exists(): if not plugin_manager.verify_parser_at_exec_time(plugin, plugin_dir): raise ValueError( @@ -1547,7 +1603,7 @@ def _parse_results(self, plugin, output: str) -> Dict[str, Any]: return self._normalize_parsed_result(plugin, parser_input, self._parse_nmap_output(parser_input)) elif parser_type == "builtin_http": return self._normalize_parsed_result(plugin, parser_input, self._parse_http_output(parser_input)) - + return self._normalize_parsed_result(plugin, parser_input, {"findings": [], "raw": parser_input}) def _resolve_parser_input(self, plugin, output: str) -> str: @@ -1720,7 +1776,7 @@ def _parse_nmap_output(self, output: str) -> Dict[str, Any]: findings = [] ports = [] services = [] - + # Regex for open ports: 80/tcp open http port_pattern = re.compile(r"(\d+)/(tcp|udp)\s+open\s+([\w-]+)") for match in port_pattern.finditer(output): @@ -1736,7 +1792,7 @@ def _parse_nmap_output(self, output: str) -> Dict[str, Any]: "remediation": "Close unnecessary ports and use a firewall to restrict access.", "metadata": {"port": port_str, "protocol": proto, "service": service} }) - + return { "open_ports": sorted(list(set(ports))), "services": sorted(list(set(services))), diff --git a/testing/backend/unit/test_executor.py b/testing/backend/unit/test_executor.py index dc611ed3..55c513aa 100644 --- a/testing/backend/unit/test_executor.py +++ b/testing/backend/unit/test_executor.py @@ -710,6 +710,172 @@ async def _wait(): await db.disconnect() +# --------------------------------------------------------------------------- +# Direct tests for extracted helper methods +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_enforce_guardrails_empty_target(): + executor = TaskExecutor() + # If target is empty, enforce_guardrails should immediately return True + res = await executor._enforce_guardrails("", "nmap", False, "task-1") + assert res is True + + +@pytest.mark.asyncio +async def test_enforce_guardrails_validation_failure(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + executor = TaskExecutor() + task_id = str(uuid.uuid4()) + + # Pre-populate task in DB so mark_task_failed works + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (task_id, "nmap", "nmap", "127.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 1) + ) + + with patch("backend.secuscan.executor.get_plugin_manager") as mock_pm, \ + patch("backend.secuscan.executor.asyncio.to_thread") as mock_to_thread: + + mock_plugin = MagicMock() + mock_plugin.category = "Network" + mock_pm.return_value.get_plugin.return_value = mock_plugin + + # validate_target returns (False, "invalid target") + mock_to_thread.return_value = (False, "invalid target") + + res = await executor._enforce_guardrails("127.0.0.1", "nmap", True, task_id) + assert res is False + + row = await db.fetchone("SELECT status, error_message FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.FAILED.value + assert "Safe mode target validation failed" in row["error_message"] + await db.disconnect() + + +@pytest.mark.asyncio +async def test_enforce_guardrails_network_policy_failure(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + executor = TaskExecutor() + task_id = str(uuid.uuid4()) + + await db.execute( + "INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (task_id, "nmap", "nmap", "10.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 0) + ) + + mock_engine = MagicMock() + mock_engine.check_access.return_value = (False, "Blocked by policy", None) + + with patch("backend.secuscan.executor.settings") as mock_settings, \ + patch("backend.secuscan.executor.get_policy_engine", return_value=mock_engine), \ + patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: + + mock_settings.enforce_network_policy = True + mock_settings.network_policy_failure_mode = "block" + mock_settings.dns_resolution_timeout_seconds = 5 + + mock_plugin = MagicMock() + mock_plugin.category = "Network" + mock_pm.return_value.get_plugin.return_value = mock_plugin + + res = await executor._enforce_guardrails("10.0.0.1", "nmap", False, task_id) + assert res is False + + row = await db.fetchone("SELECT status, error_message FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.FAILED.value + assert "Network policy denied access" in row["error_message"] + await db.disconnect() + + +@pytest.mark.asyncio +async def test_ensure_docker_network_exists(): + executor = TaskExecutor() + + proc = MagicMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("backend.secuscan.executor.asyncio.create_subprocess_exec", return_value=proc) as mock_create: + await executor._ensure_docker_network() + + # Should only call inspect network once + mock_create.assert_called_once_with( + "docker", "network", "inspect", settings.docker_network, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + + +@pytest.mark.asyncio +async def test_execute_modular_scanner(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + task_id = str(uuid.uuid4()) + owner_id = str(uuid.uuid4()) + + # Insert task in DB + await db.execute( + """ + INSERT INTO tasks (id, owner_id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (task_id, owner_id, "mock_scanner", "mock_scanner", "127.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 0) + ) + + class MockScanner: + name = "MockScanner" + def __init__(self, task_id, db, safe_mode=False): + self.task_id = task_id + self.db = db + self.safe_mode = safe_mode + + async def run(self, target, inputs): + return { + "status": "completed", + "findings": [ + { + "title": "Mock Finding", + "category": "Mock Category", + "severity": "low", + "description": "Mock description", + } + ], + "asset_services": [] + } + + executor = TaskExecutor() + + # We patch the MODULAR_SCANNERS dictionary in backend.secuscan.executor + with patch.dict("backend.secuscan.executor.MODULAR_SCANNERS", {"mock_scanner": MockScanner}): + status, duration = await executor._execute_modular_scanner( + db=db, + task_id=task_id, + owner_id=owner_id, + plugin_id="mock_scanner", + target="127.0.0.1", + inputs={}, + safe_mode=False + ) + + assert status == TaskStatus.COMPLETED.value + assert duration >= 0 + + # Verify task updated in DB + row = await db.fetchone("SELECT status, structured_json FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.COMPLETED.value + structured = json.loads(row["structured_json"]) + assert len(structured["findings"]) == 1 + assert structured["findings"][0]["title"] == "Mock Finding" + + await db.disconnect() + + @pytest.mark.asyncio async def test_execute_task_aborts_when_task_no_longer_queued(setup_test_environment): """ @@ -747,6 +913,59 @@ async def test_execute_task_aborts_when_task_no_longer_queued(setup_test_environ await db.disconnect() +@pytest.mark.asyncio +async def test_execute_standard_scanner(setup_test_environment): + await init_db(settings.database_path) + db = await get_db() + + task_id = str(uuid.uuid4()) + owner_id = str(uuid.uuid4()) + + # Insert task in DB + await db.execute( + """ + INSERT INTO tasks (id, owner_id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (task_id, owner_id, "mock_cli_plugin", "mock_cli_plugin", "127.0.0.1", "{}", TaskStatus.QUEUED.value, 1, 0) + ) + + executor = TaskExecutor() + + mock_plugin = MagicMock() + mock_plugin.id = "mock_cli_plugin" + mock_plugin.name = "mock_cli_plugin" + mock_plugin.docker_image = None + + with patch("backend.secuscan.executor.get_plugin_manager") as mock_pm, \ + patch.object(executor, "_execute_command", return_value=("Mock output\n", 0)) as mock_exec, \ + patch.object(executor, "_classify_command_result", return_value=(TaskStatus.COMPLETED.value, None)) as mock_classify, \ + patch.object(executor, "_upsert_findings_and_report") as mock_upsert: + + mock_pm.return_value.build_command.return_value = ["ping", "127.0.0.1"] + + status, duration, exit_code = await executor._execute_standard_scanner( + db=db, + task_id=task_id, + owner_id=owner_id, + plugin=mock_plugin, + plugin_id="mock_cli_plugin", + target="127.0.0.1", + inputs={} + ) + + assert status == TaskStatus.COMPLETED.value + assert exit_code == 0 + assert duration >= 0 + + # Verify task updated in DB + row = await db.fetchone("SELECT status, exit_code FROM tasks WHERE id = ?", (task_id,)) + assert row["status"] == TaskStatus.COMPLETED.value + assert row["exit_code"] == 0 + + await db.disconnect() + + @pytest.mark.asyncio async def test_execute_task_aborts_when_task_deleted_before_running(setup_test_environment): """