diff --git a/.env.example b/.env.example index 8a2b667..50148ce 100644 --- a/.env.example +++ b/.env.example @@ -12,3 +12,12 @@ LOG_RESPONSE_BODY=false # Providers are managed via the Admin API: # POST /admin/providers — register a provider (openai, openrouter, dashscope) # POST /admin/models — map a model name to a provider + +# ─── Single-instance mode (SQLite + in-memory cache) ─── +# Uncomment the line below and remove the DATABASE_URL / REDIS_URL +# above to run without PostgreSQL or Redis. +# +# DATABASE_URL=sqlite:llm_gateway.db?mode=rwc +# +# When DATABASE_URL starts with "sqlite:", Redis is not required. +# An in-memory key/route cache is used instead. diff --git a/Cargo.lock b/Cargo.lock index 16d1150..67cd0b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -990,6 +990,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ + "cc", "pkg-config", "vcpkg", ] diff --git a/Cargo.toml b/Cargo.toml index 07f3b9b..ef6bae7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ tower-http = { version = "0.6", features = ["cors", "trace"] } tower = { version = "0.5" } # Database -sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "migrate", "uuid", "chrono"] } +sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "sqlite", "migrate", "uuid", "chrono"] } # Redis redis = { version = "0.27", features = ["tokio-comp", "aio", "connection-manager"] } diff --git a/Dockerfile b/Dockerfile index ff4b569..b1a37e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,6 +13,7 @@ RUN cargo build --release && rm -rf src # Build real binary COPY src/ src/ COPY migrations/ migrations/ +COPY migrations_sqlite/ migrations_sqlite/ RUN touch src/main.rs && cargo build --release # ---- Runtime stage ---- diff --git a/README.md b/README.md index 0098165..63f3137 100644 --- a/README.md +++ b/README.md @@ -11,6 
+11,7 @@ Routes OpenAI-compatible `/v1/chat/completions` requests to multiple upstream pr - **User Key management** — Generate `sk-{uuid}` keys, rotate (old key instantly invalidated), soft-delete - **Streaming** — Full SSE streaming passthrough for `stream: true` requests - **Two-tier caching** — Redis (hot) for O(1) key validation & model routing, PostgreSQL (cold) for persistence +- **Single-instance mode** — Run with SQLite + in-memory cache when Redis and PostgreSQL are unavailable - **Admin API** — Protected by a static admin key; manage providers, models, and user keys ## Architecture @@ -18,16 +19,21 @@ Routes OpenAI-compatible `/v1/chat/completions` requests to multiple upstream pr ```text Client ──► Gateway (/v1/chat/completions) ──► Provider (OpenAI / OpenRouter / DashScope / Ark) │ - ├─ User Key auth (Redis SET → PG fallback) - ├─ Model resolution (Redis HASH → PG fallback) + ├─ User Key auth (Cache → DB fallback) + ├─ Model resolution (Cache → DB fallback) └─ Request rewrite (model name) + proxy + +Full mode: Cache = Redis, DB = PostgreSQL +Single-instance: Cache = In-memory, DB = SQLite ``` ```text src/ ├── main.rs # Entrypoint: init, migrations, server ├── config.rs # Env-based configuration -├── state.rs # Shared AppState (PgPool, Redis, HttpClient) +├── state.rs # Shared AppState (DbPool, Cache, HttpClient) +├── db.rs # DbPool enum (PgPool | SqlitePool) + query macros +├── cache.rs # Cache enum (Redis | InMemory) ├── error.rs # Unified error type → HTTP responses ├── middleware/ │ └── auth.rs # Admin key + User key auth middleware @@ -41,7 +47,7 @@ src/ └── services/ ├── key_service.rs # Key generation, hashing, validation, rotation ├── provider_service.rs # Provider CRUD - └── model_service.rs # Model CRUD, route resolution, Redis cache + └── model_service.rs # Model CRUD, route resolution, cache ``` ## Quick Start @@ -49,9 +55,11 @@ src/ ### Prerequisites - Rust 1.75+ -- Docker & Docker Compose (for PostgreSQL and Redis) +- Docker & Docker 
Compose (for PostgreSQL and Redis — **not needed in single-instance mode**) + +### Option A: Full deployment (PostgreSQL + Redis) -### 1. Clone and configure +#### 1. Clone and configure ```bash git clone && cd llm-gateway-rs @@ -67,13 +75,13 @@ ADMIN_KEY=my-secret-admin-key LISTEN_ADDR=0.0.0.0:8080 ``` -### 2. Start dependencies +#### 2. Start dependencies ```bash docker compose up -d ``` -### 3. Run the gateway +#### 3. Run the gateway ```bash cargo run @@ -81,6 +89,26 @@ cargo run The server starts on `http://localhost:8080`. Database migrations run automatically on startup. +### Option B: Single-instance mode (SQLite, no external dependencies) + +For lightweight or local deployments, the gateway can run with SQLite and an in-memory cache — no PostgreSQL or Redis required. + +```bash +git clone && cd llm-gateway-rs + +# Configure for SQLite +export DATABASE_URL="sqlite:llm_gateway.db?mode=rwc" +export ADMIN_KEY="my-secret-admin-key" +# REDIS_URL is not needed — in-memory cache is used automatically + +cargo run +``` + +A `llm_gateway.db` file will be created in the working directory with all tables. + +> **Note:** Single-instance mode stores the key/model route cache in process memory. +> It is designed for single-process deployments. For multi-instance or HA setups, use PostgreSQL + Redis. + ## Admin API All admin endpoints require `Authorization: Bearer `. @@ -262,8 +290,8 @@ The gateway will: | Variable | Required | Default | Description | | -------- | -------- | ------- | ----------- | -| `DATABASE_URL` | Yes | — | PostgreSQL connection string | -| `REDIS_URL` | No | `redis://127.0.0.1:6379` | Redis connection string | +| `DATABASE_URL` | Yes | — | PostgreSQL connection string, or `sqlite:?mode=rwc` for SQLite | +| `REDIS_URL` | No | `redis://127.0.0.1:6379` (PG mode) / *none* (SQLite mode) | Redis connection string. 
Omit for in-memory caching with SQLite | | `ADMIN_KEY` | Yes | — | Secret key for admin API access | | `LISTEN_ADDR` | No | `0.0.0.0:8080` | Server listen address | @@ -272,7 +300,8 @@ The gateway will: - **Key format**: `sk-{uuid v4}` — 39 characters, recognizable prefix - **Key storage**: Only SHA-256 hashes stored; plaintext returned once on create/rotate (like GitHub PATs) - **Redis strategy**: `SET` for key hashes (`SISMEMBER` O(1)), `HASH` for model routes (`HGET` O(1)) -- **Cache warm-up**: On startup, all active keys and model routes are loaded from PG into Redis +- **Single-instance mode**: When `DATABASE_URL` starts with `sqlite:`, an in-memory cache replaces Redis, and SQLite replaces PostgreSQL — zero external dependencies +- **Cache warm-up**: On startup, all active keys and model routes are loaded from DB into cache (Redis or in-memory) - **Streaming**: Raw byte-stream passthrough — no SSE parsing, minimal latency - **Provider API keys**: Stored in PG, listed with masked preview (`sk-x...xxxx`), never cached in plaintext outside the routing lookup diff --git a/migrations_sqlite/001_init.sql b/migrations_sqlite/001_init.sql new file mode 100644 index 0000000..4b0e19c --- /dev/null +++ b/migrations_sqlite/001_init.sql @@ -0,0 +1,39 @@ +-- User keys: gateway-issued API keys for end users +CREATE TABLE IF NOT EXISTS user_keys ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + key_hash TEXT NOT NULL, + key_prefix TEXT NOT NULL, + is_active INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys (key_hash); +CREATE INDEX IF NOT EXISTS idx_user_keys_is_active ON user_keys (is_active); + +-- Providers: each represents an LLM API backend (OpenAI, OpenRouter, DashScope, etc.) 
+CREATE TABLE IF NOT EXISTS providers ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + kind TEXT NOT NULL DEFAULT 'openai', + base_url TEXT NOT NULL, + api_key TEXT NOT NULL, + is_active INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- Models: maps user-facing model names to a provider +CREATE TABLE IF NOT EXISTS models ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + provider_id TEXT NOT NULL REFERENCES providers(id), + provider_model_name TEXT, + is_active INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE INDEX IF NOT EXISTS idx_models_name ON models (name) WHERE is_active = 1; +CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models (provider_id); diff --git a/migrations_sqlite/002_request_logs.sql b/migrations_sqlite/002_request_logs.sql new file mode 100644 index 0000000..5d478b2 --- /dev/null +++ b/migrations_sqlite/002_request_logs.sql @@ -0,0 +1,26 @@ +-- Request logs for tracking all proxy calls +CREATE TABLE IF NOT EXISTS request_logs ( + id TEXT PRIMARY KEY, + request_id TEXT, + user_key_id TEXT, + user_key_hash TEXT NOT NULL, + model_requested TEXT NOT NULL, + model_sent TEXT NOT NULL, + provider_id TEXT, + provider_kind TEXT, + status_code INTEGER NOT NULL, + is_error INTEGER NOT NULL DEFAULT 0, + prompt_tokens INTEGER, + completion_tokens INTEGER, + total_tokens INTEGER, + latency_ms INTEGER NOT NULL, + is_stream INTEGER NOT NULL DEFAULT 0, + request_body TEXT, + response_body TEXT, + error_message TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE INDEX IF NOT EXISTS idx_request_logs_created_at ON request_logs (created_at DESC); +CREATE INDEX IF NOT EXISTS idx_request_logs_user_key ON request_logs (user_key_hash); +CREATE INDEX IF NOT EXISTS idx_request_logs_model ON request_logs (model_requested); diff --git 
a/migrations_sqlite/003_token_budget.sql b/migrations_sqlite/003_token_budget.sql new file mode 100644 index 0000000..a4912d6 --- /dev/null +++ b/migrations_sqlite/003_token_budget.sql @@ -0,0 +1,3 @@ +-- Add token budget columns to user_keys +ALTER TABLE user_keys ADD COLUMN token_budget INTEGER NULL; +ALTER TABLE user_keys ADD COLUMN tokens_used INTEGER NOT NULL DEFAULT 0; diff --git a/migrations_sqlite/004_token_coefficients.sql b/migrations_sqlite/004_token_coefficients.sql new file mode 100644 index 0000000..47466bc --- /dev/null +++ b/migrations_sqlite/004_token_coefficients.sql @@ -0,0 +1,4 @@ +-- Add input/output token cost coefficients to models +-- Default 1.0 means 1 raw token = 1 budget token +ALTER TABLE models ADD COLUMN input_token_coefficient REAL NOT NULL DEFAULT 1.0; +ALTER TABLE models ADD COLUMN output_token_coefficient REAL NOT NULL DEFAULT 1.0; diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..2910afd --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,164 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Cache abstraction: either Redis (for multi-instance deployments) or +/// an in-memory store (for single-instance / SQLite mode). +#[derive(Clone)] +pub enum Cache { + Redis(Box<redis::aio::ConnectionManager>), + InMemory(Arc<InMemoryCache>), +} + +/// Simple in-memory cache that mirrors the Redis SET / HASH operations +/// used by key_service and model_service. +pub struct InMemoryCache { + sets: RwLock<HashMap<String, HashSet<String>>>, + hashes: RwLock<HashMap<String, HashMap<String, String>>>, +} + +impl InMemoryCache { + pub fn new() -> Self { + Self { + sets: RwLock::new(HashMap::new()), + hashes: RwLock::new(HashMap::new()), + } + } +} + +impl Cache { + /// Create a new in-memory cache. + pub fn in_memory() -> Self { + Cache::InMemory(Arc::new(InMemoryCache::new())) + } + + // ── SET operations ──────────────────────────────────────────────── + + /// Add a member to a set.
+ pub async fn sadd(&mut self, key: &str, member: &str) -> Result<(), crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let _: () = redis::AsyncCommands::sadd(cm.as_mut(), key, member).await?; + Ok(()) + } + Cache::InMemory(store) => { + let mut sets = store.sets.write().await; + sets.entry(key.to_string()) + .or_default() + .insert(member.to_string()); + Ok(()) + } + } + } + + /// Check if a member exists in a set. + pub async fn sismember(&mut self, key: &str, member: &str) -> Result<bool, crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let exists: bool = redis::AsyncCommands::sismember(cm.as_mut(), key, member).await?; + Ok(exists) + } + Cache::InMemory(store) => { + let sets = store.sets.read().await; + Ok(sets.get(key).is_some_and(|s| s.contains(member))) + } + } + } + + /// Remove a member from a set. + pub async fn srem(&mut self, key: &str, member: &str) -> Result<(), crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let _: () = redis::AsyncCommands::srem(cm.as_mut(), key, member).await?; + Ok(()) + } + Cache::InMemory(store) => { + let mut sets = store.sets.write().await; + if let Some(set) = sets.get_mut(key) { + set.remove(member); + } + Ok(()) + } + } + } + + // ── HASH operations ─────────────────────────────────────────────── + + /// Get a field from a hash. + pub async fn hget(&mut self, key: &str, field: &str) -> Result<Option<String>, crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let val: Option<String> = redis::AsyncCommands::hget(cm.as_mut(), key, field).await?; + Ok(val) + } + Cache::InMemory(store) => { + let hashes = store.hashes.read().await; + Ok(hashes + .get(key) + .and_then(|h| h.get(field)) + .cloned()) + } + } + } + + /// Set a field in a hash.
+ pub async fn hset(&mut self, key: &str, field: &str, value: &str) -> Result<(), crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let _: () = redis::AsyncCommands::hset(cm.as_mut(), key, field, value).await?; + Ok(()) + } + Cache::InMemory(store) => { + let mut hashes = store.hashes.write().await; + hashes + .entry(key.to_string()) + .or_default() + .insert(field.to_string(), value.to_string()); + Ok(()) + } + } + } + + /// Remove a field from a hash. + pub async fn hdel(&mut self, key: &str, field: &str) -> Result<(), crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let _: () = redis::AsyncCommands::hdel(cm.as_mut(), key, field).await?; + Ok(()) + } + Cache::InMemory(store) => { + let mut hashes = store.hashes.write().await; + if let Some(hash) = hashes.get_mut(key) { + hash.remove(field); + } + Ok(()) + } + } + } + + // ── Key-level operations ────────────────────────────────────────── + + /// Delete an entire key (set or hash). + pub async fn del(&mut self, key: &str) -> Result<(), crate::error::AppError> { + match self { + Cache::Redis(cm) => { + let _: () = redis::cmd("DEL") + .arg(key) + .query_async(cm.as_mut()) + .await?; + Ok(()) + } + Cache::InMemory(store) => { + { + let mut sets = store.sets.write().await; + sets.remove(key); + } + { + let mut hashes = store.hashes.write().await; + hashes.remove(key); + } + Ok(()) + } + } + } +} diff --git a/src/config.rs b/src/config.rs index dabfd45..8e1a25a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,8 @@ use std::env; #[derive(Debug, Clone)] pub struct Config { pub database_url: String, - pub redis_url: String, + /// Redis URL. `None` when running in single-instance mode (SQLite + in-memory cache). + pub redis_url: Option<String>, pub admin_key: String, pub listen_addr: String, /// Comma-separated list of allowed CORS origins, or "*" for any.
@@ -26,10 +27,25 @@ fn parse_bool_env(key: &str, default: bool) -> bool { impl Config { pub fn from_env() -> anyhow::Result<Self> { + let database_url = env::var("DATABASE_URL") + .map_err(|_| anyhow::anyhow!("DATABASE_URL is required"))?; + + // Redis is optional: when DATABASE_URL points at SQLite and REDIS_URL + // is not explicitly set, fall back to in-memory caching. + let redis_url = match env::var("REDIS_URL") { + Ok(url) => Some(url), + Err(_) => { + if database_url.starts_with("sqlite:") { + None // single-instance mode + } else { + Some("redis://127.0.0.1:6379".into()) + } + } + }; + Ok(Self { - database_url: env::var("DATABASE_URL") - .map_err(|_| anyhow::anyhow!("DATABASE_URL is required"))?, - redis_url: env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".into()), + database_url, + redis_url, admin_key: env::var("ADMIN_KEY") .map_err(|_| anyhow::anyhow!("ADMIN_KEY is required"))?, listen_addr: env::var("LISTEN_ADDR") @@ -44,4 +60,9 @@ impl Config { log_response_body: parse_bool_env("LOG_RESPONSE_BODY", false), }) } + + /// Whether the database is SQLite. + pub fn is_sqlite(&self) -> bool { + self.database_url.starts_with("sqlite:") + } } diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..0c2e893 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,154 @@ +/// Database pool abstraction supporting both PostgreSQL and SQLite. +/// +/// When `DATABASE_URL` starts with `sqlite:`, a SQLite pool is created; +/// otherwise a PostgreSQL pool is used. +#[derive(Clone)] +pub enum DbPool { + Pg(sqlx::PgPool), + Sqlite(sqlx::SqlitePool), +} + +impl DbPool { + pub fn is_sqlite(&self) -> bool { + matches!(self, DbPool::Sqlite(_)) + } +} + +// ── Query dispatch macros ───────────────────────────────────────────── +// +// These macros duplicate the query call across both pool variants so that +// callers can write database-agnostic service code. The Rust compiler +// resolves both arms at compile time, and only the active variant runs.
+ +/// `db_query_as!(mode, pool, sql [, bind]*)` — run a `query_as` against DbPool. +/// +/// Modes: +/// - `optional` — `fetch_optional`, returns `Result<Option<T>, sqlx::Error>` +/// - `all` — `fetch_all`, returns `Result<Vec<T>, sqlx::Error>` +/// - `one` — `fetch_one`, returns `Result<T, sqlx::Error>` (errors if no row) +macro_rules! db_query_as { + (optional, $pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query_as($sql) + $(.bind($bind))* + .fetch_optional(p) + .await + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query_as($sql) + $(.bind($bind))* + .fetch_optional(p) + .await + } + } + }}; + (all, $pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query_as($sql) + $(.bind($bind))* + .fetch_all(p) + .await + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query_as($sql) + $(.bind($bind))* + .fetch_all(p) + .await + } + } + }}; + (one, $pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query_as($sql) + $(.bind($bind))* + .fetch_one(p) + .await + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query_as($sql) + $(.bind($bind))* + .fetch_one(p) + .await + } + } + }}; +} + +/// `db_execute!(pool, sql [, bind]*)` — execute a statement, returns `Result<u64, sqlx::Error>`. +macro_rules! db_execute { + ($pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query($sql) + $(.bind($bind))* + .execute(p) + .await + .map(|r| r.rows_affected()) + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query($sql) + $(.bind($bind))* + .execute(p) + .await + .map(|r| r.rows_affected()) + } + } + }}; +} + +/// `db_query_scalar!(mode, pool, sql [, bind]*)` — run a `query_scalar` against DbPool. +/// +/// `mode` is one of `optional`, `all`, `one`. +macro_rules!
db_query_scalar { + (optional, $pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query_scalar($sql) + $(.bind($bind))* + .fetch_optional(p) + .await + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query_scalar($sql) + $(.bind($bind))* + .fetch_optional(p) + .await + } + } + }}; + (all, $pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query_scalar($sql) + $(.bind($bind))* + .fetch_all(p) + .await + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query_scalar($sql) + $(.bind($bind))* + .fetch_all(p) + .await + } + } + }}; + (one, $pool:expr, $sql:expr $(, $bind:expr)*) => {{ + match $pool { + $crate::db::DbPool::Pg(p) => { + sqlx::query_scalar($sql) + $(.bind($bind))* + .fetch_one(p) + .await + } + $crate::db::DbPool::Sqlite(p) => { + sqlx::query_scalar($sql) + $(.bind($bind))* + .fetch_one(p) + .await + } + } + }}; +} diff --git a/src/main.rs b/src/main.rs index 96fb289..ca02cc9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,6 @@ +#[macro_use] +mod db; +mod cache; mod config; mod error; mod middleware; @@ -9,13 +12,14 @@ mod state; use std::sync::Arc; use axum::{http::HeaderValue, middleware as axum_mw, Router}; -use sqlx::postgres::PgPoolOptions; use tokio::net::TcpListener; use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::trace::TraceLayer; use tracing_subscriber::EnvFilter; +use cache::Cache; use config::Config; +use db::DbPool; use state::AppState; #[tokio::main] @@ -34,29 +38,62 @@ async fn main() -> anyhow::Result<()> { let config = Config::from_env()?; tracing::info!("Starting LLM Gateway on {}", config.listen_addr); - // Create Postgres connection pool - let db = PgPoolOptions::new() - .max_connections(10) - .connect(&config.database_url) - .await?; + // Create database pool (Postgres or SQLite) + let db = if config.is_sqlite() { + tracing::info!("Using SQLite database"); + let pool = sqlx::sqlite::SqlitePoolOptions::new() + 
.max_connections(5) + .connect(&config.database_url) + .await?; + + // Enable WAL mode and foreign keys for better performance and correctness + sqlx::query("PRAGMA journal_mode = WAL") + .execute(&pool) + .await?; + sqlx::query("PRAGMA foreign_keys = ON") + .execute(&pool) + .await?; + + DbPool::Sqlite(pool) + } else { + tracing::info!("Using PostgreSQL database"); + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(10) + .connect(&config.database_url) + .await?; + DbPool::Pg(pool) + }; - // Run migrations - sqlx::migrate!("./migrations").run(&db).await?; + // Run migrations (select the correct set based on database kind) + match &db { + DbPool::Pg(pool) => { + sqlx::migrate!("./migrations").run(pool).await?; + } + DbPool::Sqlite(pool) => { + sqlx::migrate!("./migrations_sqlite").run(pool).await?; + } + } tracing::info!("Database migrations applied"); - // Create Redis connection manager - let redis_client = redis::Client::open(config.redis_url.as_str())?; - let mut redis = redis_client.get_connection_manager().await?; - tracing::info!("Connected to Redis"); + // Create cache (Redis or in-memory) + let mut cache = if let Some(ref redis_url) = config.redis_url { + let redis_client = redis::Client::open(redis_url.as_str())?; + let cm = redis_client.get_connection_manager().await?; + tracing::info!("Connected to Redis"); + Cache::Redis(Box::new(cm)) + } else { + tracing::info!("Using in-memory cache (single-instance mode)"); + Cache::in_memory() + }; - // Warm up Redis caches - services::key_service::warm_up_redis(&db, &mut redis).await?; - services::model_service::warm_up_model_routes(&db, &mut redis).await?; + // Warm up caches + services::key_service::warm_up_cache(&db, &mut cache).await?; + services::model_service::warm_up_model_routes(&db, &mut cache).await?; // Build shared state let state = Arc::new(AppState { db, - redis, + cache, config: config.clone(), http_client: reqwest::Client::new(), }); diff --git a/src/middleware/auth.rs 
b/src/middleware/auth.rs index 784f33a..dcb3156 100644 --- a/src/middleware/auth.rs +++ b/src/middleware/auth.rs @@ -58,7 +58,7 @@ pub async fn admin_auth( next.run(req).await } -/// Middleware that validates a User Key against Redis / PG. +/// Middleware that validates a User Key against cache / DB. pub async fn user_key_auth( State(state): State>, req: Request, @@ -75,8 +75,8 @@ pub async fn user_key_auth( } }; - let mut redis = state.redis.clone(); - match key_service::validate_key(&token, &mut redis, &state.db).await { + let mut cache = state.cache.clone(); + match key_service::validate_key(&token, &mut cache, &state.db).await { Ok(Some(v)) => { let mut req = req; req.extensions_mut().insert(KeyIdentity { diff --git a/src/routes/admin.rs b/src/routes/admin.rs index 8692cdc..2c054b8 100644 --- a/src/routes/admin.rs +++ b/src/routes/admin.rs @@ -39,8 +39,8 @@ async fn create_key( return Err(AppError::BadRequest("name is required".into())); } - let mut redis = state.redis.clone(); - let result = key_service::create_key(&body.name, body.token_budget, &state.db, &mut redis).await?; + let mut cache = state.cache.clone(); + let result = key_service::create_key(&body.name, body.token_budget, &state.db, &mut cache).await?; Ok((StatusCode::CREATED, Json(result))) } @@ -58,8 +58,8 @@ async fn rotate_key( State(state): State>, Path(id): Path, ) -> Result, AppError> { - let mut redis = state.redis.clone(); - let result = key_service::rotate_key(id, &state.db, &mut redis).await?; + let mut cache = state.cache.clone(); + let result = key_service::rotate_key(id, &state.db, &mut cache).await?; Ok(Json(result)) } @@ -68,8 +68,8 @@ async fn delete_key_handler( State(state): State>, Path(id): Path, ) -> Result { - let mut redis = state.redis.clone(); - key_service::delete_key(id, &state.db, &mut redis).await?; + let mut cache = state.cache.clone(); + key_service::delete_key(id, &state.db, &mut cache).await?; Ok(StatusCode::NO_CONTENT) } @@ -157,8 +157,8 @@ async fn 
update_provider( .await?; // Rebuild model route cache since provider details may have changed - let mut redis = state.redis.clone(); - model_service::warm_up_model_routes(&state.db, &mut redis).await?; + let mut cache = state.cache.clone(); + model_service::warm_up_model_routes(&state.db, &mut cache).await?; Ok(Json(result)) } @@ -171,8 +171,8 @@ async fn delete_provider_handler( provider_service::delete_provider(id, &state.db).await?; // Rebuild model route cache - let mut redis = state.redis.clone(); - model_service::warm_up_model_routes(&state.db, &mut redis).await?; + let mut cache = state.cache.clone(); + model_service::warm_up_model_routes(&state.db, &mut cache).await?; Ok(StatusCode::NO_CONTENT) } @@ -202,7 +202,7 @@ async fn create_model( return Err(AppError::BadRequest("name is required".into())); } - let mut redis = state.redis.clone(); + let mut cache = state.cache.clone(); let result = model_service::create_model( &body.name, body.provider_id, @@ -210,7 +210,7 @@ async fn create_model( body.input_token_coefficient.unwrap_or(1.0), body.output_token_coefficient.unwrap_or(1.0), &state.db, - &mut redis, + &mut cache, ) .await?; @@ -230,8 +230,8 @@ async fn delete_model_handler( State(state): State>, Path(id): Path, ) -> Result { - let mut redis = state.redis.clone(); - model_service::delete_model(id, &state.db, &mut redis).await?; + let mut cache = state.cache.clone(); + model_service::delete_model(id, &state.db, &mut cache).await?; Ok(StatusCode::NO_CONTENT) } @@ -252,7 +252,7 @@ async fn update_model_handler( Path(id): Path, Json(body): Json, ) -> Result, AppError> { - let mut redis = state.redis.clone(); + let mut cache = state.cache.clone(); let result = model_service::update_model( id, body.name.as_deref(), @@ -262,7 +262,7 @@ async fn update_model_handler( body.input_token_coefficient, body.output_token_coefficient, &state.db, - &mut redis, + &mut cache, ) .await?; diff --git a/src/routes/proxy.rs b/src/routes/proxy.rs index 32eaf8a..0b2f3ea 100644 
--- a/src/routes/proxy.rs +++ b/src/routes/proxy.rs @@ -71,8 +71,8 @@ async fn chat_completions( } // Resolve model → provider routing - let mut redis = state.redis.clone(); - let route = model_service::resolve_model_route(&model_name, &mut redis, &state.db) + let mut cache = state.cache.clone(); + let route = model_service::resolve_model_route(&model_name, &mut cache, &state.db) .await .map_err(|e| { tracing::error!("Model route resolution error: {}", e); diff --git a/src/services/key_service.rs b/src/services/key_service.rs index fdb7987..c05de89 100644 --- a/src/services/key_service.rs +++ b/src/services/key_service.rs @@ -1,14 +1,13 @@ use chrono::Utc; -use redis::aio::ConnectionManager; -use redis::AsyncCommands; use sha2::{Digest, Sha256}; -use sqlx::PgPool; use uuid::Uuid; +use crate::cache::Cache; +use crate::db::DbPool; use crate::error::AppError; use crate::models::user_key::{UserKey, UserKeyCreated, UserKeyInfo}; -const REDIS_ACTIVE_KEYS_SET: &str = "gateway:active_key_hashes"; +const CACHE_ACTIVE_KEYS_SET: &str = "gateway:active_key_hashes"; /// Generate a new key in the format `sk-{uuid v4}` pub fn generate_key() -> String { @@ -31,13 +30,13 @@ fn key_prefix(plain: &str) -> String { } } -/// Create a new user key, persist to PG + cache in Redis. +/// Create a new user key, persist to DB + cache. /// Returns the full key info plus the plaintext key (shown only once). 
pub async fn create_key( name: &str, token_budget: Option, - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result { let id = Uuid::new_v4(); let plain = generate_key(); @@ -45,23 +44,17 @@ pub async fn create_key( let prefix = key_prefix(&plain); let now = Utc::now(); - sqlx::query( + db_execute!( + db, r#" INSERT INTO user_keys (id, name, key_hash, key_prefix, is_active, token_budget, tokens_used, created_at, updated_at) VALUES ($1, $2, $3, $4, TRUE, $5, 0, $6, $6) "#, - ) - .bind(id) - .bind(name) - .bind(&hash) - .bind(&prefix) - .bind(token_budget) - .bind(now) - .execute(db) - .await?; - - // Add hash to Redis active set - let _: () = redis.sadd(REDIS_ACTIVE_KEYS_SET, &hash).await?; + id, name, &hash, &prefix, token_budget, now + )?; + + // Add hash to active set + cache.sadd(CACHE_ACTIVE_KEYS_SET, &hash).await?; Ok(UserKeyCreated { id, @@ -80,25 +73,24 @@ pub struct KeyValidation { pub tokens_used: i64, } -/// Validate a plaintext key against Redis (fast path) or PG (slow path + backfill). +/// Validate a plaintext key against cache (fast path) or DB (slow path + backfill). /// Returns `Some(KeyValidation)` on success, `None` on invalid key. 
pub async fn validate_key( plain: &str, - redis: &mut ConnectionManager, - db: &PgPool, + cache: &mut Cache, + db: &DbPool, ) -> Result, AppError> { let hash = hash_key(plain); - // Fast path: check Redis SET - let exists: bool = redis.sismember(REDIS_ACTIVE_KEYS_SET, &hash).await?; + // Fast path: check cache + let exists = cache.sismember(CACHE_ACTIVE_KEYS_SET, &hash).await?; if exists { - // Look up key details from PG - let row = sqlx::query_as::<_, (Uuid, Option, i64)>( + // Look up key details from DB + let row: Option<(Uuid, Option, i64)> = db_query_as!( + optional, db, "SELECT id, token_budget, tokens_used FROM user_keys WHERE key_hash = $1 AND is_active = TRUE", - ) - .bind(&hash) - .fetch_optional(db) - .await?; + &hash + )?; return Ok(row.map(|(id, budget, used)| KeyValidation { key_id: id, @@ -108,17 +100,16 @@ pub async fn validate_key( })); } - // Slow path: check PG - let row = sqlx::query_as::<_, (Uuid, Option, i64)>( + // Slow path: check DB + let row: Option<(Uuid, Option, i64)> = db_query_as!( + optional, db, "SELECT id, token_budget, tokens_used FROM user_keys WHERE key_hash = $1 AND is_active = TRUE", - ) - .bind(&hash) - .fetch_optional(db) - .await?; + &hash + )?; if let Some((id, budget, used)) = row { - // Backfill Redis - let _: () = redis.sadd(REDIS_ACTIVE_KEYS_SET, &hash).await?; + // Backfill cache + cache.sadd(CACHE_ACTIVE_KEYS_SET, &hash).await?; return Ok(Some(KeyValidation { key_id: id, key_hash: hash, @@ -131,10 +122,11 @@ pub async fn validate_key( } /// List all keys (without exposing hashes or plaintext). 
-pub async fn list_keys(db: &PgPool) -> Result, AppError> { - let keys = sqlx::query_as::<_, UserKey>("SELECT * FROM user_keys ORDER BY created_at DESC") - .fetch_all(db) - .await?; +pub async fn list_keys(db: &DbPool) -> Result, AppError> { + let keys: Vec = db_query_as!( + all, db, + "SELECT * FROM user_keys ORDER BY created_at DESC" + )?; Ok(keys.into_iter().map(UserKeyInfo::from).collect()) } @@ -143,20 +135,19 @@ pub async fn list_keys(db: &PgPool) -> Result, AppError> { /// Returns the new plaintext key (shown only once). pub async fn rotate_key( id: Uuid, - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result { // Fetch the existing key to get its old hash - let existing = sqlx::query_as::<_, UserKey>( + let existing: UserKey = db_query_as!( + optional, db, "SELECT * FROM user_keys WHERE id = $1 AND is_active = TRUE", - ) - .bind(id) - .fetch_optional(db) - .await? + id + )? .ok_or(AppError::NotFound)?; - // Remove old hash from Redis - let _: () = redis.srem(REDIS_ACTIVE_KEYS_SET, &existing.key_hash).await?; + // Remove old hash from cache + cache.srem(CACHE_ACTIVE_KEYS_SET, &existing.key_hash).await?; // Generate new key let new_plain = generate_key(); @@ -164,18 +155,14 @@ pub async fn rotate_key( let new_prefix = key_prefix(&new_plain); let now = Utc::now(); - sqlx::query( + db_execute!( + db, "UPDATE user_keys SET key_hash = $1, key_prefix = $2, updated_at = $3 WHERE id = $4", - ) - .bind(&new_hash) - .bind(&new_prefix) - .bind(now) - .bind(id) - .execute(db) - .await?; + &new_hash, &new_prefix, now, id + )?; - // Add new hash to Redis - let _: () = redis.sadd(REDIS_ACTIVE_KEYS_SET, &new_hash).await?; + // Add new hash to cache + cache.sadd(CACHE_ACTIVE_KEYS_SET, &new_hash).await?; Ok(UserKeyCreated { id, @@ -186,55 +173,52 @@ pub async fn rotate_key( }) } -/// Soft-delete a key: mark inactive + remove from Redis. +/// Soft-delete a key: mark inactive + remove from cache. 
pub async fn delete_key( id: Uuid, - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result<(), AppError> { - let existing = sqlx::query_as::<_, UserKey>( + let existing: UserKey = db_query_as!( + optional, db, "SELECT * FROM user_keys WHERE id = $1 AND is_active = TRUE", - ) - .bind(id) - .fetch_optional(db) - .await? + id + )? .ok_or(AppError::NotFound)?; - sqlx::query("UPDATE user_keys SET is_active = FALSE, updated_at = NOW() WHERE id = $1") - .bind(id) - .execute(db) - .await?; + let now = Utc::now(); + db_execute!( + db, + "UPDATE user_keys SET is_active = FALSE, updated_at = $1 WHERE id = $2", + now, id + )?; - let _: () = redis.srem(REDIS_ACTIVE_KEYS_SET, &existing.key_hash).await?; + cache.srem(CACHE_ACTIVE_KEYS_SET, &existing.key_hash).await?; Ok(()) } -/// Warm up Redis with all active key hashes from PG (call on startup). -pub async fn warm_up_redis( - db: &PgPool, - redis: &mut ConnectionManager, +/// Warm up cache with all active key hashes from DB (call on startup). 
+pub async fn warm_up_cache( + db: &DbPool, + cache: &mut Cache, ) -> Result<(), AppError> { - let hashes = sqlx::query_scalar::<_, String>( - "SELECT key_hash FROM user_keys WHERE is_active = TRUE", - ) - .fetch_all(db) - .await?; + let hashes: Vec = db_query_scalar!( + all, db, + "SELECT key_hash FROM user_keys WHERE is_active = TRUE" + )?; if !hashes.is_empty() { // Clear stale data and re-populate - let _: () = redis::cmd("DEL") - .arg(REDIS_ACTIVE_KEYS_SET) - .query_async(redis) - .await?; + cache.del(CACHE_ACTIVE_KEYS_SET).await?; for hash in &hashes { - let _: () = redis.sadd(REDIS_ACTIVE_KEYS_SET, hash).await?; + cache.sadd(CACHE_ACTIVE_KEYS_SET, hash).await?; } - tracing::info!("Warmed up Redis with {} active key hashes", hashes.len()); + tracing::info!("Warmed up cache with {} active key hashes", hashes.len()); } else { - tracing::info!("No active keys to warm up in Redis"); + tracing::info!("No active keys to warm up in cache"); } Ok(()) @@ -245,24 +229,21 @@ pub async fn update_key_budget( id: Uuid, token_budget: Option, reset_usage: bool, - db: &PgPool, + db: &DbPool, ) -> Result { - let key = if reset_usage { - sqlx::query_as::<_, UserKey>( - "UPDATE user_keys SET token_budget = $1, tokens_used = 0, updated_at = NOW() WHERE id = $2 RETURNING *", - ) - .bind(token_budget) - .bind(id) - .fetch_optional(db) - .await? + let now = Utc::now(); + let key: Option = if reset_usage { + db_query_as!( + optional, db, + "UPDATE user_keys SET token_budget = $1, tokens_used = 0, updated_at = $2 WHERE id = $3 RETURNING *", + token_budget, now, id + )? } else { - sqlx::query_as::<_, UserKey>( - "UPDATE user_keys SET token_budget = $1, updated_at = NOW() WHERE id = $2 RETURNING *", - ) - .bind(token_budget) - .bind(id) - .fetch_optional(db) - .await? + db_query_as!( + optional, db, + "UPDATE user_keys SET token_budget = $1, updated_at = $2 WHERE id = $3 RETURNING *", + token_budget, now, id + )? 
}; key.map(UserKeyInfo::from).ok_or(AppError::NotFound) @@ -272,14 +253,13 @@ pub async fn update_key_budget( pub async fn increment_tokens_used( id: Uuid, tokens: i64, - db: &PgPool, + db: &DbPool, ) -> Result<(), AppError> { - sqlx::query( - "UPDATE user_keys SET tokens_used = tokens_used + $1, updated_at = NOW() WHERE id = $2", - ) - .bind(tokens) - .bind(id) - .execute(db) - .await?; + let now = Utc::now(); + db_execute!( + db, + "UPDATE user_keys SET tokens_used = tokens_used + $1, updated_at = $2 WHERE id = $3", + tokens, now, id + )?; Ok(()) } diff --git a/src/services/log_service.rs b/src/services/log_service.rs index 1f1e6ee..e6af046 100644 --- a/src/services/log_service.rs +++ b/src/services/log_service.rs @@ -1,7 +1,7 @@ use chrono::Utc; -use sqlx::PgPool; use uuid::Uuid; +use crate::db::DbPool; use crate::error::AppError; use crate::models::request_log::{LogListResponse, RequestLogInfo}; @@ -27,11 +27,12 @@ pub struct NewRequestLog { } /// Insert a request log entry into the database. 
-pub async fn insert_log(db: &PgPool, log: NewRequestLog) -> Result<(), AppError> { +pub async fn insert_log(db: &DbPool, log: NewRequestLog) -> Result<(), AppError> { let id = Uuid::new_v4(); let now = Utc::now(); - sqlx::query( + db_execute!( + db, r#" INSERT INTO request_logs ( id, request_id, user_key_id, user_key_hash, @@ -43,28 +44,26 @@ pub async fn insert_log(db: &PgPool, log: NewRequestLog) -> Result<(), AppError> $14, $15, $16, $17, $18, $19 ) "#, - ) - .bind(id) - .bind(&log.request_id) - .bind(log.user_key_id) - .bind(&log.user_key_hash) - .bind(&log.model_requested) - .bind(&log.model_sent) - .bind(log.provider_id) - .bind(&log.provider_kind) - .bind(log.status_code) - .bind(log.is_error) - .bind(log.prompt_tokens) - .bind(log.completion_tokens) - .bind(log.total_tokens) - .bind(log.latency_ms) - .bind(log.is_stream) - .bind(&log.request_body) - .bind(&log.response_body) - .bind(&log.error_message) - .bind(now) - .execute(db) - .await?; + id, + &log.request_id, + log.user_key_id, + &log.user_key_hash, + &log.model_requested, + &log.model_sent, + log.provider_id, + &log.provider_kind, + log.status_code, + log.is_error, + log.prompt_tokens, + log.completion_tokens, + log.total_tokens, + log.latency_ms, + log.is_stream, + &log.request_body, + &log.response_body, + &log.error_message, + now + )?; Ok(()) } @@ -132,7 +131,7 @@ impl From for RequestLogInfo { } /// List logs with offset-based pagination and optional filters. -pub async fn list_logs(db: &PgPool, params: ListLogsParams) -> Result { +pub async fn list_logs(db: &DbPool, params: ListLogsParams) -> Result { let offset = (params.page - 1).max(0) * params.per_page; // Build dynamic WHERE clauses @@ -152,6 +151,9 @@ pub async fn list_logs(db: &PgPool, params: ListLogsParams) -> Result Result Result(&count_query); - if let Some(ref kid) = params.key_id { - q = q.bind(kid); - } - if let Some(ref m) = params.model { - q = q.bind(m); - } - q.fetch_one(db).await? 
- }; + // Execute count and data queries — must build separate query objects per DB variant + // because sqlx locks the database type at bind time. + macro_rules! run_list_queries { + ($pool:expr) => {{ + let total: i64 = { + let mut q = sqlx::query_scalar::<_, i64>(&count_query); + if let Some(ref kid) = params.key_id { + q = q.bind(kid); + } + if let Some(ref m) = params.model { + q = q.bind(m); + } + q.fetch_one($pool).await? + }; + let rows: Vec = { + let mut q = sqlx::query_as::<_, RequestLogRow>(&data_query) + .bind(params.per_page) + .bind(offset); + if let Some(ref kid) = params.key_id { + q = q.bind(kid); + } + if let Some(ref m) = params.model { + q = q.bind(m); + } + q.fetch_all($pool).await? + }; + (total, rows) + }}; + } - // Execute data query - let rows: Vec = { - let mut q = sqlx::query_as::<_, RequestLogRow>(&data_query) - .bind(params.per_page) - .bind(offset); - if let Some(ref kid) = params.key_id { - q = q.bind(kid); - } - if let Some(ref m) = params.model { - q = q.bind(m); - } - q.fetch_all(db).await? 
+ let (total, rows) = match db { + DbPool::Pg(p) => run_list_queries!(p), + DbPool::Sqlite(p) => run_list_queries!(p), }; Ok(LogListResponse { @@ -208,15 +219,14 @@ pub async fn list_logs(db: &PgPool, params: ListLogsParams) -> Result Result { - let result = sqlx::query( - "DELETE FROM request_logs WHERE created_at < NOW() - make_interval(days => $1)", - ) - .bind(retention_days as i32) - .execute(db) - .await?; - - Ok(result.rows_affected()) +pub async fn cleanup_old_logs(db: &DbPool, retention_days: u32) -> Result { + let cutoff = Utc::now() - chrono::Duration::days(retention_days as i64); + let rows = db_execute!( + db, + "DELETE FROM request_logs WHERE created_at < $1", + cutoff + )?; + Ok(rows) } // ── Dashboard Stats ─────────────────────────────────────────────────── @@ -273,7 +283,7 @@ struct SummaryRow { #[derive(Debug, sqlx::FromRow)] struct HourlyRow { - hour: chrono::DateTime, + hour: String, requests: i64, errors: i64, tokens: i64, @@ -294,67 +304,90 @@ struct ProviderRow { errors: i64, } -pub async fn get_dashboard_stats(db: &PgPool) -> Result { - // 1) Summary - let summary = sqlx::query_as::<_, SummaryRow>( +pub async fn get_dashboard_stats(db: &DbPool) -> Result { + let cutoff_24h = Utc::now() - chrono::Duration::hours(24); + let cutoff_7d = Utc::now() - chrono::Duration::days(7); + + // 1) Summary — portable SQL using CASE WHEN instead of FILTER + let summary: SummaryRow = db_query_as!( + one, db, r#" SELECT - COUNT(*)::BIGINT AS total_requests, - COUNT(*) FILTER (WHERE created_at >= NOW() - INTERVAL '24 hours')::BIGINT AS total_requests_24h, - COUNT(*) FILTER (WHERE created_at >= NOW() - INTERVAL '24 hours' AND is_error)::BIGINT AS total_errors_24h, - COALESCE(SUM(total_tokens) FILTER (WHERE created_at >= NOW() - INTERVAL '24 hours'), 0)::BIGINT AS total_tokens_24h, - COALESCE(AVG(latency_ms) FILTER (WHERE created_at >= NOW() - INTERVAL '24 hours'), 0)::FLOAT8 AS avg_latency_24h + COUNT(*) AS total_requests, + SUM(CASE WHEN created_at >= $1 THEN 
1 ELSE 0 END) AS total_requests_24h, + SUM(CASE WHEN created_at >= $1 AND is_error THEN 1 ELSE 0 END) AS total_errors_24h, + COALESCE(SUM(CASE WHEN created_at >= $1 THEN total_tokens ELSE 0 END), 0) AS total_tokens_24h, + COALESCE(AVG(CASE WHEN created_at >= $1 THEN latency_ms END), 0.0) AS avg_latency_24h FROM request_logs "#, - ) - .fetch_one(db) - .await?; + cutoff_24h + )?; - // 2) Hourly buckets (last 24h) - let hourly_rows = sqlx::query_as::<_, HourlyRow>( + // 2) Hourly buckets (last 24h) — need db-specific date truncation + let hourly_sql = if db.is_sqlite() { + r#" + SELECT + strftime('%Y-%m-%d %H:00:00', created_at) AS hour, + COUNT(*) AS requests, + SUM(CASE WHEN is_error THEN 1 ELSE 0 END) AS errors, + COALESCE(SUM(total_tokens), 0) AS tokens, + COALESCE(AVG(latency_ms), 0.0) AS avg_latency + FROM request_logs + WHERE created_at >= $1 + GROUP BY strftime('%Y-%m-%d %H:00:00', created_at) + ORDER BY hour + "# + } else { r#" SELECT - date_trunc('hour', created_at) AS hour, + to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00:00') AS hour, COUNT(*) AS requests, - COUNT(*) FILTER (WHERE is_error) AS errors, - COALESCE(SUM(total_tokens), 0)::BIGINT AS tokens, - COALESCE(AVG(latency_ms), 0)::FLOAT8 AS avg_latency + SUM(CASE WHEN is_error THEN 1 ELSE 0 END) AS errors, + COALESCE(SUM(total_tokens), 0) AS tokens, + COALESCE(AVG(latency_ms), 0.0) AS avg_latency FROM request_logs - WHERE created_at >= NOW() - INTERVAL '24 hours' + WHERE created_at >= $1 GROUP BY date_trunc('hour', created_at) ORDER BY hour - "#, - ) - .fetch_all(db) - .await?; + "# + }; + + let hourly_rows: Vec = db_query_as!(all, db, hourly_sql, cutoff_24h)?; let requests_per_hour: Vec = hourly_rows .into_iter() - .map(|r| HourlyBucket { - hour: r.hour.format("%H:%M").to_string(), - requests: r.requests, - errors: r.errors, - tokens: r.tokens, - avg_latency: (r.avg_latency * 10.0).round() / 10.0, + .map(|r| { + // Extract "HH:MM" from the hour string (expected format: "YYYY-MM-DD 
HH:00:00") + let display_hour = r.hour.find(' ') + .and_then(|pos| r.hour.get(pos + 1..pos + 6)) + .unwrap_or(&r.hour) + .to_string(); + HourlyBucket { + hour: display_hour, + requests: r.requests, + errors: r.errors, + tokens: r.tokens, + avg_latency: (r.avg_latency * 10.0).round() / 10.0, + } }) .collect(); - // 3) Per-model usage (last 7 days) - let model_rows = sqlx::query_as::<_, ModelRow>( + // 3) Per-model usage (last 7 days) — portable SQL + let model_rows: Vec = db_query_as!( + all, db, r#" SELECT model_requested AS model, COUNT(*) AS requests, - COALESCE(SUM(total_tokens), 0)::BIGINT AS tokens + COALESCE(SUM(total_tokens), 0) AS tokens FROM request_logs - WHERE created_at >= NOW() - INTERVAL '7 days' + WHERE created_at >= $1 GROUP BY model_requested ORDER BY requests DESC LIMIT 20 "#, - ) - .fetch_all(db) - .await?; + cutoff_7d + )?; let model_usage: Vec = model_rows .into_iter() @@ -365,21 +398,21 @@ pub async fn get_dashboard_stats(db: &PgPool) -> Result( + // 4) Per-provider usage (last 7 days) — portable SQL + let provider_rows: Vec = db_query_as!( + all, db, r#" SELECT COALESCE(provider_kind, 'unknown') AS provider, COUNT(*) AS requests, - COUNT(*) FILTER (WHERE is_error) AS errors + SUM(CASE WHEN is_error THEN 1 ELSE 0 END) AS errors FROM request_logs - WHERE created_at >= NOW() - INTERVAL '7 days' + WHERE created_at >= $1 GROUP BY provider_kind ORDER BY requests DESC "#, - ) - .fetch_all(db) - .await?; + cutoff_7d + )?; let provider_usage: Vec = provider_rows .into_iter() diff --git a/src/services/model_service.rs b/src/services/model_service.rs index 8b9a469..7869684 100644 --- a/src/services/model_service.rs +++ b/src/services/model_service.rs @@ -1,14 +1,13 @@ use chrono::Utc; -use redis::aio::ConnectionManager; -use redis::AsyncCommands; -use sqlx::PgPool; use uuid::Uuid; +use crate::cache::Cache; +use crate::db::DbPool; use crate::error::AppError; use crate::models::model::{Model, ModelInfo, ModelRoute}; use crate::models::provider::Provider; 
-const REDIS_MODEL_ROUTES_HASH: &str = "gateway:model_routes"; +const CACHE_MODEL_ROUTES_HASH: &str = "gateway:model_routes"; /// Create a new model mapping. pub async fn create_model( @@ -17,38 +16,33 @@ pub async fn create_model( provider_model_name: Option<&str>, input_token_coefficient: f64, output_token_coefficient: f64, - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result { // Verify provider exists - let provider = sqlx::query_as::<_, Provider>("SELECT * FROM providers WHERE id = $1") - .bind(provider_id) - .fetch_optional(db) - .await? - .ok_or_else(|| AppError::BadRequest(format!("Provider {provider_id} not found")))?; + let provider: Provider = db_query_as!( + optional, db, + "SELECT * FROM providers WHERE id = $1", + provider_id + )? + .ok_or_else(|| AppError::BadRequest(format!("Provider {provider_id} not found")))?; let id = Uuid::new_v4(); let now = Utc::now(); - sqlx::query( + db_execute!( + db, r#" INSERT INTO models (id, name, provider_id, provider_model_name, is_active, input_token_coefficient, output_token_coefficient, created_at, updated_at) VALUES ($1, $2, $3, $4, TRUE, $5, $6, $7, $7) "#, - ) - .bind(id) - .bind(name) - .bind(provider_id) - .bind(provider_model_name) - .bind(input_token_coefficient) - .bind(output_token_coefficient) - .bind(now) - .execute(db) - .await?; - - // Update Redis cache - cache_model_route(name, provider_model_name, input_token_coefficient, output_token_coefficient, &provider, redis).await?; + id, name, provider_id, provider_model_name, + input_token_coefficient, output_token_coefficient, now + )?; + + // Update cache + cache_model_route(name, provider_model_name, input_token_coefficient, output_token_coefficient, &provider, cache).await?; Ok(ModelInfo { id, @@ -65,8 +59,9 @@ pub async fn create_model( } /// List all models with their provider names. 
-pub async fn list_models(db: &PgPool) -> Result, AppError> { - let rows = sqlx::query_as::<_, ModelWithProvider>( +pub async fn list_models(db: &DbPool) -> Result, AppError> { + let rows: Vec = db_query_as!( + all, db, r#" SELECT m.id, m.name, m.provider_id, m.provider_model_name, m.is_active, m.input_token_coefficient, m.output_token_coefficient, @@ -74,10 +69,8 @@ pub async fn list_models(db: &PgPool) -> Result, AppError> { FROM models m JOIN providers p ON m.provider_id = p.id ORDER BY m.created_at DESC - "#, - ) - .fetch_all(db) - .await?; + "# + )?; Ok(rows .into_iter() @@ -96,30 +89,28 @@ pub async fn list_models(db: &PgPool) -> Result, AppError> { .collect()) } -/// Delete a model and remove from Redis cache. +/// Delete a model and remove from cache. pub async fn delete_model( id: Uuid, - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result<(), AppError> { - let model = sqlx::query_as::<_, Model>("SELECT * FROM models WHERE id = $1") - .bind(id) - .fetch_optional(db) - .await? - .ok_or(AppError::NotFound)?; + let model: Model = db_query_as!( + optional, db, + "SELECT * FROM models WHERE id = $1", + id + )? + .ok_or(AppError::NotFound)?; - sqlx::query("DELETE FROM models WHERE id = $1") - .bind(id) - .execute(db) - .await?; + db_execute!(db, "DELETE FROM models WHERE id = $1", id)?; - // Remove from Redis - let _: () = redis.hdel(REDIS_MODEL_ROUTES_HASH, &model.name).await?; + // Remove from cache + cache.hdel(CACHE_MODEL_ROUTES_HASH, &model.name).await?; Ok(()) } -/// Update an existing model and rebuild Redis cache. +/// Update an existing model and rebuild cache. 
pub async fn update_model( id: Uuid, name: Option<&str>, @@ -128,14 +119,15 @@ pub async fn update_model( is_active: Option, input_token_coefficient: Option, output_token_coefficient: Option, - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result { - let existing = sqlx::query_as::<_, Model>("SELECT * FROM models WHERE id = $1") - .bind(id) - .fetch_optional(db) - .await? - .ok_or(AppError::NotFound)?; + let existing: Model = db_query_as!( + optional, db, + "SELECT * FROM models WHERE id = $1", + id + )? + .ok_or(AppError::NotFound)?; let new_name = name.map(|s| s.to_string()).unwrap_or(existing.name.clone()); let new_provider_id = provider_id.unwrap_or(existing.provider_id); @@ -149,41 +141,38 @@ pub async fn update_model( // If provider changed, verify it exists if new_provider_id != existing.provider_id { - sqlx::query_as::<_, Provider>("SELECT * FROM providers WHERE id = $1") - .bind(new_provider_id) - .fetch_optional(db) - .await? - .ok_or_else(|| AppError::BadRequest(format!("Provider {new_provider_id} not found")))?; + let _: Provider = db_query_as!( + optional, db, + "SELECT * FROM providers WHERE id = $1", + new_provider_id + )? 
+ .ok_or_else(|| AppError::BadRequest(format!("Provider {new_provider_id} not found")))?; } - sqlx::query( + let now = Utc::now(); + db_execute!( + db, r#" UPDATE models SET name = $1, provider_id = $2, provider_model_name = $3, is_active = $4, - input_token_coefficient = $5, output_token_coefficient = $6, updated_at = NOW() - WHERE id = $7 + input_token_coefficient = $5, output_token_coefficient = $6, updated_at = $7 + WHERE id = $8 "#, - ) - .bind(&new_name) - .bind(new_provider_id) - .bind(&new_provider_model_name) - .bind(new_is_active) - .bind(new_input_coeff) - .bind(new_output_coeff) - .bind(id) - .execute(db) - .await?; - - // Remove old name from Redis if name changed + &new_name, new_provider_id, &new_provider_model_name, new_is_active, + new_input_coeff, new_output_coeff, now, id + )?; + + // Remove old name from cache if name changed if new_name != existing.name { - let _: () = redis.hdel(REDIS_MODEL_ROUTES_HASH, &existing.name).await?; + cache.hdel(CACHE_MODEL_ROUTES_HASH, &existing.name).await?; } // Rebuild the full cache to keep everything consistent - warm_up_model_routes(db, redis).await?; + warm_up_model_routes(db, cache).await?; // Fetch updated row with provider name - let row = sqlx::query_as::<_, ModelWithProvider>( + let row: ModelWithProvider = db_query_as!( + one, db, r#" SELECT m.id, m.name, m.provider_id, m.provider_model_name, m.is_active, m.input_token_coefficient, m.output_token_coefficient, @@ -192,10 +181,8 @@ pub async fn update_model( JOIN providers p ON m.provider_id = p.id WHERE m.id = $1 "#, - ) - .bind(id) - .fetch_one(db) - .await?; + id + )?; Ok(ModelInfo { id: row.id, @@ -212,22 +199,23 @@ pub async fn update_model( } /// Resolve a user-facing model name to its routing information. -/// Fast path: Redis hash lookup. Slow path: PG query + backfill Redis. +/// Fast path: cache lookup. Slow path: DB query + backfill cache. 
pub async fn resolve_model_route( model_name: &str, - redis: &mut ConnectionManager, - db: &PgPool, + cache: &mut Cache, + db: &DbPool, ) -> Result, AppError> { - // Fast path: check Redis - let cached: Option = redis.hget(REDIS_MODEL_ROUTES_HASH, model_name).await?; + // Fast path: check cache + let cached: Option = cache.hget(CACHE_MODEL_ROUTES_HASH, model_name).await?; if let Some(json_str) = cached { if let Ok(route) = serde_json::from_str::(&json_str) { return Ok(Some(route)); } } - // Slow path: query PG - let row = sqlx::query_as::<_, ModelWithProviderFull>( + // Slow path: query DB + let row: Option = db_query_as!( + optional, db, r#" SELECT m.name AS model_name, m.provider_model_name, m.provider_id, m.input_token_coefficient, m.output_token_coefficient, @@ -236,10 +224,8 @@ pub async fn resolve_model_route( JOIN providers p ON m.provider_id = p.id WHERE m.name = $1 AND m.is_active = TRUE AND p.is_active = TRUE "#, - ) - .bind(model_name) - .fetch_optional(db) - .await?; + model_name + )?; match row { Some(r) => { @@ -255,11 +241,11 @@ pub async fn resolve_model_route( output_token_coefficient: r.output_token_coefficient, }; - // Backfill Redis + // Backfill cache if let Ok(json_str) = serde_json::to_string(&route) { - let _: Result<(), _> = redis - .hset(REDIS_MODEL_ROUTES_HASH, model_name, &json_str) - .await; + if let Err(e) = cache.hset(CACHE_MODEL_ROUTES_HASH, model_name, &json_str).await { + tracing::warn!("Failed to backfill model route cache: {e}"); + } } Ok(Some(route)) @@ -268,12 +254,13 @@ pub async fn resolve_model_route( } } -/// Warm up Redis with all active model routes (call on startup). +/// Warm up cache with all active model routes (call on startup). 
pub async fn warm_up_model_routes( - db: &PgPool, - redis: &mut ConnectionManager, + db: &DbPool, + cache: &mut Cache, ) -> Result<(), AppError> { - let rows = sqlx::query_as::<_, ModelWithProviderFull>( + let rows: Vec = db_query_as!( + all, db, r#" SELECT m.name AS model_name, m.provider_model_name, m.provider_id, m.input_token_coefficient, m.output_token_coefficient, @@ -281,16 +268,11 @@ pub async fn warm_up_model_routes( FROM models m JOIN providers p ON m.provider_id = p.id WHERE m.is_active = TRUE AND p.is_active = TRUE - "#, - ) - .fetch_all(db) - .await?; + "# + )?; // Clear stale cache - let _: () = redis::cmd("DEL") - .arg(REDIS_MODEL_ROUTES_HASH) - .query_async(redis) - .await?; + cache.del(CACHE_MODEL_ROUTES_HASH).await?; for r in &rows { let route = ModelRoute { @@ -307,13 +289,13 @@ pub async fn warm_up_model_routes( }; if let Ok(json_str) = serde_json::to_string(&route) { - let _: Result<(), _> = redis - .hset(REDIS_MODEL_ROUTES_HASH, &r.model_name, &json_str) - .await; + if let Err(e) = cache.hset(CACHE_MODEL_ROUTES_HASH, &r.model_name, &json_str).await { + tracing::warn!("Failed to cache model route '{}': {e}", r.model_name); + } } } - tracing::info!("Warmed up Redis with {} model routes", rows.len()); + tracing::info!("Warmed up cache with {} model routes", rows.len()); Ok(()) } @@ -345,14 +327,14 @@ struct ModelWithProviderFull { provider_kind: String, } -/// Cache a single model route into Redis. +/// Cache a single model route. 
async fn cache_model_route( model_name: &str, provider_model_name: Option<&str>, input_token_coefficient: f64, output_token_coefficient: f64, provider: &Provider, - redis: &mut ConnectionManager, + cache: &mut Cache, ) -> Result<(), AppError> { let route = ModelRoute { provider_id: provider.id, @@ -369,6 +351,6 @@ async fn cache_model_route( let json_str = serde_json::to_string(&route) .map_err(|e| AppError::Internal(format!("JSON serialization error: {e}")))?; - let _: () = redis.hset(REDIS_MODEL_ROUTES_HASH, model_name, &json_str).await?; + cache.hset(CACHE_MODEL_ROUTES_HASH, model_name, &json_str).await?; Ok(()) } diff --git a/src/services/provider_service.rs b/src/services/provider_service.rs index 8e4c2e2..204a01c 100644 --- a/src/services/provider_service.rs +++ b/src/services/provider_service.rs @@ -1,7 +1,7 @@ use chrono::Utc; -use sqlx::PgPool; use uuid::Uuid; +use crate::db::DbPool; use crate::error::AppError; use crate::models::provider::{Provider, ProviderInfo, ProviderKind}; @@ -11,7 +11,7 @@ pub async fn create_provider( kind: &str, base_url: Option<&str>, api_key: &str, - db: &PgPool, + db: &DbPool, ) -> Result { let pk = ProviderKind::from_str(kind) .ok_or_else(|| AppError::BadRequest(format!("Unknown provider kind: {kind}. 
Supported: openai, openrouter, dashscope, ark")))?; @@ -20,34 +20,30 @@ pub async fn create_provider( let id = Uuid::new_v4(); let now = Utc::now(); - sqlx::query( + db_execute!( + db, r#" INSERT INTO providers (id, name, kind, base_url, api_key, is_active, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, TRUE, $6, $6) "#, - ) - .bind(id) - .bind(name) - .bind(pk.as_str()) - .bind(resolved_base_url) - .bind(api_key) - .bind(now) - .execute(db) - .await?; - - let provider = sqlx::query_as::<_, Provider>("SELECT * FROM providers WHERE id = $1") - .bind(id) - .fetch_one(db) - .await?; + id, name, pk.as_str(), resolved_base_url, api_key, now + )?; + + let provider: Provider = db_query_as!( + one, db, + "SELECT * FROM providers WHERE id = $1", + id + )?; Ok(ProviderInfo::from(provider)) } /// List all providers. -pub async fn list_providers(db: &PgPool) -> Result, AppError> { - let providers = sqlx::query_as::<_, Provider>("SELECT * FROM providers ORDER BY created_at DESC") - .fetch_all(db) - .await?; +pub async fn list_providers(db: &DbPool) -> Result, AppError> { + let providers: Vec = db_query_as!( + all, db, + "SELECT * FROM providers ORDER BY created_at DESC" + )?; Ok(providers.into_iter().map(ProviderInfo::from).collect()) } @@ -60,13 +56,14 @@ pub async fn update_provider( base_url: Option<&str>, api_key: Option<&str>, is_active: Option, - db: &PgPool, + db: &DbPool, ) -> Result { - let existing = sqlx::query_as::<_, Provider>("SELECT * FROM providers WHERE id = $1") - .bind(id) - .fetch_optional(db) - .await? - .ok_or(AppError::NotFound)?; + let existing: Provider = db_query_as!( + optional, db, + "SELECT * FROM providers WHERE id = $1", + id + )? 
+ .ok_or(AppError::NotFound)?; let new_kind = match kind { Some(k) => { @@ -81,39 +78,32 @@ pub async fn update_provider( let new_base_url = base_url.map(|s| s.to_string()).unwrap_or(existing.base_url); let new_api_key = api_key.map(|s| s.to_string()).unwrap_or(existing.api_key); let new_is_active = is_active.unwrap_or(existing.is_active); + let now = Utc::now(); - sqlx::query( + db_execute!( + db, r#" UPDATE providers - SET name = $1, kind = $2, base_url = $3, api_key = $4, is_active = $5, updated_at = NOW() - WHERE id = $6 + SET name = $1, kind = $2, base_url = $3, api_key = $4, is_active = $5, updated_at = $6 + WHERE id = $7 "#, - ) - .bind(&new_name) - .bind(&new_kind) - .bind(&new_base_url) - .bind(&new_api_key) - .bind(new_is_active) - .bind(id) - .execute(db) - .await?; - - let updated = sqlx::query_as::<_, Provider>("SELECT * FROM providers WHERE id = $1") - .bind(id) - .fetch_one(db) - .await?; + &new_name, &new_kind, &new_base_url, &new_api_key, new_is_active, now, id + )?; + + let updated: Provider = db_query_as!( + one, db, + "SELECT * FROM providers WHERE id = $1", + id + )?; Ok(ProviderInfo::from(updated)) } /// Delete a provider (hard delete — will fail if models reference it). 
-pub async fn delete_provider(id: Uuid, db: &PgPool) -> Result<(), AppError> { - let result = sqlx::query("DELETE FROM providers WHERE id = $1") - .bind(id) - .execute(db) - .await?; +pub async fn delete_provider(id: Uuid, db: &DbPool) -> Result<(), AppError> { + let rows = db_execute!(db, "DELETE FROM providers WHERE id = $1", id)?; - if result.rows_affected() == 0 { + if rows == 0 { return Err(AppError::NotFound); } diff --git a/src/state.rs b/src/state.rs index af62322..663c4a5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,12 +1,11 @@ -use redis::aio::ConnectionManager; -use sqlx::PgPool; - +use crate::cache::Cache; use crate::config::Config; +use crate::db::DbPool; #[derive(Clone)] pub struct AppState { - pub db: PgPool, - pub redis: ConnectionManager, + pub db: DbPool, + pub cache: Cache, pub config: Config, pub http_client: reqwest::Client, }