diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6d93ba7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +# Never publish private skill variants +skills/*-private/ +skills/**/private-*/ +skills/**/internal-*/ + +# Python +__pycache__/ +*.pyc +*.pyo +*.egg-info/ +dist/ +build/ +.pytest_cache/ + +# SQLite +*.db +*.db-wal +*.db-shm diff --git a/PUBLISH-CYBERWOODS-PUBLIC.md b/PUBLISH-CYBERWOODS-PUBLIC.md new file mode 100644 index 0000000..e54e509 --- /dev/null +++ b/PUBLISH-CYBERWOODS-PUBLIC.md @@ -0,0 +1,24 @@ +# Publish Cyberwoods Public Only + +## Stage only public skill files + +```bash +git add .gitignore +git add skills/cyberwoods-public/SKILL.md +git add skills/cyberwoods-public/agents/openai.yaml +git add skills/cyberwoods-public/references/threat-model.md +git add skills/cyberwoods-public/references/adoption-checklist.md +``` + +## Commit message template + +```txt +feat(skill): add cyberwoods-public sanitized security review workflow +``` + +## Optional verify command + +```bash +git diff --cached --name-only +``` + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2b17b85 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "ai-hot-sauce" +version = "0.1.0" +description = "Scoring-based multi-model router with circuit breakers, quality gates, and session persistence" +requires-python = ">=3.11" +license = "MIT" +dependencies = [] + +[project.optional-dependencies] +openai = ["openai>=1.0"] +anthropic = ["anthropic>=0.40"] +google = ["google-genai>=1.0"] +all = ["openai>=1.0", "anthropic>=0.40", "google-genai>=1.0"] +dev = ["pytest", "pyyaml"] + +[project.scripts] +sauce = "sauce:main" + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/sauce.py b/sauce.py new file mode 100644 index 0000000..5d512d6 --- /dev/null +++ b/sauce.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +"""sauce.py — CLI entry point for the Hot Sauce engine. + +Usage: + python sauce.py "explain this error" # auto-route to best model + python sauce.py "@gemini describe this image" # explicit model override + python sauce.py --stats # show model health + breaker states + python sauce.py --rank "fix this bug" # show how models would be ranked + python sauce.py --session abc123 "continue..." 
# resume a session + python sauce.py --reset-breaker gemini-2.5-flash # manually reset a tripped breaker +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from pathlib import Path + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent)) + +from src.engine import HotSauceEngine +from src.routing.scorer import classify_task + + +def _build_engine(db_path: str | None = None) -> HotSauceEngine: + """Build engine with available providers based on env vars.""" + engine = HotSauceEngine(db_path=db_path) + + # Only add providers we have keys for + if os.environ.get("OPENAI_API_KEY"): + from src.providers.openai_provider import OpenAIProvider + engine.add_provider(OpenAIProvider()) + + if os.environ.get("ANTHROPIC_API_KEY"): + from src.providers.anthropic_provider import AnthropicProvider + engine.add_provider(AnthropicProvider()) + + if os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"): + from src.providers.google_provider import GoogleProvider + key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") + engine.add_provider(GoogleProvider(api_key=key)) + + # Always try Ollama (local, no key needed) + try: + from src.providers.ollama_provider import OllamaProvider + p = OllamaProvider() + if p._is_available(): + engine.add_provider(p) + except Exception: + pass + + return engine + + +def cmd_chat(args): + engine = _build_engine(args.db) + sid = engine.session(args.session) + result = engine.chat(args.message, session_id=sid, system=args.system) + print(result.content) + print(f"\n--- [{result.model}] {result.tokens_in}→{result.tokens_out} tokens, " + f"{result.latency_ms:.0f}ms, ${result.cost_usd:.6f} ---", file=sys.stderr) + + +def cmd_stats(args): + engine = _build_engine(args.db) + stats = engine.stats() + print(json.dumps(stats, indent=2, default=str)) + + +def cmd_rank(args): + engine = _build_engine(args.db) + task = classify_task(args.message) + ranked = engine.router.rank(task) + print(f"Task: {task.task_type} | vision={task.needs_vision} | tools={task.needs_tools}") + print(f"{'Model':<35} {'Score':>8}") + print("-" * 45) + for name, score in ranked: + breaker_state = engine.breaker.state(name) + flag = f" [{breaker_state.upper()}]" if breaker_state != "closed" else "" + print(f"{name:<35} {score:>8.4f}{flag}") + + +def cmd_reset_breaker(args): + engine = _build_engine(args.db) + engine.breaker.reset(args.model) + print(f"Breaker reset for {args.model}") + + +def main(): + parser = argparse.ArgumentParser(description="Hot Sauce — AI model router + engine") + parser.add_argument("--db", help="SQLite database path", default=None) + sub = parser.add_subparsers(dest="command") + + # Default: chat + chat_p = sub.add_parser("chat", help="Send a message (default)") + chat_p.add_argument("message", help="Your message") + chat_p.add_argument("--session", "-s", help="Session ID to resume") + chat_p.add_argument("--system", help="System prompt") + chat_p.set_defaults(func=cmd_chat) + + # Stats + stats_p = sub.add_parser("stats", help="Show model health and breaker states") + stats_p.set_defaults(func=cmd_stats) + + # Rank + rank_p = sub.add_parser("rank", help="Show model ranking for a message") + rank_p.add_argument("message", help="Message to classify and rank") + rank_p.set_defaults(func=cmd_rank) + + # Reset breaker + reset_p = sub.add_parser("reset-breaker", help="Reset a circuit breaker") + reset_p.add_argument("model", help="Model name to reset") + reset_p.set_defaults(func=cmd_reset_breaker) + + 
argv = sys.argv[1:]
+    known_commands = {"chat", "stats", "rank", "reset-breaker"}
+
+    # Default to the chat subcommand when a bare message is given, so
+    # `python sauce.py "explain this error"` works as documented.
+    if argv and not argv[0].startswith("-") and argv[0] not in known_commands:
+        argv = ["chat"] + argv
+
+    args = parser.parse_args(argv)
+    if args.command is None:
+        parser.print_help()
+    else:
+        args.func(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..50375a8
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,2 @@
+# ai-hot-sauce engine
+__version__ = "0.1.0"
diff --git a/src/engine.py b/src/engine.py
new file mode 100644
index 0000000..c350418
--- /dev/null
+++ b/src/engine.py
@@ -0,0 +1,162 @@
+"""Hot Sauce Engine — the main orchestrator.
+
+Wires together: providers → router → breaker → quality gate → persistence.
+This module is the working core that the README only describes.
+"""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any
+
+from .providers.base import CompletionResult, Provider
+from .quality.gate import QualityGate, QualityVerdict
+from .routing.breaker import CircuitBreaker
+from .routing.scorer import ScoringRouter, TaskProfile, classify_task
+from .store.db import HotSauceDB
+
+log = logging.getLogger("hotsauce")
+
+MAX_RETRIES = 2  # extra attempts allowed on top of the ranked fallback list
+
+
+class HotSauceEngine:
+    """Main entry point. Create one, call .chat()."""
+
+    def __init__(
+        self,
+        db_path: Path | str | None = None,
+        providers: list[Provider] | None = None,
+    ):
+        self.db = HotSauceDB(db_path)
+        self.gate = QualityGate()
+
+        # Default to empty — user adds providers they have keys for
+        self._providers: list[Provider] = providers or []
+        self.breaker = CircuitBreaker(self.db)
+        self.router = ScoringRouter(self.db, self._providers, self.breaker)
+        self._session_id: str | None = None
+
+    def add_provider(self, provider: Provider):
+        """Add a provider at runtime."""
+        self._providers.append(provider)
+        self.router = ScoringRouter(self.db, self._providers, self.breaker)
+
+    def session(self, session_id: str | None = None) -> str:
+        """Start or resume a session."""
+        if session_id:
+            existing = self.db.get_session(session_id)
+            if existing:
+                self._session_id = session_id
+                return session_id
+        self._session_id = self.db.create_session()
+        return self._session_id
+
+    def chat(
+        self,
+        message: str,
+        session_id: str | None = None,
+        system: str | None = None,
+        expect_json: bool = False,
+        **kwargs,
+    ) -> CompletionResult:
+        """Send a message, get a quality-checked response.
+
+        Handles: routing → call → quality gate → retry/fallback → persistence.
+ """ + sid = session_id or self._session_id or self.session() + + # Log user turn + self.db.log_turn(sid, "user", message) + + # Build messages from session history + turns = self.db.get_turns(sid, limit=50) + messages = [] + if system: + messages.append({"role": "system", "content": system}) + for t in turns: + messages.append({"role": t["role"], "content": t["content"]}) + + # Route + model_name, provider_name, task = self.router.select(message) + ranked = self.router.rank(task) + tried_models: set[str] = set() + + for attempt in range(MAX_RETRIES + len(ranked)): + if model_name in tried_models: + # Move to next model in ranking + for name, score in ranked: + if name not in tried_models: + model_name = name + _, provider_name = self.router._model_registry[name] + break + else: + break # exhausted all models + + tried_models.add(model_name) + provider = self.router.providers.get(provider_name) + if not provider: + log.warning(f"Provider {provider_name} not registered, skipping {model_name}") + continue + + try: + result = provider.timed_complete(model_name, messages, **kwargs) + + # Record health + self.db.log_health(model_name, provider_name, True, result.latency_ms) + self.breaker.record_success(model_name) + + # Quality gate + gate_result = self.gate.check( + result.content, + task_type=task.task_type, + expect_json=expect_json, + ) + + if gate_result.passed: + # Persist and return + self.db.log_turn( + sid, "assistant", result.content, + model=model_name, provider=provider_name, + latency_ms=result.latency_ms, + tokens_in=result.tokens_in, tokens_out=result.tokens_out, + cost_usd=result.cost_usd, quality_status="pass", + ) + return result + + # Quality failed — log and try next model + log.warning(f"Quality gate failed for {model_name}: {gate_result.detail}") + self.db.log_turn( + sid, "assistant", result.content, + model=model_name, provider=provider_name, + latency_ms=result.latency_ms, + tokens_in=result.tokens_in, tokens_out=result.tokens_out, + cost_usd=result.cost_usd, + quality_status=gate_result.verdict.value, + ) + continue + + except Exception as e: + log.error(f"Model {model_name} failed: {e}") + self.db.log_health( + model_name, provider_name, False, + error_type=type(e).__name__, error_message=str(e)[:500], + ) + self.breaker.record_failure(model_name) + continue + + raise RuntimeError( + f"All models exhausted after {len(tried_models)} attempts. 
" + f"Tried: {tried_models}" + ) + + def stats(self) -> dict[str, Any]: + """Get current engine stats — model health, breaker states, session count.""" + result = {"models": {}, "breakers": {}} + for name in self.router._model_registry: + result["models"][name] = self.db.get_model_stats(name) + breaker = self.db.get_breaker(name) + if breaker: + result["breakers"][name] = dict(breaker) + return result diff --git a/src/providers/__init__.py b/src/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/providers/anthropic_provider.py b/src/providers/anthropic_provider.py new file mode 100644 index 0000000..72ca0bc --- /dev/null +++ b/src/providers/anthropic_provider.py @@ -0,0 +1,83 @@ +"""Anthropic provider.""" + +from __future__ import annotations + +import os +import time +from typing import Any + +from .base import CompletionResult, ModelInfo, Provider + +_anthropic = None + + +def _get_anthropic(): + global _anthropic + if _anthropic is None: + import anthropic + _anthropic = anthropic + return _anthropic + + +ANTHROPIC_MODELS = [ + ModelInfo("claude-opus-4-6", "anthropic", 200_000, 0.015, 0.075, supports_tools=True, supports_vision=True, tags=["reasoning", "code"]), + ModelInfo("claude-sonnet-4-6", "anthropic", 200_000, 0.003, 0.015, supports_tools=True, supports_vision=True, tags=["code", "fast"]), + ModelInfo("claude-haiku-4-5-20251001", "anthropic", 200_000, 0.0008, 0.004, supports_tools=True, tags=["fast", "cheap"]), +] + + +class AnthropicProvider(Provider): + name = "anthropic" + + def __init__(self, api_key: str | None = None): + anthropic = _get_anthropic() + self.client = anthropic.Anthropic( + api_key=api_key or os.environ.get("ANTHROPIC_API_KEY"), + ) + + def models(self) -> list[ModelInfo]: + return ANTHROPIC_MODELS + + def complete(self, model: str, messages: list[dict[str, str]], + temperature: float = 0.7, max_tokens: int = 4096, **kwargs) -> CompletionResult: + # Anthropic separates system from messages + system_msg = "" + chat_msgs = [] + for m in messages: + if m["role"] == "system": + system_msg += m["content"] + "\n" + else: + chat_msgs.append(m) + + start = time.perf_counter() + resp = self.client.messages.create( + model=model, + max_tokens=max_tokens, + temperature=temperature, + system=system_msg.strip() if system_msg else None, + messages=chat_msgs, + **kwargs, + ) + elapsed = (time.perf_counter() - start) * 1000 + + content = "" + for block in resp.content: + if hasattr(block, "text"): + content += block.text + + info = next((m for m in ANTHROPIC_MODELS if m.name == model), None) + cost = 0.0 + if info and resp.usage: + cost = (resp.usage.input_tokens / 1000 * info.cost_per_1k_in + + resp.usage.output_tokens / 1000 * info.cost_per_1k_out) + + return CompletionResult( + content=content, + model=model, + provider=self.name, + tokens_in=resp.usage.input_tokens if resp.usage else 0, + tokens_out=resp.usage.output_tokens if resp.usage else 0, + latency_ms=elapsed, + cost_usd=cost, + raw=resp, + ) diff --git a/src/providers/base.py b/src/providers/base.py new file mode 100644 index 0000000..fdf42e9 --- /dev/null +++ b/src/providers/base.py @@ -0,0 +1,66 @@ +"""Base provider interface and model registry.""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class ModelInfo: + """Static metadata about a model.""" + name: str + provider: str + context_window: int + cost_per_1k_in: float # USD per 1k input tokens + 
cost_per_1k_out: float # USD per 1k output tokens + supports_tools: bool = False + supports_vision: bool = False + supports_streaming: bool = True + tags: list[str] = field(default_factory=list) # e.g. ["code", "reasoning", "fast", "local"] + + +@dataclass +class CompletionResult: + """Standardised response from any provider.""" + content: str + model: str + provider: str + tokens_in: int + tokens_out: int + latency_ms: float + cost_usd: float + raw: Any = None # original API response for debugging + + +class Provider(ABC): + """Abstract provider — one per API backend.""" + + name: str + + @abstractmethod + def models(self) -> list[ModelInfo]: + """Return all models this provider offers.""" + ... + + @abstractmethod + def complete( + self, + model: str, + messages: list[dict[str, str]], + temperature: float = 0.7, + max_tokens: int = 4096, + **kwargs, + ) -> CompletionResult: + """Send a chat completion request and return a standardised result.""" + ... + + def timed_complete(self, model: str, messages: list[dict[str, str]], **kwargs) -> CompletionResult: + """Wrapper that measures latency.""" + start = time.perf_counter() + result = self.complete(model, messages, **kwargs) + elapsed = (time.perf_counter() - start) * 1000 + result.latency_ms = elapsed + return result diff --git a/src/providers/google_provider.py b/src/providers/google_provider.py new file mode 100644 index 0000000..ff195a1 --- /dev/null +++ b/src/providers/google_provider.py @@ -0,0 +1,87 @@ +"""Google Gemini provider.""" + +from __future__ import annotations + +import os +import time + +from .base import CompletionResult, ModelInfo, Provider + +_google = None + + +def _get_google(): + global _google + if _google is None: + from google import genai + _google = genai + return _google + + +GOOGLE_MODELS = [ + ModelInfo("gemini-2.5-flash", "google", 1_000_000, 0.0, 0.0, supports_tools=True, supports_vision=True, tags=["fast", "free", "code"]), + ModelInfo("gemini-2.5-pro", "google", 1_000_000, 0.00125, 0.01, supports_tools=True, supports_vision=True, tags=["reasoning", "code"]), +] + + +class GoogleProvider(Provider): + name = "google" + + def __init__(self, api_key: str | None = None): + genai = _get_google() + self.client = genai.Client(api_key=api_key or os.environ.get("GOOGLE_API_KEY")) + + def models(self) -> list[ModelInfo]: + return GOOGLE_MODELS + + def complete(self, model: str, messages: list[dict[str, str]], + temperature: float = 0.7, max_tokens: int = 4096, **kwargs) -> CompletionResult: + genai = _get_google() + + # Convert chat format to Gemini contents + contents = [] + system_instruction = None + for m in messages: + if m["role"] == "system": + system_instruction = m["content"] + else: + role = "user" if m["role"] == "user" else "model" + contents.append(genai.types.Content( + role=role, + parts=[genai.types.Part(text=m["content"])], + )) + + config = genai.types.GenerateContentConfig( + temperature=temperature, + max_output_tokens=max_tokens, + system_instruction=system_instruction, + ) + + start = time.perf_counter() + resp = self.client.models.generate_content( + model=model, + contents=contents, + config=config, + ) + elapsed = (time.perf_counter() - start) * 1000 + + content = resp.text or "" + tokens_in = getattr(resp.usage_metadata, "prompt_token_count", 0) or 0 + tokens_out = getattr(resp.usage_metadata, "candidates_token_count", 0) or 0 + + info = next((m for m in GOOGLE_MODELS if m.name == model), None) + cost = 0.0 + if info: + cost = (tokens_in / 1000 * info.cost_per_1k_in + + tokens_out / 1000 * 
info.cost_per_1k_out) + + return CompletionResult( + content=content, + model=model, + provider=self.name, + tokens_in=tokens_in, + tokens_out=tokens_out, + latency_ms=elapsed, + cost_usd=cost, + raw=resp, + ) diff --git a/src/providers/ollama_provider.py b/src/providers/ollama_provider.py new file mode 100644 index 0000000..1c10c36 --- /dev/null +++ b/src/providers/ollama_provider.py @@ -0,0 +1,74 @@ +"""Ollama local model provider.""" + +from __future__ import annotations + +import json +import time +import urllib.request + +from .base import CompletionResult, ModelInfo, Provider + +# Common local models — costs are always 0 +OLLAMA_MODELS = [ + ModelInfo("llama3.2", "ollama", 128_000, 0.0, 0.0, tags=["local", "free"]), + ModelInfo("phi4-mini", "ollama", 128_000, 0.0, 0.0, tags=["local", "free", "fast"]), + ModelInfo("codellama", "ollama", 16_000, 0.0, 0.0, tags=["local", "free", "code"]), + ModelInfo("mistral", "ollama", 32_000, 0.0, 0.0, tags=["local", "free"]), +] + + +class OllamaProvider(Provider): + name = "ollama" + + def __init__(self, base_url: str = "http://localhost:11434"): + self.base_url = base_url.rstrip("/") + + def models(self) -> list[ModelInfo]: + return OLLAMA_MODELS + + def _is_available(self) -> bool: + try: + req = urllib.request.Request(f"{self.base_url}/api/tags", method="GET") + urllib.request.urlopen(req, timeout=2) + return True + except Exception: + return False + + def complete(self, model: str, messages: list[dict[str, str]], + temperature: float = 0.7, max_tokens: int = 4096, **kwargs) -> CompletionResult: + payload = json.dumps({ + "model": model, + "messages": messages, + "stream": False, + "options": { + "temperature": temperature, + "num_predict": max_tokens, + }, + }).encode() + + req = urllib.request.Request( + f"{self.base_url}/api/chat", + data=payload, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + start = time.perf_counter() + with urllib.request.urlopen(req, timeout=120) as resp: + data = json.loads(resp.read()) + elapsed = (time.perf_counter() - start) * 1000 + + content = data.get("message", {}).get("content", "") + tokens_in = data.get("prompt_eval_count", 0) + tokens_out = data.get("eval_count", 0) + + return CompletionResult( + content=content, + model=model, + provider=self.name, + tokens_in=tokens_in, + tokens_out=tokens_out, + latency_ms=elapsed, + cost_usd=0.0, + raw=data, + ) diff --git a/src/providers/openai_provider.py b/src/providers/openai_provider.py new file mode 100644 index 0000000..7ec7b99 --- /dev/null +++ b/src/providers/openai_provider.py @@ -0,0 +1,75 @@ +"""OpenAI / OpenAI-compatible provider (also covers OpenRouter, local vLLM, etc).""" + +from __future__ import annotations + +import os +import time +from typing import Any + +from .base import CompletionResult, ModelInfo, Provider + +# Lazy import — don't crash if openai isn't installed +_openai = None + + +def _get_openai(): + global _openai + if _openai is None: + import openai + _openai = openai + return _openai + + +# Default model catalogue — costs in USD per 1k tokens (approximate) +OPENAI_MODELS = [ + ModelInfo("gpt-4.1", "openai", 1_000_000, 0.002, 0.008, supports_tools=True, supports_vision=True, tags=["code", "reasoning"]), + ModelInfo("gpt-4.1-mini", "openai", 1_000_000, 0.0004, 0.0016, supports_tools=True, supports_vision=True, tags=["fast", "code"]), + ModelInfo("gpt-4.1-nano", "openai", 1_000_000, 0.0001, 0.0004, supports_tools=True, tags=["fast", "cheap"]), + ModelInfo("o3", "openai", 200_000, 0.01, 0.04, supports_tools=True, 
tags=["reasoning"]), + ModelInfo("o4-mini", "openai", 200_000, 0.0011, 0.0044, supports_tools=True, tags=["reasoning", "fast"]), +] + + +class OpenAIProvider(Provider): + name = "openai" + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + openai = _get_openai() + self.client = openai.OpenAI( + api_key=api_key or os.environ.get("OPENAI_API_KEY"), + base_url=base_url, + ) + + def models(self) -> list[ModelInfo]: + return OPENAI_MODELS + + def complete(self, model: str, messages: list[dict[str, str]], + temperature: float = 0.7, max_tokens: int = 4096, **kwargs) -> CompletionResult: + start = time.perf_counter() + resp = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) + elapsed = (time.perf_counter() - start) * 1000 + choice = resp.choices[0] + usage = resp.usage + + info = next((m for m in OPENAI_MODELS if m.name == model), None) + cost = 0.0 + if info and usage: + cost = (usage.prompt_tokens / 1000 * info.cost_per_1k_in + + usage.completion_tokens / 1000 * info.cost_per_1k_out) + + return CompletionResult( + content=choice.message.content or "", + model=model, + provider=self.name, + tokens_in=usage.prompt_tokens if usage else 0, + tokens_out=usage.completion_tokens if usage else 0, + latency_ms=elapsed, + cost_usd=cost, + raw=resp, + ) diff --git a/src/quality/__init__.py b/src/quality/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/quality/gate.py b/src/quality/gate.py new file mode 100644 index 0000000..9c7ebb2 --- /dev/null +++ b/src/quality/gate.py @@ -0,0 +1,112 @@ +"""Quality gate — catch garbage responses before they reach the user. + +Runs fast heuristic checks on model output: + 1. Too short (< min_length chars for non-trivial tasks) + 2. Empty or whitespace-only + 3. Greeting/filler patterns ("Hello! How can I help you today?") + 4. Refusal when none was expected + 5. Repeated text (model looping) + 6. Malformed JSON when JSON was requested + +On failure: one retry with stricter prompt, then fallback to next model. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum + + +class QualityVerdict(Enum): + PASS = "pass" + FAIL_EMPTY = "fail_empty" + FAIL_TOO_SHORT = "fail_too_short" + FAIL_GREETING = "fail_greeting" + FAIL_REFUSAL = "fail_refusal" + FAIL_REPETITION = "fail_repetition" + FAIL_JSON = "fail_json" + + +@dataclass +class GateResult: + verdict: QualityVerdict + detail: str = "" + + @property + def passed(self) -> bool: + return self.verdict == QualityVerdict.PASS + + +# Patterns that indicate a useless "polite but empty" response +GREETING_PATTERNS = [ + r"^(hi|hello|hey)[!.,]?\s*(how can i|what can i|i'?m here to)\s*(help|assist)", + r"^(sure|of course|absolutely)[!.,]?\s*(i'?d be happy to|let me)\s*(help|assist)", + r"^i'?m\s+(an?\s+)?(ai|language model|assistant)", + r"^as an ai", +] + +REFUSAL_PATTERNS = [ + r"i (can'?t|cannot|am unable to|don'?t|do not) (help|assist|provide|generate|create) (with )?(that|this)", + r"i'?m (not able|unable) to", + r"(against|violates?) 
my (guidelines|policy|programming)", + r"i (must|have to) (decline|refuse)", +] + +REPETITION_THRESHOLD = 3 # same phrase repeated N+ times = looping + + +class QualityGate: + def __init__(self, min_length: int = 20, max_repetition: int = REPETITION_THRESHOLD): + self.min_length = min_length + self.max_repetition = max_repetition + + def check(self, response: str, task_type: str = "quick", + expect_json: bool = False) -> GateResult: + """Run all quality checks. Returns first failure or PASS.""" + + # 1. Empty + stripped = response.strip() + if not stripped: + return GateResult(QualityVerdict.FAIL_EMPTY, "Response is empty") + + # 2. Too short (skip for "quick" tasks where short answers are valid) + if task_type != "quick" and len(stripped) < self.min_length: + return GateResult(QualityVerdict.FAIL_TOO_SHORT, + f"Response too short ({len(stripped)} chars, min {self.min_length})") + + # 3. Greeting/filler + lower = stripped.lower() + for pattern in GREETING_PATTERNS: + if re.search(pattern, lower): + return GateResult(QualityVerdict.FAIL_GREETING, + f"Response is filler/greeting: {stripped[:80]}") + + # 4. Unexpected refusal + for pattern in REFUSAL_PATTERNS: + if re.search(pattern, lower): + return GateResult(QualityVerdict.FAIL_REFUSAL, + f"Model refused: {stripped[:80]}") + + # 5. Repetition detection + # Split into sentences, check for loops + sentences = re.split(r'[.!?\n]+', stripped) + sentences = [s.strip().lower() for s in sentences if len(s.strip()) > 10] + if sentences: + from collections import Counter + counts = Counter(sentences) + most_common_count = counts.most_common(1)[0][1] if counts else 0 + if most_common_count >= self.max_repetition: + return GateResult(QualityVerdict.FAIL_REPETITION, + f"Repeated phrase {most_common_count}x") + + # 6. JSON validation + if expect_json: + import json + try: + json.loads(stripped) + except json.JSONDecodeError as e: + return GateResult(QualityVerdict.FAIL_JSON, + f"Invalid JSON: {e}") + + return GateResult(QualityVerdict.PASS) diff --git a/src/routing/__init__.py b/src/routing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/routing/breaker.py b/src/routing/breaker.py new file mode 100644 index 0000000..6010648 --- /dev/null +++ b/src/routing/breaker.py @@ -0,0 +1,106 @@ +"""Circuit breaker — per-model failure tracking. 
+ +States: + CLOSED → normal operation, failures counted + OPEN → model is down, all requests rejected immediately + HALF_OPEN → cooldown expired, probe with 1 request + +Transitions: + CLOSED → OPEN when fail_count >= threshold in rolling window + OPEN → HALF_OPEN when now >= next_probe_at + HALF_OPEN → CLOSED on success (success_streak >= recovery_threshold) + HALF_OPEN → OPEN on failure (reset cooldown) +""" + +from __future__ import annotations + +import time + +from ..store.db import HotSauceDB + +FAIL_THRESHOLD = 3 +COOLDOWN_SECONDS = 60.0 +RECOVERY_SUCCESSES = 2 + + +class CircuitBreaker: + def __init__(self, db: HotSauceDB, fail_threshold: int = FAIL_THRESHOLD, + cooldown: float = COOLDOWN_SECONDS, recovery: int = RECOVERY_SUCCESSES): + self.db = db + self.fail_threshold = fail_threshold + self.cooldown = cooldown + self.recovery = recovery + + def state(self, model: str) -> str: + """Get current breaker state, auto-transitioning OPEN → HALF_OPEN if cooldown elapsed.""" + rec = self.db.get_breaker(model) + if rec is None: + return "closed" + + if rec["state"] == "open" and rec["next_probe_at"] and time.time() >= rec["next_probe_at"]: + self.db.upsert_breaker(model, "half_open", fail_count=rec["fail_count"], + success_streak=0, opened_at=rec["opened_at"]) + return "half_open" + + return rec["state"] + + def is_available(self, model: str) -> bool: + """Can we send a request to this model right now?""" + s = self.state(model) + return s in ("closed", "half_open") + + def record_success(self, model: str): + """Record a successful call.""" + rec = self.db.get_breaker(model) + if rec is None: + self.db.upsert_breaker(model, "closed", fail_count=0, success_streak=1) + return + + new_streak = rec["success_streak"] + 1 + + if rec["state"] == "half_open" and new_streak >= self.recovery: + # Recovered — close the breaker + self.db.upsert_breaker(model, "closed", fail_count=0, success_streak=0) + else: + self.db.upsert_breaker( + model, rec["state"], fail_count=max(0, rec["fail_count"] - 1), + success_streak=new_streak, + opened_at=rec["opened_at"], + next_probe_at=rec["next_probe_at"], + ) + + def record_failure(self, model: str): + """Record a failed call, potentially tripping the breaker.""" + rec = self.db.get_breaker(model) + now = time.time() + + if rec is None: + if 1 >= self.fail_threshold: + self._trip(model, 1, now) + else: + self.db.upsert_breaker(model, "closed", fail_count=1, success_streak=0) + return + + new_fails = rec["fail_count"] + 1 + + if rec["state"] == "half_open": + # Probe failed — back to open + self._trip(model, new_fails, now) + elif rec["state"] == "closed" and new_fails >= self.fail_threshold: + self._trip(model, new_fails, now) + else: + self.db.upsert_breaker( + model, rec["state"], fail_count=new_fails, success_streak=0, + opened_at=rec.get("opened_at"), next_probe_at=rec.get("next_probe_at"), + ) + + def _trip(self, model: str, fail_count: int, now: float): + """Trip breaker to OPEN.""" + self.db.upsert_breaker( + model, "open", fail_count=fail_count, success_streak=0, + opened_at=now, next_probe_at=now + self.cooldown, + ) + + def reset(self, model: str): + """Manually reset a breaker.""" + self.db.upsert_breaker(model, "closed", fail_count=0, success_streak=0) diff --git a/src/routing/scorer.py b/src/routing/scorer.py new file mode 100644 index 0000000..00377f0 --- /dev/null +++ b/src/routing/scorer.py @@ -0,0 +1,198 @@ +"""Scoring-based model router. + +Replaces the static markdown routing table with weighted scoring. 
+Each model gets a composite score based on: + - capability_fit: does this model match the task's needs? + - latency: historical p95 latency (lower is better) + - cost: cost per 1k tokens (lower is better) + - reliability: recent success rate from health telemetry + - breaker_penalty: heavy penalty if breaker is open + +The highest-scoring available model wins. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..providers.base import ModelInfo +from ..store.db import HotSauceDB +from .breaker import CircuitBreaker + +if TYPE_CHECKING: + from ..providers.base import Provider + + +@dataclass +class TaskProfile: + """Classified task for routing decisions.""" + task_type: str # "code", "reasoning", "quick", "vision", "creative" + estimated_tokens: int = 1000 + needs_tools: bool = False + needs_vision: bool = False + explicit_model: str | None = None # user override e.g. @gemini + + +# Tag-to-task affinity scores (0.0 to 1.0) +TAG_AFFINITY = { + "code": {"code": 1.0, "reasoning": 0.6, "quick": 0.3, "vision": 0.2, "creative": 0.3}, + "reasoning": {"code": 0.5, "reasoning": 1.0, "quick": 0.2, "vision": 0.3, "creative": 0.6}, + "fast": {"code": 0.4, "reasoning": 0.2, "quick": 1.0, "vision": 0.4, "creative": 0.5}, + "cheap": {"code": 0.3, "reasoning": 0.1, "quick": 0.8, "vision": 0.3, "creative": 0.4}, + "local": {"code": 0.3, "reasoning": 0.2, "quick": 0.5, "vision": 0.0, "creative": 0.3}, + "free": {"code": 0.4, "reasoning": 0.3, "quick": 0.9, "vision": 0.4, "creative": 0.5}, +} + +# Scoring weights +W_CAPABILITY = 0.35 +W_LATENCY = 0.15 +W_COST = 0.25 +W_RELIABILITY = 0.20 +W_BREAKER = 0.50 # additive penalty + + +def classify_task(user_message: str) -> TaskProfile: + """Simple heuristic intent classifier. 
Replace with LLM classifier for production.""" + msg = user_message.lower() + + # Check for explicit model override + explicit = None + override_match = re.search(r"@(\w+)", user_message) + if override_match: + explicit = override_match.group(1) + + needs_vision = any(kw in msg for kw in ["image", "screenshot", "picture", "photo", "diagram"]) + needs_tools = any(kw in msg for kw in ["search the web", "search for", "browse", "fetch url", "run command", "execute command"]) + + # Task classification + code_signals = ["code", "function", "bug", "error", "debug", "refactor", "implement", + "class ", "def ", "```", "syntax", "compile", "test"] + reasoning_signals = ["explain", "why", "analyse", "compare", "trade-off", "design", + "architect", "plan", "strategy", "evaluate"] + quick_signals = ["what is", "define", "tldr", "summarise", "summary", "list", "name"] + + code_score = sum(1 for s in code_signals if s in msg) + reasoning_score = sum(1 for s in reasoning_signals if s in msg) + quick_score = sum(1 for s in quick_signals if s in msg) + + if needs_vision: + task_type = "vision" + elif code_score > reasoning_score and code_score > quick_score: + task_type = "code" + elif reasoning_score > quick_score: + task_type = "reasoning" + elif quick_score > 0: + task_type = "quick" + else: + task_type = "quick" # default to cheapest + + estimated = 500 if task_type == "quick" else 2000 if task_type == "code" else 1500 + + return TaskProfile( + task_type=task_type, + estimated_tokens=estimated, + needs_tools=needs_tools, + needs_vision=needs_vision, + explicit_model=explicit, + ) + + +def _capability_score(model: ModelInfo, task: TaskProfile) -> float: + """How well does this model fit the task?""" + score = 0.0 + for tag in model.tags: + affinities = TAG_AFFINITY.get(tag, {}) + score = max(score, affinities.get(task.task_type, 0.0)) + return score + + +def _cost_score(model: ModelInfo) -> float: + """Normalised cost score — lower cost = higher score. 
Free models get 1.0.""" + avg_cost = (model.cost_per_1k_in + model.cost_per_1k_out) / 2 + if avg_cost == 0: + return 1.0 + return min(1.0, 0.01 / avg_cost) # 0.01/1k as reference point + + +def _latency_score(stats: dict) -> float: + """Normalised latency score from health telemetry.""" + avg = stats.get("avg_latency_ms") + if avg is None or avg == 0: + return 0.5 # unknown — neutral + return min(1.0, 2000 / avg) # 2s as reference point + + +class ScoringRouter: + def __init__(self, db: HotSauceDB, providers: list[Provider], breaker: CircuitBreaker): + self.db = db + self.providers = {p.name: p for p in providers} + self.breaker = breaker + self._model_registry: dict[str, tuple[ModelInfo, str]] = {} + self._rebuild_registry() + + def _rebuild_registry(self): + """Build flat lookup of model_name → (ModelInfo, provider_name).""" + self._model_registry.clear() + for provider in self.providers.values(): + for model in provider.models(): + self._model_registry[model.name] = (model, provider.name) + + def score_model(self, model: ModelInfo, task: TaskProfile) -> float: + """Compute composite score for a model given a task.""" + # Hard filters + if task.needs_vision and not model.supports_vision: + return -1.0 + if task.needs_tools and not model.supports_tools: + return -1.0 + if task.estimated_tokens > model.context_window: + return -1.0 + + cap = _capability_score(model, task) + cost = _cost_score(model) + stats = self.db.get_model_stats(model.name) + lat = _latency_score(stats) + rel = stats.get("success_rate", 1.0) + + breaker_penalty = 0.0 + if not self.breaker.is_available(model.name): + breaker_penalty = 1.0 + + score = (W_CAPABILITY * cap + + W_LATENCY * lat + + W_COST * cost + + W_RELIABILITY * rel - + W_BREAKER * breaker_penalty) + + return round(score, 4) + + def rank(self, task: TaskProfile) -> list[tuple[str, float]]: + """Return all models ranked by score, descending.""" + scores = [] + for name, (model, provider) in self._model_registry.items(): + s = self.score_model(model, task) + if s >= 0: + scores.append((name, s)) + scores.sort(key=lambda x: x[1], reverse=True) + return scores + + def select(self, user_message: str) -> tuple[str, str, TaskProfile]: + """Classify task and select best model. Returns (model_name, provider_name, task_profile).""" + task = classify_task(user_message) + + # Explicit override + if task.explicit_model: + for name, (model, provider) in self._model_registry.items(): + if task.explicit_model.lower() in name.lower() or task.explicit_model.lower() == provider.lower(): + if self.breaker.is_available(name): + return name, provider, task + # Fall through to scoring if override not found / breaker tripped + + ranked = self.rank(task) + if not ranked: + raise RuntimeError("No available models for this task") + + best_name = ranked[0][0] + _, provider = self._model_registry[best_name] + return best_name, provider, task diff --git a/src/store/__init__.py b/src/store/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/store/db.py b/src/store/db.py new file mode 100644 index 0000000..23dd5a5 --- /dev/null +++ b/src/store/db.py @@ -0,0 +1,248 @@ +"""SQLite persistence layer for the Hot Sauce engine. + +Stores sessions, turns, model health telemetry, circuit breaker state, +and eval results. Local-first — no external dependencies. 
+""" + +import json +import sqlite3 +import time +import uuid +from contextlib import contextmanager +from pathlib import Path +from typing import Any + +DEFAULT_DB_PATH = Path.home() / ".agent" / "hotsauce.db" + +SCHEMA = """ +CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + metadata_json TEXT DEFAULT '{}' +); + +CREATE TABLE IF NOT EXISTS turns ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + role TEXT NOT NULL CHECK(role IN ('user', 'assistant', 'system')), + content TEXT NOT NULL, + model TEXT, + provider TEXT, + latency_ms REAL, + tokens_in INTEGER, + tokens_out INTEGER, + cost_usd REAL, + quality_status TEXT CHECK(quality_status IN ('pass', 'fail', 'retry', 'fallback', NULL)), + created_at REAL NOT NULL, + FOREIGN KEY (session_id) REFERENCES sessions(id) +); + +CREATE TABLE IF NOT EXISTS model_health ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model TEXT NOT NULL, + provider TEXT NOT NULL, + ts REAL NOT NULL, + success INTEGER NOT NULL CHECK(success IN (0, 1)), + latency_ms REAL, + error_type TEXT, + error_message TEXT +); + +CREATE TABLE IF NOT EXISTS breaker_state ( + model TEXT PRIMARY KEY, + state TEXT NOT NULL DEFAULT 'closed' CHECK(state IN ('closed', 'open', 'half_open')), + fail_count INTEGER NOT NULL DEFAULT 0, + success_streak INTEGER NOT NULL DEFAULT 0, + opened_at REAL, + next_probe_at REAL, + updated_at REAL NOT NULL +); + +CREATE TABLE IF NOT EXISTS eval_cases ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + input_json TEXT NOT NULL, + assertions_json TEXT NOT NULL, + created_at REAL NOT NULL +); + +CREATE TABLE IF NOT EXISTS eval_runs ( + id TEXT PRIMARY KEY, + case_id TEXT NOT NULL, + model TEXT NOT NULL, + passed INTEGER NOT NULL CHECK(passed IN (0, 1)), + metrics_json TEXT DEFAULT '{}', + created_at REAL NOT NULL, + FOREIGN KEY (case_id) REFERENCES eval_cases(id) +); + +CREATE INDEX IF NOT EXISTS idx_turns_session ON turns(session_id); +CREATE INDEX IF NOT EXISTS idx_turns_model ON turns(model); +CREATE INDEX IF NOT EXISTS idx_health_model_ts ON model_health(model, ts); +CREATE INDEX IF NOT EXISTS idx_eval_runs_case ON eval_runs(case_id); +""" + + +class HotSauceDB: + def __init__(self, db_path: Path | str | None = None): + self._in_memory = db_path == ":memory:" + if not self._in_memory: + self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH + self.db_path.parent.mkdir(parents=True, exist_ok=True) + else: + self.db_path = None + self._conn: sqlite3.Connection | None = None + self._init_schema() + + def _connect(self) -> sqlite3.Connection: + if self._conn is None: + target = ":memory:" if self._in_memory else str(self.db_path) + self._conn = sqlite3.connect(target) + self._conn.row_factory = sqlite3.Row + if not self._in_memory: + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA foreign_keys=ON") + return self._conn + + @contextmanager + def _tx(self): + conn = self._connect() + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + + def _init_schema(self): + conn = self._connect() + conn.executescript(SCHEMA) + + def close(self): + """Close the database connection. 
Call before deleting the DB file on Windows.""" + if self._conn is not None: + self._conn.close() + self._conn = None + + # -- Sessions -- + + def create_session(self, metadata: dict | None = None) -> str: + sid = str(uuid.uuid4()) + now = time.time() + with self._tx() as conn: + conn.execute( + "INSERT INTO sessions (id, created_at, updated_at, metadata_json) VALUES (?, ?, ?, ?)", + (sid, now, now, json.dumps(metadata or {})), + ) + return sid + + def get_session(self, session_id: str) -> dict | None: + conn = self._connect() + row = conn.execute("SELECT * FROM sessions WHERE id = ?", (session_id,)).fetchone() + return dict(row) if row else None + + # -- Turns -- + + def log_turn( + self, + session_id: str, + role: str, + content: str, + model: str | None = None, + provider: str | None = None, + latency_ms: float | None = None, + tokens_in: int | None = None, + tokens_out: int | None = None, + cost_usd: float | None = None, + quality_status: str | None = None, + ) -> str: + tid = str(uuid.uuid4()) + now = time.time() + with self._tx() as conn: + conn.execute( + """INSERT INTO turns + (id, session_id, role, content, model, provider, + latency_ms, tokens_in, tokens_out, cost_usd, quality_status, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (tid, session_id, role, content, model, provider, + latency_ms, tokens_in, tokens_out, cost_usd, quality_status, now), + ) + conn.execute( + "UPDATE sessions SET updated_at = ? WHERE id = ?", + (now, session_id), + ) + return tid + + def get_turns(self, session_id: str, limit: int = 50) -> list[dict]: + conn = self._connect() + rows = conn.execute( + "SELECT * FROM turns WHERE session_id = ? ORDER BY created_at ASC LIMIT ?", + (session_id, limit), + ).fetchall() + return [dict(r) for r in rows] + + # -- Model Health -- + + def log_health( + self, + model: str, + provider: str, + success: bool, + latency_ms: float | None = None, + error_type: str | None = None, + error_message: str | None = None, + ): + with self._tx() as conn: + conn.execute( + """INSERT INTO model_health (model, provider, ts, success, latency_ms, error_type, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (model, provider, time.time(), int(success), latency_ms, error_type, error_message), + ) + + def get_model_stats(self, model: str, window_seconds: float = 300) -> dict[str, Any]: + """Get success rate and latency stats for a model over a time window.""" + cutoff = time.time() - window_seconds + conn = self._connect() + row = conn.execute( + """SELECT + COUNT(*) as total, + SUM(success) as successes, + AVG(CASE WHEN success = 1 THEN latency_ms END) as avg_latency, + MAX(CASE WHEN success = 1 THEN latency_ms END) as p_max_latency + FROM model_health + WHERE model = ? 
AND ts > ?""", + (model, cutoff), + ).fetchone() + total = row["total"] or 0 + successes = row["successes"] or 0 + return { + "model": model, + "total": total, + "successes": successes, + "success_rate": successes / total if total > 0 else 1.0, + "avg_latency_ms": row["avg_latency"], + "max_latency_ms": row["p_max_latency"], + } + + # -- Breaker State -- + + def get_breaker(self, model: str) -> dict | None: + conn = self._connect() + row = conn.execute("SELECT * FROM breaker_state WHERE model = ?", (model,)).fetchone() + return dict(row) if row else None + + def upsert_breaker(self, model: str, state: str, fail_count: int = 0, + success_streak: int = 0, opened_at: float | None = None, + next_probe_at: float | None = None): + now = time.time() + with self._tx() as conn: + conn.execute( + """INSERT INTO breaker_state (model, state, fail_count, success_streak, opened_at, next_probe_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(model) DO UPDATE SET + state=excluded.state, fail_count=excluded.fail_count, + success_streak=excluded.success_streak, opened_at=excluded.opened_at, + next_probe_at=excluded.next_probe_at, updated_at=excluded.updated_at""", + (model, state, fail_count, success_streak, opened_at, next_probe_at, now), + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/eval_cases/routing.yaml b/tests/eval_cases/routing.yaml new file mode 100644 index 0000000..8bd85e5 --- /dev/null +++ b/tests/eval_cases/routing.yaml @@ -0,0 +1,62 @@ +# Self-eval test vectors for routing classification +# Run with: python -m pytest tests/test_eval.py -v + +- name: simple_question + input: "what is a closure" + expect_type: quick + expect_vision: false + expect_tools: false + +- name: code_debug + input: "fix this bug in the function that parses JSON" + expect_type: code + expect_vision: false + expect_tools: false + +- name: architecture_reasoning + input: "explain why microservices are better than monoliths for this use case" + expect_type: reasoning + expect_vision: false + expect_tools: false + +- name: screenshot_vision + input: "describe what's in this screenshot" + expect_type: vision + expect_vision: true + expect_tools: false + +- name: web_search + input: "search for recent papers on transformer attention" + expect_type: quick + expect_vision: false + expect_tools: true + +- name: code_with_override + input: "@gemini refactor this class" + expect_type: code + expect_vision: false + expect_tools: false + +- name: definition + input: "define polymorphism" + expect_type: quick + expect_vision: false + expect_tools: false + +- name: compare_tradeoffs + input: "compare the trade-offs between REST and GraphQL" + expect_type: reasoning + expect_vision: false + expect_tools: false + +- name: implement_feature + input: "implement a binary search tree with insert and delete" + expect_type: code + expect_vision: false + expect_tools: false + +- name: image_analysis + input: "what's wrong with this diagram of the database schema" + expect_type: vision + expect_vision: true + expect_tools: false diff --git a/tests/test_breaker.py b/tests/test_breaker.py new file mode 100644 index 0000000..f8ffbbc --- /dev/null +++ b/tests/test_breaker.py @@ -0,0 +1,69 @@ +"""Tests for the circuit breaker.""" + +import time +import unittest + +from src.routing.breaker import CircuitBreaker +from src.store.db import HotSauceDB + + +class TestCircuitBreaker(unittest.TestCase): + def setUp(self): + self.db = HotSauceDB(":memory:") + self.breaker = 
CircuitBreaker(self.db, fail_threshold=3, cooldown=0.5, recovery=2) + + def tearDown(self): + self.db.close() + + def test_starts_closed(self): + self.assertEqual(self.breaker.state("test-model"), "closed") + self.assertTrue(self.breaker.is_available("test-model")) + + def test_trips_after_threshold(self): + for _ in range(3): + self.breaker.record_failure("test-model") + self.assertEqual(self.breaker.state("test-model"), "open") + self.assertFalse(self.breaker.is_available("test-model")) + + def test_does_not_trip_below_threshold(self): + self.breaker.record_failure("test-model") + self.breaker.record_failure("test-model") + self.assertEqual(self.breaker.state("test-model"), "closed") + + def test_transitions_to_half_open(self): + for _ in range(3): + self.breaker.record_failure("test-model") + self.assertEqual(self.breaker.state("test-model"), "open") + # Wait for cooldown + time.sleep(0.6) + self.assertEqual(self.breaker.state("test-model"), "half_open") + self.assertTrue(self.breaker.is_available("test-model")) + + def test_recovers_from_half_open(self): + for _ in range(3): + self.breaker.record_failure("test-model") + time.sleep(0.6) + self.assertEqual(self.breaker.state("test-model"), "half_open") + # Successful probes + self.breaker.record_success("test-model") + self.breaker.record_success("test-model") + self.assertEqual(self.breaker.state("test-model"), "closed") + + def test_half_open_failure_reopens(self): + for _ in range(3): + self.breaker.record_failure("test-model") + time.sleep(0.6) + self.assertEqual(self.breaker.state("test-model"), "half_open") + self.breaker.record_failure("test-model") + self.assertEqual(self.breaker.state("test-model"), "open") + + def test_manual_reset(self): + for _ in range(3): + self.breaker.record_failure("test-model") + self.assertEqual(self.breaker.state("test-model"), "open") + self.breaker.reset("test-model") + self.assertEqual(self.breaker.state("test-model"), "closed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..d3e9b3e --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,59 @@ +"""Tests for the SQLite persistence layer.""" + +import unittest + +from src.store.db import HotSauceDB + + +class TestHotSauceDB(unittest.TestCase): + def setUp(self): + self.db = HotSauceDB(":memory:") + + def tearDown(self): + self.db.close() + + def test_create_session(self): + sid = self.db.create_session({"project": "test"}) + self.assertIsNotNone(sid) + session = self.db.get_session(sid) + self.assertIsNotNone(session) + + def test_log_and_get_turns(self): + sid = self.db.create_session() + self.db.log_turn(sid, "user", "hello") + self.db.log_turn(sid, "assistant", "hi there", model="test-model") + turns = self.db.get_turns(sid) + self.assertEqual(len(turns), 2) + self.assertEqual(turns[0]["role"], "user") + self.assertEqual(turns[1]["model"], "test-model") + + def test_health_logging(self): + self.db.log_health("test-model", "test", True, 150.0) + self.db.log_health("test-model", "test", True, 200.0) + self.db.log_health("test-model", "test", False, error_type="Timeout") + stats = self.db.get_model_stats("test-model") + self.assertEqual(stats["total"], 3) + self.assertEqual(stats["successes"], 2) + self.assertAlmostEqual(stats["success_rate"], 2/3, places=2) + + def test_breaker_upsert(self): + self.db.upsert_breaker("test-model", "closed", fail_count=0) + rec = self.db.get_breaker("test-model") + self.assertEqual(rec["state"], "closed") + # Update + 
self.db.upsert_breaker("test-model", "open", fail_count=3) + rec = self.db.get_breaker("test-model") + self.assertEqual(rec["state"], "open") + self.assertEqual(rec["fail_count"], 3) + + def test_nonexistent_session(self): + self.assertIsNone(self.db.get_session("nonexistent")) + + def test_empty_stats(self): + stats = self.db.get_model_stats("never-seen") + self.assertEqual(stats["total"], 0) + self.assertEqual(stats["success_rate"], 1.0) # optimistic default + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_eval.py b/tests/test_eval.py new file mode 100644 index 0000000..1df723a --- /dev/null +++ b/tests/test_eval.py @@ -0,0 +1,52 @@ +"""Self-eval tests — verify routing classification against test vectors.""" + +import unittest +from pathlib import Path + +import yaml + +from src.routing.scorer import classify_task + +EVAL_DIR = Path(__file__).parent / "eval_cases" + + +class TestRoutingEval(unittest.TestCase): + """Parameterised tests from YAML eval cases.""" + + @classmethod + def setUpClass(cls): + cls.cases = [] + for yaml_file in EVAL_DIR.glob("*.yaml"): + with open(yaml_file) as f: + cls.cases.extend(yaml.safe_load(f)) + + def test_eval_cases_loaded(self): + self.assertGreater(len(self.cases), 0, "No eval cases found") + + def test_routing_classification(self): + failures = [] + for case in self.cases: + task = classify_task(case["input"]) + + if task.task_type != case["expect_type"]: + failures.append( + f" {case['name']}: expected type={case['expect_type']}, got={task.task_type}" + ) + if task.needs_vision != case["expect_vision"]: + failures.append( + f" {case['name']}: expected vision={case['expect_vision']}, got={task.needs_vision}" + ) + if task.needs_tools != case["expect_tools"]: + failures.append( + f" {case['name']}: expected tools={case['expect_tools']}, got={task.needs_tools}" + ) + + if failures: + pass_count = len(self.cases) * 3 - len(failures) + total = len(self.cases) * 3 + msg = f"\n{len(failures)}/{total} assertions failed:\n" + "\n".join(failures) + self.fail(msg) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gate.py b/tests/test_gate.py new file mode 100644 index 0000000..bb42a38 --- /dev/null +++ b/tests/test_gate.py @@ -0,0 +1,59 @@ +"""Tests for the quality gate.""" + +import unittest + +from src.quality.gate import QualityGate, QualityVerdict + + +class TestQualityGate(unittest.TestCase): + def setUp(self): + self.gate = QualityGate(min_length=20) + + def test_pass_normal_response(self): + r = self.gate.check("Here is a detailed explanation of how the function works.", "code") + self.assertTrue(r.passed) + + def test_fail_empty(self): + r = self.gate.check("", "code") + self.assertEqual(r.verdict, QualityVerdict.FAIL_EMPTY) + + def test_fail_whitespace(self): + r = self.gate.check(" \n\t ", "code") + self.assertEqual(r.verdict, QualityVerdict.FAIL_EMPTY) + + def test_fail_too_short_for_code(self): + r = self.gate.check("Yes.", "code") + self.assertEqual(r.verdict, QualityVerdict.FAIL_TOO_SHORT) + + def test_pass_short_for_quick(self): + r = self.gate.check("Yes.", "quick") + self.assertTrue(r.passed) + + def test_fail_greeting(self): + r = self.gate.check("Hello! 
How can I help you today?", "quick") + self.assertEqual(r.verdict, QualityVerdict.FAIL_GREETING) + + def test_fail_greeting_variant(self): + r = self.gate.check("Sure, I'd be happy to help you with that!", "quick") + self.assertEqual(r.verdict, QualityVerdict.FAIL_GREETING) + + def test_fail_refusal(self): + r = self.gate.check("I cannot help with that request.", "code") + self.assertEqual(r.verdict, QualityVerdict.FAIL_REFUSAL) + + def test_fail_repetition(self): + repeated = "The answer is 42. " * 5 + r = self.gate.check(repeated, "code") + self.assertEqual(r.verdict, QualityVerdict.FAIL_REPETITION) + + def test_fail_json(self): + r = self.gate.check("{ invalid json here }", "code", expect_json=True) + self.assertEqual(r.verdict, QualityVerdict.FAIL_JSON) + + def test_pass_valid_json(self): + r = self.gate.check('{"key": "value", "num": 42}', "code", expect_json=True) + self.assertTrue(r.passed) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scorer.py b/tests/test_scorer.py new file mode 100644 index 0000000..d110d92 --- /dev/null +++ b/tests/test_scorer.py @@ -0,0 +1,102 @@ +"""Tests for the scoring router.""" + +import unittest + +from src.providers.base import ModelInfo, Provider, CompletionResult +from src.routing.breaker import CircuitBreaker +from src.routing.scorer import ScoringRouter, TaskProfile, classify_task +from src.store.db import HotSauceDB + + +class FakeProvider(Provider): + name = "fake" + + def __init__(self, model_list: list[ModelInfo]): + self._models = model_list + + def models(self) -> list[ModelInfo]: + return self._models + + def complete(self, model, messages, **kwargs) -> CompletionResult: + return CompletionResult("fake", model, "fake", 10, 10, 100.0, 0.0) + + +class TestClassifyTask(unittest.TestCase): + def test_code_task(self): + t = classify_task("fix this bug in the function") + self.assertEqual(t.task_type, "code") + + def test_reasoning_task(self): + t = classify_task("explain why this architecture is better") + self.assertEqual(t.task_type, "reasoning") + + def test_quick_task(self): + t = classify_task("what is a monad") + self.assertEqual(t.task_type, "quick") + + def test_vision_task(self): + t = classify_task("describe this screenshot") + self.assertEqual(t.task_type, "vision") + + def test_explicit_override(self): + t = classify_task("@gemini explain this error") + self.assertEqual(t.explicit_model, "gemini") + + def test_tools_needed(self): + t = classify_task("search for recent papers on transformers") + self.assertTrue(t.needs_tools) + + +class TestScoringRouter(unittest.TestCase): + def setUp(self): + self.db = HotSauceDB(":memory:") + self.breaker = CircuitBreaker(self.db) + + self.cheap = ModelInfo("cheap-fast", "fake", 128_000, 0.0, 0.0, tags=["fast", "cheap", "free"]) + self.coder = ModelInfo("code-heavy", "fake", 200_000, 0.01, 0.04, supports_tools=True, tags=["code", "reasoning"]) + self.vision = ModelInfo("vision-model", "fake", 128_000, 0.005, 0.02, supports_vision=True, tags=["fast"]) + + self.provider = FakeProvider([self.cheap, self.coder, self.vision]) + self.router = ScoringRouter(self.db, [self.provider], self.breaker) + + def tearDown(self): + self.db.close() + + def test_code_task_prefers_coder(self): + task = TaskProfile("code", estimated_tokens=1000) + ranked = self.router.rank(task) + names = [n for n, _ in ranked] + self.assertEqual(names[0], "code-heavy") + + def test_quick_task_prefers_cheap(self): + task = TaskProfile("quick", estimated_tokens=100) + ranked = self.router.rank(task) + names = 
[n for n, _ in ranked] + self.assertEqual(names[0], "cheap-fast") + + def test_vision_filters_non_vision(self): + task = TaskProfile("vision", needs_vision=True) + ranked = self.router.rank(task) + names = [n for n, _ in ranked] + self.assertIn("vision-model", names) + self.assertNotIn("cheap-fast", names) + self.assertNotIn("code-heavy", names) + + def test_breaker_penalises(self): + task = TaskProfile("quick", estimated_tokens=100) + # Trip the cheap model's breaker + for _ in range(3): + self.breaker.record_failure("cheap-fast") + ranked = self.router.rank(task) + names = [n for n, _ in ranked] + # cheap-fast should no longer be first + self.assertNotEqual(names[0], "cheap-fast") + + def test_select_returns_valid(self): + model, provider, task = self.router.select("what is python") + self.assertIn(model, ["cheap-fast", "code-heavy", "vision-model"]) + self.assertEqual(provider, "fake") + + +if __name__ == "__main__": + unittest.main()
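
For reviewers, a minimal end-to-end sketch of how the pieces in this patch fit together. It assumes the patch is applied, the code runs from the repository root, and `OPENAI_API_KEY` is exported so at least one provider registers; the prompt text and the `:memory:` database path are illustrative only, not part of the patch.

```python
# Review sketch (not part of the patch): exercises routing, ranking, and the
# quality-gated chat loop end to end. Assumes OPENAI_API_KEY is set.
import os

from src.engine import HotSauceEngine
from src.providers.openai_provider import OpenAIProvider
from src.routing.scorer import classify_task

engine = HotSauceEngine(db_path=":memory:")   # throwaway in-memory store
if os.environ.get("OPENAI_API_KEY"):
    engine.add_provider(OpenAIProvider())     # register providers that have keys

# Peek at the routing decision before spending tokens: classify the prompt,
# then rank every registered model by capability fit, latency, cost,
# reliability, and breaker state.
prompt = "refactor this function to remove the global state"
task = classify_task(prompt)
for name, score in engine.router.rank(task):
    print(f"{name:<30} {score:.4f}")

# One routed call: the engine logs both turns, runs the quality gate on the
# reply, and falls back to the next-ranked model if the gate or the call fails.
result = engine.chat(prompt)
print(f"[{result.model}] ${result.cost_usd:.6f}")
print(result.content)
```

The `chat`, `rank`, and `stats` subcommands in `sauce.py` drive these same calls through `_build_engine()`.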