diff --git a/README.md b/README.md index 478051d..9116263 100644 --- a/README.md +++ b/README.md @@ -221,6 +221,20 @@ with budget("$5/hr + 100 calls/hr", name="api-tier", backend=backend) as b: Works with `AsyncRedisBackend` for async workflows. Circuit breaker built in — configurable threshold + cooldown. Fail-open or fail-closed. +### Per-user / per-tenant enforcement + +```python +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() # reads REDIS_URL from env + +with budget(max_usd=0.10, tenant_id=user.id, name="api", backend=backend) as b: + run_agent() +# Each user gets their own isolated $0.10 cap — same Redis, zero per-tenant config +``` + +Quota management: `backend.set_tenant_limit(...)`, `backend.get_tenant_spend(...)`, `backend.reset_tenant(...)`, `backend.list_tenants(...)`. Inspect from the command line with `shekel tenants --name api`. + ### Rolling-window rate limits ```python diff --git a/docs/api-reference.md b/docs/api-reference.md index caaeecb..eef3764 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -48,6 +48,9 @@ def budget( | `loop_guard_window_seconds` | `float` | `60.0` | Rolling window duration in seconds. `0` = all-time cap (no rolling window). Only applies when `loop_guard=True`. | | `max_velocity` | `str \| None` | `None` | Spend velocity cap. Format: `"$/"` (e.g. `"$0.50/min"`, `"$5/hr"`). Raises `SpendVelocityExceededError` when the burn rate exceeds this threshold. | | `warn_velocity` | `str \| None` | `None` | Soft velocity warning threshold. Same format as `max_velocity`. Must be less than `max_velocity`. Fires `on_warn` callback when crossed; does not raise. | +| `tenant_id` | `str \| None` | `None` | Tenant or user identifier for per-tenant spend isolation. When set, Redis state is namespaced under `shekel:tb:{name}:{tenant_id}`. Requires `name` and `backend`. Empty string raises `ValueError`. 
| +| `backend` | `RedisBackend \| AsyncRedisBackend \| None` | `None` | Redis backend for distributed or per-tenant enforcement. Required when `tenant_id` is set. | +| `window_seconds` | `float \| None` | `None` | Rolling-window duration in seconds. Required (or inferred from a spec string) for temporal budgets, except that it defaults to `86400 * 30` (30 days) when `tenant_id` or `backend` is set without it. | ### Returns @@ -172,6 +175,7 @@ The budget context manager object. | `switched_at_usd` | `float \| None` | USD spent when fallback occurred, or `None`. | | `fallback_spent` | `float` | USD spent on the fallback model. | | `loop_guard_counts` | `dict[str, int]` | Per-tool call counts recorded by the loop guard. Empty dict when `loop_guard=False`. Keys are tool names; values are total calls recorded within the current window. | +| `tenant_id` | `str \| None` | Tenant identifier passed to `budget()`, or `None` if not set. | ### Nested Budget Properties {#nested-budget-properties} @@ -444,6 +448,38 @@ with budget("$5/hr + 100 calls/hr", name="api-tier", backend=backend) as b: **Raises:** `BudgetConfigMismatchError` if `budget_name` is already registered with different limits or windows. +### Per-Tenant Methods + +| Method | Returns | Description | +|---|---|---| +| `get_tenant_spend(name, tenant_id)` | `float` | Current window spend for the tenant. Returns `0.0` if unknown. | +| `get_tenant_limit(name, tenant_id)` | `float \| None` | Active spend limit for the tenant. Returns `None` if no limit recorded. | +| `set_tenant_limit(name, tenant_id, max_usd)` | `None` | Override the tenant's spend limit without resetting accumulated spend. | +| `reset_tenant(name, tenant_id)` | `None` | Zero out accumulated spend while preserving the limit. | +| `list_tenants(name)` | `list[str]` | All tenant IDs that have recorded spend for the budget name. 
| + +```python +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() + +# Inspect a tenant +spent = backend.get_tenant_spend(name="api", tenant_id="user-42") +limit = backend.get_tenant_limit(name="api", tenant_id="user-42") + +# Adjust quota +backend.set_tenant_limit(name="api", tenant_id="user-42", max_usd=0.50) + +# Reset at billing period rollover +backend.reset_tenant(name="api", tenant_id="user-42") + +# Enumerate all tenants +for tid in backend.list_tenants(name="api"): + print(tid, backend.get_tenant_spend(name="api", tenant_id=tid)) +``` + +See [Per-Tenant Budgets](usage/per-tenant-budgets.md) for the full guide. + --- ## `AsyncRedisBackend` @@ -461,6 +497,16 @@ async with budget("$5/hr", name="api", backend=backend) as b: Constructor and parameters are identical to `RedisBackend`. +All five per-tenant methods are available as coroutines: + +```python +spent = await backend.get_tenant_spend(name="api", tenant_id="user-42") +limit = await backend.get_tenant_limit(name="api", tenant_id="user-42") +await backend.set_tenant_limit(name="api", tenant_id="user-42", max_usd=0.50) +await backend.reset_tenant(name="api", tenant_id="user-42") +tenants = await backend.list_tenants(name="api") +``` + --- ## `@with_budget` diff --git a/docs/usage/distributed-budgets.md b/docs/usage/distributed-budgets.md index 4305b21..08de175 100644 --- a/docs/usage/distributed-budgets.md +++ b/docs/usage/distributed-budgets.md @@ -291,8 +291,29 @@ See [Docker & Containers](../docker.md) for a full production setup. 
--- +## Per-Tenant Namespacing + +When `tenant_id` is passed to `budget()`, shekel appends it to the Redis key so each tenant gets fully isolated state: + +``` +shekel:tb:{name}:{tenant_id} # with tenant_id +shekel:tb:{name} # without tenant_id (shared) +``` + +This means you can enforce per-user spend caps in a SaaS app with a single shared `RedisBackend`: + +```python +with budget(max_usd=0.10, tenant_id=user.id, name="api", backend=backend) as b: + run_agent() +``` + +See **[Per-Tenant Budgets](per-tenant-budgets.md)** for the full guide, including quota management methods and the `shekel tenants` CLI. + +--- + ## Next Steps +- **[Per-Tenant Budgets](per-tenant-budgets.md)** - Per-user spend isolation for SaaS apps - **[Temporal Budgets](temporal-budgets.md)** - Rolling-window budget fundamentals - **[Docker & Containers](../docker.md)** - Full containerized production setup - **[API Reference](../api-reference.md)** - Complete parameter reference diff --git a/docs/usage/per-tenant-budgets.md b/docs/usage/per-tenant-budgets.md new file mode 100644 index 0000000..1d79cdb --- /dev/null +++ b/docs/usage/per-tenant-budgets.md @@ -0,0 +1,340 @@ +--- +title: Per-Tenant Budgets – Enforce Per-User LLM Spend Limits +description: "Enforce isolated per-user or per-tenant LLM spend limits in SaaS apps using shekel and Redis. Each tenant gets an independent cap — same backend, zero per-tenant config." +tags: + - per-tenant + - multi-tenant + - saas + - redis + - budget-enforcement +--- + +# Per-Tenant Budgets + +**Give every user their own isolated LLM spend cap — same Redis backend, zero per-tenant infrastructure.** + +```python +with budget(max_usd=0.10, tenant_id=user.id, name="api", backend=RedisBackend()) as b: + run_agent() +# Each user gets an independent $0.10 cap. No shared state. No cross-contamination. +``` + +When `tenant_id` is set, shekel namespaces all Redis state under `shekel:tb:{name}:{tenant_id}`. Two tenants with the same `name` never share counters. 
+ +--- + +## Installation + +```bash +pip install shekel[redis] +``` + +--- + +## Quick Start + +```python +from shekel import budget +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() # reads REDIS_URL from env + +# Enforce a $0.10 monthly cap for user "user-42" +with budget(max_usd=0.10, tenant_id="user-42", name="api", backend=backend) as b: + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + ) +``` + +!!! note "Required parameters" + `tenant_id` requires both `name` and `backend`. Omitting either raises `ValueError`. + +--- + +## FastAPI SaaS Example + +A production-ready endpoint that enforces per-user spend, returns HTTP 429 on exhaustion, and lets admins inspect quotas: + +```python +from fastapi import FastAPI, Depends, HTTPException, Request +from shekel import budget +from shekel.backends.redis import AsyncRedisBackend +from shekel.exceptions import BudgetExceededError + +app = FastAPI() +backend = AsyncRedisBackend(url="redis://redis:6379/0") + +MONTHLY_CAP_USD = 0.10 # $0.10 per user per 30 days + +async def get_current_user(request: Request) -> str: + return request.headers["X-User-ID"] # your auth here + +@app.post("/chat") +async def chat(prompt: str, user_id: str = Depends(get_current_user)): + try: + async with budget( + max_usd=MONTHLY_CAP_USD, + tenant_id=user_id, + name="api", + backend=backend, + window_seconds=86400 * 30, # 30-day rolling window (default) + ) as b: + response = await client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + ) + return { + "reply": response.choices[0].message.content, + "spent": b.spent, + } + except BudgetExceededError as e: + raise HTTPException( + status_code=429, + detail="Monthly spend limit reached.", + headers={"Retry-After": str(int(e.retry_after or 0))}, + ) +``` + +--- + +## Async Usage + +`budget()` with `tenant_id` works identically in async contexts — just 
use `async with`: + +```python +from shekel import budget +from shekel.backends.redis import AsyncRedisBackend + +backend = AsyncRedisBackend() + +async with budget( + max_usd=0.10, + tenant_id=user_id, + name="api", + backend=backend, +) as b: + await call_llm() + +print(f"Tenant {user_id} spent: ${b.spent:.4f}") +``` + +--- + +## Redis Key Scheme + +Each tenant's state lives in its own Redis hash, completely isolated from other tenants: + +| Key pattern | Example | Contains | +|---|---|---| +| `shekel:tb:{name}:{tenant_id}` | `shekel:tb:api:user-42` | `usd:spent`, `usd:max`, `usd:window_s`, `usd:start`, `spec_hash` | + +Without `tenant_id`, the key is `shekel:tb:{name}` (shared across all callers). With `tenant_id`, an extra segment is appended so no two tenants ever touch the same hash. + +--- + +## Quota Management + +`RedisBackend` (and `AsyncRedisBackend`) expose five admin methods for managing tenant quotas programmatically. + +### `get_tenant_spend(name, tenant_id) → float` + +Return the current window spend for a tenant. Returns `0.0` if the tenant has never been seen. + +```python +spent = backend.get_tenant_spend(name="api", tenant_id="user-42") +print(f"User 42 has spent ${spent:.4f} this window") +``` + +### `get_tenant_limit(name, tenant_id) → float | None` + +Return the active spend limit for a tenant. Returns `None` if the tenant has no recorded limit. + +```python +limit = backend.get_tenant_limit(name="api", tenant_id="user-42") +if limit is not None: + print(f"User 42 limit: ${limit:.2f}") +``` + +### `set_tenant_limit(name, tenant_id, max_usd)` + +Override the spend limit for a tenant without resetting their accumulated spend. Useful for upgrades (free → pro) or admin adjustments. + +```python +# Upgrade user to a $1.00 monthly cap +backend.set_tenant_limit(name="api", tenant_id="user-42", max_usd=1.00) +``` + +After calling `set_tenant_limit`, subsequent `budget(max_usd=1.00, tenant_id="user-42", ...)` calls succeed. 
Passing the old limit raises `BudgetConfigMismatchError` — see [Limit-change flow](#limit-change-flow). + +### `reset_tenant(name, tenant_id)` + +Zero out a tenant's accumulated spend while preserving their limit. Use this at the start of a new billing period. + +```python +backend.reset_tenant(name="api", tenant_id="user-42") +# spend → 0.0, limit unchanged +``` + +### `list_tenants(name) → list[str]` + +Return all tenant IDs that have ever recorded spend for the given budget name. + +```python +tenants = backend.list_tenants(name="api") +for tid in tenants: + spent = backend.get_tenant_spend(name="api", tenant_id=tid) + limit = backend.get_tenant_limit(name="api", tenant_id=tid) + print(f"{tid}: ${spent:.4f} / ${limit:.2f}") +``` + +### Async equivalents + +All five methods are available as coroutines on `AsyncRedisBackend`: + +```python +spent = await backend.get_tenant_spend(name="api", tenant_id="user-42") +limit = await backend.get_tenant_limit(name="api", tenant_id="user-42") +await backend.set_tenant_limit(name="api", tenant_id="user-42", max_usd=1.00) +await backend.reset_tenant(name="api", tenant_id="user-42") +tenants = await backend.list_tenants(name="api") +``` + +--- + +## `shekel tenants` CLI + +The `shekel tenants` command lists every tenant of a named budget with its spend, limit, and utilization — no code changes needed. 
+ +### List tenants + +```bash +shekel tenants --name api +``` + +``` +TENANT SPENT LIMIT USED% +----------------------------------------- +user-1 $0.0821 $0.1000 82.1% +user-2 $0.0034 $0.1000 3.4% +org:user-3 $0.0990 $0.1000 99.0% +``` + +JSON output (reformatted here for readability; each row also carries `utilization`, the spent-to-limit ratio, or `null` when no limit is stored): + +```bash +shekel tenants --name api --output json +``` + +```json +[ + {"tenant_id": "user-1", "spent": 0.0821, "limit": 0.1, "utilization": 0.821}, + {"tenant_id": "user-2", "spent": 0.0034, "limit": 0.1, "utilization": 0.034}, + {"tenant_id": "org:user-3", "spent": 0.0990, "limit": 0.1, "utilization": 0.99} +] +``` + +### Set a limit + +The CLI only lists tenants; change a limit from Python with `set_tenant_limit`: + +```python +backend.set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.50) +``` + +### Reset spend + +Resetting spend is likewise done from Python with `reset_tenant`: + +```python +backend.reset_tenant(name="api", tenant_id="user-1") +``` + +### Flag reference + +| Flag | Description | +|---|---| +| `--name` | Budget name (required) | +| `--redis-url` | Redis URL (default: `$REDIS_URL`); the command exits with status 1 when neither is set | +| `--output` | Output format: `table` (default) or `json` | + +--- + +## Limit-Change Flow + +When the tenant limit changes (e.g. a user upgrades), shekel detects the mismatch via a stored `spec_hash` and raises `BudgetConfigMismatchError` if you call `budget()` with the old limit. + +**Correct flow:** + +```python +# 1. Admin raises the limit in Redis +backend.set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.50) + +# 2. Next request uses the new limit — no mismatch +with budget(max_usd=0.50, tenant_id="user-1", name="api", backend=backend): + call_llm() +``` + +**Incorrect — still passing old limit:** + +```python +backend.set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.50) + +# Passing old limit 0.10 → BudgetConfigMismatchError +with budget(max_usd=0.10, tenant_id="user-1", name="api", backend=backend): + call_llm() +``` + +The mismatch check is **per-tenant** — changing user-1's limit has no effect on user-2. 
+ +--- + +## Error Reference + +| Exception | When raised | +|---|---| +| `ValueError` | `tenant_id=""` (empty string), or `tenant_id` set without `backend`, or `tenant_id` set without `name` | +| `BudgetExceededError` | Tenant's spend cap is reached during a call | +| `BudgetConfigMismatchError` | Same `(name, tenant_id)` called with a different `max_usd` than what's stored in Redis | + +```python +from shekel.exceptions import BudgetExceededError, BudgetConfigMismatchError + +try: + with budget(max_usd=0.10, tenant_id=user_id, name="api", backend=backend): + call_llm() +except BudgetExceededError as e: + # Tenant is over their cap — retry_after tells them when the window resets + print(f"Limit reached. Retry in {e.retry_after:.0f}s") +except BudgetConfigMismatchError: + # Limit was changed in Redis but code still uses the old value + print("Budget config mismatch — check set_tenant_limit()") +``` + +--- + +## `tenant_id` on the Budget Object + +The `tenant_id` is accessible on the budget instance after the context exits: + +```python +with budget(max_usd=0.10, tenant_id="user-42", name="api", backend=backend) as b: + call_llm() + +print(b.tenant_id) # "user-42" +print(b.spent) # e.g. 
0.0023 +``` + +`b.summary()` also surfaces the tenant: + +``` +Budget: api +Tenant: user-42 +Spent: $0.0023 / $0.1000 (2.3%) +Calls: 1 +``` + +--- + +## Next Steps + +- **[Distributed Budgets](distributed-budgets.md)** — shared caps across multiple workers using Redis +- **[Temporal Budgets](temporal-budgets.md)** — rolling-window rate limits +- **[API Reference](../api-reference.md)** — complete parameter and method reference diff --git a/mkdocs.yml b/mkdocs.yml index c309307..f40ba97 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -190,6 +190,7 @@ nav: - Tool Budgets: usage/tool-budgets.md - Temporal Budgets: usage/temporal-budgets.md - Distributed Budgets: usage/distributed-budgets.md + - Per-Tenant Budgets: usage/per-tenant-budgets.md - Accumulating Budgets: usage/accumulating-budgets.md - Streaming: usage/streaming.md - Decorators: usage/decorators.md diff --git a/pyproject.toml b/pyproject.toml index 08ff83c..8433387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ dev = [ "opentelemetry-api>=1.0.0", "opentelemetry-sdk>=1.0.0", "redis>=4.0.0", - "fakeredis>=2.0.0", + "fakeredis[lua]>=2.0.0", "testcontainers[redis]>=4.0.0", ] docs = [ diff --git a/shekel/__init__.py b/shekel/__init__.py index 320ac38..e3d446c 100644 --- a/shekel/__init__.py +++ b/shekel/__init__.py @@ -85,6 +85,15 @@ def budget( raise ValueError("TemporalBudget requires name=") return TemporalBudget(window_seconds=window_seconds, name=name, **kwargs) + # Route to TemporalBudget when backend or tenant_id is set (Redis-backed enforcement). + # Default window: 30 days (one billing period) when not explicitly provided. 
+ tenant_id = kwargs.get("tenant_id") + backend = kwargs.get("backend") + if tenant_id is not None or backend is not None: + if not name: + raise ValueError("budget() with tenant_id or backend requires name=") + return TemporalBudget(window_seconds=86400 * 30, name=name, **kwargs) + return Budget(name=name, **kwargs) diff --git a/shekel/_budget.py b/shekel/_budget.py index 55ec1f2..9b181fc 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -1285,6 +1285,8 @@ def summary_data(self) -> dict[str, object]: "tool_calls_limit": self.max_tool_calls, "tool_spent": self._tool_spent, "by_tool": by_tool, + # Tenant tracking (v1.2.0) + "tenant_id": None, } def summary(self) -> str: diff --git a/shekel/_cli.py b/shekel/_cli.py index 4144a81..b41f133 100644 --- a/shekel/_cli.py +++ b/shekel/_cli.py @@ -286,6 +286,7 @@ def run( "calls": data["calls_used"], "tool_calls": data["tool_calls_used"], "status": status, + "tenant_id": data.get("tenant_id"), } if by_model: top_model = max( @@ -303,3 +304,63 @@ def run( ) sys.exit(script_exit_code) + + +@cli.command() +@click.option("--name", required=True, help="Budget name to inspect") +@click.option( + "--redis-url", + default=None, + envvar="REDIS_URL", + help="Redis URL (falls back to REDIS_URL env var)", +) +@click.option( + "--output", + type=click.Choice(["table", "json"]), + default="table", + help="Output format", +) +def tenants(name: str, redis_url: str | None, output: str) -> None: + """List all tenants for a named budget with their spend and limit.""" + import json as _json + + from shekel.backends.redis import RedisBackend + + if not redis_url: + click.echo("Error: Redis URL required. 
Set --redis-url or REDIS_URL.", err=False) + sys.exit(1) + + backend = RedisBackend(url=redis_url) + try: + tenant_ids = backend.list_tenants(name=name) + except Exception as exc: # noqa: BLE001 + click.echo(f"Error: Redis unreachable — {exc}", err=True) + sys.exit(1) + + TenantRow = dict[str, object] + rows: list[TenantRow] = [] + for tid in tenant_ids: + spent: float = backend.get_tenant_spend(name=name, tenant_id=tid) + limit: float | None = backend.get_tenant_limit(name=name, tenant_id=tid) + utilization: float | None = (spent / limit) if limit else None + rows.append({"tenant_id": tid, "spent": spent, "limit": limit, "utilization": utilization}) + + if output == "json": + click.echo(_json.dumps(rows)) + return + + # Table output + w0 = max(len("TENANT"), max((len(str(r["tenant_id"])) for r in rows), default=0)) + w1, w2, w3 = 9, 9, 7 + header = f"{'TENANT':<{w0}} {'SPENT':>{w1}} {'LIMIT':>{w2}} {'USED%':>{w3}}" + click.echo(header) + click.echo("-" * (w0 + w1 + w2 + w3 + 6)) + for r in rows: + tid_s = str(r["tenant_id"]) + spent_f = float(r["spent"]) # type: ignore[arg-type] + limit_f: float | None = float(r["limit"]) if r["limit"] is not None else None # type: ignore[arg-type] + util_f: float | None = float(r["utilization"]) if r["utilization"] is not None else None # type: ignore[arg-type] + spent_s = f"${spent_f:.4f}" + limit_s = f"${limit_f:.4f}" if limit_f is not None else "—" + pct_s = f"{util_f * 100:.1f}%" if util_f is not None else "—" + click.echo(f"{tid_s:<{w0}} {spent_s:>{w1}} {limit_s:>{w2}} {pct_s:>{w3}}") diff --git a/shekel/_temporal.py b/shekel/_temporal.py index e518d38..7285964 100644 --- a/shekel/_temporal.py +++ b/shekel/_temporal.py @@ -263,11 +263,18 @@ def __init__( name: str, backend: TemporalBudgetBackend | None = None, caps: list[tuple[str, float | None, float]] | None = None, + tenant_id: str | None = None, **kwargs: Any, ) -> None: if not name: raise ValueError("TemporalBudget requires a non-empty name=") + if tenant_id is not 
None: + if not isinstance(tenant_id, str) or tenant_id == "": + raise ValueError("tenant_id must be a non-empty string") + if backend is None: + raise ValueError("tenant_id requires a Redis backend") + # Extract multi-cap kwargs that should NOT be passed to Budget # (Budget enforces them cumulatively; TemporalBudget enforces via backend). max_llm_calls: int | None = kwargs.pop("max_llm_calls", None) @@ -285,6 +292,7 @@ def __init__( # Pass max_usd to parent for .spent / .remaining / .limit property tracking. super().__init__(max_usd=effective_max_usd, name=name, **kwargs) self._backend: TemporalBudgetBackend = backend or InMemoryBackend() + self._tenant_id: str | None = tenant_id if caps is not None: # Structured caps from factory (spec-string form). @@ -326,9 +334,16 @@ def _check_temporal_ancestor(self) -> None: ) current = current.parent + @property + def tenant_id(self) -> str | None: + """The tenant identifier scoping this budget's Redis key, or None.""" + return self._tenant_id + def _lazy_window_reset(self) -> None: """If the primary window has expired since last entry, emit on_window_reset.""" - budget_name = self.name or "unnamed" + budget_name = ( + f"{self.name}:{self._tenant_id}" if self._tenant_id else (self.name or "unnamed") + ) # Use get_window_info if available (InMemoryBackend exposes it). if not hasattr(self._backend, "get_window_info"): @@ -365,7 +380,9 @@ def _lazy_window_reset(self) -> None: def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None: """Override to enforce rolling-window spend via backend before calling parent.""" - budget_name = self.name or "unnamed" + budget_name = ( + f"{self.name}:{self._tenant_id}" if self._tenant_id else (self.name or "unnamed") + ) # Build amounts/limits/windows for backend call. # Only include LLM-relevant counters (usd + llm_calls). 
@@ -423,6 +440,20 @@ def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None super()._record_spend(cost, model, tokens) + def summary_data(self) -> dict[str, object]: + data = super().summary_data() + data["tenant_id"] = self._tenant_id + return data + + def summary(self) -> str: + text = super().summary() + if not self._tenant_id: + return text + lines = text.split("\n") + # Insert "Tenant:" after the first line (the top border) + lines.insert(1, f"│ Tenant: {self._tenant_id}") + return "\n".join(lines) + def __enter__(self) -> TemporalBudget: self._check_temporal_ancestor() self._lazy_window_reset() diff --git a/shekel/backends/redis.py b/shekel/backends/redis.py index 7d29cc5..2a50468 100644 --- a/shekel/backends/redis.py +++ b/shekel/backends/redis.py @@ -324,6 +324,70 @@ def reset(self, budget_name: str) -> None: key = f"shekel:tb:{budget_name}" self._ensure_client().delete(key) + # ------------------------------------------------------------------ + # Per-tenant quota management API (SHEK-4) + # ------------------------------------------------------------------ + + def _tenant_key(self, name: str, tenant_id: str) -> str: + return f"shekel:tb:{name}:{tenant_id}" + + def get_tenant_spend(self, name: str, tenant_id: str) -> float: + """Return accumulated spend for a tenant; 0.0 if the tenant is unknown.""" + key = self._tenant_key(name, tenant_id) + try: + raw = self._ensure_client().hget(key, "usd:spent") + except Exception: # noqa: BLE001 + return 0.0 + return float(raw) if raw else 0.0 + + def get_tenant_limit(self, name: str, tenant_id: str) -> float | None: + """Return the stored limit for a tenant, or None if not set.""" + key = self._tenant_key(name, tenant_id) + try: + raw = self._ensure_client().hget(key, "usd:max") + except Exception: # noqa: BLE001 + return None + if not raw or raw in (b"", ""): + return None + return float(raw) + + def set_tenant_limit(self, name: str, tenant_id: str, max_usd: float) -> None: + """Update a 
tenant's limit and recompute spec_hash so budget() calls with + the new limit succeed without BudgetConfigMismatchError.""" + key = self._tenant_key(name, tenant_id) + client = self._ensure_client() + raw_window = client.hget(key, "usd:window_s") + window_s: float | int = float(86400 * 30) + if raw_window: + w = float(raw_window) + window_s = int(w) if w.is_integer() else w + new_hash = _build_spec_hash({"usd": max_usd}, {"usd": window_s}) + client.hset(key, mapping={"usd:max": str(max_usd), "spec_hash": new_hash}) + + def reset_tenant(self, name: str, tenant_id: str) -> None: + """Zero the spend counter for a tenant; preserves limit and spec_hash.""" + key = self._tenant_key(name, tenant_id) + self._ensure_client().hset(key, "usd:spent", "0") + + def list_tenants(self, name: str) -> list[str]: + """Return all tenant IDs that have a key under this budget name.""" + prefix = f"shekel:tb:{name}:" + pattern = f"{prefix}*" + client = self._ensure_client() + tenant_ids: list[str] = [] + try: + cursor = 0 + while True: + cursor, keys = client.scan(cursor, match=pattern, count=100) + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + tenant_ids.append(key_str[len(prefix) :]) + if cursor == 0: + break + except Exception: # noqa: BLE001 + return [] + return tenant_ids + def close(self) -> None: if self._client is not None: self._client.close() @@ -474,6 +538,72 @@ async def reset(self, budget_name: str) -> None: client = await self._ensure_client() await client.delete(key) + # ------------------------------------------------------------------ + # Per-tenant quota management API (SHEK-4) + # ------------------------------------------------------------------ + + def _tenant_key(self, name: str, tenant_id: str) -> str: + return f"shekel:tb:{name}:{tenant_id}" + + async def get_tenant_spend(self, name: str, tenant_id: str) -> float: + """Return accumulated spend for a tenant; 0.0 if the tenant is unknown.""" + key = self._tenant_key(name, tenant_id) 
+ try: + client = await self._ensure_client() + raw = await client.hget(key, "usd:spent") + except Exception: # noqa: BLE001 + return 0.0 + return float(raw) if raw else 0.0 + + async def get_tenant_limit(self, name: str, tenant_id: str) -> float | None: + """Return the stored limit for a tenant, or None if not set.""" + key = self._tenant_key(name, tenant_id) + try: + client = await self._ensure_client() + raw = await client.hget(key, "usd:max") + except Exception: # noqa: BLE001 + return None + if not raw or raw in (b"", ""): + return None + return float(raw) + + async def set_tenant_limit(self, name: str, tenant_id: str, max_usd: float) -> None: + """Update a tenant's limit and recompute spec_hash.""" + key = self._tenant_key(name, tenant_id) + client = await self._ensure_client() + raw_window = await client.hget(key, "usd:window_s") + window_s: float | int = float(86400 * 30) + if raw_window: + w = float(raw_window) + window_s = int(w) if w.is_integer() else w + new_hash = _build_spec_hash({"usd": max_usd}, {"usd": window_s}) + await client.hset(key, mapping={"usd:max": str(max_usd), "spec_hash": new_hash}) + + async def reset_tenant(self, name: str, tenant_id: str) -> None: + """Zero the spend counter for a tenant; preserves limit and spec_hash.""" + key = self._tenant_key(name, tenant_id) + client = await self._ensure_client() + await client.hset(key, "usd:spent", "0") + + async def list_tenants(self, name: str) -> list[str]: + """Return all tenant IDs that have a key under this budget name.""" + prefix = f"shekel:tb:{name}:" + pattern = f"{prefix}*" + client = await self._ensure_client() + tenant_ids: list[str] = [] + try: + cursor = 0 + while True: + cursor, keys = await client.scan(cursor, match=pattern, count=100) + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + tenant_ids.append(key_str[len(prefix) :]) + if cursor == 0: + break + except Exception: # noqa: BLE001 + return [] + return tenant_ids + async def close(self) -> 
None: if self._client is not None: await self._client.aclose() diff --git a/tests/test_cli.py b/tests/test_cli.py index 4e752d2..19959fa 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -114,3 +114,125 @@ def test_models_no_results(runner: CliRunner) -> None: result = runner.invoke(cli, ["models"]) assert result.exit_code == 0 assert "No models found" in result.output + + +# --------------------------------------------------------------------------- +# SHEK-5: shekel tenants command +# --------------------------------------------------------------------------- + + +def test_tenants_no_redis_url_exits_nonzero(runner: CliRunner) -> None: + """Missing REDIS_URL and no --redis-url exits 1 with error message.""" + import os + from unittest.mock import patch + + env = {k: v for k, v in os.environ.items() if k != "REDIS_URL"} + with patch.dict(os.environ, env, clear=True): + result = runner.invoke(cli, ["tenants", "--name", "api"]) + assert result.exit_code != 0 + assert "Redis URL required" in result.output + + +def test_tenants_table_output(runner: CliRunner) -> None: + """shekel tenants prints a table with TENANT, SPENT, LIMIT, USED% columns.""" + from unittest.mock import MagicMock, patch + + mock_backend = MagicMock() + mock_backend.list_tenants.return_value = ["user-1", "user-2"] + mock_backend.get_tenant_spend.side_effect = lambda name, tenant_id: ( + 0.08 if tenant_id == "user-1" else 0.03 + ) + mock_backend.get_tenant_limit.side_effect = lambda name, tenant_id: 0.10 + + with patch("shekel.backends.redis.RedisBackend", return_value=mock_backend): + result = runner.invoke( + cli, ["tenants", "--name", "api", "--redis-url", "redis://localhost"] + ) + + assert result.exit_code == 0 + assert "TENANT" in result.output + assert "SPENT" in result.output + assert "LIMIT" in result.output + assert "user-1" in result.output + assert "user-2" in result.output + + +def test_tenants_json_output(runner: CliRunner) -> None: + """shekel tenants --output json prints valid 
JSON array.""" + import json + from unittest.mock import MagicMock, patch + + mock_backend = MagicMock() + mock_backend.list_tenants.return_value = ["user-1"] + mock_backend.get_tenant_spend.return_value = 0.05 + mock_backend.get_tenant_limit.return_value = 0.10 + + with patch("shekel.backends.redis.RedisBackend", return_value=mock_backend): + result = runner.invoke( + cli, + ["tenants", "--name", "api", "--redis-url", "redis://localhost", "--output", "json"], + ) + + assert result.exit_code == 0 + rows = json.loads(result.output) + assert isinstance(rows, list) + assert rows[0]["tenant_id"] == "user-1" + assert rows[0]["spent"] == pytest.approx(0.05) + assert rows[0]["limit"] == pytest.approx(0.10) + assert rows[0]["utilization"] == pytest.approx(0.50) + + +def test_tenants_no_limit_shows_dash(runner: CliRunner) -> None: + """Tenant with no stored limit shows — in LIMIT and USED% columns.""" + from unittest.mock import MagicMock, patch + + mock_backend = MagicMock() + mock_backend.list_tenants.return_value = ["user-x"] + mock_backend.get_tenant_spend.return_value = 0.12 + mock_backend.get_tenant_limit.return_value = None + + with patch("shekel.backends.redis.RedisBackend", return_value=mock_backend): + result = runner.invoke( + cli, ["tenants", "--name", "api", "--redis-url", "redis://localhost"] + ) + + assert result.exit_code == 0 + assert "—" in result.output + + +def test_tenants_json_null_limit(runner: CliRunner) -> None: + """JSON output has null for limit/utilization when no limit stored.""" + import json + from unittest.mock import MagicMock, patch + + mock_backend = MagicMock() + mock_backend.list_tenants.return_value = ["user-x"] + mock_backend.get_tenant_spend.return_value = 0.12 + mock_backend.get_tenant_limit.return_value = None + + with patch("shekel.backends.redis.RedisBackend", return_value=mock_backend): + result = runner.invoke( + cli, + ["tenants", "--name", "api", "--redis-url", "redis://localhost", "--output", "json"], + ) + + assert 
result.exit_code == 0 + rows = json.loads(result.output) + assert rows[0]["limit"] is None + assert rows[0]["utilization"] is None + + +def test_tenants_redis_unreachable_exits_nonzero(runner: CliRunner) -> None: + """list_tenants raising an exception prints an error and exits non-zero.""" + from unittest.mock import MagicMock, patch + + mock_backend = MagicMock() + mock_backend.list_tenants.side_effect = RuntimeError("connection refused") + + with patch("shekel.backends.redis.RedisBackend", return_value=mock_backend): + result = runner.invoke( + cli, ["tenants", "--name", "api", "--redis-url", "redis://localhost"] + ) + + assert result.exit_code != 0 + assert "Redis unreachable" in result.output diff --git a/tests/test_tenant_budgets.py b/tests/test_tenant_budgets.py new file mode 100644 index 0000000..4d85e46 --- /dev/null +++ b/tests/test_tenant_budgets.py @@ -0,0 +1,1061 @@ +"""Tests for per-tenant budget enforcement (SHEK-1 / SHEK-3). + +Uses fakeredis — no Docker or real Redis required. 
+""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +try: + import fakeredis + + FAKEREDIS_AVAILABLE = True +except ImportError: + FAKEREDIS_AVAILABLE = False + +pytestmark = pytest.mark.skipif(not FAKEREDIS_AVAILABLE, reason="fakeredis not installed") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_backend(): + """Return a RedisBackend wired to an in-process fakeredis server.""" + from shekel.backends.redis import RedisBackend + + server = fakeredis.FakeServer() + backend = RedisBackend.__new__(RedisBackend) + backend._url = "redis://localhost" + backend._tls = False + backend._on_unavailable = "closed" + backend._cb_threshold = 3 + backend._cb_cooldown = 10.0 + backend._client = fakeredis.FakeRedis(server=server, decode_responses=False) + backend._script_sha = None + backend._consecutive_errors = 0 + backend._circuit_open_at = None + return backend + + +# --------------------------------------------------------------------------- +# SHEK-3: budget() factory routing +# --------------------------------------------------------------------------- + + +class TestBudgetFactoryRouting: + """budget() routes to TemporalBudget when tenant_id or backend is set.""" + + def test_tenant_id_routes_to_temporal_budget(self) -> None: + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget(max_usd=0.10, tenant_id="user-1", name="api", backend=_make_backend()) + assert isinstance(b, TemporalBudget) + + def test_default_window_is_30_days(self) -> None: + from shekel import budget + + b = budget(max_usd=0.10, tenant_id="user-1", name="api", backend=_make_backend()) + assert b._caps["usd"][1] == 86400 * 30 + + def test_explicit_window_seconds_overrides_default(self) -> None: + from shekel import budget + + b = budget( + max_usd=0.10, + tenant_id="user-1", + name="api", + 
backend=_make_backend(), + window_seconds=3600, + ) + assert b._caps["usd"][1] == 3600 + + def test_backend_alone_routes_to_temporal_budget(self) -> None: + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget(max_usd=1.00, name="api", backend=_make_backend(), window_seconds=3600) + assert isinstance(b, TemporalBudget) + + def test_no_tenant_no_backend_returns_plain_budget(self) -> None: + from shekel import budget + from shekel._budget import Budget + from shekel._temporal import TemporalBudget + + b = budget(max_usd=1.00) + assert isinstance(b, Budget) + assert not isinstance(b, TemporalBudget) + + def test_tenant_id_without_name_raises(self) -> None: + from shekel import budget + + with pytest.raises(ValueError, match="name"): + budget(max_usd=0.10, tenant_id="user-1", backend=_make_backend()) + + def test_tenant_id_accessible_on_instance(self) -> None: + from shekel import budget + + b = budget(max_usd=0.10, tenant_id="user-99", name="api", backend=_make_backend()) + assert b.tenant_id == "user-99" + + def test_tenant_id_none_when_not_set(self) -> None: + from shekel import budget + + b = budget(max_usd=1.00, name="api", backend=_make_backend(), window_seconds=3600) + assert b.tenant_id is None + + +# --------------------------------------------------------------------------- +# SHEK-3: TemporalBudget validation +# --------------------------------------------------------------------------- + + +class TestTemporalBudgetTenantValidation: + """tenant_id validation in TemporalBudget.__init__.""" + + def test_empty_tenant_id_raises(self) -> None: + from shekel import budget + + with pytest.raises(ValueError, match="non-empty"): + budget(max_usd=0.10, tenant_id="", name="api", backend=_make_backend()) + + def test_tenant_id_without_backend_raises(self) -> None: + from shekel._temporal import TemporalBudget + + with pytest.raises(ValueError, match="Redis backend"): + TemporalBudget( + max_usd=0.10, + tenant_id="user-1", + name="api", 
+ window_seconds=3600, + backend=None, + ) + + def test_valid_tenant_id_does_not_raise(self) -> None: + from shekel import budget + + b = budget(max_usd=0.10, tenant_id="user-1", name="api", backend=_make_backend()) + assert b.tenant_id == "user-1" + + def test_tenant_id_with_colon_is_valid(self) -> None: + from shekel import budget + + b = budget(max_usd=0.10, tenant_id="org:user-1", name="api", backend=_make_backend()) + assert b.tenant_id == "org:user-1" + + +# --------------------------------------------------------------------------- +# SHEK-3: Redis key isolation +# --------------------------------------------------------------------------- + + +class TestTenantKeyIsolation: + """Different tenant_ids use distinct Redis keys and never share state.""" + + def _make_shared_backend(self): + """Return two backends sharing the same fakeredis server.""" + server = fakeredis.FakeServer() + from shekel.backends.redis import RedisBackend + + def _backend(): + b = RedisBackend.__new__(RedisBackend) + b._url = "redis://localhost" + b._tls = False + b._on_unavailable = "closed" + b._cb_threshold = 3 + b._cb_cooldown = 10.0 + b._client = fakeredis.FakeRedis(server=server, decode_responses=False) + b._script_sha = None + b._consecutive_errors = 0 + b._circuit_open_at = None + return b + + return _backend + + def test_two_tenants_do_not_share_state(self) -> None: + from shekel import budget + + make = self._make_shared_backend() + + # User A: tiny budget + with budget(max_usd=0.001, tenant_id="user-a", name="api", backend=make()) as b_a: + pass + + # User B: generous budget — should not be affected by user-a's spend + with budget(max_usd=10.00, tenant_id="user-b", name="api", backend=make()) as b_b: + pass + + assert b_a.spent == pytest.approx(0.0) + assert b_b.spent == pytest.approx(0.0) + + def test_tenant_a_exceeded_does_not_block_tenant_b(self) -> None: + from unittest.mock import patch + + import openai + + from shekel import budget + from shekel.exceptions import 
BudgetExceededError + + make = self._make_shared_backend() + mock_resp = _fake_openai_response(100, 50) + + # Apply mock OUTSIDE budget context so shekel wraps it, not replaces it + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + client = openai.OpenAI(api_key="test") + + # Exhaust tenant-a + with pytest.raises(BudgetExceededError): + with budget( + max_usd=0.001, + tenant_id="user-a", + name="api", + backend=make(), + price_per_1k_tokens={"input": 100.0, "output": 100.0}, + ): + client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hi"}], + ) + + # Tenant-b must still work + with budget(max_usd=5.00, tenant_id="user-b", name="api", backend=make()) as b_b: + pass + assert b_b is not None + + def test_redis_key_includes_tenant_id(self) -> None: + import openai + + from shekel import budget + + server = fakeredis.FakeServer() + from shekel.backends.redis import RedisBackend + + backend = RedisBackend.__new__(RedisBackend) + backend._url = "redis://localhost" + backend._tls = False + backend._on_unavailable = "closed" + backend._cb_threshold = 3 + backend._cb_cooldown = 10.0 + redis_client = fakeredis.FakeRedis(server=server, decode_responses=False) + backend._client = redis_client + backend._script_sha = None + backend._consecutive_errors = 0 + backend._circuit_open_at = None + + mock_resp = _fake_openai_response(100, 50) + # Mock OUTSIDE budget so shekel's patch wraps the mock + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + client = openai.OpenAI(api_key="test") + with budget( + max_usd=1.00, + tenant_id="user-123", + name="api", + backend=backend, + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hi"}], + ) + + keys = [k.decode() for k in redis_client.keys("shekel:tb:*")] + assert any("api:user-123" in k 
for k in keys) + assert not any(k == "shekel:tb:api" for k in keys) + + +# --------------------------------------------------------------------------- +# SHEK-3: BudgetConfigMismatchError +# --------------------------------------------------------------------------- + + +class TestTenantConfigMismatch: + """Same (name, tenant_id) with different max_usd raises BudgetConfigMismatchError.""" + + def _seed_tenant(self, make, tenant_id: str, max_usd: float) -> None: + """Open a budget and make a tiny LLM call to seed the Redis spec_hash.""" + from unittest.mock import patch as upatch + + import openai + + from shekel import budget + + mock_resp = _fake_openai_response(1, 1) + with upatch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + client = openai.OpenAI(api_key="test") + with budget( + max_usd=max_usd, + tenant_id=tenant_id, + name="api", + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + def test_mismatch_on_different_limit(self) -> None: + import openai + + from shekel import budget + from shekel.exceptions import BudgetConfigMismatchError + + make = self._make_shared_backend() + # Seed spec_hash for user-1 with limit 0.10 + self._seed_tenant(make, "user-1", 0.10) + + mock_resp = _fake_openai_response(1, 1) + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_resp): + client = openai.OpenAI(api_key="test") + with pytest.raises(BudgetConfigMismatchError): + with budget( + max_usd=0.20, + tenant_id="user-1", + name="api", + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + def test_same_limit_does_not_raise(self) -> None: + import openai + + from shekel import budget + + make = self._make_shared_backend() + 
self._seed_tenant(make, "user-1", 0.10) + + mock_resp = _fake_openai_response(1, 1) + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_resp): + client = openai.OpenAI(api_key="test") + # Same limit — must not raise + with budget( + max_usd=0.10, + tenant_id="user-1", + name="api", + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + def test_mismatch_is_per_tenant_not_global(self) -> None: + import openai + + from shekel import budget + + make = self._make_shared_backend() + self._seed_tenant(make, "user-1", 0.10) + + mock_resp = _fake_openai_response(1, 1) + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_resp): + client = openai.OpenAI(api_key="test") + # user-2 can use a different limit — no mismatch (different key) + with budget( + max_usd=0.50, + tenant_id="user-2", + name="api", + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + def _make_shared_backend(self): + server = fakeredis.FakeServer() + from shekel.backends.redis import RedisBackend + + def _backend(): + b = RedisBackend.__new__(RedisBackend) + b._url = "redis://localhost" + b._tls = False + b._on_unavailable = "closed" + b._cb_threshold = 3 + b._cb_cooldown = 10.0 + b._client = fakeredis.FakeRedis(server=server, decode_responses=False) + b._script_sha = None + b._consecutive_errors = 0 + b._circuit_open_at = None + return b + + return _backend + + +# --------------------------------------------------------------------------- +# SHEK-3: async parity +# --------------------------------------------------------------------------- + + +class TestTenantAsync: + """async with budget(..., tenant_id=...) 
enforces identically to sync.""" + + @pytest.mark.asyncio + async def test_async_budget_with_tenant_id(self) -> None: + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget(max_usd=1.00, tenant_id="user-1", name="api", backend=_make_backend()) + assert isinstance(b, TemporalBudget) + async with b: + pass + + @pytest.mark.asyncio + async def test_async_tenant_id_accessible(self) -> None: + from shekel import budget + + async with budget( + max_usd=1.00, tenant_id="async-user", name="api", backend=_make_backend() + ) as b: + assert b.tenant_id == "async-user" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _fake_openai_response(input_tokens: int, output_tokens: int): + from unittest.mock import MagicMock + + m = MagicMock() + m.choices[0].message.content = "hi" + m.usage.prompt_tokens = input_tokens + m.usage.completion_tokens = output_tokens + m.model = "gpt-4o-mini" + return m + + +def _shared_fakeredis_backend(): + """Return a factory that produces backends sharing one FakeServer.""" + server = fakeredis.FakeServer() + from shekel.backends.redis import RedisBackend + + def _make(): + b = RedisBackend.__new__(RedisBackend) + b._url = "redis://localhost" + b._tls = False + b._on_unavailable = "closed" + b._cb_threshold = 3 + b._cb_cooldown = 10.0 + b._client = fakeredis.FakeRedis(server=server, decode_responses=False) + b._script_sha = None + b._consecutive_errors = 0 + b._circuit_open_at = None + return b + + return _make + + +def _shared_async_fakeredis_backend(): + """Return a factory that produces async backends sharing one FakeServer.""" + server = fakeredis.FakeServer() + from shekel.backends.redis import AsyncRedisBackend + + def _make(): + b = AsyncRedisBackend.__new__(AsyncRedisBackend) + b._url = "redis://localhost" + b._tls = False + b._on_unavailable = "closed" + b._cb_threshold = 3 + 
b._cb_cooldown = 10.0 + b._client = fakeredis.aioredis.FakeRedis(server=server, decode_responses=False) + b._script_sha = None + b._consecutive_errors = 0 + b._circuit_open_at = None + return b + + return _make + + +def _seed_spend(make, name: str, tenant_id: str, max_usd: float) -> None: + """Open a budget and make a tiny LLM call to seed Redis spend + spec_hash.""" + import openai + + from shekel import budget + + mock_resp = _fake_openai_response(1, 1) + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_resp): + client = openai.OpenAI(api_key="test") + with budget( + max_usd=max_usd, + tenant_id=tenant_id, + name=name, + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + +# --------------------------------------------------------------------------- +# SHEK-4: sync quota management API +# --------------------------------------------------------------------------- + + +class TestTenantQuotaManagement: + """RedisBackend quota management methods: get_tenant_spend, get_tenant_limit, + set_tenant_limit, reset_tenant, list_tenants.""" + + def test_get_tenant_spend_unknown_tenant_returns_zero(self) -> None: + + backend = _make_backend() + assert backend.get_tenant_spend(name="api", tenant_id="nobody") == 0.0 + + def test_get_tenant_spend_returns_accumulated_spend(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 1.00) + + backend = make() + spent = backend.get_tenant_spend(name="api", tenant_id="user-1") + assert spent > 0.0 + + def test_get_tenant_limit_unknown_tenant_returns_none(self) -> None: + backend = _make_backend() + assert backend.get_tenant_limit(name="api", tenant_id="nobody") is None + + def test_get_tenant_limit_returns_limit_after_set(self) -> None: + backend = _make_backend() + backend.set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.50) 
+ assert backend.get_tenant_limit(name="api", tenant_id="user-1") == pytest.approx(0.50) + + def test_set_tenant_limit_allows_budget_with_new_limit(self) -> None: + make = _shared_fakeredis_backend() + # Seed with 0.10 + _seed_spend(make, "api", "user-1", 0.10) + + # Admin raises limit to 0.20 + make().set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.20) + + # budget(max_usd=0.20) must now succeed without BudgetConfigMismatchError + import openai + + from shekel import budget + + mock_resp = _fake_openai_response(1, 1) + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_resp): + client = openai.OpenAI(api_key="test") + with budget( + max_usd=0.20, + tenant_id="user-1", + name="api", + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + def test_set_tenant_limit_old_limit_raises_mismatch(self) -> None: + from shekel import budget + from shekel.exceptions import BudgetConfigMismatchError + + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 0.10) + + # Admin raises limit to 0.20 + make().set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.20) + + # budget with OLD limit must raise BudgetConfigMismatchError + import openai + + mock_resp = _fake_openai_response(1, 1) + with patch("openai.resources.chat.completions.Completions.create", return_value=mock_resp): + client = openai.OpenAI(api_key="test") + with pytest.raises(BudgetConfigMismatchError): + with budget( + max_usd=0.10, + tenant_id="user-1", + name="api", + backend=make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + def test_reset_tenant_zeroes_spend(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 1.00) + + backend = make() + assert 
backend.get_tenant_spend(name="api", tenant_id="user-1") > 0.0 + + backend.reset_tenant(name="api", tenant_id="user-1") + assert backend.get_tenant_spend(name="api", tenant_id="user-1") == pytest.approx(0.0) + + def test_reset_tenant_preserves_limit(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 0.50) + + backend = make() + backend.reset_tenant(name="api", tenant_id="user-1") + + # Limit should still be readable after reset + limit = backend.get_tenant_limit(name="api", tenant_id="user-1") + assert limit == pytest.approx(0.50) + + def test_reset_tenant_allows_re_accumulation(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 1.00) + + make().reset_tenant(name="api", tenant_id="user-1") + + # Can accumulate spend again from zero without BudgetConfigMismatchError + _seed_spend(make, "api", "user-1", 1.00) + spent = make().get_tenant_spend(name="api", tenant_id="user-1") + assert spent > 0.0 + + def test_list_tenants_empty_when_no_tenants(self) -> None: + backend = _make_backend() + assert backend.list_tenants(name="api") == [] + + def test_list_tenants_returns_all_tenant_ids(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 1.00) + _seed_spend(make, "api", "user-2", 1.00) + _seed_spend(make, "api", "user-3", 1.00) + + tenants = make().list_tenants(name="api") + assert sorted(tenants) == ["user-1", "user-2", "user-3"] + + def test_list_tenants_excludes_other_budget_names(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "user-1", 1.00) + _seed_spend(make, "other", "user-9", 1.00) + + tenants = make().list_tenants(name="api") + assert tenants == ["user-1"] + + def test_list_tenants_handles_colon_in_tenant_id(self) -> None: + make = _shared_fakeredis_backend() + _seed_spend(make, "api", "org:user-1", 1.00) + + tenants = make().list_tenants(name="api") + assert tenants == ["org:user-1"] + + +# 
--------------------------------------------------------------------------- +# SHEK-4: async quota management API +# --------------------------------------------------------------------------- + + +class TestAsyncTenantQuotaManagement: + """AsyncRedisBackend mirrors all five sync methods.""" + + @pytest.mark.asyncio + async def test_async_get_tenant_spend_unknown_returns_zero(self) -> None: + make = _shared_async_fakeredis_backend() + backend = make() + assert await backend.get_tenant_spend(name="api", tenant_id="nobody") == 0.0 + + @pytest.mark.asyncio + async def test_async_get_tenant_limit_unknown_returns_none(self) -> None: + make = _shared_async_fakeredis_backend() + backend = make() + assert await backend.get_tenant_limit(name="api", tenant_id="nobody") is None + + @pytest.mark.asyncio + async def test_async_set_and_get_tenant_limit(self) -> None: + make = _shared_async_fakeredis_backend() + backend = make() + await backend.set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.75) + limit = await backend.get_tenant_limit(name="api", tenant_id="user-1") + assert limit == pytest.approx(0.75) + + @pytest.mark.asyncio + async def test_async_reset_tenant_zeroes_spend(self) -> None: + make = _shared_async_fakeredis_backend() + backend = make() + # Seed spend via the async set + key = "shekel:tb:api:user-1" + await backend._client.hset(key, "usd:spent", "0.42") + + await backend.reset_tenant(name="api", tenant_id="user-1") + assert await backend.get_tenant_spend(name="api", tenant_id="user-1") == pytest.approx(0.0) + + @pytest.mark.asyncio + async def test_async_list_tenants(self) -> None: + make = _shared_async_fakeredis_backend() + backend = make() + # Seed two tenant keys via async client + await backend._client.hset("shekel:tb:api:user-a", "usd:spent", "0.10") + await backend._client.hset("shekel:tb:api:user-b", "usd:spent", "0.20") + + tenants = await backend.list_tenants(name="api") + assert sorted(tenants) == ["user-a", "user-b"] + + +# 
--------------------------------------------------------------------------- +# SHEK-5: summary() and summary_data() tenant display +# --------------------------------------------------------------------------- + + +class TestTenantSummary: + """b.summary() and b.summary_data() surface tenant_id when set.""" + + def test_summary_data_includes_tenant_id_when_set(self) -> None: + from shekel import budget + + b = budget(max_usd=1.00, tenant_id="user-42", name="api", backend=_make_backend()) + with b: + pass + data = b.summary_data() + assert data["tenant_id"] == "user-42" + + def test_summary_data_tenant_id_none_on_plain_budget(self) -> None: + from shekel import budget + from shekel._budget import Budget + + b = budget(max_usd=1.00) + assert isinstance(b, Budget) + with b: + pass + data = b.summary_data() + assert data.get("tenant_id") is None + + def test_summary_contains_tenant_line_when_set(self) -> None: + from shekel import budget + + b = budget(max_usd=1.00, tenant_id="user-42", name="api", backend=_make_backend()) + with b: + pass + text = b.summary() + assert "Tenant: user-42" in text + + def test_summary_has_no_tenant_line_when_not_set(self) -> None: + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget(max_usd=1.00, name="api", backend=_make_backend(), window_seconds=3600) + assert isinstance(b, TemporalBudget) + with b: + pass + text = b.summary() + assert "Tenant:" not in text + + +# --------------------------------------------------------------------------- +# SHEK-6: concurrency — 10 async tenants via asyncio.gather +# --------------------------------------------------------------------------- + + +class TestTenantConcurrency: + """10 async budget() contexts running concurrently enforce independently.""" + + @pytest.mark.asyncio + async def test_ten_concurrent_tenants_enforce_independently(self) -> None: + import asyncio + + import openai + + from shekel import budget + + server = fakeredis.FakeServer() + + def 
_make(): + from shekel.backends.redis import RedisBackend + + b = RedisBackend.__new__(RedisBackend) + b._url = "redis://localhost" + b._tls = False + b._on_unavailable = "closed" + b._cb_threshold = 3 + b._cb_cooldown = 10.0 + b._client = fakeredis.FakeRedis(server=server, decode_responses=False) + b._script_sha = None + b._consecutive_errors = 0 + b._circuit_open_at = None + return b + + mock_resp = _fake_openai_response(1, 1) + + async def run_tenant(tid: str) -> float: + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + client = openai.OpenAI(api_key="test") + async with budget( + max_usd=1.00, + tenant_id=tid, + name="api", + backend=_make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ) as b: + client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hi"}], + ) + return b.spent + + results = await asyncio.gather(*[run_tenant(f"user-{i}") for i in range(10)]) + + # Every tenant must have recorded independent spend + assert len(results) == 10 + assert all(s > 0 for s in results) + + @pytest.mark.asyncio + async def test_concurrent_tenants_do_not_cross_contaminate(self) -> None: + """A tiny budget for one tenant doesn't affect others running simultaneously.""" + import asyncio + + import openai + + from shekel import budget + from shekel.exceptions import BudgetExceededError + + server = fakeredis.FakeServer() + + def _make(): + from shekel.backends.redis import RedisBackend + + b = RedisBackend.__new__(RedisBackend) + b._url = "redis://localhost" + b._tls = False + b._on_unavailable = "closed" + b._cb_threshold = 3 + b._cb_cooldown = 10.0 + b._client = fakeredis.FakeRedis(server=server, decode_responses=False) + b._script_sha = None + b._consecutive_errors = 0 + b._circuit_open_at = None + return b + + exceeded_tenants: list[str] = [] + succeeded_tenants: list[str] = [] + + async def run_tenant(tid: str, max_usd: float) -> None: + # price: 1 token * 
0.001/1k = 0.000001 — tiny cost per call + mock_resp = _fake_openai_response(1, 1) + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + client = openai.OpenAI(api_key="test") + try: + async with budget( + max_usd=max_usd, + tenant_id=tid, + name="api", + backend=_make(), + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hi"}], + ) + succeeded_tenants.append(tid) + except BudgetExceededError: + exceeded_tenants.append(tid) + + # user-tiny: budget 1e-6 < cost 2e-6 → exceeded; others: 5.00 >> 2e-6 → ok + await asyncio.gather( + run_tenant("user-tiny", 0.000001), + *[run_tenant(f"user-ok-{i}", 5.00) for i in range(5)], + ) + + assert "user-tiny" in exceeded_tenants + assert len(succeeded_tenants) == 5 + assert all(t.startswith("user-ok-") for t in succeeded_tenants) + + +# --------------------------------------------------------------------------- +# SHEK-6: circuit breaker — global, not per-tenant +# --------------------------------------------------------------------------- + + +class TestTenantCircuitBreaker: + """Circuit breaker is global — Redis failure affects all tenants together.""" + + def test_circuit_breaker_fires_for_all_tenants_on_redis_failure(self) -> None: + from unittest.mock import MagicMock + from unittest.mock import patch as upatch + + from shekel.backends.redis import RedisBackend + from shekel.exceptions import BudgetExceededError + + # Build a backend whose evalsha always raises (simulates Redis down) + backend = RedisBackend.__new__(RedisBackend) + backend._url = "redis://localhost" + backend._tls = False + backend._on_unavailable = "closed" + backend._cb_threshold = 1 + backend._cb_cooldown = 60.0 + backend._consecutive_errors = 0 + backend._circuit_open_at = None + backend._script_sha = "fakecha" + + mock_client = MagicMock() + mock_client.evalsha.side_effect = 
RuntimeError("connection refused") + backend._client = mock_client + + import openai + + from shekel import budget + + mock_resp = _fake_openai_response(1, 1) + + with upatch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + client = openai.OpenAI(api_key="test") + + # First call opens the circuit breaker + with pytest.raises(BudgetExceededError): + with budget( + max_usd=1.00, + tenant_id="user-a", + name="api", + backend=backend, + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ): + client.chat.completions.create( + model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}] + ) + + # Circuit is now open — affects user-b too (same backend, global breaker) + assert backend._circuit_open_at is not None, "Circuit breaker should be open" + assert backend._is_circuit_open() + + +# --------------------------------------------------------------------------- +# SHEK-6: exception path coverage — new method error returns +# --------------------------------------------------------------------------- + + +class TestTenantErrorPaths: + """Defensive error paths in new RedisBackend tenant methods return safe defaults.""" + + def test_get_tenant_spend_returns_zero_on_redis_error(self) -> None: + from unittest.mock import MagicMock + + from shekel.backends.redis import RedisBackend + + backend = RedisBackend.__new__(RedisBackend) + backend._url = "redis://localhost" + backend._tls = False + mock_client = MagicMock() + mock_client.hget.side_effect = RuntimeError("connection refused") + backend._client = mock_client + + assert backend.get_tenant_spend(name="api", tenant_id="user-1") == 0.0 + + def test_get_tenant_limit_returns_none_on_redis_error(self) -> None: + from unittest.mock import MagicMock + + from shekel.backends.redis import RedisBackend + + backend = RedisBackend.__new__(RedisBackend) + backend._url = "redis://localhost" + backend._tls = False + mock_client = MagicMock() + mock_client.hget.side_effect = 
RuntimeError("connection refused") + backend._client = mock_client + + assert backend.get_tenant_limit(name="api", tenant_id="user-1") is None + + def test_list_tenants_returns_empty_on_redis_error(self) -> None: + from unittest.mock import MagicMock + + from shekel.backends.redis import RedisBackend + + backend = RedisBackend.__new__(RedisBackend) + backend._url = "redis://localhost" + backend._tls = False + mock_client = MagicMock() + mock_client.scan.side_effect = RuntimeError("connection refused") + backend._client = mock_client + + assert backend.list_tenants(name="api") == [] + + @pytest.mark.asyncio + async def test_async_get_tenant_spend_returns_zero_on_error(self) -> None: + from unittest.mock import AsyncMock, MagicMock + + from shekel.backends.redis import AsyncRedisBackend + + backend = AsyncRedisBackend.__new__(AsyncRedisBackend) + backend._url = "redis://localhost" + backend._tls = False + mock_client = MagicMock() + mock_client.hget = AsyncMock(side_effect=RuntimeError("connection refused")) + backend._client = mock_client + + assert await backend.get_tenant_spend(name="api", tenant_id="user-1") == 0.0 + + @pytest.mark.asyncio + async def test_async_get_tenant_limit_returns_none_on_error(self) -> None: + from unittest.mock import AsyncMock, MagicMock + + from shekel.backends.redis import AsyncRedisBackend + + backend = AsyncRedisBackend.__new__(AsyncRedisBackend) + backend._url = "redis://localhost" + backend._tls = False + mock_client = MagicMock() + mock_client.hget = AsyncMock(side_effect=RuntimeError("connection refused")) + backend._client = mock_client + + assert await backend.get_tenant_limit(name="api", tenant_id="user-1") is None + + @pytest.mark.asyncio + async def test_async_list_tenants_returns_empty_on_error(self) -> None: + from unittest.mock import AsyncMock, MagicMock + + from shekel.backends.redis import AsyncRedisBackend + + backend = AsyncRedisBackend.__new__(AsyncRedisBackend) + backend._url = "redis://localhost" + backend._tls = 
False + mock_client = MagicMock() + mock_client.scan = AsyncMock(side_effect=RuntimeError("connection refused")) + backend._client = mock_client + + assert await backend.list_tenants(name="api") == [] + + +class TestAsyncTenantEdgeCases: + """Cover remaining async method branches.""" + + @pytest.mark.asyncio + async def test_async_set_tenant_limit_with_existing_window(self) -> None: + """set_tenant_limit reads existing usd:window_s when key already populated.""" + make = _shared_async_fakeredis_backend() + backend = make() + # Seed key with usd:window_s already present (simulates a budget having run) + key = "shekel:tb:api:user-1" + await backend._client.hset(key, mapping={"usd:spent": "0.01", "usd:window_s": "2592000"}) + + # Now set_tenant_limit must read the existing window_s (lines 577-578) + await backend.set_tenant_limit(name="api", tenant_id="user-1", max_usd=0.50) + limit = await backend.get_tenant_limit(name="api", tenant_id="user-1") + assert limit == pytest.approx(0.50) + + @pytest.mark.asyncio + async def test_async_close(self) -> None: + """AsyncRedisBackend.close() releases the client without raising.""" + make = _shared_async_fakeredis_backend() + backend = make() + # Ensure _client is set + await backend._client.ping() + # close() should not raise + await backend.close()