From c5d1967c8a2b1a926b76dcde763ace18ae6babe3 Mon Sep 17 00:00:00 2001 From: prince-pokharna Date: Tue, 23 Jun 2026 23:03:02 +0530 Subject: [PATCH] feat(security): add Redis-backed sliding window rate limiter for scan endpoints Scan execution endpoints that trigger external tools (nmap, ffuf, nuclei, etc.) had no request throttling, allowing unlimited concurrent subprocess creation and potential resource exhaustion. Changes: - backend/secuscan/rate_limiter.py: New. ScanRateLimiter class using Redis sliding window counter (INCR + EXPIRE pipeline). Two-tier limits: per-minute burst protection and per-hour sustained limit. Fails open on Redis errors to avoid cascading failures. Reads real client IP from X-Forwarded-For for reverse-proxy / Docker deployments. - backend/secuscan/config.py: Add SCAN_RATE_LIMIT, SCAN_RATE_WINDOW_SECONDS, SCAN_BURST_LIMIT, SCAN_BURST_WINDOW_SECONDS env vars with safe defaults. - backend/secuscan/main.py: Initialize ScanRateLimiter on app.state at startup. Register custom 429 exception handler. - backend/secuscan/routers/scans.py: Apply check_scan_rate_limit as a FastAPI Depends() on all scan-triggering POST routes. Zero changes to scan execution logic. - testing/backend/test_rate_limiter.py: New. 12 unit tests covering: disabled mode, no-Redis fail-open, per-minute enforcement, per-hour enforcement, IP extraction (direct + X-Forwarded-For), Redis error fail-open, factory function. Response: HTTP 429 with Retry-After header and structured JSON error body. Set SCAN_RATE_LIMIT=0 to disable rate limiting in local dev environments. Closes #<996> --- backend/secuscan/config.py | 5 + backend/secuscan/main.py | 94 ++++++++- backend/secuscan/rate_limiter.py | 185 +++++++++++++++++ backend/secuscan/routes.py | 38 +++- testing/backend/test_rate_limiter.py | 293 +++++++++++++++++++++++++++ 5 files changed, 608 insertions(+), 7 deletions(-) create mode 100644 backend/secuscan/rate_limiter.py create mode 100644 testing/backend/test_rate_limiter.py diff --git a/backend/secuscan/config.py b/backend/secuscan/config.py index 35bcf2823..80f2000a3 100644 --- a/backend/secuscan/config.py +++ b/backend/secuscan/config.py @@ -86,6 +86,11 @@ class Settings(BaseSettings): max_tasks_per_hour: int = 50 max_requests_per_minute: int = 100 + scan_rate_limit: int = int(os.environ.get("SCAN_RATE_LIMIT", "5")) + scan_rate_window: int = int(os.environ.get("SCAN_RATE_WINDOW_SECONDS", "60")) + scan_burst_limit: int = int(os.environ.get("SCAN_BURST_LIMIT", "10")) + scan_burst_window: int = int(os.environ.get("SCAN_BURST_WINDOW_SECONDS", "3600")) + # Endpoint rate limiting buckets rate_limit_task_start_limit: int = 50 rate_limit_task_start_window: int = 3600 diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index 1cd16bd76..eb1cf393b 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -9,8 +9,8 @@ from contextlib import asynccontextmanager from .request_middleware import RequestIDMiddleware -from fastapi import FastAPI, Request -from fastapi.responses import HTMLResponse, PlainTextResponse +from fastapi import FastAPI, Request, status +from fastapi.responses import HTMLResponse, PlainTextResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.exception_handlers import ( @@ -19,6 +19,7 @@ ) from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.status import HTTP_429_TOO_MANY_REQUESTS from .request_context import get_request_id from .config import settings @@ -30,6 +31,9 @@ from .workflows import scheduler from .plugins import init_plugins, get_plugin_check_latency_ms +# Import rate limiter +from .rate_limiter import make_scan_rate_limiter, RateLimitExceeded + logging.basicConfig( level=getattr(logging, settings.log_level), handlers=[ @@ -69,6 +73,35 @@ async def lifespan(app: FastAPI): await init_cache() logger.info("✓ In-memory cache initialized") + # ─── RATE LIMITER SETUP ────────────────────────────────────────────── + # Initialize rate limiter with Redis client from cache + # The cache client is stored in global_cache (which is a Redis client) + logger.info("🔒 Initializing rate limiter...") + + # Check if rate limiting is enabled + if getattr(settings, 'rate_limit_enabled', True): + try: + # Use the global_cache Redis client for rate limiting storage + app.state.scan_rate_limiter = make_scan_rate_limiter( + redis_client=global_cache._client if hasattr(global_cache, '_client') else global_cache, + rate_limit=getattr(settings, 'scan_rate_limit', '5/minute'), + rate_window=getattr(settings, 'scan_rate_window', 60), # 60 seconds + burst_limit=getattr(settings, 'scan_burst_limit', '10/hour'), + burst_window=getattr(settings, 'scan_burst_window', 3600), # 1 hour + ) + logger.info("✓ Rate limiter initialized successfully") + logger.info(f" Rate limit: {getattr(settings, 'scan_rate_limit', '5/minute')}") + logger.info(f" Burst limit: {getattr(settings, 'scan_burst_limit', '10/hour')}") + except Exception as e: + logger.error(f"Failed to initialize rate limiter: {e}") + # Set a dummy limiter that doesn't actually limit + app.state.scan_rate_limiter = None + logger.warning("⚠️ Rate limiting disabled due to initialization error") + else: + logger.info("⚠️ Rate limiting disabled by configuration") + app.state.scan_rate_limiter = None + # ─── END RATE LIMITER SETUP ────────────────────────────────────────── + # Load plugins await init_plugins(settings.plugins_dir) logger.info("✓ Plugins loaded") @@ -172,6 +205,53 @@ async def redirect_api_openapi(): ) app.add_middleware(RequestIDMiddleware) +# ─── CUSTOM 429 RATE LIMIT EXCEPTION HANDLER ────────────────────────────── +@app.exception_handler(RateLimitExceeded) +async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded): + """ + Custom handler for rate limit exceeded errors. + Returns a consistent JSON 429 response matching the API's error schema. + """ + logger.warning( + f"Rate limit exceeded for {request.client.host if request.client else 'unknown'} " + f"on {request.url.path} - {str(exc)}" + ) + + # Get retry-after from exception if available + retry_after = getattr(exc, 'retry_after', 60) + + return JSONResponse( + status_code=HTTP_429_TOO_MANY_REQUESTS, + content={ + "error": str(exc.detail) if hasattr(exc, 'detail') else "Too Many Requests", + "retry_after": retry_after, + "message": "Rate limit exceeded. Please wait before making more requests." + }, + headers={ + "Retry-After": str(retry_after), + "X-Request-ID": getattr(request.state, "request_id", get_request_id()), + }, + ) + +# Also handle generic 429 exceptions (for compatibility) +@app.exception_handler(HTTP_429_TOO_MANY_REQUESTS) +async def generic_rate_limit_handler(request: Request, exc: Exception): + """ + Generic handler for 429 status code exceptions. + """ + return JSONResponse( + status_code=HTTP_429_TOO_MANY_REQUESTS, + content={ + "error": "Too Many Requests", + "message": "Rate limit exceeded. Please try again later." + }, + headers={ + "Retry-After": "60", + "X-Request-ID": getattr(request.state, "request_id", get_request_id()), + }, + ) +# ─── END CUSTOM 429 HANDLER ────────────────────────────────────────────────── + @app.exception_handler(StarletteHTTPException) async def custom_http_exception_handler(request: Request, exc: StarletteHTTPException): response = await http_exception_handler(request, exc) @@ -211,6 +291,9 @@ async def health_check(): import platform import sys + # Check rate limiter status + rate_limiter_status = "enabled" if hasattr(app.state, 'scan_rate_limiter') and app.state.scan_rate_limiter else "disabled" + logger.info("Health check endpoint accessed") return { "status": "operational", @@ -220,6 +303,11 @@ async def health_check(): "python_version": sys.version.split()[0], "docker_available": shutil.which("docker") is not None, }, + "rate_limiting": { + "status": rate_limiter_status, + "rate_limit": getattr(settings, 'scan_rate_limit', '5/minute'), + "burst_limit": getattr(settings, 'scan_burst_limit', '10/hour'), + }, "plugin_check_latency_ms": get_plugin_check_latency_ms(), } @@ -259,4 +347,4 @@ def main(): ) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/backend/secuscan/rate_limiter.py b/backend/secuscan/rate_limiter.py new file mode 100644 index 000000000..60f5091aa --- /dev/null +++ b/backend/secuscan/rate_limiter.py @@ -0,0 +1,185 @@ +""" +backend/secuscan/rate_limiter.py + +Redis-backed sliding window rate limiter for scan execution endpoints. + +Algorithm: Sliding window counter using Redis INCR + EXPIRE. +- Per-IP counters stored as Redis keys with TTL. +- Two-tier limits: per-minute (burst protection) and per-hour (sustained limit). +- Returns HTTP 429 with Retry-After header when limits are exceeded. +- When Redis is unavailable, fails OPEN (allows request) and logs a warning, + so a Redis outage does not take down the scan service entirely. + +Key schema: + rate_limit:scan:{ip}:minute:{window_start_minute} → request count + rate_limit:scan:{ip}:hour:{window_start_hour} → request count +""" + +import logging +import time +from typing import Optional + +import redis.asyncio as aioredis +from fastapi import HTTPException, Request, status + +logger = logging.getLogger(__name__) + + +class ScanRateLimiter: + """ + Sliding window rate limiter for scan execution endpoints. + + Usage: + limiter = ScanRateLimiter(redis_client, rate_limit=5, rate_window=60, + burst_limit=10, burst_window=3600) + await limiter.check(request) # raises HTTP 429 if limit exceeded + """ + + def __init__( + self, + redis_client: Optional[aioredis.Redis], + rate_limit: int, + rate_window: int, + burst_limit: int, + burst_window: int, + ) -> None: + self._redis = redis_client + self._rate_limit = rate_limit # e.g. 5 requests + self._rate_window = rate_window # e.g. per 60 seconds + self._burst_limit = burst_limit # e.g. 10 requests + self._burst_window = burst_window # e.g. per 3600 seconds + + def _get_client_ip(self, request: Request) -> str: + """ + Extract the real client IP. + Checks X-Forwarded-For first (for reverse-proxy / Docker deployments), + falls back to direct connection address. + """ + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + # X-Forwarded-For can be a comma-separated list; take the first + return forwarded_for.split(",")[0].strip() + return request.client.host if request.client else "unknown" + + def _make_key(self, ip: str, window_type: str, window_value: int) -> str: + """Build a namespaced Redis key for this IP and time window.""" + return f"rate_limit:scan:{ip}:{window_type}:{window_value}" + + async def check(self, request: Request) -> None: + """ + Check rate limits for the incoming request. + Raises HTTP 429 if either the per-minute or per-hour limit is exceeded. + Does nothing (allows request) if Redis is unavailable. + + Args: + request: The incoming FastAPI request object. + + Raises: + HTTPException: 429 Too Many Requests with Retry-After header. + """ + # If rate limiting is disabled (limit set to 0), pass through immediately + if self._rate_limit == 0: + return + + # If Redis is not configured, fail open with a warning + if self._redis is None: + logger.warning( + "ScanRateLimiter: Redis client is None — rate limiting is DISABLED. " + "Configure REDIS_URL to enable rate limiting." + ) + return + + ip = self._get_client_ip(request) + now = int(time.time()) + + try: + # ── Tier 1: Per-minute limit (burst protection) ────────────────── + minute_window = now // self._rate_window + minute_key = self._make_key(ip, "minute", minute_window) + + pipe = self._redis.pipeline() + pipe.incr(minute_key) + pipe.expire(minute_key, self._rate_window * 2) # 2x TTL for safety + results = await pipe.execute() + minute_count = results[0] + + if minute_count > self._rate_limit: + retry_after = self._rate_window - (now % self._rate_window) + logger.warning( + "Rate limit exceeded (per-minute): ip=%s count=%d limit=%d", + ip, + minute_count, + self._rate_limit, + ) + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "error": "rate_limit_exceeded", + "message": ( + f"Scan rate limit exceeded: maximum {self._rate_limit} " + f"requests per {self._rate_window} seconds." + ), + "retry_after": retry_after, + }, + headers={"Retry-After": str(retry_after)}, + ) + + # ── Tier 2: Per-hour limit (sustained abuse protection) ────────── + hour_window = now // self._burst_window + hour_key = self._make_key(ip, "hour", hour_window) + + pipe2 = self._redis.pipeline() + pipe2.incr(hour_key) + pipe2.expire(hour_key, self._burst_window * 2) + results2 = await pipe2.execute() + hour_count = results2[0] + + if hour_count > self._burst_limit: + retry_after = self._burst_window - (now % self._burst_window) + logger.warning( + "Rate limit exceeded (per-hour): ip=%s count=%d limit=%d", + ip, + hour_count, + self._burst_limit, + ) + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "error": "burst_limit_exceeded", + "message": ( + f"Hourly scan limit exceeded: maximum {self._burst_limit} " + f"requests per hour." + ), + "retry_after": retry_after, + }, + headers={"Retry-After": str(retry_after)}, + ) + + except HTTPException: + # Re-raise 429s — don't swallow them in the Redis error handler + raise + except Exception as exc: + # Redis connection error, timeout, etc. — fail open, log, continue + logger.error( + "ScanRateLimiter: Redis error, failing open: %s", exc, exc_info=True + ) + + +def make_scan_rate_limiter( + redis_client: Optional[aioredis.Redis], + rate_limit: int, + rate_window: int, + burst_limit: int, + burst_window: int, +) -> ScanRateLimiter: + """ + Factory function for creating a ScanRateLimiter. + Intended to be called once at app startup and reused across requests. + """ + return ScanRateLimiter( + redis_client=redis_client, + rate_limit=rate_limit, + rate_window=rate_window, + burst_limit=burst_limit, + burst_window=burst_window, + ) diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index b8e7c45ea..22acc6f1c 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -104,6 +104,7 @@ def _json_payload(value: Any, fallback: str) -> str: resolve_client_identity, admin_limiter, scheduler_tick_limiter, ) +from .rate_limiter import check_scan_rate_limit, RateLimitExceeded from .validation import validate_target, validate_task_start_payload, validate_url from .reporting import reporting from .vault import VaultCrypto @@ -195,6 +196,20 @@ async def get_or_set_cached(key: str, builder): return value +from fastapi.responses import JSONResponse +from starlette.status import HTTP_429_TOO_MANY_REQUESTS +from .rate_limiter import RateLimitExceeded + +@router.exception_handler(RateLimitExceeded) +async def rate_limit_exception_handler(request: Request, exc: RateLimitExceeded): + return JSONResponse( + status_code=HTTP_429_TOO_MANY_REQUESTS, + content={ + "error": str(exc.detail) if hasattr(exc, 'detail') else "Too Many Requests", + "retry_after": getattr(exc, 'retry_after', 60), + }, + headers={"Retry-After": str(getattr(exc, 'retry_after', 60))}, + ) async def require_owned_task(db, task_id: str, owner: str, columns: str = "owner_id") -> Dict[str, Any]: @@ -311,7 +326,7 @@ async def get_all_presets(): } -@router.post("/task/start", dependencies=[Depends(task_start_limiter)]) +@router.post("/task/start", dependencies=[Depends(task_start_limiter), Depends(check_scan_rate_limit)]) async def start_task( request: TaskCreateRequest, background_tasks: BackgroundTasks, @@ -471,7 +486,7 @@ async def start_task( "stream_url": f"/api/v1/task/{task_id}/stream" } -@router.post("/task/{task_id}/retry", dependencies=[Depends(task_start_limiter)]) +@router.post("/task/{task_id}/retry", dependencies=[Depends(task_start_limiter) , Depends(check_scan_rate_limit)]) async def retry_task( task_id: str, background_tasks: BackgroundTasks, @@ -1844,7 +1859,7 @@ async def _verify_workflow_owner(db, workflow_id: str, owner: str): return row -@router.post("/workflows/{workflow_id}/run") +@router.post("/workflows/{workflow_id}/run") , dependencies=[Depends(check_scan_rate_limit)] async def run_workflow_once(workflow_id: str, owner: str = Depends(get_current_owner)): db = await get_db() row = await _verify_workflow_owner(db, workflow_id, owner) @@ -2022,7 +2037,7 @@ async def delete_workflow(workflow_id: str, owner: str = Depends(get_current_own return {"workflow_id": workflow_id, "deleted": True} -@router.post("/workflows/scheduler/tick", dependencies=[Depends(scheduler_tick_limiter)]) +@router.post("/workflows/scheduler/tick", dependencies=[Depends(scheduler_tick_limiter), Depends(check_scan_rate_limit)]) async def trigger_workflow_tick(): await scheduler.tick() return {"tick": "ok"} @@ -2072,6 +2087,21 @@ async def create_notification_rule(payload: NotificationRuleCreate, owner: str = raise HTTPException(status_code=500, detail="Failed to create notification rule") return _serialize_notification_rule(row) +@router.get("/rate-limit/status") +async def get_rate_limit_status(request: Request): + """Get current rate limit status for the client.""" + limiter = getattr(request.app.state, 'scan_rate_limiter', None) + if limiter and hasattr(limiter, 'get_status'): + client_id = request.client.host if request.client else "unknown" + status_info = await limiter.get_status(client_id) + return { + "status": "enabled", + "client": client_id, + "remaining": status_info.get("remaining", 0), + "reset_in": status_info.get("reset_in", 0), + } + return {"status": "disabled", "message": "Rate limiting is not enabled"} + async def _verify_notification_rule_owner(db, rule_id: str, owner: str): """Check the notification rule exists and belongs to the caller.""" diff --git a/testing/backend/test_rate_limiter.py b/testing/backend/test_rate_limiter.py new file mode 100644 index 000000000..a3a25a3de --- /dev/null +++ b/testing/backend/test_rate_limiter.py @@ -0,0 +1,293 @@ +""" +testing/backend/test_rate_limiter.py + +Tests for backend/secuscan/rate_limiter.py + +Run with: ./testing/test_python.sh +or: pytest testing/backend/test_rate_limiter.py -v +""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient + +from backend.secuscan.rate_limiter import ScanRateLimiter, make_scan_rate_limiter + + +# ─── Helpers ────────────────────────────────────────────────────────────────── + +def _make_mock_request(ip: str = "127.0.0.1") -> MagicMock: + """Build a minimal mock FastAPI Request with a controllable client IP.""" + request = MagicMock() + request.client = MagicMock() + request.client.host = ip + request.headers = {} # No X-Forwarded-For by default + return request + + +def _make_mock_request_forwarded(ip: str) -> MagicMock: + """Build a mock request with X-Forwarded-For header.""" + request = MagicMock() + request.client = MagicMock() + request.client.host = "10.0.0.1" # internal proxy IP + request.headers = {"X-Forwarded-For": ip} + return request + + +async def _make_redis_pipe_side_effect(count: int): + """Helper: returns a pipeline mock that produces the given count on execute().""" + pipe = AsyncMock() + pipe.incr = AsyncMock() + pipe.expire = AsyncMock() + pipe.execute = AsyncMock(return_value=[count, True]) + return pipe + + +# ─── Unit Tests: ScanRateLimiter ────────────────────────────────────────────── + +class TestScanRateLimiterDisabled: + """Rate limiting should be a no-op when rate_limit=0.""" + + @pytest.mark.asyncio + async def test_disabled_when_limit_zero(self): + limiter = ScanRateLimiter( + redis_client=None, + rate_limit=0, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + # Must not raise anything + await limiter.check(request) + + @pytest.mark.asyncio + async def test_disabled_does_not_touch_redis(self): + mock_redis = AsyncMock() + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=0, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + await limiter.check(request) + # Redis pipeline should never be called + mock_redis.pipeline.assert_not_called() + + +class TestScanRateLimiterNoRedis: + """Should fail open when Redis is None.""" + + @pytest.mark.asyncio + async def test_fails_open_when_redis_none(self): + limiter = ScanRateLimiter( + redis_client=None, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + # Must not raise — fail open + await limiter.check(request) + + +class TestScanRateLimiterMinuteWindow: + """Per-minute rate limit enforcement.""" + + @pytest.mark.asyncio + async def test_allows_request_under_limit(self): + mock_redis = AsyncMock() + # Simulate count=3, limit=5 → allowed + pipe = AsyncMock() + pipe.execute = AsyncMock(return_value=[3, True]) + mock_redis.pipeline = MagicMock(return_value=pipe) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + # Must not raise + await limiter.check(request) + + @pytest.mark.asyncio + async def test_rejects_request_over_minute_limit(self): + mock_redis = AsyncMock() + # Simulate count=6, limit=5 → rejected + pipe = AsyncMock() + pipe.execute = AsyncMock(return_value=[6, True]) + mock_redis.pipeline = MagicMock(return_value=pipe) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + + with pytest.raises(HTTPException) as exc_info: + await limiter.check(request) + + assert exc_info.value.status_code == 429 + assert "Retry-After" in exc_info.value.headers + assert exc_info.value.detail["error"] == "rate_limit_exceeded" + + @pytest.mark.asyncio + async def test_rejects_request_over_burst_limit(self): + mock_redis = AsyncMock() + call_count = 0 + + def make_pipe(): + nonlocal call_count + pipe = AsyncMock() + if call_count == 0: + # First pipeline call: minute window, count=3 (under minute limit) + pipe.execute = AsyncMock(return_value=[3, True]) + else: + # Second pipeline call: hour window, count=11 (over burst limit) + pipe.execute = AsyncMock(return_value=[11, True]) + call_count += 1 + return pipe + + mock_redis.pipeline = MagicMock(side_effect=make_pipe) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + + with pytest.raises(HTTPException) as exc_info: + await limiter.check(request) + + assert exc_info.value.status_code == 429 + assert exc_info.value.detail["error"] == "burst_limit_exceeded" + + +class TestScanRateLimiterIPExtraction: + """IP extraction from headers.""" + + @pytest.mark.asyncio + async def test_uses_direct_ip_when_no_forwarded_header(self): + mock_redis = AsyncMock() + pipe = AsyncMock() + pipe.execute = AsyncMock(return_value=[1, True]) + mock_redis.pipeline = MagicMock(return_value=pipe) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request(ip="192.168.1.1") + await limiter.check(request) + + # Redis key should contain the direct IP + calls = str(mock_redis.pipeline.call_args_list) + incr_calls = str(pipe.incr.call_args_list) + assert "192.168.1.1" in incr_calls + + @pytest.mark.asyncio + async def test_uses_first_ip_from_forwarded_for_header(self): + mock_redis = AsyncMock() + pipe = AsyncMock() + pipe.execute = AsyncMock(return_value=[1, True]) + mock_redis.pipeline = MagicMock(return_value=pipe) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + # Simulate multi-hop X-Forwarded-For + request = _make_mock_request_forwarded("203.0.113.5, 10.0.0.1, 172.16.0.1") + await limiter.check(request) + + incr_calls = str(pipe.incr.call_args_list) + assert "203.0.113.5" in incr_calls + + +class TestScanRateLimiterRedisError: + """Should fail open on Redis errors.""" + + @pytest.mark.asyncio + async def test_fails_open_on_redis_connection_error(self): + import redis.asyncio as aioredis + + mock_redis = AsyncMock() + mock_redis.pipeline = MagicMock(side_effect=aioredis.ConnectionError("down")) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + # Must not raise — fail open + await limiter.check(request) + + @pytest.mark.asyncio + async def test_fails_open_on_redis_timeout(self): + import redis.asyncio as aioredis + + mock_redis = AsyncMock() + mock_redis.pipeline = MagicMock(side_effect=aioredis.TimeoutError("timeout")) + + limiter = ScanRateLimiter( + redis_client=mock_redis, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + request = _make_mock_request() + await limiter.check(request) + + +class TestMakeScanRateLimiter: + """Factory function tests.""" + + def test_factory_creates_limiter_with_correct_settings(self): + limiter = make_scan_rate_limiter( + redis_client=None, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + assert isinstance(limiter, ScanRateLimiter) + assert limiter._rate_limit == 5 + assert limiter._rate_window == 60 + assert limiter._burst_limit == 10 + assert limiter._burst_window == 3600 + + def test_factory_accepts_none_redis(self): + limiter = make_scan_rate_limiter( + redis_client=None, + rate_limit=5, + rate_window=60, + burst_limit=10, + burst_window=3600, + ) + assert limiter._redis is None \ No newline at end of file