Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/secuscan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ class Settings(BaseSettings):
max_tasks_per_hour: int = 50
max_requests_per_minute: int = 100

scan_rate_limit: int = int(os.environ.get("SCAN_RATE_LIMIT", "5"))
scan_rate_window: int = int(os.environ.get("SCAN_RATE_WINDOW_SECONDS", "60"))
scan_burst_limit: int = int(os.environ.get("SCAN_BURST_LIMIT", "10"))
scan_burst_window: int = int(os.environ.get("SCAN_BURST_WINDOW_SECONDS", "3600"))

# Endpoint rate limiting buckets
rate_limit_task_start_limit: int = 50
rate_limit_task_start_window: int = 3600
Expand Down
94 changes: 91 additions & 3 deletions backend/secuscan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from contextlib import asynccontextmanager
from .request_middleware import RequestIDMiddleware

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi import FastAPI, Request, status
from fastapi.responses import HTMLResponse, PlainTextResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.exception_handlers import (
Expand All @@ -19,6 +19,7 @@
)
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.status import HTTP_429_TOO_MANY_REQUESTS
from .request_context import get_request_id

from .config import settings
Expand All @@ -30,6 +31,9 @@
from .workflows import scheduler
from .plugins import init_plugins, get_plugin_check_latency_ms

# Import rate limiter
from .rate_limiter import make_scan_rate_limiter, RateLimitExceeded

logging.basicConfig(
level=getattr(logging, settings.log_level),
handlers=[
Expand Down Expand Up @@ -69,6 +73,35 @@ async def lifespan(app: FastAPI):
await init_cache()
logger.info("✓ In-memory cache initialized")

# ─── RATE LIMITER SETUP ──────────────────────────────────────────────
# Initialize rate limiter with Redis client from cache
# The cache client is stored in global_cache (which is a Redis client)
logger.info("🔒 Initializing rate limiter...")

# Check if rate limiting is enabled
if getattr(settings, 'rate_limit_enabled', True):
try:
# Use the global_cache Redis client for rate limiting storage
app.state.scan_rate_limiter = make_scan_rate_limiter(
redis_client=global_cache._client if hasattr(global_cache, '_client') else global_cache,
rate_limit=getattr(settings, 'scan_rate_limit', '5/minute'),
rate_window=getattr(settings, 'scan_rate_window', 60), # 60 seconds
burst_limit=getattr(settings, 'scan_burst_limit', '10/hour'),
burst_window=getattr(settings, 'scan_burst_window', 3600), # 1 hour
)
logger.info("✓ Rate limiter initialized successfully")
logger.info(f" Rate limit: {getattr(settings, 'scan_rate_limit', '5/minute')}")
logger.info(f" Burst limit: {getattr(settings, 'scan_burst_limit', '10/hour')}")
except Exception as e:
logger.error(f"Failed to initialize rate limiter: {e}")
# Set a dummy limiter that doesn't actually limit
app.state.scan_rate_limiter = None
logger.warning("⚠️ Rate limiting disabled due to initialization error")
else:
logger.info("⚠️ Rate limiting disabled by configuration")
app.state.scan_rate_limiter = None
# ─── END RATE LIMITER SETUP ──────────────────────────────────────────

# Load plugins
await init_plugins(settings.plugins_dir)
logger.info("✓ Plugins loaded")
Expand Down Expand Up @@ -172,6 +205,53 @@ async def redirect_api_openapi():
)
app.add_middleware(RequestIDMiddleware)

# ─── CUSTOM 429 RATE LIMIT EXCEPTION HANDLER ──────────────────────────────
@app.exception_handler(RateLimitExceeded)
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
"""
Custom handler for rate limit exceeded errors.
Returns a consistent JSON 429 response matching the API's error schema.
"""
logger.warning(
f"Rate limit exceeded for {request.client.host if request.client else 'unknown'} "
f"on {request.url.path} - {str(exc)}"
)

# Get retry-after from exception if available
retry_after = getattr(exc, 'retry_after', 60)

return JSONResponse(
status_code=HTTP_429_TOO_MANY_REQUESTS,
content={
"error": str(exc.detail) if hasattr(exc, 'detail') else "Too Many Requests",
"retry_after": retry_after,
"message": "Rate limit exceeded. Please wait before making more requests."
},
headers={
"Retry-After": str(retry_after),
"X-Request-ID": getattr(request.state, "request_id", get_request_id()),
},
)

# Also handle generic 429 exceptions (for compatibility)
@app.exception_handler(HTTP_429_TOO_MANY_REQUESTS)
async def generic_rate_limit_handler(request: Request, exc: Exception):
"""
Generic handler for 429 status code exceptions.
"""
return JSONResponse(
status_code=HTTP_429_TOO_MANY_REQUESTS,
content={
"error": "Too Many Requests",
"message": "Rate limit exceeded. Please try again later."
},
headers={
"Retry-After": "60",
"X-Request-ID": getattr(request.state, "request_id", get_request_id()),
},
)
# ─── END CUSTOM 429 HANDLER ──────────────────────────────────────────────────

@app.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request: Request, exc: StarletteHTTPException):
response = await http_exception_handler(request, exc)
Expand Down Expand Up @@ -211,6 +291,9 @@ async def health_check():
import platform
import sys

# Check rate limiter status
rate_limiter_status = "enabled" if hasattr(app.state, 'scan_rate_limiter') and app.state.scan_rate_limiter else "disabled"

logger.info("Health check endpoint accessed")
return {
"status": "operational",
Expand All @@ -220,6 +303,11 @@ async def health_check():
"python_version": sys.version.split()[0],
"docker_available": shutil.which("docker") is not None,
},
"rate_limiting": {
"status": rate_limiter_status,
"rate_limit": getattr(settings, 'scan_rate_limit', '5/minute'),
"burst_limit": getattr(settings, 'scan_burst_limit', '10/hour'),
},
"plugin_check_latency_ms": get_plugin_check_latency_ms(),
}

Expand Down Expand Up @@ -259,4 +347,4 @@ def main():
)

if __name__ == "__main__":
main()
main()
185 changes: 185 additions & 0 deletions backend/secuscan/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""
backend/secuscan/rate_limiter.py

Redis-backed sliding window rate limiter for scan execution endpoints.

Algorithm: Sliding window counter using Redis INCR + EXPIRE.
- Per-IP counters stored as Redis keys with TTL.
- Two-tier limits: per-minute (burst protection) and per-hour (sustained limit).
- Returns HTTP 429 with Retry-After header when limits are exceeded.
- When Redis is unavailable, fails OPEN (allows request) and logs a warning,
so a Redis outage does not take down the scan service entirely.

Key schema:
rate_limit:scan:{ip}:minute:{window_start_minute} → request count
rate_limit:scan:{ip}:hour:{window_start_hour} → request count
"""

import logging
import time
from typing import Optional

import redis.asyncio as aioredis
from fastapi import HTTPException, Request, status

logger = logging.getLogger(__name__)


class ScanRateLimiter:
"""
Sliding window rate limiter for scan execution endpoints.

Usage:
limiter = ScanRateLimiter(redis_client, rate_limit=5, rate_window=60,
burst_limit=10, burst_window=3600)
await limiter.check(request) # raises HTTP 429 if limit exceeded
"""

def __init__(
self,
redis_client: Optional[aioredis.Redis],
rate_limit: int,
rate_window: int,
burst_limit: int,
burst_window: int,
) -> None:
self._redis = redis_client
self._rate_limit = rate_limit # e.g. 5 requests
self._rate_window = rate_window # e.g. per 60 seconds
self._burst_limit = burst_limit # e.g. 10 requests
self._burst_window = burst_window # e.g. per 3600 seconds

def _get_client_ip(self, request: Request) -> str:
"""
Extract the real client IP.
Checks X-Forwarded-For first (for reverse-proxy / Docker deployments),
falls back to direct connection address.
"""
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For can be a comma-separated list; take the first
return forwarded_for.split(",")[0].strip()
return request.client.host if request.client else "unknown"

def _make_key(self, ip: str, window_type: str, window_value: int) -> str:
"""Build a namespaced Redis key for this IP and time window."""
return f"rate_limit:scan:{ip}:{window_type}:{window_value}"

async def check(self, request: Request) -> None:
"""
Check rate limits for the incoming request.
Raises HTTP 429 if either the per-minute or per-hour limit is exceeded.
Does nothing (allows request) if Redis is unavailable.

Args:
request: The incoming FastAPI request object.

Raises:
HTTPException: 429 Too Many Requests with Retry-After header.
"""
# If rate limiting is disabled (limit set to 0), pass through immediately
if self._rate_limit == 0:
return

# If Redis is not configured, fail open with a warning
if self._redis is None:
logger.warning(
"ScanRateLimiter: Redis client is None — rate limiting is DISABLED. "
"Configure REDIS_URL to enable rate limiting."
)
return

ip = self._get_client_ip(request)
now = int(time.time())

try:
# ── Tier 1: Per-minute limit (burst protection) ──────────────────
minute_window = now // self._rate_window
minute_key = self._make_key(ip, "minute", minute_window)

pipe = self._redis.pipeline()
pipe.incr(minute_key)
pipe.expire(minute_key, self._rate_window * 2) # 2x TTL for safety
results = await pipe.execute()
minute_count = results[0]

if minute_count > self._rate_limit:
retry_after = self._rate_window - (now % self._rate_window)
logger.warning(
"Rate limit exceeded (per-minute): ip=%s count=%d limit=%d",
ip,
minute_count,
self._rate_limit,
)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail={
"error": "rate_limit_exceeded",
"message": (
f"Scan rate limit exceeded: maximum {self._rate_limit} "
f"requests per {self._rate_window} seconds."
),
"retry_after": retry_after,
},
headers={"Retry-After": str(retry_after)},
)

# ── Tier 2: Per-hour limit (sustained abuse protection) ──────────
hour_window = now // self._burst_window
hour_key = self._make_key(ip, "hour", hour_window)

pipe2 = self._redis.pipeline()
pipe2.incr(hour_key)
pipe2.expire(hour_key, self._burst_window * 2)
results2 = await pipe2.execute()
hour_count = results2[0]

if hour_count > self._burst_limit:
retry_after = self._burst_window - (now % self._burst_window)
logger.warning(
"Rate limit exceeded (per-hour): ip=%s count=%d limit=%d",
ip,
hour_count,
self._burst_limit,
)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail={
"error": "burst_limit_exceeded",
"message": (
f"Hourly scan limit exceeded: maximum {self._burst_limit} "
f"requests per hour."
),
"retry_after": retry_after,
},
headers={"Retry-After": str(retry_after)},
)

except HTTPException:
# Re-raise 429s — don't swallow them in the Redis error handler
raise
except Exception as exc:
# Redis connection error, timeout, etc. — fail open, log, continue
logger.error(
"ScanRateLimiter: Redis error, failing open: %s", exc, exc_info=True
)


def make_scan_rate_limiter(
redis_client: Optional[aioredis.Redis],
rate_limit: int,
rate_window: int,
burst_limit: int,
burst_window: int,
) -> ScanRateLimiter:
"""
Factory function for creating a ScanRateLimiter.
Intended to be called once at app startup and reused across requests.
"""
return ScanRateLimiter(
redis_client=redis_client,
rate_limit=rate_limit,
rate_window=rate_window,
burst_limit=burst_limit,
burst_window=burst_window,
)
Loading
Loading