From 441f6ff0ce80f6f923d231099e23dfd61291fbf4 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Thu, 11 Jun 2026 18:18:41 +0530 Subject: [PATCH 01/13] feat: implement Redis scan result caching layer with bypass param and automatic hashing invalidation --- backend/secuscan/cache.py | 59 ++++++- backend/secuscan/executor.py | 225 +++++++++++++++++++++++- backend/secuscan/main.py | 4 +- backend/secuscan/routes.py | 3 +- testing/backend/unit/test_scan_cache.py | 125 +++++++++++++ 5 files changed, 405 insertions(+), 11 deletions(-) create mode 100644 testing/backend/unit/test_scan_cache.py diff --git a/backend/secuscan/cache.py b/backend/secuscan/cache.py index a42c3edbf..894ec0518 100644 --- a/backend/secuscan/cache.py +++ b/backend/secuscan/cache.py @@ -1,5 +1,5 @@ """ -In-memory cache helpers for API responses. +In-memory and Redis-based cache helpers for API responses. """ import json @@ -15,9 +15,15 @@ SWEEP_EVICT_FRACTION = 0.25 OPPORTUNISTIC_SWEEP_INTERVAL = 50 +try: + from redis.asyncio import Redis, ConnectionError as RedisConnectionError +except ImportError: + Redis = None + RedisConnectionError = Exception + class CacheClient: - """In-memory dictionary based cache client with TTL, size limit, and LRU eviction.""" + """Cache client supporting Redis with an in-memory dictionary fallback.""" def __init__(self, url: Optional[str] = None, max_entries: int = DEFAULT_MAX_ENTRIES): self.url = url @@ -28,11 +34,27 @@ def __init__(self, url: Optional[str] = None, max_entries: int = DEFAULT_MAX_ENT self._eviction_count = 0 self._sweep_count = 0 self._write_count = 0 + self.client: Optional[Redis] = None async def connect(self): - pass + if self.url and Redis is not None: + try: + self.client = Redis.from_url(self.url, decode_responses=True) + await self.client.ping() + logger.info("✓ Connected to Redis cache at %s", self.url) + except RedisConnectionError as e: + logger.warning("Failed to connect to Redis, falling back to in-memory: %s", e) + self.client = None + else: + self.client = None async def disconnect(self): + if self.client: + try: + await self.client.aclose() + except Exception: + pass + self.client = None self._data.clear() self._expires.clear() self._access_order.clear() @@ -60,7 +82,14 @@ def _evict_lru(self): self._eviction_count += evict_count async def get_json(self, key: str) -> Optional[Any]: - """Retrieve and parse JSON from memory, respecting TTL.""" + """Retrieve and parse JSON from cache, respecting TTL.""" + if self.client: + try: + val = await self.client.get(key) + return json.loads(val) if val is not None else None + except Exception as e: + logger.warning("Redis get_json error (falling back to in-memory): %s", e) + now = time.time() expiry = self._expires.get(key) @@ -76,12 +105,20 @@ async def get_json(self, key: str) -> Optional[Any]: return self._data.get(key) async def set_json(self, key: str, value: Any, ttl: Optional[int] = None): - """Store value in memory with optional TTL.""" + """Store value in cache with optional TTL.""" + actual_ttl = ttl or settings.cache_ttl_seconds + + if self.client: + try: + await self.client.set(key, json.dumps(value), ex=actual_ttl) + return + except Exception as e: + logger.warning("Redis set_json error (falling back to in-memory): %s", e) + if len(self._data) >= self.max_entries and key not in self._data: self._evict_lru() self._data[key] = value - actual_ttl = ttl or settings.cache_ttl_seconds self._expires[key] = time.time() + actual_ttl self._access_order[key] = time.time() self._write_count += 1 @@ -91,6 +128,15 @@ async def set_json(self, key: str, value: Any, ttl: Optional[int] = None): async def delete_prefix(self, prefix: str): """Delete all keys starting with prefix.""" + if self.client: + try: + keys = await self.client.keys(f"{prefix}*") + if keys: + await self.client.delete(*keys) + return + except Exception as e: + logger.warning("Redis delete_prefix error (falling back to in-memory): %s", e) + to_delete = [k for k in self._data.keys() if k.startswith(prefix)] for k in to_delete: self._data.pop(k, None) @@ -128,3 +174,4 @@ async def get_cache() -> CacheClient: if cache is None: raise RuntimeError("Cache not initialized") return cache + diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 4317b476f..61a50b0bf 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -155,6 +155,71 @@ def extract_target(inputs: Dict[str, Any]) -> str: ) +def generate_scan_cache_key(plugin_id: str, target: str) -> tuple[str, str, str]: + """ + Generate target hash, dependency hash, and a cache key. + + Returns: + tuple: (target_hash, dependency_hash, cache_key) + """ + import hashlib + import subprocess + from pathlib import Path + + target_hash = None + if target and os.path.isdir(target): + try: + res = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=target, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=5 + ) + if res.returncode == 0: + target_hash = res.stdout.strip() + except Exception: + pass + + if not target_hash: + target_hash = hashlib.sha256(str(target or "").encode("utf-8")).hexdigest() + + dependency_files = [ + "package-lock.json", + "poetry.lock", + "Cargo.lock", + "go.sum", + "requirements.txt", + "Pipfile.lock", + "pnpm-lock.yaml", + "yarn.lock", + "gemfile.lock" + ] + hasher = hashlib.sha256() + found_any = False + + if target and os.path.isdir(target): + p = Path(target) + for dep_file in sorted(dependency_files): + file_path = p / dep_file + if file_path.exists() and file_path.is_file(): + try: + hasher.update(dep_file.encode("utf-8")) + hasher.update(file_path.read_bytes()) + found_any = True + except Exception: + pass + + if not found_any: + dependency_hash = "no_deps" + else: + dependency_hash = hasher.hexdigest() + + cache_key = f"scan_cache:{plugin_id}:{target_hash}:{dependency_hash}" + return target_hash, dependency_hash, cache_key + + def _stable_asset_id(target: str, host: Any, port: Any, protocol: Any) -> str: material = "||".join( [ @@ -351,12 +416,13 @@ async def mark_task_failed(self, task_id: str, reason: str) -> None: task_id=task_id, ) - async def execute_task(self, task_id: str): + async def execute_task(self, task_id: str, bypass_cache: bool = False): """ Execute a task asynchronously. Args: task_id: Task identifier + bypass_cache: Whether to bypass Redis scan result cache """ db = await get_db() self.running_tasks[task_id] = asyncio.current_task() @@ -393,6 +459,134 @@ async def execute_task(self, task_id: str): execution_context=execution_context, ) + # Check cache if not bypassed + cached_result = None + cache_key = None + if target and not bypass_cache: + try: + target_hash, dependency_hash, cache_key = generate_scan_cache_key(plugin_id, target) + cache_client = await get_cache() + cached_result = await cache_client.get_json(cache_key) + except Exception as cache_exc: + logger.warning("Failed to query scan cache: %s", cache_exc) + + if cached_result is not None: + logger.info("Cache hit for scan task %s (key: %s)", task_id, cache_key) + await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) + await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) + + raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" + try: + with open(raw_path, 'w', encoding='utf-8') as f: + f.write(cached_result.get("raw_output", "")) + except Exception as f_exc: + logger.warning("Failed to write raw output for cached task: %s", f_exc) + + status = cached_result.get("status", TaskStatus.COMPLETED.value) + duration = cached_result.get("duration_seconds", 0.0) + exit_code = cached_result.get("exit_code", 0) + error_message = cached_result.get("error_message") + structured_data = cached_result.get("structured", {}) + + # Update task in SQLite with the cached results + await db.execute( + """ + UPDATE tasks SET + status = ?, + completed_at = ?, + duration_seconds = ?, + exit_code = ?, + raw_output_path = ?, + structured_json = ?, + error_message = ? + WHERE id = ? + """, + ( + status, + datetime.now().isoformat(), + duration, + exit_code, + str(raw_path), + json.dumps(structured_data), + error_message, + task_id + ) + ) + + # Persist findings and reports to database for the new task + await self._broadcast_phase(task_id, ScanPhase.PARSING.value) + + findings_data: List[Dict[str, Any]] = [] + for finding in structured_data.get("findings", []): + findings_data.append( + await self._persist_finding( + db, + owner_id=owner_id, + task_id=task_id, + plugin_id=plugin_id, + target=target, + finding=finding, + ) + ) + + # Create/Update report + plugin_manager = get_plugin_manager() + plugin = plugin_manager.get_plugin(plugin_id) + report_name = f"{plugin.name} Report" if plugin else f"{plugin_id} Report" + report_type = "technical" + pages = 1 + if plugin_id in MODULAR_SCANNERS: + scanner_class = MODULAR_SCANNERS[plugin_id] + report_name = f"{scanner_class.__name__} Report" + report_type = "professional" + pages = 2 + + await db.execute( + """ + INSERT INTO reports ( + id, owner_id, task_id, name, type, generated_at, status, findings, pages + ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + status = EXCLUDED.status, + findings = EXCLUDED.findings, + pages = EXCLUDED.pages + """, + ( + f"report:{task_id}", + owner_id, + task_id, + report_name, + report_type, + "ready" if status == TaskStatus.COMPLETED.value else "failed", + len(findings_data), + pages, + ), + ) + + await self._persist_result_resources( + db, + owner_id=owner_id, + task_id=task_id, + plugin_id=plugin_id, + target=target, + result=structured_data, + ) + + await self._dispatch_task_notifications(db, task_id) + await self._broadcast_phase(task_id, ScanPhase.FINISHED.value) + await self._broadcast(task_id, "status", status) + await self._invalidate_cached_views() + + await db.log_audit( + "task_completed", + f"Task completed from cache (duration: {duration:.2f}s)", + context={"task_id": task_id, "exit_code": exit_code, "cached": True}, + task_id=task_id, + plugin_id=plugin_id + ) + logger.info(f"Task {task_id} completed from cache") + return + # ── Safe Mode & Network policy enforcement ─────────────────────── # Enforce Safe Mode target validation inside TaskExecutor to guarantee # that all execution paths (manual API, workflows, scheduled tasks) are protected. @@ -652,7 +846,34 @@ async def execute_task(self, task_id: str): status=final_status, output=output ) - await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) + if target and not bypass_cache: + try: + target_hash, dependency_hash, cache_key = generate_scan_cache_key(plugin_id, target) + task_data = await db.fetchone( + "SELECT status, duration_seconds, exit_code, error_message, structured_json, raw_output_path FROM tasks WHERE id = ?", + (task_id,) + ) + if task_data and task_data["status"] in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value): + raw_output = "" + if task_data["raw_output_path"]: + try: + with open(task_data["raw_output_path"], "r", encoding="utf-8") as f: + raw_output = f.read() + except Exception: + pass + cache_data = { + "status": task_data["status"], + "duration_seconds": task_data["duration_seconds"], + "exit_code": task_data["exit_code"], + "error_message": task_data["error_message"], + "raw_output": raw_output, + "structured": json.loads(task_data["structured_json"]) if task_data["structured_json"] else {} + } + cache_client = await get_cache() + await cache_client.set_json(cache_key, cache_data, ttl=86400) + logger.info("Saved scan results to cache for task %s (key: %s)", task_id, cache_key) + except Exception as cache_exc: + logger.warning("Failed to save scan results to cache: %s", cache_exc) await self._dispatch_task_notifications(db, task_id) diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index 8e06d6638..9b43d8c98 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -66,8 +66,8 @@ async def lifespan(app: FastAPI): await init_db(settings.database_path) logger.info("✓ SQLite connected") - await init_cache() - logger.info("✓ In-memory cache initialized") + await init_cache(settings.redis_url) + logger.info("✓ Cache initialized") # Load plugins await init_plugins(settings.plugins_dir) diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index 53bac0b84..e064d84e5 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -389,6 +389,7 @@ async def start_task( request: TaskCreateRequest, background_tasks: BackgroundTasks, raw_request: Request, + bypass_cache: bool = Query(False), owner: str = Depends(get_current_owner), ): """ @@ -534,7 +535,7 @@ async def start_task( # Use BackgroundTasks so the response can be sent without waiting in real # ASGI servers, while tests using TestClient still execute the task to keep # contract tests deterministic. - background_tasks.add_task(executor.execute_task, task_id) + background_tasks.add_task(executor.execute_task, task_id, bypass_cache=bypass_cache) await invalidate_view_cache() return { diff --git a/testing/backend/unit/test_scan_cache.py b/testing/backend/unit/test_scan_cache.py new file mode 100644 index 000000000..5100c3422 --- /dev/null +++ b/testing/backend/unit/test_scan_cache.py @@ -0,0 +1,125 @@ +import os +import json +import shutil +import tempfile +import pytest +import asyncio +from unittest.mock import AsyncMock, patch, MagicMock, ANY + +from backend.secuscan.executor import generate_scan_cache_key, TaskExecutor +from backend.secuscan.cache import init_cache, get_cache +from backend.secuscan.models import TaskStatus + +@pytest.fixture +def temp_repo(): + # Create a temporary directory structure representing a project + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + +def test_generate_scan_cache_key_no_repo(temp_repo): + # If no git or dependency files exist, it hashes target string + target_hash, dep_hash, key = generate_scan_cache_key("test_plugin", temp_repo) + assert len(target_hash) == 64 + assert dep_hash == "no_deps" + assert key.startswith("scan_cache:test_plugin:") + +def test_generate_scan_cache_key_with_deps(temp_repo): + # Create package-lock.json + dep_file = os.path.join(temp_repo, "package-lock.json") + with open(dep_file, "w") as f: + f.write("npm-deps-v1") + + target_hash, dep_hash, key = generate_scan_cache_key("test_plugin", temp_repo) + assert len(target_hash) == 64 + assert dep_hash != "no_deps" + + # Modify package-lock.json -> dependency hash changes! + with open(dep_file, "w") as f: + f.write("npm-deps-v2") + + target_hash_2, dep_hash_2, key_2 = generate_scan_cache_key("test_plugin", temp_repo) + assert dep_hash != dep_hash_2 + assert key != key_2 + +@pytest.mark.asyncio +async def test_execute_task_cache_hit(temp_repo): + # Initialize in-memory cache + await init_cache() + + # We will mock the database and task run details + mock_db = AsyncMock() + mock_db.fetchone = AsyncMock(return_value={ + "owner_id": "owner_1", + "plugin_id": "test_plugin", + "inputs_json": json.dumps({"target": temp_repo}), + "execution_context_json": "{}", + "safe_mode": False + }) + + executor = TaskExecutor() + + # Pre-populate cache for this target + target_hash, dep_hash, cache_key = generate_scan_cache_key("test_plugin", temp_repo) + cache_client = await get_cache() + + cached_data = { + "status": TaskStatus.COMPLETED.value, + "duration_seconds": 1.5, + "exit_code": 0, + "error_message": None, + "raw_output": "cached output text", + "structured": { + "findings": [ + { + "title": "Cached Finding", + "category": "Code Security", + "severity": "high", + "description": "Cached desc" + } + ] + } + } + await cache_client.set_json(cache_key, cached_data) + + # We mock internal helper methods + executor._persist_finding = AsyncMock(return_value={"id": "finding_1"}) + executor._persist_result_resources = AsyncMock() + executor._dispatch_task_notifications = AsyncMock() + executor._invalidate_cached_views = AsyncMock() + + with patch("backend.secuscan.executor.get_db", return_value=mock_db), \ + patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: + + mock_plugin = MagicMock() + mock_plugin.name = "Test Plugin" + mock_pm.return_value.get_plugin.return_value = mock_plugin + + await executor.execute_task("task_id_123", bypass_cache=False) + + # Verify db was updated with cached data + mock_db.execute.assert_any_call( + """ + UPDATE tasks SET + status = ?, + completed_at = ?, + duration_seconds = ?, + exit_code = ?, + raw_output_path = ?, + structured_json = ?, + error_message = ? + WHERE id = ? + """, + ( + TaskStatus.COMPLETED.value, + ANY, + 1.5, + 0, + ANY, + '{"findings": [{"title": "Cached Finding", "category": "Code Security", "severity": "high", "description": "Cached desc"}]}', + None, + "task_id_123" + ) + ) + # Verify it persisted the cached findings + executor._persist_finding.assert_called_once() From 93c92ffbd8b3cea53b4cad84ff21af81d135486f Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Thu, 11 Jun 2026 18:37:19 +0530 Subject: [PATCH 02/13] style: remove trailing whitespace in test_scan_cache.py --- testing/backend/unit/test_scan_cache.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/testing/backend/unit/test_scan_cache.py b/testing/backend/unit/test_scan_cache.py index 5100c3422..4da1f72ea 100644 --- a/testing/backend/unit/test_scan_cache.py +++ b/testing/backend/unit/test_scan_cache.py @@ -29,15 +29,15 @@ def test_generate_scan_cache_key_with_deps(temp_repo): dep_file = os.path.join(temp_repo, "package-lock.json") with open(dep_file, "w") as f: f.write("npm-deps-v1") - + target_hash, dep_hash, key = generate_scan_cache_key("test_plugin", temp_repo) assert len(target_hash) == 64 assert dep_hash != "no_deps" - + # Modify package-lock.json -> dependency hash changes! with open(dep_file, "w") as f: f.write("npm-deps-v2") - + target_hash_2, dep_hash_2, key_2 = generate_scan_cache_key("test_plugin", temp_repo) assert dep_hash != dep_hash_2 assert key != key_2 @@ -46,7 +46,7 @@ def test_generate_scan_cache_key_with_deps(temp_repo): async def test_execute_task_cache_hit(temp_repo): # Initialize in-memory cache await init_cache() - + # We will mock the database and task run details mock_db = AsyncMock() mock_db.fetchone = AsyncMock(return_value={ @@ -56,13 +56,13 @@ async def test_execute_task_cache_hit(temp_repo): "execution_context_json": "{}", "safe_mode": False }) - + executor = TaskExecutor() - + # Pre-populate cache for this target target_hash, dep_hash, cache_key = generate_scan_cache_key("test_plugin", temp_repo) cache_client = await get_cache() - + cached_data = { "status": TaskStatus.COMPLETED.value, "duration_seconds": 1.5, @@ -81,22 +81,22 @@ async def test_execute_task_cache_hit(temp_repo): } } await cache_client.set_json(cache_key, cached_data) - + # We mock internal helper methods executor._persist_finding = AsyncMock(return_value={"id": "finding_1"}) executor._persist_result_resources = AsyncMock() executor._dispatch_task_notifications = AsyncMock() executor._invalidate_cached_views = AsyncMock() - + with patch("backend.secuscan.executor.get_db", return_value=mock_db), \ patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: - + mock_plugin = MagicMock() mock_plugin.name = "Test Plugin" mock_pm.return_value.get_plugin.return_value = mock_plugin - + await executor.execute_task("task_id_123", bypass_cache=False) - + # Verify db was updated with cached data mock_db.execute.assert_any_call( """ From 11bf6b42588c2c408af24baff549c042b8ae4b4a Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Thu, 11 Jun 2026 18:42:47 +0530 Subject: [PATCH 03/13] style: remove trailing whitespace in executor.py --- backend/secuscan/executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 61a50b0bf..2722dff23 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -515,7 +515,7 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): # Persist findings and reports to database for the new task await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - + findings_data: List[Dict[str, Any]] = [] for finding in structured_data.get("findings", []): findings_data.append( @@ -576,7 +576,7 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): await self._broadcast_phase(task_id, ScanPhase.FINISHED.value) await self._broadcast(task_id, "status", status) await self._invalidate_cached_views() - + await db.log_audit( "task_completed", f"Task completed from cache (duration: {duration:.2f}s)", From e2c7199149711454411c48fe5d2a29662ccc8eb7 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Thu, 11 Jun 2026 18:49:59 +0530 Subject: [PATCH 04/13] style: remove trailing blank line at EOF of cache.py --- backend/secuscan/cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/secuscan/cache.py b/backend/secuscan/cache.py index 894ec0518..68a3231e9 100644 --- a/backend/secuscan/cache.py +++ b/backend/secuscan/cache.py @@ -174,4 +174,3 @@ async def get_cache() -> CacheClient: if cache is None: raise RuntimeError("Cache not initialized") return cache - From 55e399f49490d7e54218cc9067a89c0b799fc61c Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Thu, 11 Jun 2026 19:25:41 +0530 Subject: [PATCH 05/13] test: reset global cache singleton in setup_test_environment to prevent test pollution --- testing/backend/conftest.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/testing/backend/conftest.py b/testing/backend/conftest.py index fc34fdb28..73ae1fd1f 100644 --- a/testing/backend/conftest.py +++ b/testing/backend/conftest.py @@ -26,6 +26,12 @@ def anyio_backend(): @pytest.fixture(autouse=True) def setup_test_environment(monkeypatch): """Override settings for tests to ensure isolated execution.""" + try: + from backend.secuscan import cache as cache_module + cache_module.cache = None + except ImportError: + pass + temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) temp_path = temp_dir.name From 62505e0af1fbd551956b63308b4996eb92a7dc16 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Fri, 12 Jun 2026 12:09:49 +0530 Subject: [PATCH 06/13] feat: enhance scan caching safety, tenant isolation, and unify result persistence --- backend/secuscan/executor.py | 199 ++++++++++------------ testing/backend/unit/test_scan_cache.py | 212 ++++++++++++++++++++++-- 2 files changed, 280 insertions(+), 131 deletions(-) diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 2722dff23..1afd39c91 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -155,9 +155,16 @@ def extract_target(inputs: Dict[str, Any]) -> str: ) -def generate_scan_cache_key(plugin_id: str, target: str) -> tuple[str, str, str]: +def generate_scan_cache_key( + owner_id: str, + plugin_id: str, + target: str, + inputs: Dict[str, Any], + execution_context: Dict[str, Any], + safe_mode: bool +) -> tuple[str, str, str]: """ - Generate target hash, dependency hash, and a cache key. + Generate target hash, dependency hash, and an owner-scoped cache key. Returns: tuple: (target_hash, dependency_hash, cache_key) @@ -216,7 +223,13 @@ def generate_scan_cache_key(plugin_id: str, target: str) -> tuple[str, str, str] else: dependency_hash = hasher.hexdigest() - cache_key = f"scan_cache:{plugin_id}:{target_hash}:{dependency_hash}" + inputs_str = json.dumps(inputs, sort_keys=True) + inputs_hash = hashlib.sha256(inputs_str.encode("utf-8")).hexdigest() + + context_str = json.dumps(execution_context, sort_keys=True) + context_hash = hashlib.sha256(context_str.encode("utf-8")).hexdigest() + + cache_key = f"scan_cache:{owner_id}:{plugin_id}:{int(safe_mode)}:{target_hash}:{dependency_hash}:{inputs_hash}:{context_hash}" return target_hash, dependency_hash, cache_key @@ -464,7 +477,14 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): cache_key = None if target and not bypass_cache: try: - target_hash, dependency_hash, cache_key = generate_scan_cache_key(plugin_id, target) + target_hash, dependency_hash, cache_key = generate_scan_cache_key( + owner_id=owner_id, + plugin_id=plugin_id, + target=target, + inputs=inputs, + execution_context=execution_context, + safe_mode=safe_mode, + ) cache_client = await get_cache() cached_result = await cache_client.get_json(cache_key) except Exception as cache_exc: @@ -497,7 +517,6 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): duration_seconds = ?, exit_code = ?, raw_output_path = ?, - structured_json = ?, error_message = ? WHERE id = ? """, @@ -507,7 +526,6 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): duration, exit_code, str(raw_path), - json.dumps(structured_data), error_message, task_id ) @@ -516,60 +534,25 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): # Persist findings and reports to database for the new task await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - findings_data: List[Dict[str, Any]] = [] - for finding in structured_data.get("findings", []): - findings_data.append( - await self._persist_finding( - db, - owner_id=owner_id, - task_id=task_id, - plugin_id=plugin_id, - target=target, - finding=finding, - ) - ) - - # Create/Update report plugin_manager = get_plugin_manager() plugin = plugin_manager.get_plugin(plugin_id) - report_name = f"{plugin.name} Report" if plugin else f"{plugin_id} Report" - report_type = "technical" - pages = 1 - if plugin_id in MODULAR_SCANNERS: + is_modular = plugin_id in MODULAR_SCANNERS + if is_modular: scanner_class = MODULAR_SCANNERS[plugin_id] report_name = f"{scanner_class.__name__} Report" - report_type = "professional" - pages = 2 - - await db.execute( - """ - INSERT INTO reports ( - id, owner_id, task_id, name, type, generated_at, status, findings, pages - ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - status = EXCLUDED.status, - findings = EXCLUDED.findings, - pages = EXCLUDED.pages - """, - ( - f"report:{task_id}", - owner_id, - task_id, - report_name, - report_type, - "ready" if status == TaskStatus.COMPLETED.value else "failed", - len(findings_data), - pages, - ), - ) + else: + report_name = f"{plugin.name} Report" if plugin else f"{plugin_id} Report" - await self._persist_result_resources( + await self._persist_findings_and_report_common( db, - owner_id=owner_id, task_id=task_id, + owner_id=owner_id, plugin_id=plugin_id, target=target, - result=structured_data, + status=status, + result_dict=structured_data, + is_modular=is_modular, + report_name=report_name, ) await self._dispatch_task_notifications(db, task_id) @@ -848,12 +831,19 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): ) if target and not bypass_cache: try: - target_hash, dependency_hash, cache_key = generate_scan_cache_key(plugin_id, target) + target_hash, dependency_hash, cache_key = generate_scan_cache_key( + owner_id=owner_id, + plugin_id=plugin_id, + target=target, + inputs=inputs, + execution_context=execution_context, + safe_mode=safe_mode, + ) task_data = await db.fetchone( "SELECT status, duration_seconds, exit_code, error_message, structured_json, raw_output_path FROM tasks WHERE id = ?", (task_id,) ) - if task_data and task_data["status"] in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value): + if task_data and task_data["status"] == TaskStatus.COMPLETED.value: raw_output = "" if task_data["raw_output_path"]: try: @@ -1553,16 +1543,27 @@ async def _persist_finding( "risk_factors": risk_factors, } - async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plugin, plugin_id: str, target: str, status: str, output: str = ""): - """Persist derived findings and report records into SQLite.""" - parsed = self._parse_results(plugin, output) + async def _persist_findings_and_report_common( + self, + db, + *, + task_id: str, + owner_id: str, + plugin_id: str, + target: str, + status: str, + result_dict: Dict[str, Any], + is_modular: bool, + report_name: str, + ) -> Dict[str, Any]: + """Common logic to persist findings, report, and result resources for a scan.""" structured_result, previous_findings, asset_services = await self._build_result_contract( db, task_id=task_id, owner_id=owner_id, plugin_id=plugin_id, target=target, - result=parsed, + result=result_dict, ) findings_data: List[Dict[str, Any]] = [] for finding in structured_result.get("findings", []): @@ -1588,6 +1589,13 @@ async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plu (json.dumps(structured_result), task_id) ) + if is_modular: + report_type = "professional" if status == TaskStatus.COMPLETED.value else "failed" + pages = 2 + else: + report_type = "technical" + pages = 1 + await db.execute( """ INSERT INTO reports ( @@ -1602,11 +1610,11 @@ async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plu f"report:{task_id}", owner_id, task_id, - f"{plugin.name} Report", - "technical", + report_name, + report_type, "ready" if status == TaskStatus.COMPLETED.value else "failed", len(findings_data), - 1, + pages, ), ) @@ -1618,73 +1626,38 @@ async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plu target=target, result=structured_result, ) + return structured_result - async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner_id: str, scanner: Any, plugin_id: str, target: str, status: str, result: Dict[str, Any]): - """Persist modular scanner results into findings, and reports.""" - structured_result, previous_findings, asset_services = await self._build_result_contract( + async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plugin, plugin_id: str, target: str, status: str, output: str = ""): + """Persist derived findings and report records into SQLite.""" + parsed = self._parse_results(plugin, output) + await self._persist_findings_and_report_common( db, task_id=task_id, owner_id=owner_id, plugin_id=plugin_id, target=target, - result=result, - ) - findings_data: List[Dict[str, Any]] = [] - for finding in structured_result.get("findings", []): - findings_data.append( - await self._persist_finding( - db, - owner_id=owner_id, - task_id=task_id, - plugin_id=plugin_id, - target=target, - finding=finding, - ) - ) - - structured_result["findings"] = findings_data - structured_result["severity_counts"] = self._build_severity_counts(findings_data) - structured_result["finding_groups"] = build_finding_groups(findings_data) - structured_result["asset_summary"] = build_asset_summary(findings_data, asset_services) - structured_result["scan_diff"] = build_scan_diff(findings_data, previous_findings) - - await db.execute( - "UPDATE tasks SET structured_json = ? WHERE id = ?", - (json.dumps(structured_result), task_id) + status=status, + result_dict=parsed, + is_modular=False, + report_name=f"{plugin.name} Report", ) - # Create/Update report - await db.execute( - """ - INSERT INTO reports ( - id, owner_id, task_id, name, type, generated_at, status, findings, pages - ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - status = EXCLUDED.status, - findings = EXCLUDED.findings, - pages = EXCLUDED.pages - """, - ( - f"report:{task_id}", - owner_id, - task_id, - f"{scanner.name} Report", - "professional" if status == TaskStatus.COMPLETED.value else "failed", - "ready" if status == TaskStatus.COMPLETED.value else "failed", - len(findings_data), - 2, # Professional reports are typically multi-page - ), - ) - - await self._persist_result_resources( + async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner_id: str, scanner: Any, plugin_id: str, target: str, status: str, result: Dict[str, Any]): + """Persist modular scanner results into findings, and reports.""" + await self._persist_findings_and_report_common( db, - owner_id=owner_id, task_id=task_id, + owner_id=owner_id, plugin_id=plugin_id, target=target, - result=structured_result, + status=status, + result_dict=result, + is_modular=True, + report_name=f"{scanner.name} Report", ) + async def _persist_result_resources( self, db, diff --git a/testing/backend/unit/test_scan_cache.py b/testing/backend/unit/test_scan_cache.py index 4da1f72ea..ed759d97e 100644 --- a/testing/backend/unit/test_scan_cache.py +++ b/testing/backend/unit/test_scan_cache.py @@ -9,6 +9,7 @@ from backend.secuscan.executor import generate_scan_cache_key, TaskExecutor from backend.secuscan.cache import init_cache, get_cache from backend.secuscan.models import TaskStatus +from backend.secuscan.execution_context import normalize_execution_context @pytest.fixture def temp_repo(): @@ -19,10 +20,17 @@ def temp_repo(): def test_generate_scan_cache_key_no_repo(temp_repo): # If no git or dependency files exist, it hashes target string - target_hash, dep_hash, key = generate_scan_cache_key("test_plugin", temp_repo) + target_hash, dep_hash, key = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo}, + execution_context={}, + safe_mode=False + ) assert len(target_hash) == 64 assert dep_hash == "no_deps" - assert key.startswith("scan_cache:test_plugin:") + assert key.startswith("scan_cache:owner_1:test_plugin:0:") def test_generate_scan_cache_key_with_deps(temp_repo): # Create package-lock.json @@ -30,7 +38,14 @@ def test_generate_scan_cache_key_with_deps(temp_repo): with open(dep_file, "w") as f: f.write("npm-deps-v1") - target_hash, dep_hash, key = generate_scan_cache_key("test_plugin", temp_repo) + target_hash, dep_hash, key = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo}, + execution_context={}, + safe_mode=False + ) assert len(target_hash) == 64 assert dep_hash != "no_deps" @@ -38,29 +53,111 @@ def test_generate_scan_cache_key_with_deps(temp_repo): with open(dep_file, "w") as f: f.write("npm-deps-v2") - target_hash_2, dep_hash_2, key_2 = generate_scan_cache_key("test_plugin", temp_repo) + target_hash_2, dep_hash_2, key_2 = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo}, + execution_context={}, + safe_mode=False + ) assert dep_hash != dep_hash_2 assert key != key_2 +def test_cache_key_tenant_isolation(temp_repo): + # Same inputs/target, different owners -> different cache keys! + _, _, key_owner1 = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo, "flag": "x"}, + execution_context={"profile": "admin"}, + safe_mode=False + ) + _, _, key_owner2 = generate_scan_cache_key( + owner_id="owner_2", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo, "flag": "x"}, + execution_context={"profile": "admin"}, + safe_mode=False + ) + assert key_owner1 != key_owner2 + +def test_cache_key_inputs_isolation(temp_repo): + # Same target/owner, different inputs/flags -> different cache keys! + _, _, key_inputs1 = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo, "wordlist": "common.txt"}, + execution_context={}, + safe_mode=False + ) + _, _, key_inputs2 = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo, "wordlist": "deep.txt"}, + execution_context={}, + safe_mode=False + ) + assert key_inputs1 != key_inputs2 + +def test_cache_key_safe_mode_isolation(temp_repo): + # Same inputs/owner, safe_mode toggled -> different cache keys! + _, _, key_safe = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo}, + execution_context={}, + safe_mode=True + ) + _, _, key_unsafe = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs={"target": temp_repo}, + execution_context={}, + safe_mode=False + ) + assert key_unsafe != key_safe + @pytest.mark.asyncio async def test_execute_task_cache_hit(temp_repo): # Initialize in-memory cache await init_cache() - # We will mock the database and task run details + # We will mock the database and task run details using a SQL-inspecting side effect mock_db = AsyncMock() - mock_db.fetchone = AsyncMock(return_value={ - "owner_id": "owner_1", - "plugin_id": "test_plugin", - "inputs_json": json.dumps({"target": temp_repo}), - "execution_context_json": "{}", - "safe_mode": False - }) + async def db_fetchone_mock(query, params=()): + query_lower = query.lower() + if "select owner_id, plugin_id" in query_lower: + return { + "owner_id": "owner_1", + "plugin_id": "test_plugin", + "inputs_json": json.dumps({"target": temp_repo}), + "execution_context_json": "{}", + "safe_mode": False + } + return None + mock_db.fetchone = AsyncMock(side_effect=db_fetchone_mock) executor = TaskExecutor() - # Pre-populate cache for this target - target_hash, dep_hash, cache_key = generate_scan_cache_key("test_plugin", temp_repo) + # Pre-populate cache for this target/owner/inputs/context/safe_mode + # Note: inputs is hydrated inside execute_task to contain normalized execution_context + execution_context = normalize_execution_context({}) + inputs = {"target": temp_repo, "__execution_context": execution_context} + target_hash, dep_hash, cache_key = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs=inputs, + execution_context=execution_context, + safe_mode=False + ) cache_client = await get_cache() cached_data = { @@ -97,7 +194,7 @@ async def test_execute_task_cache_hit(temp_repo): await executor.execute_task("task_id_123", bypass_cache=False) - # Verify db was updated with cached data + # Verify db was updated with status, duration, etc. mock_db.execute.assert_any_call( """ UPDATE tasks SET @@ -106,7 +203,6 @@ async def test_execute_task_cache_hit(temp_repo): duration_seconds = ?, exit_code = ?, raw_output_path = ?, - structured_json = ?, error_message = ? WHERE id = ? """, @@ -116,10 +212,90 @@ async def test_execute_task_cache_hit(temp_repo): 1.5, 0, ANY, - '{"findings": [{"title": "Cached Finding", "category": "Code Security", "severity": "high", "description": "Cached desc"}]}', None, "task_id_123" ) ) - # Verify it persisted the cached findings + + # Verify it persisted the cached findings and updated structured_json executor._persist_finding.assert_called_once() + mock_db.execute.assert_any_call( + "UPDATE tasks SET structured_json = ? WHERE id = ?", + (ANY, "task_id_123") + ) + +@pytest.mark.asyncio +async def test_execute_task_transient_failure_not_cached(temp_repo): + # Initialize in-memory cache + await init_cache() + + mock_db = AsyncMock() + async def db_fetchone_mock(query, params=()): + query_lower = query.lower() + if "select owner_id, plugin_id" in query_lower: + return { + "owner_id": "owner_1", + "plugin_id": "test_plugin", + "inputs_json": json.dumps({"target": temp_repo}), + "execution_context_json": "{}", + "safe_mode": False + } + if "select status, duration_seconds" in query_lower: + return { + "status": TaskStatus.FAILED.value, + "duration_seconds": 2.0, + "exit_code": 1, + "error_message": "Transient network timeout", + "structured_json": None, + "raw_output_path": None + } + return None + mock_db.fetchone = AsyncMock(side_effect=db_fetchone_mock) + + executor = TaskExecutor() + executor._persist_findings_and_report_common = AsyncMock() + executor._dispatch_task_notifications = AsyncMock() + executor._invalidate_cached_views = AsyncMock() + + # Stub the actually executed command to fail + async def fake_command(*args, **kwargs): + return "Network timeout", 1 + + execution_context = normalize_execution_context({}) + inputs = {"target": temp_repo, "__execution_context": execution_context} + _, _, cache_key = generate_scan_cache_key( + owner_id="owner_1", + plugin_id="test_plugin", + target=temp_repo, + inputs=inputs, + execution_context=execution_context, + safe_mode=False + ) + + with patch("backend.secuscan.executor.get_db", return_value=mock_db), \ + patch.object(executor, "_execute_command", side_effect=fake_command), \ + patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: + + mock_plugin = MagicMock() + mock_plugin.name = "Test Plugin" + mock_plugin.presets = {} + mock_plugin.docker_image = None + mock_plugin.output = {"parser": "builtin_nmap", "format": "text"} + mock_plugin.category = "Network" + mock_plugin.id = "test_plugin" + + mock_pm.return_value.get_plugin.return_value = mock_plugin + mock_pm.return_value.build_command.return_value = ["ping", temp_repo] + mock_pm.return_value.plugins_dir = MagicMock() + mock_pm.return_value.plugins_dir.__truediv__ = MagicMock( + return_value=MagicMock( + __truediv__=MagicMock(return_value=MagicMock(exists=lambda: False)) + ) + ) + + await executor.execute_task("task_id_456", bypass_cache=False) + + # The cache should be empty for this key because the task status is FAILED + cache_client = await get_cache() + cached_val = await cache_client.get_json(cache_key) + assert cached_val is None From 5e30d40443f1f616c2d3fd48800f7858e3a56497 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Fri, 12 Jun 2026 13:24:00 +0530 Subject: [PATCH 07/13] feat(remediation): validate upgrade suggestions against transitive dependency graph --- backend/secuscan/executor.py | 15 + backend/secuscan/models.py | 3 + backend/secuscan/remediation.py | 407 ++++++++++++++++++ backend/secuscan/routes.py | 9 + .../backend/unit/test_remediation_safety.py | 170 ++++++++ 5 files changed, 604 insertions(+) create mode 100644 backend/secuscan/remediation.py create mode 100644 testing/backend/unit/test_remediation_safety.py diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 4317b476f..d9ae1c211 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -1190,6 +1190,21 @@ async def _build_result_contract( target=target, findings=[item for item in result.get("findings", []) if isinstance(item, dict)], ) + + try: + from .remediation import build_dependency_graph, validate_remediation + graph = build_dependency_graph(target) + for f in normalized_findings: + remediation_str = f.get("remediation", "") + if remediation_str: + val_res = validate_remediation(remediation_str, graph) + f_metadata = f.setdefault("metadata", {}) + f_metadata["safe_to_apply"] = val_res["safe_to_apply"] + f_metadata["compatible_range"] = val_res["compatible_range"] + f_metadata["alternatives"] = val_res["alternatives"] + except Exception: + pass + previous_findings = await self._load_previous_task_findings( db, owner_id=owner_id, diff --git a/backend/secuscan/models.py b/backend/secuscan/models.py index b6cb61c03..b7501ff47 100644 --- a/backend/secuscan/models.py +++ b/backend/secuscan/models.py @@ -235,6 +235,9 @@ class Finding(BaseModel): evidence_count: int = 0 analyst_status: AnalystStatus = AnalystStatus.NEW retest_status: RetestStatus = RetestStatus.NOT_REQUESTED + safe_to_apply: Optional[bool] = None + compatible_range: Optional[str] = None + alternatives: Optional[List[str]] = None class TaskResult(BaseModel): diff --git a/backend/secuscan/remediation.py b/backend/secuscan/remediation.py new file mode 100644 index 000000000..7004b7430 --- /dev/null +++ b/backend/secuscan/remediation.py @@ -0,0 +1,407 @@ +""" +Dependency graph resolution and remediation conflict validation. +""" + +import json +import re +import importlib.metadata +from pathlib import Path +from typing import Dict, List, Any, Tuple +from packaging.version import Version +from packaging.specifiers import SpecifierSet + + +def normalize_package_name(name: str) -> str: + """Normalize a package name to lowercase with PEP 503 compatibility.""" + return re.sub(r"[-_.]+", "-", name).strip().lower() + + +def clean_version_string(ver_str: str) -> str: + """Extract numeric prefix from version strings for comparison.""" + ver_str = ver_str.strip().lower() + if ver_str.startswith("v"): + ver_str = ver_str[1:] + # Match the first sequence of digits and dots (e.g., "1.2.3" in "1.2.3-ubuntu") + match = re.match(r"^([0-9]+(?:\.[0-9]+)*)", ver_str) + if match: + return match.group(1) + return ver_str + + +def parse_remediation_suggestion(remediation_str: str) -> Tuple[str, str] | None: + """Parse recommendation string to extract package name and target upgrade version. + + Example: "Update framer-motion to version 11.0.0" -> ("framer-motion", "11.0.0") + """ + pattern = r"(?:update|upgrade)\s+([a-zA-Z0-9_\-\.]+)\s+(?:to\s+)?(?:version\s+)?([a-zA-Z0-9_\-\.\+\~]+)" + match = re.search(pattern, remediation_str, re.IGNORECASE) + if match: + pkg_name = normalize_package_name(match.group(1)) + version = match.group(2) + return pkg_name, version + return None + + +def handle_caret(ver_str: str) -> List[str]: + """Convert NPM caret specification to PEP 440 constraints. + + ^1.2.3 -> >=1.2.3, <2.0.0 + ^0.2.3 -> >=0.2.3, <0.3.0 + ^0.0.3 -> >=0.0.3, <0.0.4 + """ + parts = ver_str.split(".") + while len(parts) < 3: + parts.append("0") + + major = "".join(filter(str.isdigit, parts[0])) or "0" + minor = "".join(filter(str.isdigit, parts[1])) or "0" + patch = "".join(filter(str.isdigit, parts[2])) or "0" + + if major != "0": + next_major = int(major) + 1 + return [f">={ver_str}", f"<{next_major}.0.0"] + elif minor != "0": + next_minor = int(minor) + 1 + return [f">={ver_str}", f"<0.{next_minor}.0"] + else: + next_patch = int(patch) + 1 + return [f">={ver_str}", f"<0.0.{next_patch}"] + + +def handle_tilde(ver_str: str) -> List[str]: + """Convert NPM tilde specification to PEP 440 constraints. + + ~1.2.3 -> >=1.2.3, <1.3.0 + ~1.2 -> >=1.2.0, <1.3.0 + """ + parts = ver_str.split(".") + while len(parts) < 2: + parts.append("0") + major = "".join(filter(str.isdigit, parts[0])) or "0" + minor = "".join(filter(str.isdigit, parts[1])) or "0" + next_minor = int(minor) + 1 + return [f">={ver_str}", f"<{major}.{next_minor}.0"] + + +def handle_wildcard(part: str) -> List[str]: + """Convert wildcard version strings (e.g. 1.x or 1.*) to PEP 440 constraints.""" + part = part.replace("*", "x") + parts = part.split(".") + if len(parts) == 1 or parts[0] == "x": + return [] + if len(parts) == 2 or parts[1] == "x": + major = "".join(filter(str.isdigit, parts[0])) or "0" + next_major = int(major) + 1 + return [f">={major}.0.0", f"<{next_major}.0.0"] + if parts[2] == "x": + major = "".join(filter(str.isdigit, parts[0])) or "0" + minor = "".join(filter(str.isdigit, parts[1])) or "0" + next_minor = int(minor) + 1 + return [f">={major}.{minor}.0", f"<{major}.{next_minor}.0"] + return [] + + +def semver_to_pep440(semver_str: str) -> SpecifierSet: + """Convert NPM/semver package version specifier into PEP 440 SpecifierSet.""" + semver_str = semver_str.strip() + if not semver_str or semver_str in ("*", "x", "any"): + return SpecifierSet() + + parts = semver_str.split() + pep440_parts = [] + + for part in parts: + part = part.strip() + if not part: + continue + + if part.startswith("^"): + pep440_parts.extend(handle_caret(part[1:])) + elif part.startswith("~"): + pep440_parts.extend(handle_tilde(part[1:])) + elif "x" in part or "*" in part: + pep440_parts.extend(handle_wildcard(part)) + elif part.startswith((">=", "<=", ">", "<", "==")): + match = re.match(r"^(>=|<=|>|<|==)\s*([0-9a-zA-Z\.\-\+]+)$", part) + if match: + op, ver = match.groups() + pep440_parts.append(f"{op}{ver}") + else: + if re.match(r"^[0-9]+(?:\.[0-9]+)?$", part): + pep440_parts.extend(handle_wildcard(part + ".x")) + elif re.match(r"^[0-9a-zA-Z\.\-\+]+$", part): + pep440_parts.append(f"=={part}") + + try: + return SpecifierSet(",".join(pep440_parts)) + except Exception: + return SpecifierSet() + + +def parse_package_lock(filepath: str) -> Dict[str, List[Tuple[str, str]]]: + """Parse a package-lock.json and extract direct and transitive package dependency requirements.""" + try: + with open(filepath, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception: + return {} + + relations = {} + + # Check for packages key (modern NPM lockfile v2/v3) + packages = data.get("packages", {}) + for path, info in packages.items(): + if not path: + parent = "root" + else: + parent = path.replace("node_modules/", "") + + deps = info.get("dependencies", {}) + peer_deps = info.get("peerDependencies", {}) + all_deps = {**deps, **peer_deps} + + if all_deps: + relations[parent] = [(normalize_package_name(k), v) for k, v in all_deps.items()] + + # Fallback to dependencies key (NPM lockfile v1) + dependencies = data.get("dependencies", {}) + def parse_v1_deps(deps_dict): + for name, info in deps_dict.items(): + requires = info.get("requires", {}) + if requires: + relations[name] = [(normalize_package_name(k), v) for k, v in requires.items()] + child_deps = info.get("dependencies", {}) + if child_deps: + parse_v1_deps(child_deps) + + if not packages and dependencies: + parse_v1_deps(dependencies) + + return relations + + +def parse_package_json(filepath: str) -> Dict[str, List[Tuple[str, str]]]: + """Parse a package.json for direct project dependencies.""" + try: + with open(filepath, "r", encoding="utf-8") as f: + data = json.load(f) + deps = data.get("dependencies", {}) + dev_deps = data.get("devDependencies", {}) + peer_deps = data.get("peerDependencies", {}) + all_deps = {**deps, **dev_deps, **peer_deps} + return { + "root": [(normalize_package_name(k), v) for k, v in all_deps.items()] + } + except Exception: + return {} + + +def parse_requirement_line(line: str) -> Tuple[str, SpecifierSet] | None: + """Parse a single requirements.txt line into a normalized package name and SpecifierSet.""" + line = line.strip() + if not line or line.startswith(('#', '-')): + return None + # Strip environment markers (e.g. "pydantic; python_version >= '3.8'") + line = line.split(";")[0].strip() + match = re.match(r"^([a-zA-Z0-9_\-\.]+)\s*(.*)$", line) + if not match: + return None + name, spec_str = match.groups() + # Normalize comparison operators if present + spec_str = spec_str.strip() + name = normalize_package_name(name) + try: + spec = SpecifierSet(spec_str) + except Exception: + spec = SpecifierSet() + return name, spec + + +def get_mock_dependencies(pkg_name: str) -> List[Tuple[str, str]]: + """Return mock transitive dependencies for deterministic offline unit testing.""" + registry = { + "library-y": [("library-x", "<2.0")], + "parent-package": [("child-package", "<=1.5.0")], + } + return registry.get(pkg_name, []) + + +def get_python_transitive_dependencies(package_name: str) -> List[Tuple[str, SpecifierSet]]: + """Retrieve python transitive dependencies from installed metadata.""" + try: + reqs = importlib.metadata.requires(package_name) + if not reqs: + return [] + dependencies = [] + for req in reqs: + req_clean = req.split(";")[0].strip() + match = re.match(r"^([a-zA-Z0-9_\-\.]+)\s*\((.*)\)$", req_clean) + if match: + dep_name, dep_spec = match.groups() + else: + match2 = re.match(r"^([a-zA-Z0-9_\-\.]+)\s*(.*)$", req_clean) + if match2: + dep_name, dep_spec = match2.groups() + else: + continue + dep_name = normalize_package_name(dep_name) + try: + spec = SpecifierSet(dep_spec) + except Exception: + spec = SpecifierSet() + dependencies.append((dep_name, spec)) + return dependencies + except importlib.metadata.PackageNotFoundError: + return [] + + +def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: + """Scan the target directory for Python/Node manifests and construct a transitive dependency constraint graph.""" + graph: Dict[str, List[Dict[str, Any]]] = {} + + target_path = Path(target_dir) if target_dir else Path(".") + + # 1. Search for python requirements + req_files = ["requirements.txt", "requirements-dev.txt"] + for req_name in req_files: + p = target_path / req_name + if not p.exists(): + # Fallback to local project root + p = Path("backend") / req_name + + if p.exists(): + try: + with open(p, "r", encoding="utf-8") as f: + for line in f: + parsed = parse_requirement_line(line) + if parsed: + name, spec = parsed + graph.setdefault(name, []).append({ + "parent": "root", + "specifier": spec + }) + + # Transitive resolution + mock_deps = get_mock_dependencies(name) + if mock_deps: + for dep_name, dep_spec_str in mock_deps: + try: + dep_spec = SpecifierSet(dep_spec_str) + except Exception: + dep_spec = SpecifierSet() + graph.setdefault(dep_name, []).append({ + "parent": name, + "specifier": dep_spec + }) + else: + for dep_name, dep_spec in get_python_transitive_dependencies(name): + graph.setdefault(dep_name, []).append({ + "parent": name, + "specifier": dep_spec + }) + except Exception: + pass + + # 2. Search for Node.js package-lock.json / package.json + lock_path = target_path / "package-lock.json" + if not lock_path.exists(): + lock_path = Path("frontend/package-lock.json") + + pkg_path = target_path / "package.json" + if not pkg_path.exists(): + pkg_path = Path("frontend/package.json") + + if lock_path.exists(): + try: + relations = parse_package_lock(str(lock_path)) + for parent, children in relations.items(): + for child_name, semver_str in children: + spec = semver_to_pep440(semver_str) + graph.setdefault(child_name, []).append({ + "parent": parent, + "specifier": spec + }) + except Exception: + pass + elif pkg_path.exists(): + try: + relations = parse_package_json(str(pkg_path)) + for parent, children in relations.items(): + for child_name, semver_str in children: + spec = semver_to_pep440(semver_str) + graph.setdefault(child_name, []).append({ + "parent": parent, + "specifier": spec + }) + except Exception: + pass + + return graph + + +def validate_remediation(remediation_str: str, graph: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]: + """Validate a remediation string against a dependency graph, yielding safety status and alternative actions.""" + res = { + "safe_to_apply": True, + "compatible_range": None, + "alternatives": [] + } + + parsed = parse_remediation_suggestion(remediation_str) + if not parsed: + return res + + pkg_name, target_version = parsed + if pkg_name not in graph: + return res + + constraints = graph[pkg_name] + specifiers = [c["specifier"] for c in constraints] + + clean_target = clean_version_string(target_version) + + is_safe = True + try: + ver = Version(clean_target) + for c in constraints: + spec = c["specifier"] + if ver not in spec: + is_safe = False + break + except Exception: + # Fall back to safe if parsing error happens to prevent blocking valid tools + pass + + if not is_safe: + res["safe_to_apply"] = False + + # Combine all constraints to show the allowed range + combined_parts = [] + for c in constraints: + for spec in c["specifier"]: + combined_parts.append(str(spec)) + res["compatible_range"] = ", ".join(combined_parts) if combined_parts else "N/A" + + # Determine which packages impose conflicting requirements + try: + ver = Version(clean_target) + conflicting_parents = sorted(list({ + c["parent"] for c in constraints if ver not in c["specifier"] + })) + except Exception: + conflicting_parents = sorted(list({c["parent"] for c in constraints})) + + for parent in conflicting_parents: + if parent == "root": + res["alternatives"].append( + f"Update root project constraints for '{pkg_name}' to allow version {target_version}." + ) + else: + res["alternatives"].append( + f"Upgrade parent package '{parent}' to a version that supports '{pkg_name}' version {target_version}." + ) + res["alternatives"].append( + f"Downgrade or keep '{pkg_name}' within compatible range: {res['compatible_range']}." + ) + + return res diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index 53bac0b84..76cc79800 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -54,6 +54,12 @@ def deserialize_finding_rows(rows: List[Dict]) -> List[Dict[str, Any]]: finding["references"] = finding.pop("references_json") if "corroborating_sources_json" in finding: finding["corroborating_sources"] = finding.pop("corroborating_sources_json") + + # Expose remediation safety fields at the top level + metadata = finding.get("metadata", {}) or {} + finding["safe_to_apply"] = metadata.get("safe_to_apply") + finding["compatible_range"] = metadata.get("compatible_range") + finding["alternatives"] = metadata.get("alternatives") return findings @@ -2174,6 +2180,9 @@ async def get_finding_details(finding_id: str, owner: str = Depends(get_current_ "asset_exposure": finding_row.get("asset_exposure"), "risk_score": finding_row.get("risk_score"), "risk_factors": risk_factors, + "safe_to_apply": metadata.get("safe_to_apply"), + "compatible_range": metadata.get("compatible_range"), + "alternatives": metadata.get("alternatives"), } diff --git a/testing/backend/unit/test_remediation_safety.py b/testing/backend/unit/test_remediation_safety.py new file mode 100644 index 000000000..0861efc49 --- /dev/null +++ b/testing/backend/unit/test_remediation_safety.py @@ -0,0 +1,170 @@ +import json +import tempfile +from pathlib import Path +import pytest +from packaging.specifiers import SpecifierSet +from backend.secuscan.remediation import ( + normalize_package_name, + clean_version_string, + parse_remediation_suggestion, + semver_to_pep440, + parse_package_lock, + parse_package_json, + parse_requirement_line, + build_dependency_graph, + validate_remediation +) +from backend.secuscan.models import Finding + + +def test_normalize_package_name(): + assert normalize_package_name("pydantic_settings") == "pydantic-settings" + assert normalize_package_name("Flask-RESTful") == "flask-restful" + assert normalize_package_name(" PyJWT ") == "pyjwt" + assert normalize_package_name("libssl.1.1") == "libssl-1-1" + + +def test_clean_version_string(): + assert clean_version_string("v1.2.3") == "1.2.3" + assert clean_version_string("1.1.1f-1ubuntu2.23") == "1.1.1" + assert clean_version_string("3.0.0-rc1") == "3.0.0" + assert clean_version_string("invalid") == "invalid" + + +def test_parse_remediation_suggestion(): + res1 = parse_remediation_suggestion("Update framer-motion to version 11.0.0") + assert res1 == ("framer-motion", "11.0.0") + + res2 = parse_remediation_suggestion("upgrade library-x to 2.4.1") + assert res2 == ("library-x", "2.4.1") + + res3 = parse_remediation_suggestion("Apply secure controls") + assert res3 is None + + +def test_semver_to_pep440(): + # Carets + assert semver_to_pep440("^1.2.3") == SpecifierSet(">=1.2.3,<2.0.0") + assert semver_to_pep440("^0.2.3") == SpecifierSet(">=0.2.3,<0.3.0") + assert semver_to_pep440("^0.0.3") == SpecifierSet(">=0.0.3,<0.0.4") + + # Tildes + assert semver_to_pep440("~1.2.3") == SpecifierSet(">=1.2.3,<1.3.0") + assert semver_to_pep440("~1.2") == SpecifierSet(">=1.2.0,<1.3.0") + + # Wildcards + assert semver_to_pep440("1.x") == SpecifierSet(">=1.0.0,<2.0.0") + assert semver_to_pep440("1.*") == SpecifierSet(">=1.0.0,<2.0.0") + assert semver_to_pep440("1.2.x") == SpecifierSet(">=1.2.0,<1.3.0") + + # Partial without wildcards + assert semver_to_pep440("1.2") == SpecifierSet(">=1.2.0,<1.3.0") + assert semver_to_pep440("1") == SpecifierSet(">=1.0.0,<2.0.0") + + # Operators & ranges + assert semver_to_pep440(">=1.0.0 <2.0.0") == SpecifierSet(">=1.0.0,<2.0.0") + assert semver_to_pep440("<=2.0.0") == SpecifierSet("<=2.0.0") + + # Exact and fallbacks + assert semver_to_pep440("1.2.3") == SpecifierSet("==1.2.3") + assert semver_to_pep440("*") == SpecifierSet("") + + +def test_parse_requirement_line(): + assert parse_requirement_line("fastapi>=0.115.0") == ("fastapi", SpecifierSet(">=0.115.0")) + assert parse_requirement_line("cryptography>=42.0.0 ; extra == 'ssl'") == ("cryptography", SpecifierSet(">=42.0.0")) + assert parse_requirement_line(" # commented line") is None + assert parse_requirement_line("") is None + + +def test_parse_package_lock(): + lock_data = { + "packages": { + "": { + "dependencies": { + "framer-motion": "^10.0.0" + } + }, + "node_modules/framer-motion": { + "version": "10.16.4", + "dependencies": { + "react": "^18.0.0" + } + } + } + } + with tempfile.TemporaryDirectory() as tmpdir: + lock_file = Path(tmpdir) / "package-lock.json" + with open(lock_file, "w") as f: + json.dump(lock_data, f) + + relations = parse_package_lock(str(lock_file)) + assert "root" in relations + assert relations["root"] == [("framer-motion", "^10.0.0")] + assert "framer-motion" in relations + assert relations["framer-motion"] == [("react", "^18.0.0")] + + +def test_parse_package_json(): + pkg_data = { + "dependencies": { + "express": "^4.17.1" + }, + "devDependencies": { + "jest": "^26.0.0" + } + } + with tempfile.TemporaryDirectory() as tmpdir: + pkg_file = Path(tmpdir) / "package.json" + with open(pkg_file, "w") as f: + json.dump(pkg_data, f) + + relations = parse_package_json(str(pkg_file)) + assert "root" in relations + assert ("express", "^4.17.1") in relations["root"] + assert ("jest", "^26.0.0") in relations["root"] + + +def test_validate_remediation_no_conflict(): + # If package not in graph, defaults to safe + graph = {} + res = validate_remediation("Update framer-motion to version 11.0.0", graph) + assert res["safe_to_apply"] is True + assert res["compatible_range"] is None + assert len(res["alternatives"]) == 0 + + +def test_validate_remediation_with_conflict(): + # Setup graph where root requires library-y, which transitively requires library-x <2.0 + graph = { + "library-x": [ + {"parent": "library-y", "specifier": SpecifierSet("<2.0")} + ] + } + + # Suggest upgrade of library-x to 1.5.0 (compatible with <2.0) + res_safe = validate_remediation("Update library-x to version 1.5.0", graph) + assert res_safe["safe_to_apply"] is True + + # Suggest upgrade of library-x to 2.1.0 (conflicts with <2.0) + res_unsafe = validate_remediation("Update library-x to version 2.1.0", graph) + assert res_unsafe["safe_to_apply"] is False + assert res_unsafe["compatible_range"] == "<2.0" + assert len(res_unsafe["alternatives"]) > 0 + assert any("library-y" in alt for alt in res_unsafe["alternatives"]) + + +def test_finding_model_safety_fields(): + finding = Finding( + title="Outdated dependency", + category="Dependency Vulnerability", + severity="high", + target="package.json", + description="Vulnerability in library-x", + safe_to_apply=False, + compatible_range="<2.0", + alternatives=["Upgrade library-y"] + ) + assert finding.safe_to_apply is False + assert finding.compatible_range == "<2.0" + assert finding.alternatives == ["Upgrade library-y"] From 7f464c218fe76b8c532e5de270a1122dd5d47fb0 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Fri, 12 Jun 2026 15:49:41 +0530 Subject: [PATCH 08/13] style: remove trailing whitespaces to pass formatting-hygiene check --- backend/secuscan/remediation.py | 68 +++++++++---------- backend/secuscan/routes.py | 2 +- .../backend/unit/test_remediation_safety.py | 8 +-- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/backend/secuscan/remediation.py b/backend/secuscan/remediation.py index 7004b7430..b4b9a3ff9 100644 --- a/backend/secuscan/remediation.py +++ b/backend/secuscan/remediation.py @@ -30,7 +30,7 @@ def clean_version_string(ver_str: str) -> str: def parse_remediation_suggestion(remediation_str: str) -> Tuple[str, str] | None: """Parse recommendation string to extract package name and target upgrade version. - + Example: "Update framer-motion to version 11.0.0" -> ("framer-motion", "11.0.0") """ pattern = r"(?:update|upgrade)\s+([a-zA-Z0-9_\-\.]+)\s+(?:to\s+)?(?:version\s+)?([a-zA-Z0-9_\-\.\+\~]+)" @@ -44,7 +44,7 @@ def parse_remediation_suggestion(remediation_str: str) -> Tuple[str, str] | None def handle_caret(ver_str: str) -> List[str]: """Convert NPM caret specification to PEP 440 constraints. - + ^1.2.3 -> >=1.2.3, <2.0.0 ^0.2.3 -> >=0.2.3, <0.3.0 ^0.0.3 -> >=0.0.3, <0.0.4 @@ -52,11 +52,11 @@ def handle_caret(ver_str: str) -> List[str]: parts = ver_str.split(".") while len(parts) < 3: parts.append("0") - + major = "".join(filter(str.isdigit, parts[0])) or "0" minor = "".join(filter(str.isdigit, parts[1])) or "0" patch = "".join(filter(str.isdigit, parts[2])) or "0" - + if major != "0": next_major = int(major) + 1 return [f">={ver_str}", f"<{next_major}.0.0"] @@ -70,7 +70,7 @@ def handle_caret(ver_str: str) -> List[str]: def handle_tilde(ver_str: str) -> List[str]: """Convert NPM tilde specification to PEP 440 constraints. - + ~1.2.3 -> >=1.2.3, <1.3.0 ~1.2 -> >=1.2.0, <1.3.0 """ @@ -106,15 +106,15 @@ def semver_to_pep440(semver_str: str) -> SpecifierSet: semver_str = semver_str.strip() if not semver_str or semver_str in ("*", "x", "any"): return SpecifierSet() - + parts = semver_str.split() pep440_parts = [] - + for part in parts: part = part.strip() if not part: continue - + if part.startswith("^"): pep440_parts.extend(handle_caret(part[1:])) elif part.startswith("~"): @@ -131,7 +131,7 @@ def semver_to_pep440(semver_str: str) -> SpecifierSet: pep440_parts.extend(handle_wildcard(part + ".x")) elif re.match(r"^[0-9a-zA-Z\.\-\+]+$", part): pep440_parts.append(f"=={part}") - + try: return SpecifierSet(",".join(pep440_parts)) except Exception: @@ -145,9 +145,9 @@ def parse_package_lock(filepath: str) -> Dict[str, List[Tuple[str, str]]]: data = json.load(f) except Exception: return {} - + relations = {} - + # Check for packages key (modern NPM lockfile v2/v3) packages = data.get("packages", {}) for path, info in packages.items(): @@ -155,14 +155,14 @@ def parse_package_lock(filepath: str) -> Dict[str, List[Tuple[str, str]]]: parent = "root" else: parent = path.replace("node_modules/", "") - + deps = info.get("dependencies", {}) peer_deps = info.get("peerDependencies", {}) all_deps = {**deps, **peer_deps} - + if all_deps: relations[parent] = [(normalize_package_name(k), v) for k, v in all_deps.items()] - + # Fallback to dependencies key (NPM lockfile v1) dependencies = data.get("dependencies", {}) def parse_v1_deps(deps_dict): @@ -173,10 +173,10 @@ def parse_v1_deps(deps_dict): child_deps = info.get("dependencies", {}) if child_deps: parse_v1_deps(child_deps) - + if not packages and dependencies: parse_v1_deps(dependencies) - + return relations @@ -258,9 +258,9 @@ def get_python_transitive_dependencies(package_name: str) -> List[Tuple[str, Spe def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: """Scan the target directory for Python/Node manifests and construct a transitive dependency constraint graph.""" graph: Dict[str, List[Dict[str, Any]]] = {} - + target_path = Path(target_dir) if target_dir else Path(".") - + # 1. Search for python requirements req_files = ["requirements.txt", "requirements-dev.txt"] for req_name in req_files: @@ -268,7 +268,7 @@ def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: if not p.exists(): # Fallback to local project root p = Path("backend") / req_name - + if p.exists(): try: with open(p, "r", encoding="utf-8") as f: @@ -280,7 +280,7 @@ def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: "parent": "root", "specifier": spec }) - + # Transitive resolution mock_deps = get_mock_dependencies(name) if mock_deps: @@ -301,16 +301,16 @@ def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: }) except Exception: pass - + # 2. Search for Node.js package-lock.json / package.json lock_path = target_path / "package-lock.json" if not lock_path.exists(): lock_path = Path("frontend/package-lock.json") - + pkg_path = target_path / "package.json" if not pkg_path.exists(): pkg_path = Path("frontend/package.json") - + if lock_path.exists(): try: relations = parse_package_lock(str(lock_path)) @@ -335,7 +335,7 @@ def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: }) except Exception: pass - + return graph @@ -346,20 +346,20 @@ def validate_remediation(remediation_str: str, graph: Dict[str, List[Dict[str, A "compatible_range": None, "alternatives": [] } - + parsed = parse_remediation_suggestion(remediation_str) if not parsed: return res - + pkg_name, target_version = parsed if pkg_name not in graph: return res - + constraints = graph[pkg_name] specifiers = [c["specifier"] for c in constraints] - + clean_target = clean_version_string(target_version) - + is_safe = True try: ver = Version(clean_target) @@ -371,17 +371,17 @@ def validate_remediation(remediation_str: str, graph: Dict[str, List[Dict[str, A except Exception: # Fall back to safe if parsing error happens to prevent blocking valid tools pass - + if not is_safe: res["safe_to_apply"] = False - + # Combine all constraints to show the allowed range combined_parts = [] for c in constraints: for spec in c["specifier"]: combined_parts.append(str(spec)) res["compatible_range"] = ", ".join(combined_parts) if combined_parts else "N/A" - + # Determine which packages impose conflicting requirements try: ver = Version(clean_target) @@ -390,7 +390,7 @@ def validate_remediation(remediation_str: str, graph: Dict[str, List[Dict[str, A })) except Exception: conflicting_parents = sorted(list({c["parent"] for c in constraints})) - + for parent in conflicting_parents: if parent == "root": res["alternatives"].append( @@ -403,5 +403,5 @@ def validate_remediation(remediation_str: str, graph: Dict[str, List[Dict[str, A res["alternatives"].append( f"Downgrade or keep '{pkg_name}' within compatible range: {res['compatible_range']}." ) - + return res diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index 76cc79800..8732dda2b 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -54,7 +54,7 @@ def deserialize_finding_rows(rows: List[Dict]) -> List[Dict[str, Any]]: finding["references"] = finding.pop("references_json") if "corroborating_sources_json" in finding: finding["corroborating_sources"] = finding.pop("corroborating_sources_json") - + # Expose remediation safety fields at the top level metadata = finding.get("metadata", {}) or {} finding["safe_to_apply"] = metadata.get("safe_to_apply") diff --git a/testing/backend/unit/test_remediation_safety.py b/testing/backend/unit/test_remediation_safety.py index 0861efc49..78fe81a2b 100644 --- a/testing/backend/unit/test_remediation_safety.py +++ b/testing/backend/unit/test_remediation_safety.py @@ -97,7 +97,7 @@ def test_parse_package_lock(): lock_file = Path(tmpdir) / "package-lock.json" with open(lock_file, "w") as f: json.dump(lock_data, f) - + relations = parse_package_lock(str(lock_file)) assert "root" in relations assert relations["root"] == [("framer-motion", "^10.0.0")] @@ -118,7 +118,7 @@ def test_parse_package_json(): pkg_file = Path(tmpdir) / "package.json" with open(pkg_file, "w") as f: json.dump(pkg_data, f) - + relations = parse_package_json(str(pkg_file)) assert "root" in relations assert ("express", "^4.17.1") in relations["root"] @@ -141,11 +141,11 @@ def test_validate_remediation_with_conflict(): {"parent": "library-y", "specifier": SpecifierSet("<2.0")} ] } - + # Suggest upgrade of library-x to 1.5.0 (compatible with <2.0) res_safe = validate_remediation("Update library-x to version 1.5.0", graph) assert res_safe["safe_to_apply"] is True - + # Suggest upgrade of library-x to 2.1.0 (conflicts with <2.0) res_unsafe = validate_remediation("Update library-x to version 2.1.0", graph) assert res_unsafe["safe_to_apply"] is False From 22468c1e200597c6d726ee945ca220259af0f2b5 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Sat, 13 Jun 2026 12:30:33 +0530 Subject: [PATCH 09/13] test(cache): add regression test asserting _execute_command is not called on cache hit --- testing/backend/unit/test_scan_cache.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/testing/backend/unit/test_scan_cache.py b/testing/backend/unit/test_scan_cache.py index ed759d97e..42a322690 100644 --- a/testing/backend/unit/test_scan_cache.py +++ b/testing/backend/unit/test_scan_cache.py @@ -184,6 +184,7 @@ async def db_fetchone_mock(query, params=()): executor._persist_result_resources = AsyncMock() executor._dispatch_task_notifications = AsyncMock() executor._invalidate_cached_views = AsyncMock() + executor._execute_command = AsyncMock() with patch("backend.secuscan.executor.get_db", return_value=mock_db), \ patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: @@ -194,6 +195,9 @@ async def db_fetchone_mock(query, params=()): await executor.execute_task("task_id_123", bypass_cache=False) + # Verify _execute_command was never called (regression test for cache bypass) + executor._execute_command.assert_not_called() + # Verify db was updated with status, duration, etc. mock_db.execute.assert_any_call( """ From b96a41773ab4faaa1af9a7d5018d545a0635706d Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Sat, 13 Jun 2026 12:48:02 +0530 Subject: [PATCH 10/13] chore: add npm audit exception for esbuild GHSA-gv7w-rqvm-qjhr --- .audit-config.yaml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.audit-config.yaml b/.audit-config.yaml index 71aedcb5f..90700a80b 100644 --- a/.audit-config.yaml +++ b/.audit-config.yaml @@ -17,7 +17,12 @@ policy: # Documented exceptions with business justification # Format: CVE-XXXX-XXXXX or GHSA-xxxx-xxxx-xxxx -exceptions: {} +exceptions: + GHSA-gv7w-rqvm-qjhr: + package: esbuild + severity: high + reason: "Development bundler dependency; not exposed in production" + expires_at: "2026-12-31" # Packages to exclude from audits (use sparingly!) excluded_packages: [] From e9e05e09166492141e9dc75d351dfe668b2e6687 Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Sat, 13 Jun 2026 17:46:58 +0530 Subject: [PATCH 11/13] feat(remediation): improve safety validation logic, remove fallbacks, mock test dependencies, and add serialization tests --- backend/secuscan/executor.py | 16 +++- backend/secuscan/remediation.py | 51 +++-------- .../test_routes_remediation_safety.py | 89 +++++++++++++++++++ .../backend/unit/test_remediation_safety.py | 34 +++++++ 4 files changed, 151 insertions(+), 39 deletions(-) create mode 100644 testing/backend/integration/test_routes_remediation_safety.py diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 3a09f802c..75451c481 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -1406,16 +1406,28 @@ async def _build_result_contract( try: from .remediation import build_dependency_graph, validate_remediation graph = build_dependency_graph(target) + validations = {} for f in normalized_findings: remediation_str = f.get("remediation", "") if remediation_str: val_res = validate_remediation(remediation_str, graph) + validations[id(f)] = val_res + + for f in normalized_findings: + if id(f) in validations: + val_res = validations[id(f)] f_metadata = f.setdefault("metadata", {}) f_metadata["safe_to_apply"] = val_res["safe_to_apply"] f_metadata["compatible_range"] = val_res["compatible_range"] f_metadata["alternatives"] = val_res["alternatives"] - except Exception: - pass + except Exception as e: + logger.warning( + "Remediation safety validation failed for task %s (plugin %s): %s. Skipping safety metadata enrichment.", + task_id, + plugin_id, + str(e), + exc_info=True, + ) previous_findings = await self._load_previous_task_findings( db, diff --git a/backend/secuscan/remediation.py b/backend/secuscan/remediation.py index b4b9a3ff9..0cc76e1c8 100644 --- a/backend/secuscan/remediation.py +++ b/backend/secuscan/remediation.py @@ -216,16 +216,6 @@ def parse_requirement_line(line: str) -> Tuple[str, SpecifierSet] | None: spec = SpecifierSet() return name, spec - -def get_mock_dependencies(pkg_name: str) -> List[Tuple[str, str]]: - """Return mock transitive dependencies for deterministic offline unit testing.""" - registry = { - "library-y": [("library-x", "<2.0")], - "parent-package": [("child-package", "<=1.5.0")], - } - return registry.get(pkg_name, []) - - def get_python_transitive_dependencies(package_name: str) -> List[Tuple[str, SpecifierSet]]: """Retrieve python transitive dependencies from installed metadata.""" try: @@ -259,16 +249,20 @@ def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: """Scan the target directory for Python/Node manifests and construct a transitive dependency constraint graph.""" graph: Dict[str, List[Dict[str, Any]]] = {} - target_path = Path(target_dir) if target_dir else Path(".") + if not target_dir: + return graph + + target_path = Path(target_dir) + if not target_path.exists(): + return graph + + if target_path.is_file(): + target_path = target_path.parent # 1. Search for python requirements req_files = ["requirements.txt", "requirements-dev.txt"] for req_name in req_files: p = target_path / req_name - if not p.exists(): - # Fallback to local project root - p = Path("backend") / req_name - if p.exists(): try: with open(p, "r", encoding="utf-8") as f: @@ -282,34 +276,17 @@ def build_dependency_graph(target_dir: str) -> Dict[str, List[Dict[str, Any]]]: }) # Transitive resolution - mock_deps = get_mock_dependencies(name) - if mock_deps: - for dep_name, dep_spec_str in mock_deps: - try: - dep_spec = SpecifierSet(dep_spec_str) - except Exception: - dep_spec = SpecifierSet() - graph.setdefault(dep_name, []).append({ - "parent": name, - "specifier": dep_spec - }) - else: - for dep_name, dep_spec in get_python_transitive_dependencies(name): - graph.setdefault(dep_name, []).append({ - "parent": name, - "specifier": dep_spec - }) + for dep_name, dep_spec in get_python_transitive_dependencies(name): + graph.setdefault(dep_name, []).append({ + "parent": name, + "specifier": dep_spec + }) except Exception: pass # 2. Search for Node.js package-lock.json / package.json lock_path = target_path / "package-lock.json" - if not lock_path.exists(): - lock_path = Path("frontend/package-lock.json") - pkg_path = target_path / "package.json" - if not pkg_path.exists(): - pkg_path = Path("frontend/package.json") if lock_path.exists(): try: diff --git a/testing/backend/integration/test_routes_remediation_safety.py b/testing/backend/integration/test_routes_remediation_safety.py new file mode 100644 index 000000000..84a2fc438 --- /dev/null +++ b/testing/backend/integration/test_routes_remediation_safety.py @@ -0,0 +1,89 @@ +import sqlite3 +import json +import pytest +from backend.secuscan.config import settings + +ALICE = {"X-User-Id": "alice"} +ALICE_OWNER = "user:alice" + +def _seed_task(owner_id: str, task_id: str) -> None: + """Insert a task row directly with an explicit owner_id.""" + conn = sqlite3.connect(settings.database_path) + try: + conn.execute( + """ + INSERT INTO tasks (id, owner_id, plugin_id, tool_name, target, + status, inputs_json, structured_json, consent_granted) + VALUES (?, ?, 'nmap', 'nmap', '127.0.0.1', 'completed', '{}', '{"findings": []}', 1) + """, + (task_id, owner_id), + ) + conn.commit() + finally: + conn.close() + +def _seed_finding(owner_id: str, finding_id: str, task_id: str, metadata: dict | None = None) -> None: + conn = sqlite3.connect(settings.database_path) + metadata_json = json.dumps(metadata) if metadata is not None else None + try: + conn.execute( + """ + INSERT INTO findings (id, owner_id, task_id, plugin_id, title, category, + severity, target, description, remediation, metadata_json) + VALUES (?, ?, ?, 'nmap', 'Open port', 'network', 'low', '127.0.0.1', 'desc', 'fix', ?) + """, + (finding_id, owner_id, task_id, metadata_json), + ) + conn.commit() + finally: + conn.close() + +def test_routes_expose_remediation_safety_fields(test_client): + """Test that safe_to_apply, compatible_range, and alternatives fields are exposed in API responses when present in metadata, and default to None otherwise.""" + _seed_task(ALICE_OWNER, "task-1") + + # 1. Seed finding with validated remediation metadata + metadata_validated = { + "safe_to_apply": False, + "compatible_range": "<2.0", + "alternatives": ["Upgrade package-y"], + "other_key": "some_value" + } + _seed_finding(ALICE_OWNER, "finding-validated", "task-1", metadata=metadata_validated) + + # 2. Seed finding without validated remediation metadata + metadata_unvalidated = { + "other_key": "some_value" + } + _seed_finding(ALICE_OWNER, "finding-unvalidated", "task-1", metadata=metadata_unvalidated) + + # 3. Test `/findings` list endpoint + response_list = test_client.get("/api/v1/findings", headers=ALICE) + assert response_list.status_code == 200 + findings_list = response_list.json()["findings"] + + finding_val = next(f for f in findings_list if f["id"] == "finding-validated") + assert finding_val["safe_to_apply"] is False + assert finding_val["compatible_range"] == "<2.0" + assert finding_val["alternatives"] == ["Upgrade package-y"] + + finding_unval = next(f for f in findings_list if f["id"] == "finding-unvalidated") + assert finding_unval["safe_to_apply"] is None + assert finding_unval["compatible_range"] is None + assert finding_unval["alternatives"] is None + + # 4. Test `/finding/{finding_id}` detail endpoint - Validated Case + response_detail_val = test_client.get("/api/v1/finding/finding-validated", headers=ALICE) + assert response_detail_val.status_code == 200 + detail_val = response_detail_val.json() + assert detail_val["safe_to_apply"] is False + assert detail_val["compatible_range"] == "<2.0" + assert detail_val["alternatives"] == ["Upgrade package-y"] + + # 5. Test `/finding/{finding_id}` detail endpoint - Unvalidated Case + response_detail_unval = test_client.get("/api/v1/finding/finding-unvalidated", headers=ALICE) + assert response_detail_unval.status_code == 200 + detail_unval = response_detail_unval.json() + assert detail_unval["safe_to_apply"] is None + assert detail_unval["compatible_range"] is None + assert detail_unval["alternatives"] is None diff --git a/testing/backend/unit/test_remediation_safety.py b/testing/backend/unit/test_remediation_safety.py index 78fe81a2b..7fed0533d 100644 --- a/testing/backend/unit/test_remediation_safety.py +++ b/testing/backend/unit/test_remediation_safety.py @@ -168,3 +168,37 @@ def test_finding_model_safety_fields(): assert finding.safe_to_apply is False assert finding.compatible_range == "<2.0" assert finding.alternatives == ["Upgrade library-y"] + + +def test_build_dependency_graph_fallback_disabled(): + """Verify that build_dependency_graph does not fall back to local manifests when target is invalid/nonexistent.""" + # 1. Non-existent directory + graph_nonexistent = build_dependency_graph("nonexistent_directory_123") + assert graph_nonexistent == {} + + # 2. URL/IP target + graph_url = build_dependency_graph("http://example.com/api") + assert graph_url == {} + + +def test_build_dependency_graph_python_transitive_mocked(): + """Test building dependency graph for Python requirements with mocked transitive dependencies.""" + from unittest.mock import patch + + req_content = "library-y>=1.0.0\n" + + with tempfile.TemporaryDirectory() as tmpdir: + req_file = Path(tmpdir) / "requirements.txt" + with open(req_file, "w", encoding="utf-8") as f: + f.write(req_content) + + # Mock get_python_transitive_dependencies to return a transitive dependency + mock_transitive = [("library-x", SpecifierSet("<2.0"))] + with patch("backend.secuscan.remediation.get_python_transitive_dependencies", return_value=mock_transitive): + graph = build_dependency_graph(tmpdir) + + assert "library-y" in graph + assert graph["library-y"] == [{"parent": "root", "specifier": SpecifierSet(">=1.0.0")}] + + assert "library-x" in graph + assert graph["library-x"] == [{"parent": "library-y", "specifier": SpecifierSet("<2.0")}] From 78e63469c0f845295051549d9c9934a8d35b008b Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Sat, 13 Jun 2026 18:10:20 +0530 Subject: [PATCH 12/13] style: remove trailing whitespaces in test files to pass formatting hygiene check --- .../integration/test_routes_remediation_safety.py | 8 ++++---- testing/backend/unit/test_remediation_safety.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/testing/backend/integration/test_routes_remediation_safety.py b/testing/backend/integration/test_routes_remediation_safety.py index 84a2fc438..2954c09ea 100644 --- a/testing/backend/integration/test_routes_remediation_safety.py +++ b/testing/backend/integration/test_routes_remediation_safety.py @@ -41,7 +41,7 @@ def _seed_finding(owner_id: str, finding_id: str, task_id: str, metadata: dict | def test_routes_expose_remediation_safety_fields(test_client): """Test that safe_to_apply, compatible_range, and alternatives fields are exposed in API responses when present in metadata, and default to None otherwise.""" _seed_task(ALICE_OWNER, "task-1") - + # 1. Seed finding with validated remediation metadata metadata_validated = { "safe_to_apply": False, @@ -50,7 +50,7 @@ def test_routes_expose_remediation_safety_fields(test_client): "other_key": "some_value" } _seed_finding(ALICE_OWNER, "finding-validated", "task-1", metadata=metadata_validated) - + # 2. Seed finding without validated remediation metadata metadata_unvalidated = { "other_key": "some_value" @@ -61,12 +61,12 @@ def test_routes_expose_remediation_safety_fields(test_client): response_list = test_client.get("/api/v1/findings", headers=ALICE) assert response_list.status_code == 200 findings_list = response_list.json()["findings"] - + finding_val = next(f for f in findings_list if f["id"] == "finding-validated") assert finding_val["safe_to_apply"] is False assert finding_val["compatible_range"] == "<2.0" assert finding_val["alternatives"] == ["Upgrade package-y"] - + finding_unval = next(f for f in findings_list if f["id"] == "finding-unvalidated") assert finding_unval["safe_to_apply"] is None assert finding_unval["compatible_range"] is None diff --git a/testing/backend/unit/test_remediation_safety.py b/testing/backend/unit/test_remediation_safety.py index 7fed0533d..710e0800b 100644 --- a/testing/backend/unit/test_remediation_safety.py +++ b/testing/backend/unit/test_remediation_safety.py @@ -184,21 +184,21 @@ def test_build_dependency_graph_fallback_disabled(): def test_build_dependency_graph_python_transitive_mocked(): """Test building dependency graph for Python requirements with mocked transitive dependencies.""" from unittest.mock import patch - + req_content = "library-y>=1.0.0\n" - + with tempfile.TemporaryDirectory() as tmpdir: req_file = Path(tmpdir) / "requirements.txt" with open(req_file, "w", encoding="utf-8") as f: f.write(req_content) - + # Mock get_python_transitive_dependencies to return a transitive dependency mock_transitive = [("library-x", SpecifierSet("<2.0"))] with patch("backend.secuscan.remediation.get_python_transitive_dependencies", return_value=mock_transitive): graph = build_dependency_graph(tmpdir) - + assert "library-y" in graph assert graph["library-y"] == [{"parent": "root", "specifier": SpecifierSet(">=1.0.0")}] - + assert "library-x" in graph assert graph["library-x"] == [{"parent": "library-y", "specifier": SpecifierSet("<2.0")}] From 8aa1e8300b9a3f1a5ec5f23523228869fd47d3de Mon Sep 17 00:00:00 2001 From: dinesh9997 Date: Sun, 14 Jun 2026 17:57:52 +0530 Subject: [PATCH 13/13] chore: remove unrelated cache and audit config changes from remediation safety PR --- .audit-config.yaml | 7 +- PLUGINS.md | 4 +- backend/secuscan/cache.py | 58 +---- backend/secuscan/executor.py | 318 +++++------------------- backend/secuscan/main.py | 4 +- backend/secuscan/routes.py | 3 +- plugins/httpx/metadata.json | 8 +- plugins/website-recon-2/metadata.json | 8 +- scripts/refresh_plugin_checksum.py | 10 +- testing/backend/conftest.py | 6 - testing/backend/unit/test_scan_cache.py | 305 ----------------------- 11 files changed, 87 insertions(+), 644 deletions(-) delete mode 100644 testing/backend/unit/test_scan_cache.py diff --git a/.audit-config.yaml b/.audit-config.yaml index 90700a80b..71aedcb5f 100644 --- a/.audit-config.yaml +++ b/.audit-config.yaml @@ -17,12 +17,7 @@ policy: # Documented exceptions with business justification # Format: CVE-XXXX-XXXXX or GHSA-xxxx-xxxx-xxxx -exceptions: - GHSA-gv7w-rqvm-qjhr: - package: esbuild - severity: high - reason: "Development bundler dependency; not exposed in production" - expires_at: "2026-12-31" +exceptions: {} # Packages to exclude from audits (use sparingly!) excluded_packages: [] diff --git a/PLUGINS.md b/PLUGINS.md index 4bd256f0d..67a5b7685 100644 --- a/PLUGINS.md +++ b/PLUGINS.md @@ -60,7 +60,7 @@ Only run scans against systems you own or are explicitly authorized to assess. | Password Recovery Audit | `hashcat` | `expert` | `exploit` | `hashcat` | Password recovery and hash audit workflow. | | HTTP Inspector | `http_inspector` | `web` | `safe` | `curl` | Inspect HTTP/HTTPS endpoints for headers, cookies, and TLS configuration. | | HTTP Request Logger | `http_request_logger` | `exploit` | `intrusive` | `httpx` | Handle incoming HTTP requests and record data. | -| httpx | `httpx` | `recon` | `safe` | `httpx` | Live host probing with status, title, and technology fingerprinting. | +| httpx | `httpx` | `recon` | `safe` | `httpx` | Probe live hosts and collect reachability information, status codes, page titles, and basic technology indicators. | | IaC Scanner (Checkov) | `iac_scanner` | `vulnerability` | `safe` | `python3` | Analyze Terraform and CloudFormation code for flaws. | | ICMP Ping | `icmp_ping` | `utils` | `safe` | `ping` | Check if a server is live and responds to ICMP Echo requests. | | Joomla Security Scan | `joomscan` | `vulnerability` | `intrusive` | `joomscan` | Joomla security scanner for version and common weakness discovery. | @@ -95,7 +95,7 @@ Only run scans against systems you own or are explicitly authorized to assess. | Virtual Hosts Finder | `virtual-host-finder` | `recon` | `intrusive` | `ffuf` | Find multiple websites hosted on the same server. | | Volatility | `volatility` | `forensics` | `intrusive` | `volatility3` | Memory forensics workflow using Volatility 3 plugins. | | WAF Detector | `waf_detector` | `robots` | `safe` | `wafw00f` | Automatically identify Web Application Firewalls protecting targets. | -| Website Recon | `website-recon-2` | `recon` | `safe` | `httpx` | Fingerprint web technologies of target website. | +| Website Recon | `website-recon-2` | `recon` | `safe` | `httpx` | Perform website reconnaissance focused on identifying web technologies, frameworks, and application stack details. | | Domain Registration Lookup | `whois_lookup` | `utils` | `safe` | `python3` | Domain registration information lookup. | | WordPress Security Scan | `wpscan` | `vulnerability` | `intrusive` | `wpscan` | WordPress security scanner for plugin, theme, and core risk visibility. | | XSS Exploiter | `xss_exploiter` | `exploit` | `exploit` | `python3` | Exploit XSS in real-life attacks to extract cookies and data. | diff --git a/backend/secuscan/cache.py b/backend/secuscan/cache.py index 68a3231e9..a42c3edbf 100644 --- a/backend/secuscan/cache.py +++ b/backend/secuscan/cache.py @@ -1,5 +1,5 @@ """ -In-memory and Redis-based cache helpers for API responses. +In-memory cache helpers for API responses. """ import json @@ -15,15 +15,9 @@ SWEEP_EVICT_FRACTION = 0.25 OPPORTUNISTIC_SWEEP_INTERVAL = 50 -try: - from redis.asyncio import Redis, ConnectionError as RedisConnectionError -except ImportError: - Redis = None - RedisConnectionError = Exception - class CacheClient: - """Cache client supporting Redis with an in-memory dictionary fallback.""" + """In-memory dictionary based cache client with TTL, size limit, and LRU eviction.""" def __init__(self, url: Optional[str] = None, max_entries: int = DEFAULT_MAX_ENTRIES): self.url = url @@ -34,27 +28,11 @@ def __init__(self, url: Optional[str] = None, max_entries: int = DEFAULT_MAX_ENT self._eviction_count = 0 self._sweep_count = 0 self._write_count = 0 - self.client: Optional[Redis] = None async def connect(self): - if self.url and Redis is not None: - try: - self.client = Redis.from_url(self.url, decode_responses=True) - await self.client.ping() - logger.info("✓ Connected to Redis cache at %s", self.url) - except RedisConnectionError as e: - logger.warning("Failed to connect to Redis, falling back to in-memory: %s", e) - self.client = None - else: - self.client = None + pass async def disconnect(self): - if self.client: - try: - await self.client.aclose() - except Exception: - pass - self.client = None self._data.clear() self._expires.clear() self._access_order.clear() @@ -82,14 +60,7 @@ def _evict_lru(self): self._eviction_count += evict_count async def get_json(self, key: str) -> Optional[Any]: - """Retrieve and parse JSON from cache, respecting TTL.""" - if self.client: - try: - val = await self.client.get(key) - return json.loads(val) if val is not None else None - except Exception as e: - logger.warning("Redis get_json error (falling back to in-memory): %s", e) - + """Retrieve and parse JSON from memory, respecting TTL.""" now = time.time() expiry = self._expires.get(key) @@ -105,20 +76,12 @@ async def get_json(self, key: str) -> Optional[Any]: return self._data.get(key) async def set_json(self, key: str, value: Any, ttl: Optional[int] = None): - """Store value in cache with optional TTL.""" - actual_ttl = ttl or settings.cache_ttl_seconds - - if self.client: - try: - await self.client.set(key, json.dumps(value), ex=actual_ttl) - return - except Exception as e: - logger.warning("Redis set_json error (falling back to in-memory): %s", e) - + """Store value in memory with optional TTL.""" if len(self._data) >= self.max_entries and key not in self._data: self._evict_lru() self._data[key] = value + actual_ttl = ttl or settings.cache_ttl_seconds self._expires[key] = time.time() + actual_ttl self._access_order[key] = time.time() self._write_count += 1 @@ -128,15 +91,6 @@ async def set_json(self, key: str, value: Any, ttl: Optional[int] = None): async def delete_prefix(self, prefix: str): """Delete all keys starting with prefix.""" - if self.client: - try: - keys = await self.client.keys(f"{prefix}*") - if keys: - await self.client.delete(*keys) - return - except Exception as e: - logger.warning("Redis delete_prefix error (falling back to in-memory): %s", e) - to_delete = [k for k in self._data.keys() if k.startswith(prefix)] for k in to_delete: self._data.pop(k, None) diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 75451c481..76b1e90e6 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -155,84 +155,6 @@ def extract_target(inputs: Dict[str, Any]) -> str: ) -def generate_scan_cache_key( - owner_id: str, - plugin_id: str, - target: str, - inputs: Dict[str, Any], - execution_context: Dict[str, Any], - safe_mode: bool -) -> tuple[str, str, str]: - """ - Generate target hash, dependency hash, and an owner-scoped cache key. - - Returns: - tuple: (target_hash, dependency_hash, cache_key) - """ - import hashlib - import subprocess - from pathlib import Path - - target_hash = None - if target and os.path.isdir(target): - try: - res = subprocess.run( - ["git", "rev-parse", "HEAD"], - cwd=target, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - timeout=5 - ) - if res.returncode == 0: - target_hash = res.stdout.strip() - except Exception: - pass - - if not target_hash: - target_hash = hashlib.sha256(str(target or "").encode("utf-8")).hexdigest() - - dependency_files = [ - "package-lock.json", - "poetry.lock", - "Cargo.lock", - "go.sum", - "requirements.txt", - "Pipfile.lock", - "pnpm-lock.yaml", - "yarn.lock", - "gemfile.lock" - ] - hasher = hashlib.sha256() - found_any = False - - if target and os.path.isdir(target): - p = Path(target) - for dep_file in sorted(dependency_files): - file_path = p / dep_file - if file_path.exists() and file_path.is_file(): - try: - hasher.update(dep_file.encode("utf-8")) - hasher.update(file_path.read_bytes()) - found_any = True - except Exception: - pass - - if not found_any: - dependency_hash = "no_deps" - else: - dependency_hash = hasher.hexdigest() - - inputs_str = json.dumps(inputs, sort_keys=True) - inputs_hash = hashlib.sha256(inputs_str.encode("utf-8")).hexdigest() - - context_str = json.dumps(execution_context, sort_keys=True) - context_hash = hashlib.sha256(context_str.encode("utf-8")).hexdigest() - - cache_key = f"scan_cache:{owner_id}:{plugin_id}:{int(safe_mode)}:{target_hash}:{dependency_hash}:{inputs_hash}:{context_hash}" - return target_hash, dependency_hash, cache_key - - def _stable_asset_id(target: str, host: Any, port: Any, protocol: Any) -> str: material = "||".join( [ @@ -429,13 +351,12 @@ async def mark_task_failed(self, task_id: str, reason: str) -> None: task_id=task_id, ) - async def execute_task(self, task_id: str, bypass_cache: bool = False): + async def execute_task(self, task_id: str): """ Execute a task asynchronously. Args: task_id: Task identifier - bypass_cache: Whether to bypass Redis scan result cache """ db = await get_db() self.running_tasks[task_id] = asyncio.current_task() @@ -472,104 +393,6 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): execution_context=execution_context, ) - # Check cache if not bypassed - cached_result = None - cache_key = None - if target and not bypass_cache: - try: - target_hash, dependency_hash, cache_key = generate_scan_cache_key( - owner_id=owner_id, - plugin_id=plugin_id, - target=target, - inputs=inputs, - execution_context=execution_context, - safe_mode=safe_mode, - ) - cache_client = await get_cache() - cached_result = await cache_client.get_json(cache_key) - except Exception as cache_exc: - logger.warning("Failed to query scan cache: %s", cache_exc) - - if cached_result is not None: - logger.info("Cache hit for scan task %s (key: %s)", task_id, cache_key) - await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) - await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - - raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" - try: - with open(raw_path, 'w', encoding='utf-8') as f: - f.write(cached_result.get("raw_output", "")) - except Exception as f_exc: - logger.warning("Failed to write raw output for cached task: %s", f_exc) - - status = cached_result.get("status", TaskStatus.COMPLETED.value) - duration = cached_result.get("duration_seconds", 0.0) - exit_code = cached_result.get("exit_code", 0) - error_message = cached_result.get("error_message") - structured_data = cached_result.get("structured", {}) - - # Update task in SQLite with the cached results - await db.execute( - """ - UPDATE tasks SET - status = ?, - completed_at = ?, - duration_seconds = ?, - exit_code = ?, - raw_output_path = ?, - error_message = ? - WHERE id = ? - """, - ( - status, - datetime.now().isoformat(), - duration, - exit_code, - str(raw_path), - error_message, - task_id - ) - ) - - # Persist findings and reports to database for the new task - await self._broadcast_phase(task_id, ScanPhase.PARSING.value) - - plugin_manager = get_plugin_manager() - plugin = plugin_manager.get_plugin(plugin_id) - is_modular = plugin_id in MODULAR_SCANNERS - if is_modular: - scanner_class = MODULAR_SCANNERS[plugin_id] - report_name = f"{scanner_class.__name__} Report" - else: - report_name = f"{plugin.name} Report" if plugin else f"{plugin_id} Report" - - await self._persist_findings_and_report_common( - db, - task_id=task_id, - owner_id=owner_id, - plugin_id=plugin_id, - target=target, - status=status, - result_dict=structured_data, - is_modular=is_modular, - report_name=report_name, - ) - - await self._dispatch_task_notifications(db, task_id) - await self._broadcast_phase(task_id, ScanPhase.FINISHED.value) - await self._broadcast(task_id, "status", status) - await self._invalidate_cached_views() - - await db.log_audit( - "task_completed", - f"Task completed from cache (duration: {duration:.2f}s)", - context={"task_id": task_id, "exit_code": exit_code, "cached": True}, - task_id=task_id, - plugin_id=plugin_id - ) - logger.info(f"Task {task_id} completed from cache") - return - # ── Safe Mode & Network policy enforcement ─────────────────────── # Enforce Safe Mode target validation inside TaskExecutor to guarantee # that all execution paths (manual API, workflows, scheduled tasks) are protected. @@ -830,41 +653,7 @@ async def execute_task(self, task_id: str, bypass_cache: bool = False): status=final_status, output=output ) - if target and not bypass_cache: - try: - target_hash, dependency_hash, cache_key = generate_scan_cache_key( - owner_id=owner_id, - plugin_id=plugin_id, - target=target, - inputs=inputs, - execution_context=execution_context, - safe_mode=safe_mode, - ) - task_data = await db.fetchone( - "SELECT status, duration_seconds, exit_code, error_message, structured_json, raw_output_path FROM tasks WHERE id = ?", - (task_id,) - ) - if task_data and task_data["status"] == TaskStatus.COMPLETED.value: - raw_output = "" - if task_data["raw_output_path"]: - try: - with open(task_data["raw_output_path"], "r", encoding="utf-8") as f: - raw_output = f.read() - except Exception: - pass - cache_data = { - "status": task_data["status"], - "duration_seconds": task_data["duration_seconds"], - "exit_code": task_data["exit_code"], - "error_message": task_data["error_message"], - "raw_output": raw_output, - "structured": json.loads(task_data["structured_json"]) if task_data["structured_json"] else {} - } - cache_client = await get_cache() - await cache_client.set_json(cache_key, cache_data, ttl=86400) - logger.info("Saved scan results to cache for task %s (key: %s)", task_id, cache_key) - except Exception as cache_exc: - logger.warning("Failed to save scan results to cache: %s", cache_exc) + await self._broadcast_phase(task_id, ScanPhase.REPORTING.value) await self._dispatch_task_notifications(db, task_id) @@ -1571,27 +1360,16 @@ async def _persist_finding( "risk_factors": risk_factors, } - async def _persist_findings_and_report_common( - self, - db, - *, - task_id: str, - owner_id: str, - plugin_id: str, - target: str, - status: str, - result_dict: Dict[str, Any], - is_modular: bool, - report_name: str, - ) -> Dict[str, Any]: - """Common logic to persist findings, report, and result resources for a scan.""" + async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plugin, plugin_id: str, target: str, status: str, output: str = ""): + """Persist derived findings and report records into SQLite.""" + parsed = self._parse_results(plugin, output) structured_result, previous_findings, asset_services = await self._build_result_contract( db, task_id=task_id, owner_id=owner_id, plugin_id=plugin_id, target=target, - result=result_dict, + result=parsed, ) findings_data: List[Dict[str, Any]] = [] for finding in structured_result.get("findings", []): @@ -1617,13 +1395,6 @@ async def _persist_findings_and_report_common( (json.dumps(structured_result), task_id) ) - if is_modular: - report_type = "professional" if status == TaskStatus.COMPLETED.value else "failed" - pages = 2 - else: - report_type = "technical" - pages = 1 - await db.execute( """ INSERT INTO reports ( @@ -1638,11 +1409,11 @@ async def _persist_findings_and_report_common( f"report:{task_id}", owner_id, task_id, - report_name, - report_type, + f"{plugin.name} Report", + "technical", "ready" if status == TaskStatus.COMPLETED.value else "failed", len(findings_data), - pages, + 1, ), ) @@ -1654,38 +1425,73 @@ async def _persist_findings_and_report_common( target=target, result=structured_result, ) - return structured_result - async def _upsert_findings_and_report(self, db, task_id: str, owner_id: str, plugin, plugin_id: str, target: str, status: str, output: str = ""): - """Persist derived findings and report records into SQLite.""" - parsed = self._parse_results(plugin, output) - await self._persist_findings_and_report_common( + async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner_id: str, scanner: Any, plugin_id: str, target: str, status: str, result: Dict[str, Any]): + """Persist modular scanner results into findings, and reports.""" + structured_result, previous_findings, asset_services = await self._build_result_contract( db, task_id=task_id, owner_id=owner_id, plugin_id=plugin_id, target=target, - status=status, - result_dict=parsed, - is_modular=False, - report_name=f"{plugin.name} Report", + result=result, ) + findings_data: List[Dict[str, Any]] = [] + for finding in structured_result.get("findings", []): + findings_data.append( + await self._persist_finding( + db, + owner_id=owner_id, + task_id=task_id, + plugin_id=plugin_id, + target=target, + finding=finding, + ) + ) - async def _upsert_findings_and_report_from_scanner(self, db, task_id: str, owner_id: str, scanner: Any, plugin_id: str, target: str, status: str, result: Dict[str, Any]): - """Persist modular scanner results into findings, and reports.""" - await self._persist_findings_and_report_common( + structured_result["findings"] = findings_data + structured_result["severity_counts"] = self._build_severity_counts(findings_data) + structured_result["finding_groups"] = build_finding_groups(findings_data) + structured_result["asset_summary"] = build_asset_summary(findings_data, asset_services) + structured_result["scan_diff"] = build_scan_diff(findings_data, previous_findings) + + await db.execute( + "UPDATE tasks SET structured_json = ? WHERE id = ?", + (json.dumps(structured_result), task_id) + ) + + # Create/Update report + await db.execute( + """ + INSERT INTO reports ( + id, owner_id, task_id, name, type, generated_at, status, findings, pages + ) VALUES (?, ?, ?, ?, ?, (datetime('now')), ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + status = EXCLUDED.status, + findings = EXCLUDED.findings, + pages = EXCLUDED.pages + """, + ( + f"report:{task_id}", + owner_id, + task_id, + f"{scanner.name} Report", + "professional" if status == TaskStatus.COMPLETED.value else "failed", + "ready" if status == TaskStatus.COMPLETED.value else "failed", + len(findings_data), + 2, # Professional reports are typically multi-page + ), + ) + + await self._persist_result_resources( db, - task_id=task_id, owner_id=owner_id, + task_id=task_id, plugin_id=plugin_id, target=target, - status=status, - result_dict=result, - is_modular=True, - report_name=f"{scanner.name} Report", + result=structured_result, ) - async def _persist_result_resources( self, db, diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index 9b43d8c98..8e06d6638 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -66,8 +66,8 @@ async def lifespan(app: FastAPI): await init_db(settings.database_path) logger.info("✓ SQLite connected") - await init_cache(settings.redis_url) - logger.info("✓ Cache initialized") + await init_cache() + logger.info("✓ In-memory cache initialized") # Load plugins await init_plugins(settings.plugins_dir) diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index 5f389fe6e..34cf97993 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -367,7 +367,6 @@ async def start_task( request: TaskCreateRequest, background_tasks: BackgroundTasks, raw_request: Request, - bypass_cache: bool = Query(False), owner: str = Depends(get_current_owner), ): """ @@ -513,7 +512,7 @@ async def start_task( # Use BackgroundTasks so the response can be sent without waiting in real # ASGI servers, while tests using TestClient still execute the task to keep # contract tests deterministic. - background_tasks.add_task(executor.execute_task, task_id, bypass_cache=bypass_cache) + background_tasks.add_task(executor.execute_task, task_id) await invalidate_view_cache() return { diff --git a/plugins/httpx/metadata.json b/plugins/httpx/metadata.json index 8a2a7bee0..bbde28b40 100644 --- a/plugins/httpx/metadata.json +++ b/plugins/httpx/metadata.json @@ -2,15 +2,15 @@ "id": "httpx", "name": "httpx", "version": "1.0.0", - "description": "Live host probing with status, title, and technology fingerprinting.", - "long_description": "Live host probing with status, title, and technology fingerprinting.", + "description": "Probe live hosts and collect reachability information, status codes, page titles, and basic technology indicators.", + "long_description": "Probe live hosts and collect reachability information, status codes, page titles, and basic technology indicators.", "category": "recon", "author": { "name": "SecuScan Contributors", "email": "dev@secuscan.local" }, "license": "MIT", - "icon": "\ud83d\udd0e", + "icon": "🔎", "engine": { "type": "cli", "binary": "httpx" @@ -55,5 +55,5 @@ "python_packages": [], "system_packages": [] }, - "checksum": "b74defa2b8d5595ae6a8fbd8020c35ce05a214beb65d11f31847ae28d6517e2f" + "checksum": "6570954f2b2cae9ce3a2281445f5c4c46533ada5cd6b5c35859d592523d517e9" } diff --git a/plugins/website-recon-2/metadata.json b/plugins/website-recon-2/metadata.json index f5865d5a3..b1ad20a7f 100644 --- a/plugins/website-recon-2/metadata.json +++ b/plugins/website-recon-2/metadata.json @@ -2,15 +2,15 @@ "id": "website-recon-2", "name": "Website Recon", "version": "1.0.0", - "description": "Fingerprint web technologies of target website.", - "long_description": "Fingerprint web technologies of target website.", + "description": "Perform website reconnaissance focused on identifying web technologies, frameworks, and application stack details.", + "long_description": "Perform website reconnaissance focused on identifying web technologies, frameworks, and application stack details.", "category": "recon", "author": { "name": "SecuScan Contributors", "email": "dev@secuscan.local" }, "license": "MIT", - "icon": "\ud83d\udd0e", + "icon": "🔎", "engine": { "type": "cli", "binary": "httpx" @@ -60,5 +60,5 @@ "python_packages": [], "system_packages": [] }, - "checksum": "53ac15d9af192a5ac70225f2faaf1f3c086868ea67438d9b588de6645555ef01" + "checksum": "4f8cf37fc9c3de4cdab5d4cb140c619f11bbf76ca6926aa6c41288168e8637f4" } diff --git a/scripts/refresh_plugin_checksum.py b/scripts/refresh_plugin_checksum.py index 7738589e8..2398a5d00 100644 --- a/scripts/refresh_plugin_checksum.py +++ b/scripts/refresh_plugin_checksum.py @@ -49,11 +49,11 @@ def compute_plugin_digest(metadata_file: Path, parser_file: Path) -> str: metadata_digest = hashlib.sha256(metadata_canonical.encode("utf-8")).hexdigest() # Hash parser.py if it exists, otherwise use empty string - parser_digest = ( - hashlib.sha256(parser_file.read_bytes()).hexdigest() - if parser_file.exists() - else "" - ) + parser_digest = "" + if parser_file.exists(): + parser_bytes = parser_file.read_bytes() + parser_bytes_normalized = parser_bytes.replace(b"\r\n", b"\n") + parser_digest = hashlib.sha256(parser_bytes_normalized).hexdigest() # Final digest combines both return hashlib.sha256( diff --git a/testing/backend/conftest.py b/testing/backend/conftest.py index 73ae1fd1f..fc34fdb28 100644 --- a/testing/backend/conftest.py +++ b/testing/backend/conftest.py @@ -26,12 +26,6 @@ def anyio_backend(): @pytest.fixture(autouse=True) def setup_test_environment(monkeypatch): """Override settings for tests to ensure isolated execution.""" - try: - from backend.secuscan import cache as cache_module - cache_module.cache = None - except ImportError: - pass - temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) temp_path = temp_dir.name diff --git a/testing/backend/unit/test_scan_cache.py b/testing/backend/unit/test_scan_cache.py deleted file mode 100644 index 42a322690..000000000 --- a/testing/backend/unit/test_scan_cache.py +++ /dev/null @@ -1,305 +0,0 @@ -import os -import json -import shutil -import tempfile -import pytest -import asyncio -from unittest.mock import AsyncMock, patch, MagicMock, ANY - -from backend.secuscan.executor import generate_scan_cache_key, TaskExecutor -from backend.secuscan.cache import init_cache, get_cache -from backend.secuscan.models import TaskStatus -from backend.secuscan.execution_context import normalize_execution_context - -@pytest.fixture -def temp_repo(): - # Create a temporary directory structure representing a project - temp_dir = tempfile.mkdtemp() - yield temp_dir - shutil.rmtree(temp_dir) - -def test_generate_scan_cache_key_no_repo(temp_repo): - # If no git or dependency files exist, it hashes target string - target_hash, dep_hash, key = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo}, - execution_context={}, - safe_mode=False - ) - assert len(target_hash) == 64 - assert dep_hash == "no_deps" - assert key.startswith("scan_cache:owner_1:test_plugin:0:") - -def test_generate_scan_cache_key_with_deps(temp_repo): - # Create package-lock.json - dep_file = os.path.join(temp_repo, "package-lock.json") - with open(dep_file, "w") as f: - f.write("npm-deps-v1") - - target_hash, dep_hash, key = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo}, - execution_context={}, - safe_mode=False - ) - assert len(target_hash) == 64 - assert dep_hash != "no_deps" - - # Modify package-lock.json -> dependency hash changes! - with open(dep_file, "w") as f: - f.write("npm-deps-v2") - - target_hash_2, dep_hash_2, key_2 = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo}, - execution_context={}, - safe_mode=False - ) - assert dep_hash != dep_hash_2 - assert key != key_2 - -def test_cache_key_tenant_isolation(temp_repo): - # Same inputs/target, different owners -> different cache keys! - _, _, key_owner1 = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo, "flag": "x"}, - execution_context={"profile": "admin"}, - safe_mode=False - ) - _, _, key_owner2 = generate_scan_cache_key( - owner_id="owner_2", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo, "flag": "x"}, - execution_context={"profile": "admin"}, - safe_mode=False - ) - assert key_owner1 != key_owner2 - -def test_cache_key_inputs_isolation(temp_repo): - # Same target/owner, different inputs/flags -> different cache keys! - _, _, key_inputs1 = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo, "wordlist": "common.txt"}, - execution_context={}, - safe_mode=False - ) - _, _, key_inputs2 = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo, "wordlist": "deep.txt"}, - execution_context={}, - safe_mode=False - ) - assert key_inputs1 != key_inputs2 - -def test_cache_key_safe_mode_isolation(temp_repo): - # Same inputs/owner, safe_mode toggled -> different cache keys! - _, _, key_safe = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo}, - execution_context={}, - safe_mode=True - ) - _, _, key_unsafe = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs={"target": temp_repo}, - execution_context={}, - safe_mode=False - ) - assert key_unsafe != key_safe - -@pytest.mark.asyncio -async def test_execute_task_cache_hit(temp_repo): - # Initialize in-memory cache - await init_cache() - - # We will mock the database and task run details using a SQL-inspecting side effect - mock_db = AsyncMock() - async def db_fetchone_mock(query, params=()): - query_lower = query.lower() - if "select owner_id, plugin_id" in query_lower: - return { - "owner_id": "owner_1", - "plugin_id": "test_plugin", - "inputs_json": json.dumps({"target": temp_repo}), - "execution_context_json": "{}", - "safe_mode": False - } - return None - mock_db.fetchone = AsyncMock(side_effect=db_fetchone_mock) - - executor = TaskExecutor() - - # Pre-populate cache for this target/owner/inputs/context/safe_mode - # Note: inputs is hydrated inside execute_task to contain normalized execution_context - execution_context = normalize_execution_context({}) - inputs = {"target": temp_repo, "__execution_context": execution_context} - target_hash, dep_hash, cache_key = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs=inputs, - execution_context=execution_context, - safe_mode=False - ) - cache_client = await get_cache() - - cached_data = { - "status": TaskStatus.COMPLETED.value, - "duration_seconds": 1.5, - "exit_code": 0, - "error_message": None, - "raw_output": "cached output text", - "structured": { - "findings": [ - { - "title": "Cached Finding", - "category": "Code Security", - "severity": "high", - "description": "Cached desc" - } - ] - } - } - await cache_client.set_json(cache_key, cached_data) - - # We mock internal helper methods - executor._persist_finding = AsyncMock(return_value={"id": "finding_1"}) - executor._persist_result_resources = AsyncMock() - executor._dispatch_task_notifications = AsyncMock() - executor._invalidate_cached_views = AsyncMock() - executor._execute_command = AsyncMock() - - with patch("backend.secuscan.executor.get_db", return_value=mock_db), \ - patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: - - mock_plugin = MagicMock() - mock_plugin.name = "Test Plugin" - mock_pm.return_value.get_plugin.return_value = mock_plugin - - await executor.execute_task("task_id_123", bypass_cache=False) - - # Verify _execute_command was never called (regression test for cache bypass) - executor._execute_command.assert_not_called() - - # Verify db was updated with status, duration, etc. - mock_db.execute.assert_any_call( - """ - UPDATE tasks SET - status = ?, - completed_at = ?, - duration_seconds = ?, - exit_code = ?, - raw_output_path = ?, - error_message = ? - WHERE id = ? - """, - ( - TaskStatus.COMPLETED.value, - ANY, - 1.5, - 0, - ANY, - None, - "task_id_123" - ) - ) - - # Verify it persisted the cached findings and updated structured_json - executor._persist_finding.assert_called_once() - mock_db.execute.assert_any_call( - "UPDATE tasks SET structured_json = ? WHERE id = ?", - (ANY, "task_id_123") - ) - -@pytest.mark.asyncio -async def test_execute_task_transient_failure_not_cached(temp_repo): - # Initialize in-memory cache - await init_cache() - - mock_db = AsyncMock() - async def db_fetchone_mock(query, params=()): - query_lower = query.lower() - if "select owner_id, plugin_id" in query_lower: - return { - "owner_id": "owner_1", - "plugin_id": "test_plugin", - "inputs_json": json.dumps({"target": temp_repo}), - "execution_context_json": "{}", - "safe_mode": False - } - if "select status, duration_seconds" in query_lower: - return { - "status": TaskStatus.FAILED.value, - "duration_seconds": 2.0, - "exit_code": 1, - "error_message": "Transient network timeout", - "structured_json": None, - "raw_output_path": None - } - return None - mock_db.fetchone = AsyncMock(side_effect=db_fetchone_mock) - - executor = TaskExecutor() - executor._persist_findings_and_report_common = AsyncMock() - executor._dispatch_task_notifications = AsyncMock() - executor._invalidate_cached_views = AsyncMock() - - # Stub the actually executed command to fail - async def fake_command(*args, **kwargs): - return "Network timeout", 1 - - execution_context = normalize_execution_context({}) - inputs = {"target": temp_repo, "__execution_context": execution_context} - _, _, cache_key = generate_scan_cache_key( - owner_id="owner_1", - plugin_id="test_plugin", - target=temp_repo, - inputs=inputs, - execution_context=execution_context, - safe_mode=False - ) - - with patch("backend.secuscan.executor.get_db", return_value=mock_db), \ - patch.object(executor, "_execute_command", side_effect=fake_command), \ - patch("backend.secuscan.executor.get_plugin_manager") as mock_pm: - - mock_plugin = MagicMock() - mock_plugin.name = "Test Plugin" - mock_plugin.presets = {} - mock_plugin.docker_image = None - mock_plugin.output = {"parser": "builtin_nmap", "format": "text"} - mock_plugin.category = "Network" - mock_plugin.id = "test_plugin" - - mock_pm.return_value.get_plugin.return_value = mock_plugin - mock_pm.return_value.build_command.return_value = ["ping", temp_repo] - mock_pm.return_value.plugins_dir = MagicMock() - mock_pm.return_value.plugins_dir.__truediv__ = MagicMock( - return_value=MagicMock( - __truediv__=MagicMock(return_value=MagicMock(exists=lambda: False)) - ) - ) - - await executor.execute_task("task_id_456", bypass_cache=False) - - # The cache should be empty for this key because the task status is FAILED - cache_client = await get_cache() - cached_val = await cache_client.get_json(cache_key) - assert cached_val is None