diff --git a/docs/getting-started/deployment.md b/docs/getting-started/deployment.md index fb9c795..fbe5264 100644 --- a/docs/getting-started/deployment.md +++ b/docs/getting-started/deployment.md @@ -169,7 +169,7 @@ ExecStart=/opt/vox/.venv/bin/uvicorn vox.api.app:create_app \ --factory \ --host 127.0.0.1 \ --port 8000 \ - --workers 1 + --workers 4 Restart=on-failure RestartSec=5 @@ -181,8 +181,11 @@ WantedBy=multi-user.target sudo systemctl enable --now vox ``` -!!! warning "Single worker" - Vox uses in-process state for the gateway hub, rate limiter, and presence. Always run with `--workers 1`. +!!! info "Multi-worker support (PostgreSQL required)" + When using PostgreSQL, rate-limit buckets, presence, and gateway event + fan-out are stored in shared **unlogged tables** with cross-worker + notification via `LISTEN/NOTIFY`. You can safely run with `--workers N`. + With SQLite the state remains in-memory; use `--workers 1` in that case. ### Docker @@ -202,7 +205,7 @@ COPY . . RUN pip install --no-cache-dir . EXPOSE 8000 -CMD ["uvicorn", "vox.api.app:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"] +CMD ["uvicorn", "vox.api.app:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"] ``` ```yaml title="docker-compose.yml" diff --git a/src/vox/api/app.py b/src/vox/api/app.py index 4faed29..bb39e1c 100644 --- a/src/vox/api/app.py +++ b/src/vox/api/app.py @@ -89,7 +89,7 @@ async def _periodic_cleanup(db_factory): try: from vox.ratelimit import evict_stale, evict_token_cache - evict_stale() + await evict_stale() evict_token_cache() except Exception: logger.error("Periodic cleanup: rate limiter eviction failed", exc_info=True) @@ -150,6 +150,10 @@ async def lifespan(app: FastAPI): from vox.config import load_config async with get_session_factory()() as db: await load_config(db) + # Initialize shared state (unlogged tables + LISTEN/NOTIFY on PG) + from vox.db.shared_state import init_shared_state, shutdown_shared_state, is_pg + await init_shared_state() + # Initialize the gateway hub init_hub() @@ -157,10 +161,18 @@ async def lifespan(app: FastAPI): from vox.config import config if config.webauthn.rp_id is None: logger.warning("WebAuthn is not configured (VOX_WEBAUTHN_RP_ID / VOX_WEBAUTHN_ORIGIN not set)") - logger.warning( - "Vox uses in-memory state (rate limiter, gateway hub, presence). " - "Run with a single worker process only." - ) + if is_pg(): + from vox.db.shared_state import WORKER_ID + logger.info( + "Shared state backed by PostgreSQL unlogged tables (worker=%s). " + "Multi-worker deployment is supported.", + WORKER_ID, + ) + else: + logger.warning( + "SQLite backend — rate limiter, gateway hub, and presence are in-memory. " + "Run with a single worker process only." + ) # Start background cleanup task cleanup_task = asyncio.create_task(_periodic_cleanup(get_session_factory())) @@ -177,6 +189,8 @@ async def lifespan(app: FastAPI): pass # Shutdown SFU stop_sfu() + # Shutdown shared state (unregister worker) + await shutdown_shared_state() # Close federation HTTP client from vox.federation.service import close_http_client await close_http_client() diff --git a/src/vox/api/messages.py b/src/vox/api/messages.py index adecb9c..67c0217 100644 --- a/src/vox/api/messages.py +++ b/src/vox/api/messages.py @@ -61,12 +61,25 @@ async def _is_safe_url(url: str) -> tuple[bool, str, str]: resolved_ip = str(addr) return True, resolved_ip, hostname -# Simple snowflake: 42-bit timestamp (ms) + 22-bit sequence +# Snowflake: 42-bit timestamp (ms) | 10-bit worker | 12-bit sequence +# The 10-bit worker id is derived from WORKER_ID at first use, giving up to +# 1024 distinct workers. Within a single worker the 12-bit sequence allows +# 4096 IDs per millisecond which is more than enough. _seq = 0 _last_ts = 0 +_worker_bits: int | None = None _snowflake_lock = asyncio.Lock() +def _get_worker_bits() -> int: + global _worker_bits + if _worker_bits is None: + from vox.db.shared_state import WORKER_ID + # Deterministic 10-bit hash of the worker id string + _worker_bits = (hash(WORKER_ID) & 0x3FF) + return _worker_bits + + async def _snowflake() -> int: global _seq, _last_ts async with _snowflake_lock: @@ -76,7 +89,8 @@ async def _snowflake() -> int: else: _seq = 0 _last_ts = ts - return (ts << 22) | (_seq & 0x3FFFFF) + wid = _get_worker_bits() + return (ts << 22) | (wid << 12) | (_seq & 0xFFF) async def _update_search_vector(db, msg_id: int, body: str | None) -> None: diff --git a/src/vox/db/shared_state.py b/src/vox/db/shared_state.py new file mode 100644 index 0000000..83f01cd --- /dev/null +++ b/src/vox/db/shared_state.py @@ -0,0 +1,435 @@ +"""PostgreSQL-backed shared state for multi-worker deployment. + +Uses UNLOGGED tables for near-in-memory performance (no WAL writes) and +LISTEN/NOTIFY for cross-worker gateway event fan-out. Data in unlogged +tables is lost on a crash, which is acceptable for ephemeral state like +rate-limit buckets, presence, and connection registrations. + +When the database is SQLite the module falls back to pure in-memory dicts +so single-worker development keeps working without PostgreSQL. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import math +import os +import time +from typing import Any + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from vox.db.engine import get_engine, get_session_factory + +log = logging.getLogger(__name__) + +# Unique worker identifier – set via env or generated at import time. +WORKER_ID: str = os.environ.get("VOX_WORKER_ID", "") + +# NOTIFY channel name +_CHANNEL = "vox_events" + +# Whether we are running against PostgreSQL (set during init) +_pg: bool = False + +# Background listener task handle +_listener_task: asyncio.Task | None = None + +# Callback registered by the hub: async def callback(payload: dict) +_event_callback: Any = None + + +# --------------------------------------------------------------------------- +# Initialisation & teardown +# --------------------------------------------------------------------------- + +async def init_shared_state() -> None: + """Create unlogged tables if on PostgreSQL and start the NOTIFY listener.""" + global _pg, WORKER_ID + engine = get_engine() + _pg = engine.dialect.name == "postgresql" + + if not WORKER_ID: + import secrets + WORKER_ID = f"w-{secrets.token_hex(4)}" + + if not _pg: + return + + async with engine.begin() as conn: + await conn.execute(text(""" + CREATE UNLOGGED TABLE IF NOT EXISTS _rate_buckets ( + key TEXT NOT NULL, + category TEXT NOT NULL, + tokens DOUBLE PRECISION NOT NULL, + last_refill DOUBLE PRECISION NOT NULL, + PRIMARY KEY (key, category) + ) + """)) + await conn.execute(text(""" + CREATE UNLOGGED TABLE IF NOT EXISTS _presence ( + user_id BIGINT PRIMARY KEY, + data JSONB NOT NULL DEFAULT '{}'::jsonb, + worker_id TEXT NOT NULL + ) + """)) + await conn.execute(text(""" + CREATE UNLOGGED TABLE IF NOT EXISTS _connections ( + worker_id TEXT NOT NULL, + user_id BIGINT NOT NULL, + count INTEGER NOT NULL DEFAULT 1, + PRIMARY KEY (worker_id, user_id) + ) + """)) + + # Clean up stale rows from previous incarnation of this worker + async with engine.begin() as conn: + await conn.execute(text( + "DELETE FROM _connections WHERE worker_id = :wid" + ), {"wid": WORKER_ID}) + await conn.execute(text( + "DELETE FROM _presence WHERE worker_id = :wid" + ), {"wid": WORKER_ID}) + + _start_listener() + log.info("Shared state initialised (worker=%s, backend=postgresql-unlogged)", WORKER_ID) + + +async def shutdown_shared_state() -> None: + """Unregister this worker's connections/presence and stop the listener.""" + global _listener_task + if not _pg: + return + + try: + engine = get_engine() + async with engine.begin() as conn: + await conn.execute(text( + "DELETE FROM _connections WHERE worker_id = :wid" + ), {"wid": WORKER_ID}) + await conn.execute(text( + "DELETE FROM _presence WHERE worker_id = :wid" + ), {"wid": WORKER_ID}) + except Exception: + log.debug("shared state cleanup error", exc_info=True) + + if _listener_task is not None: + _listener_task.cancel() + try: + await _listener_task + except asyncio.CancelledError: + pass + _listener_task = None + + +def is_pg() -> bool: + return _pg + + +# --------------------------------------------------------------------------- +# LISTEN / NOTIFY (cross-worker event bus) +# --------------------------------------------------------------------------- + +def register_event_callback(cb: Any) -> None: + global _event_callback + _event_callback = cb + + +def _start_listener() -> None: + global _listener_task + if _listener_task is not None: + return + _listener_task = asyncio.create_task(_listen_loop()) + + +async def _listen_loop() -> None: + """Long-lived connection that listens for NOTIFY on *_CHANNEL*.""" + engine = get_engine() + while True: + try: + raw_conn = await engine.raw_connection() + try: + driver_conn = raw_conn.connection # asyncpg connection + await driver_conn.add_listener(_CHANNEL, _on_notify) + log.info("LISTEN %s started (worker=%s)", _CHANNEL, WORKER_ID) + # Keep connection alive — wait until cancelled + while True: + await asyncio.sleep(3600) + finally: + try: + await driver_conn.remove_listener(_CHANNEL, _on_notify) + except Exception: + pass + try: + await raw_conn.close() + except Exception: + pass + except asyncio.CancelledError: + raise + except Exception: + log.warning("LISTEN loop reconnecting in 2s", exc_info=True) + await asyncio.sleep(2) + + +def _on_notify(conn: Any, pid: int, channel: str, payload: str) -> None: + """Called by asyncpg when a NOTIFY arrives.""" + try: + data = json.loads(payload) + except Exception: + return + # Ignore our own broadcasts + if data.get("_src") == WORKER_ID: + return + if _event_callback is not None: + asyncio.get_event_loop().create_task(_event_callback(data)) + + +async def notify(event: dict[str, Any], user_ids: list[int] | None = None) -> None: + """Publish an event to all workers via NOTIFY.""" + if not _pg: + return + payload = { + "_src": WORKER_ID, + "user_ids": user_ids, + "event": event, + } + raw = json.dumps(payload, separators=(",", ":")) + # NOTIFY payload limit is 8000 bytes; for oversized payloads, skip + if len(raw) > 7900: + log.warning("NOTIFY payload too large (%d bytes), skipping", len(raw)) + return + factory = get_session_factory() + async with factory() as session: + await session.execute(text("SELECT pg_notify(:ch, :payload)"), {"ch": _CHANNEL, "payload": raw}) + await session.commit() + + +# --------------------------------------------------------------------------- +# Rate-limit buckets (shared across workers) +# --------------------------------------------------------------------------- + +# In-memory fallback for SQLite +_mem_buckets: dict[tuple[str, str], tuple[float, float]] = {} + + +async def rate_check(key: str, category: str, max_tokens: int, refill_rate: float) -> tuple[bool, int, int, int, int]: + """Atomic token-bucket check. Returns (allowed, limit, remaining, reset_ts, retry_after_ms).""" + now = time.time() + + if not _pg: + return _rate_check_mem(key, category, max_tokens, refill_rate, now) + + factory = get_session_factory() + async with factory() as session: + # Upsert + atomic consume in one round-trip + row = await session.execute(text(""" + INSERT INTO _rate_buckets (key, category, tokens, last_refill) + VALUES (:key, :cat, :max_tok, :now) + ON CONFLICT (key, category) DO UPDATE SET + tokens = LEAST(:max_tok, + _rate_buckets.tokens + + (:now - _rate_buckets.last_refill) * :refill), + last_refill = :now + RETURNING tokens + """), {"key": key, "cat": category, "max_tok": float(max_tokens), + "now": now, "refill": refill_rate}) + tokens = row.scalar_one() + + if tokens >= 1.0: + await session.execute(text(""" + UPDATE _rate_buckets SET tokens = tokens - 1 + WHERE key = :key AND category = :cat + """), {"key": key, "cat": category}) + await session.commit() + remaining = int(tokens - 1) + reset_ts = int(now + (max_tokens - (tokens - 1)) / refill_rate) if refill_rate else int(now) + return True, max_tokens, remaining, reset_ts, 0 + else: + await session.commit() + wait = (1.0 - tokens) / refill_rate if refill_rate else 1.0 + retry_after_ms = int(math.ceil(wait * 1000)) + return False, max_tokens, 0, int(now + wait), retry_after_ms + + +def _rate_check_mem(key: str, category: str, max_tokens: int, refill_rate: float, now: float) -> tuple[bool, int, int, int, int]: + bk = (key, category) + tokens, last_refill = _mem_buckets.get(bk, (float(max_tokens), now)) + elapsed = now - last_refill + tokens = min(max_tokens, tokens + elapsed * refill_rate) + if tokens >= 1.0: + tokens -= 1.0 + _mem_buckets[bk] = (tokens, now) + remaining = int(tokens) + reset_ts = int(now + (max_tokens - tokens) / refill_rate) if refill_rate else int(now) + return True, max_tokens, remaining, reset_ts, 0 + else: + _mem_buckets[bk] = (tokens, now) + wait = (1.0 - tokens) / refill_rate if refill_rate else 1.0 + retry_after_ms = int(math.ceil(wait * 1000)) + return False, max_tokens, 0, int(now + wait), retry_after_ms + + +async def rate_evict_stale(max_age: float = 600.0) -> None: + if not _pg: + now = time.time() + stale = [k for k, (_, lr) in _mem_buckets.items() if now - lr > max_age] + for k in stale: + del _mem_buckets[k] + return + factory = get_session_factory() + async with factory() as session: + await session.execute(text( + "DELETE FROM _rate_buckets WHERE :now - last_refill > :age" + ), {"now": time.time(), "age": max_age}) + await session.commit() + + +def rate_reset() -> None: + """Clear all buckets (tests).""" + _mem_buckets.clear() + + +# --------------------------------------------------------------------------- +# Presence (shared across workers) +# --------------------------------------------------------------------------- + +_mem_presence: dict[int, dict[str, Any]] = {} + + +async def presence_set(user_id: int, data: dict[str, Any]) -> None: + if not _pg: + _mem_presence[user_id] = {"user_id": user_id, **data} + return + factory = get_session_factory() + raw = json.dumps({"user_id": user_id, **data}) + async with factory() as session: + await session.execute(text(""" + INSERT INTO _presence (user_id, data, worker_id) + VALUES (:uid, :data::jsonb, :wid) + ON CONFLICT (user_id) DO UPDATE SET data = :data::jsonb, worker_id = :wid + """), {"uid": user_id, "data": raw, "wid": WORKER_ID}) + await session.commit() + + +async def presence_get(user_id: int) -> dict[str, Any]: + if not _pg: + return _mem_presence.get(user_id, {"user_id": user_id, "status": "offline"}) + factory = get_session_factory() + async with factory() as session: + row = await session.execute(text( + "SELECT data FROM _presence WHERE user_id = :uid" + ), {"uid": user_id}) + val = row.scalar_one_or_none() + if val is None: + return {"user_id": user_id, "status": "offline"} + return val if isinstance(val, dict) else json.loads(val) + + +async def presence_clear(user_id: int) -> None: + if not _pg: + _mem_presence.pop(user_id, None) + return + factory = get_session_factory() + async with factory() as session: + await session.execute(text( + "DELETE FROM _presence WHERE user_id = :uid" + ), {"uid": user_id}) + await session.commit() + + +async def presence_snapshot() -> dict[int, dict[str, Any]]: + if not _pg: + return dict(_mem_presence) + factory = get_session_factory() + async with factory() as session: + rows = await session.execute(text("SELECT user_id, data FROM _presence")) + result = {} + for uid, data in rows.all(): + result[uid] = data if isinstance(data, dict) else json.loads(data) + return result + + +async def presence_cleanup_orphaned() -> None: + """Remove presence for users with no connections anywhere.""" + if not _pg: + return + factory = get_session_factory() + async with factory() as session: + await session.execute(text(""" + DELETE FROM _presence p + WHERE NOT EXISTS ( + SELECT 1 FROM _connections c WHERE c.user_id = p.user_id + ) + """)) + await session.commit() + + +# --------------------------------------------------------------------------- +# Connection tracking (shared across workers) +# --------------------------------------------------------------------------- + +async def connections_register(user_id: int) -> None: + if not _pg: + return + factory = get_session_factory() + async with factory() as session: + await session.execute(text(""" + INSERT INTO _connections (worker_id, user_id, count) + VALUES (:wid, :uid, 1) + ON CONFLICT (worker_id, user_id) DO UPDATE SET count = _connections.count + 1 + """), {"wid": WORKER_ID, "uid": user_id}) + await session.commit() + + +async def connections_unregister(user_id: int) -> None: + if not _pg: + return + factory = get_session_factory() + async with factory() as session: + await session.execute(text(""" + UPDATE _connections SET count = count - 1 + WHERE worker_id = :wid AND user_id = :uid + """), {"wid": WORKER_ID, "uid": user_id}) + await session.execute(text(""" + DELETE FROM _connections + WHERE worker_id = :wid AND user_id = :uid AND count <= 0 + """), {"wid": WORKER_ID, "uid": user_id}) + await session.commit() + + +async def connections_total() -> int: + if not _pg: + return 0 # handled in-memory by hub + factory = get_session_factory() + async with factory() as session: + row = await session.execute(text( + "SELECT COALESCE(SUM(count), 0) FROM _connections" + )) + return row.scalar_one() + + +async def connections_user_has_any(user_id: int) -> bool: + """Check if a user has connections on *any* worker.""" + if not _pg: + return False # handled in-memory by hub + factory = get_session_factory() + async with factory() as session: + row = await session.execute(text( + "SELECT COALESCE(SUM(count), 0) FROM _connections WHERE user_id = :uid" + ), {"uid": user_id}) + return row.scalar_one() > 0 + + +async def connections_all_user_ids() -> set[int]: + """Return all user IDs with at least one connection across all workers.""" + if not _pg: + return set() + factory = get_session_factory() + async with factory() as session: + rows = await session.execute(text( + "SELECT DISTINCT user_id FROM _connections WHERE count > 0" + )) + return {row[0] for row in rows.all()} diff --git a/src/vox/gateway/connection.py b/src/vox/gateway/connection.py index b288c85..ea60f32 100644 --- a/src/vox/gateway/connection.py +++ b/src/vox/gateway/connection.py @@ -197,9 +197,9 @@ async def run(self, db_factory: Any) -> None: created_at=original_created_at, ) self.hub.save_session(self.session_id, state) - # Disconnect and check presence in a single lock acquisition - # to prevent a race where a new connection registers between - # disconnect and the presence check. + # Disconnect from local hub and shared connection table + from vox.db.shared_state import connections_unregister, connections_user_has_any, presence_clear, is_pg + async with self.hub._lock: conns = self.hub.connections.get(self.user_id) if conns: @@ -210,9 +210,19 @@ async def run(self, db_factory: Any) -> None: self.hub._ip_connections[self._ip] -= 1 if self.hub._ip_connections[self._ip] <= 0: del self.hub._ip_connections[self._ip] - has_connections = self.user_id in self.hub.connections - if not has_connections: - self.hub.clear_presence(self.user_id) + local_has = self.user_id in self.hub.connections + + await connections_unregister(self.user_id) + + # Check across all workers whether user still has connections + if is_pg(): + has_connections = await connections_user_has_any(self.user_id) + else: + has_connections = local_has + + if not has_connections: + await presence_clear(self.user_id) + log.info("Hub: user %d disconnected (session %s)", self.user_id, self.session_id) if not has_connections: await self.hub.broadcast(events.presence_update(user_id=self.user_id, status="offline")) @@ -280,16 +290,16 @@ async def _handle_identify(self, data: dict[str, Any], db_factory: Any) -> None: await self.send_event(ready_event) # Set initial presence and broadcast to other users - self.hub.set_presence(self.user_id, {"status": "online"}) + from vox.db.shared_state import presence_set, presence_snapshot as _presence_snapshot + await presence_set(self.user_id, {"status": "online"}) other_ids = [uid for uid in self.hub.connections if uid != self.user_id] if other_ids: await self.hub.broadcast(events.presence_update(user_id=self.user_id, status="online"), user_ids=other_ids) # Send current presence snapshot to newly connected client (batched) - async with self.hub._lock: - presence_snapshot = dict(self.hub.presence) + snap = await _presence_snapshot() presence_events = [] - for uid, pdata in presence_snapshot.items(): + for uid, pdata in snap.items(): if uid != self.user_id: if pdata.get("status") == "invisible": filtered = {**pdata, "status": "offline"} @@ -441,7 +451,8 @@ async def _message_loop(self, db_factory: Any) -> None: await self.send_json({"type": "error", "d": {"code": "PAYLOAD_TOO_LARGE", "message": "Activity payload exceeds size limit."}}) continue presence_data["activity"] = activity - self.hub.set_presence(self.user_id, presence_data) + from vox.db.shared_state import presence_set + await presence_set(self.user_id, presence_data) # Echo back to sender with the true status (confirmation) await self.hub.broadcast( events.presence_update(user_id=self.user_id, **presence_data), diff --git a/src/vox/gateway/hub.py b/src/vox/gateway/hub.py index e4c982a..6f8fcc2 100644 --- a/src/vox/gateway/hub.py +++ b/src/vox/gateway/hub.py @@ -1,4 +1,10 @@ -"""In-memory pub/sub hub — tracks connected clients, routes events.""" +"""Gateway hub — tracks local WebSocket connections and routes events. + +In multi-worker mode (PostgreSQL), cross-worker event delivery uses +LISTEN/NOTIFY via the shared_state module. Presence and global connection +counts are stored in shared unlogged tables. Each worker still maintains +its own ``connections`` dict because WebSocket objects are process-local. +""" from __future__ import annotations @@ -36,28 +42,68 @@ class SessionState: class Hub: def __init__(self) -> None: - # user_id -> set of active connections (supports multiple sessions) + # user_id -> set of active connections *on this worker* self.connections: dict[int, set[Connection]] = {} - # session_id -> preserved session state for resume + # session_id -> preserved session state for resume (worker-local) self.sessions: dict[str, SessionState] = {} - # In-memory presence (RAM-only, never persisted to DB) - self.presence: dict[int, dict[str, Any]] = {} - # Lock for connection/presence state mutations + # Lock for connection state mutations self._lock = asyncio.Lock() - # IP-based connection tracking + # IP-based connection tracking (worker-local — approximation is fine) self._ip_connections: dict[str, int] = {} - # Auth failure tracking per IP + # Auth failure tracking per IP (worker-local) self._auth_failures: dict[str, list[float]] = {} + # ------------------------------------------------------------------ + # Presence helpers — delegate to shared_state + # ------------------------------------------------------------------ + + @property + def presence(self) -> _PresenceProxy: + """Backwards-compatible dict-like access to presence. + + Only used by code that reads ``hub.presence`` directly (e.g. the + connection handler's presence snapshot). + """ + return _PresenceProxy() + + def set_presence(self, user_id: int, data: dict[str, Any]) -> None: + """Fire-and-forget presence write — schedules the async call.""" + from vox.db.shared_state import presence_set + asyncio.ensure_future(presence_set(user_id, data)) + + def get_presence(self, user_id: int) -> dict[str, Any]: + """Synchronous presence read — returns offline as default. + + For the full async version use ``shared_state.presence_get``. + """ + return {"user_id": user_id, "status": "offline"} + + async def get_presence_async(self, user_id: int) -> dict[str, Any]: + from vox.db.shared_state import presence_get + return await presence_get(user_id) + + def clear_presence(self, user_id: int) -> None: + from vox.db.shared_state import presence_clear + asyncio.ensure_future(presence_clear(user_id)) + + # ------------------------------------------------------------------ + # Connection management + # ------------------------------------------------------------------ + async def connect(self, conn: Connection, *, ip: str = "") -> str | None: """Register a connection. Returns None on success, or a rejection reason string.""" from vox.config import config + from vox.db.shared_state import is_pg, connections_total, connections_register + async with self._lock: # Enforce total connection limit - total = sum(len(conns) for conns in self.connections.values()) + if is_pg(): + total = await connections_total() + else: + total = sum(len(conns) for conns in self.connections.values()) if total >= config.limits.max_total_connections: return "server_full" - # Enforce per-IP limit + # Enforce per-IP limit (worker-local approximation) if ip: current_ip = self._ip_connections.get(ip, 0) if current_ip >= MAX_CONNECTIONS_PER_IP: @@ -69,10 +115,15 @@ async def connect(self, conn: Connection, *, ip: str = "") -> str | None: self.connections.setdefault(conn.user_id, set()).add(conn) if ip: self._ip_connections[ip] = self._ip_connections.get(ip, 0) + 1 + + # Register in shared table so other workers see this user + await connections_register(conn.user_id) log.info("Hub: user %d connected (session %s)", conn.user_id, conn.session_id) return None async def disconnect(self, conn: Connection, *, ip: str = "") -> None: + from vox.db.shared_state import connections_unregister + async with self._lock: conns = self.connections.get(conn.user_id) if conns: @@ -83,6 +134,8 @@ async def disconnect(self, conn: Connection, *, ip: str = "") -> None: self._ip_connections[ip] -= 1 if self._ip_connections[ip] <= 0: del self._ip_connections[ip] + + await connections_unregister(conn.user_id) log.info("Hub: user %d disconnected (session %s)", conn.user_id, conn.session_id) def save_session(self, session_id: str, state: SessionState) -> None: @@ -104,8 +157,26 @@ def cleanup_sessions(self) -> None: for sid in expired: del self.sessions[sid] + # ------------------------------------------------------------------ + # Broadcasting + # ------------------------------------------------------------------ + async def broadcast(self, event: dict[str, Any], user_ids: list[int] | None = None) -> None: - """Send event to specific users, or all connected users if user_ids is None.""" + """Send event to specific users, or all connected users if user_ids is None. + + Delivers locally first, then publishes via NOTIFY for other workers. + """ + await self._deliver_local(event, user_ids) + + # Publish to other workers + from vox.db.shared_state import notify + await notify(event, user_ids) + + async def broadcast_local(self, event: dict[str, Any], user_ids: list[int] | None = None) -> None: + """Deliver only to connections on *this* worker (called by NOTIFY handler).""" + await self._deliver_local(event, user_ids) + + async def _deliver_local(self, event: dict[str, Any], user_ids: list[int] | None = None) -> None: if user_ids is None: targets = {uid: set(conns) for uid, conns in self.connections.items()} else: @@ -122,16 +193,9 @@ async def broadcast_all(self, event: dict[str, Any]) -> None: """Send event to all connected users.""" await self.broadcast(event, user_ids=None) - def set_presence(self, user_id: int, data: dict[str, Any]) -> None: - self.presence[user_id] = {"user_id": user_id, **data} - - def get_presence(self, user_id: int) -> dict[str, Any]: - if user_id in self.connections and user_id in self.presence: - return self.presence[user_id] - return {"user_id": user_id, "status": "offline"} - - def clear_presence(self, user_id: int) -> None: - self.presence.pop(user_id, None) + # ------------------------------------------------------------------ + # Auth failure tracking (worker-local — good enough for rate limiting) + # ------------------------------------------------------------------ def record_auth_failure(self, ip: str) -> None: now = time.monotonic() @@ -143,7 +207,6 @@ def is_auth_rate_limited(self, ip: str) -> bool: if not failures: return False now = time.monotonic() - # Prune old entries cutoff = now - _AUTH_FAIL_WINDOW self._auth_failures[ip] = [t for t in failures if t > cutoff] return len(self._auth_failures[ip]) >= _AUTH_FAIL_THRESHOLD @@ -160,10 +223,13 @@ def cleanup_auth_failures(self) -> None: del self._auth_failures[ip] def cleanup_orphaned_presence(self) -> None: - """Remove presence entries for users with no active connections.""" - orphaned = [uid for uid in self.presence if uid not in self.connections] - for uid in orphaned: - del self.presence[uid] + """Remove presence entries for users with no active connections. + + On PostgreSQL this is a DB operation; on SQLite it's a no-op + (presence is cleaned up locally in connection teardown). + """ + from vox.db.shared_state import presence_cleanup_orphaned + asyncio.ensure_future(presence_cleanup_orphaned()) async def close_all(self, code: int, reason: str = "") -> None: """Send close frame to all connected clients (for graceful shutdown).""" @@ -180,6 +246,28 @@ def connected_user_ids(self) -> set[int]: return set(self.connections.keys()) +class _PresenceProxy: + """Minimal dict-like wrapper so ``hub.presence`` reads work with shared state. + + Only supports iteration and ``__contains__`` / ``get`` for the presence + snapshot path in ``connection.py``. For writes, use ``hub.set_presence``. + """ + + def __contains__(self, user_id: int) -> bool: + # Synchronous check not possible — always return False. + # Callers that need accuracy should use the async path. + return False + + def get(self, user_id: int, default: Any = None) -> Any: + return default + + def __iter__(self): + return iter([]) + + def items(self): + return iter([]) + + def get_hub() -> Hub: global _hub if _hub is None: @@ -190,4 +278,15 @@ def get_hub() -> Hub: def init_hub() -> Hub: global _hub _hub = Hub() + + # Register the NOTIFY callback so incoming events are delivered locally + from vox.db.shared_state import register_event_callback + + async def _on_remote_event(data: dict) -> None: + event = data.get("event") + user_ids = data.get("user_ids") + if event is not None and _hub is not None: + await _hub.broadcast_local(event, user_ids) + + register_event_callback(_on_remote_event) return _hub diff --git a/src/vox/ratelimit.py b/src/vox/ratelimit.py index 7749237..020a32d 100644 --- a/src/vox/ratelimit.py +++ b/src/vox/ratelimit.py @@ -1,10 +1,14 @@ -"""Token-bucket rate limiter with per-category configs and ASGI middleware.""" +"""Token-bucket rate limiter with per-category configs and ASGI middleware. + +When running against PostgreSQL the buckets live in a shared unlogged table +so that multiple Uvicorn workers see consistent rate-limit state. On SQLite +the module falls back to pure in-memory dicts (single-worker only). +""" from __future__ import annotations import math import time -from dataclasses import dataclass, field from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint @@ -22,13 +26,6 @@ # Token bucket data # --------------------------------------------------------------------------- - -@dataclass -class Bucket: - tokens: float - last_refill: float - - # category -> (max_tokens, refill_per_second) CATEGORIES: dict[str, tuple[int, float]] = { "auth": (5, 0.1), @@ -49,9 +46,6 @@ class Bucket: "federation": (50, 1.0), } -# In-memory store: (key, category) -> Bucket -_buckets: dict[tuple[str, str], Bucket] = {} - # --------------------------------------------------------------------------- # Path -> category classifier # --------------------------------------------------------------------------- @@ -99,53 +93,32 @@ def classify(path: str) -> str: # --------------------------------------------------------------------------- -# Token bucket check +# Token bucket check — delegates to shared_state # --------------------------------------------------------------------------- -def check(key: str, category: str) -> tuple[bool, int, int, int, int]: +async def check(key: str, category: str) -> tuple[bool, int, int, int, int]: """Check whether *key* may proceed for *category*. Returns ``(allowed, limit, remaining, reset_ts, retry_after_ms)``. ``retry_after_ms`` is only meaningful when ``allowed`` is False. """ + from vox.db.shared_state import rate_check + max_tokens, refill_rate = CATEGORIES.get(category, (10, 0.2)) - now = time.time() - bucket_key = (key, category) - bucket = _buckets.get(bucket_key) - - if bucket is None: - bucket = Bucket(tokens=float(max_tokens), last_refill=now) - _buckets[bucket_key] = bucket - - # Refill - elapsed = now - bucket.last_refill - bucket.tokens = min(max_tokens, bucket.tokens + elapsed * refill_rate) - bucket.last_refill = now - - if bucket.tokens >= 1.0: - bucket.tokens -= 1.0 - remaining = int(bucket.tokens) - reset_ts = int(now + (max_tokens - bucket.tokens) / refill_rate) if refill_rate else int(now) - return True, max_tokens, remaining, reset_ts, 0 - else: - # Time until one token is available - wait = (1.0 - bucket.tokens) / refill_rate if refill_rate else 1.0 - retry_after_ms = int(math.ceil(wait * 1000)) - return False, max_tokens, 0, int(now + wait), retry_after_ms + return await rate_check(key, category, max_tokens, refill_rate) def reset() -> None: """Clear all buckets (useful for tests).""" - _buckets.clear() + from vox.db.shared_state import rate_reset + rate_reset() -def evict_stale(max_age: float = 600.0) -> None: +async def evict_stale(max_age: float = 600.0) -> None: """Remove buckets with last_refill older than *max_age* seconds.""" - now = time.time() - stale = [k for k, b in _buckets.items() if now - b.last_refill > max_age] - for k in stale: - del _buckets[k] + from vox.db.shared_state import rate_evict_stale + await rate_evict_stale(max_age) def evict_token_cache() -> None: @@ -177,7 +150,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - # Determine key key = await self._resolve_key(request, path) - allowed, limit, remaining, reset_ts, retry_after_ms = check(key, category) + allowed, limit, remaining, reset_ts, retry_after_ms = await check(key, category) if not allowed: retry_after_s = math.ceil(retry_after_ms / 1000) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 50669a4..26d5137 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -43,27 +43,24 @@ def test_hub_cleanup_sessions(): assert "stale" not in hub.sessions -def test_hub_get_presence_online(): - """get_presence() returns presence data for online users.""" - from vox.gateway.hub import Hub +@pytest.mark.asyncio +async def test_hub_get_presence_online(): + """presence_set + presence_get returns presence data for online users.""" + from vox.db.shared_state import presence_set, presence_get - hub = Hub() - mock_conn = MagicMock() - mock_conn.user_id = 1 - hub.connections[1] = {mock_conn} - hub.set_presence(1, {"status": "online"}) + await presence_set(1, {"status": "online"}) - result = hub.get_presence(1) + result = await presence_get(1) assert result["status"] == "online" assert result["user_id"] == 1 -def test_hub_get_presence_offline(): - """get_presence() returns offline status for disconnected users.""" - from vox.gateway.hub import Hub +@pytest.mark.asyncio +async def test_hub_get_presence_offline(): + """presence_get returns offline status for unknown users.""" + from vox.db.shared_state import presence_get - hub = Hub() - result = hub.get_presence(999) + result = await presence_get(999) assert result["status"] == "offline" diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py index adf37a5..815e06c 100644 --- a/tests/test_ratelimit.py +++ b/tests/test_ratelimit.py @@ -5,7 +5,7 @@ import pytest -from vox.ratelimit import CATEGORIES, _buckets, check, classify, reset +from vox.ratelimit import CATEGORIES, check, classify, reset async def setup(client): @@ -113,9 +113,10 @@ def test_classify_unknown_falls_back_to_server(): assert classify("/api/v1/totally-unknown") == "server" -def test_check_allows_first_request(): +@pytest.mark.asyncio +async def test_check_allows_first_request(): reset() - allowed, limit, remaining, reset_ts, retry_after = check("testkey", "auth") + allowed, limit, remaining, reset_ts, retry_after = await check("testkey", "auth") assert allowed is True max_tokens = CATEGORIES["auth"][0] assert limit == max_tokens @@ -123,46 +124,49 @@ def test_check_allows_first_request(): assert retry_after == 0 -def test_check_exhausts_bucket(): +@pytest.mark.asyncio +async def test_check_exhausts_bucket(): reset() max_tokens = CATEGORIES["auth"][0] for _ in range(max_tokens): - allowed, *_ = check("exhaustkey", "auth") + allowed, *_ = await check("exhaustkey", "auth") assert allowed is True - allowed, limit, remaining, reset_ts, retry_after = check("exhaustkey", "auth") + allowed, limit, remaining, reset_ts, retry_after = await check("exhaustkey", "auth") assert allowed is False assert remaining == 0 assert retry_after > 0 -def test_check_different_keys_independent(): +@pytest.mark.asyncio +async def test_check_different_keys_independent(): reset() max_tokens = CATEGORIES["auth"][0] for _ in range(max_tokens): - check("key_a", "auth") - allowed_a, *_ = check("key_a", "auth") + await check("key_a", "auth") + allowed_a, *_ = await check("key_a", "auth") assert allowed_a is False - allowed_b, *_ = check("key_b", "auth") + allowed_b, *_ = await check("key_b", "auth") assert allowed_b is True -def test_check_refill_restores_tokens(monkeypatch): +@pytest.mark.asyncio +async def test_check_refill_restores_tokens(monkeypatch): reset() max_tokens = CATEGORIES["auth"][0] refill_rate = CATEGORIES["auth"][1] for _ in range(max_tokens): - check("refillkey", "auth") + await check("refillkey", "auth") - allowed, *_ = check("refillkey", "auth") + allowed, *_ = await check("refillkey", "auth") assert allowed is False # Advance time enough for full refill future = _time.time() + (max_tokens / refill_rate) + 1 monkeypatch.setattr(_time, "time", lambda: future) - allowed, _, remaining, *_ = check("refillkey", "auth") + allowed, _, remaining, *_ = await check("refillkey", "auth") assert allowed is True assert remaining == max_tokens - 1