1313import os
1414from typing import Any
1515
16- from fastapi import FastAPI , HTTPException
16+ from fastapi import FastAPI , HTTPException , Request
1717from fastapi .middleware .cors import CORSMiddleware
18- from fastapi .responses import FileResponse
18+ from fastapi .responses import FileResponse , JSONResponse
1919from fastapi .staticfiles import StaticFiles
2020from pydantic import BaseModel , Field
21+ from slowapi import Limiter
22+ from slowapi .errors import RateLimitExceeded
23+ from slowapi .util import get_remote_address
2124
22- from config import CHAT_UI , HOST , PORT
25+ from config import CHAT_UI , HOST , PORT , validate
2326from llm import chat_response , generate_action_plan
2427from vector_store import ingest_playbooks , list_playbooks , search_playbooks
2528
2932)
3033logger = logging .getLogger ("gone-phishing" )
3134
35+ limiter = Limiter (key_func = get_remote_address )
36+
3237app = FastAPI (
3338 title = "Gone-Phishing IRP Engine" ,
3439 version = "0.2.0" ,
3540 description = "AI-powered Incident Response Plan retrieval and action plan generation." ,
3641)
42+ app .state .limiter = limiter
43+
44+
45+ @app .exception_handler (RateLimitExceeded )
46+ async def _rate_limit_handler (request : Request , exc : RateLimitExceeded ):
47+ return JSONResponse (status_code = 429 , content = {"detail" : "Rate limit exceeded. Try again shortly." })
48+
49+ _cors_origins = os .getenv ("CORS_ORIGINS" , "" ).split ("," )
50+ _cors_origins = [o .strip () for o in _cors_origins if o .strip ()]
3751
3852app .add_middleware (
3953 CORSMiddleware ,
40- allow_origins = ["*" ],
41- allow_credentials = True ,
54+ allow_origins = _cors_origins or ["*" ],
55+ allow_credentials = bool ( _cors_origins ) ,
4256 allow_methods = ["*" ],
4357 allow_headers = ["*" ],
4458)
4559
4660
61+ @app .middleware ("http" )
62+ async def security_headers (request : Request , call_next ):
63+ response = await call_next (request )
64+ response .headers ["X-Content-Type-Options" ] = "nosniff"
65+ response .headers ["X-Frame-Options" ] = "DENY"
66+ response .headers ["Referrer-Policy" ] = "strict-origin-when-cross-origin"
67+ return response
68+
69+
4770# -- Request / response schemas ---------------------------------------------
4871
4972
@@ -66,7 +89,8 @@ class ChatInput(BaseModel):
6689
6790
6891@app .post ("/api/incident" )
69- async def handle_incident (body : IncidentInput ) -> dict [str , Any ]:
92+ @limiter .limit ("10/minute" )
93+ async def handle_incident (request : Request , body : IncidentInput ) -> dict [str , Any ]:
7094 """Submit an incident description → receive a role-assigned action plan."""
7195 try :
7296 matches = search_playbooks (body .description , n_results = 8 )
@@ -91,13 +115,14 @@ async def handle_incident(body: IncidentInput) -> dict[str, Any]:
91115 }
92116 except HTTPException :
93117 raise
94- except Exception as exc :
118+ except Exception :
95119 logger .exception ("Incident endpoint error" )
96- raise HTTPException (500 , str ( exc ) )
120+ raise HTTPException (500 , "Failed to generate action plan. Check server logs." )
97121
98122
99123@app .post ("/api/chat" )
100- async def handle_chat (body : ChatInput ) -> dict [str , str ]:
124+ @limiter .limit ("20/minute" )
125+ async def handle_chat (request : Request , body : ChatInput ) -> dict [str , str ]:
101126 """Follow-up questions in chat context."""
102127 try :
103128 latest_user_msg = next (
@@ -107,9 +132,9 @@ async def handle_chat(body: ChatInput) -> dict[str, str]:
107132 matches = search_playbooks (latest_user_msg , n_results = 5 ) if latest_user_msg else []
108133 response = chat_response (messages = body .messages , playbook_context = matches or None )
109134 return {"response" : response }
110- except Exception as exc :
135+ except Exception :
111136 logger .exception ("Chat endpoint error" )
112- raise HTTPException (500 , str ( exc ) )
137+ raise HTTPException (500 , "Chat request failed. Check server logs." )
113138
114139
115140@app .post ("/api/search" )
@@ -127,8 +152,9 @@ async def handle_search(body: SearchInput) -> dict[str, Any]:
127152 for m in matches
128153 ]
129154 }
130- except Exception as exc :
131- raise HTTPException (500 , str (exc ))
155+ except Exception :
156+ logger .exception ("Search endpoint error" )
157+ raise HTTPException (500 , "Search failed. Check server logs." )
132158
133159
134160@app .get ("/api/playbooks" )
@@ -144,8 +170,34 @@ async def run_ingest() -> dict[str, Any]:
144170
145171
146172@app .get ("/api/health" )
147- async def health () -> dict [str , str ]:
148- return {"status" : "ok" , "service" : "gone-phishing" }
173+ async def health () -> dict [str , Any ]:
174+ """Liveness check with dependency status."""
175+ checks : list [dict [str , Any ]] = []
176+
177+ # ChromaDB reachable + playbooks ingested
178+ try :
179+ playbooks = list_playbooks ()
180+ checks .append ({"name" : "chromadb" , "ok" : True , "playbooks" : len (playbooks )})
181+ except Exception as exc :
182+ checks .append ({"name" : "chromadb" , "ok" : False , "error" : str (exc )})
183+
184+ # LLM provider configured
185+ from config import LLM_MODEL , LLM_PROVIDER
186+ llm_configured = True
187+ try :
188+ from adapters import get_adapter
189+ adapter = get_adapter ()
190+ checks .append ({"name" : "llm" , "ok" : True , "provider" : LLM_PROVIDER , "model" : adapter .model_name })
191+ except Exception as exc :
192+ llm_configured = False
193+ checks .append ({"name" : "llm" , "ok" : False , "error" : str (exc )})
194+
195+ all_ok = all (c ["ok" ] for c in checks )
196+ return {
197+ "status" : "ok" if all_ok else "degraded" ,
198+ "service" : "gone-phishing" ,
199+ "checks" : checks ,
200+ }
149201
150202
151203# -- Chat UI mount -----------------------------------------------------------
@@ -191,6 +243,7 @@ def _mount_chainlit() -> None:
191243if __name__ == "__main__" :
192244 import uvicorn
193245
246+ validate ()
194247 logger .info ("Ingesting playbooks on startup..." )
195248 result = ingest_playbooks ()
196249 logger .info ("Ready: %d files, %d chunks" , result ["files_ingested" ], result ["total_chunks" ])
0 commit comments