diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/hmmv/end_to_end.py b/examples/hmmv/end_to_end.py new file mode 100644 index 0000000..8ffc884 --- /dev/null +++ b/examples/hmmv/end_to_end.py @@ -0,0 +1,424 @@ +""" +examples/hmmv/end_to_end.py +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Full worked example: from setup to proposal approval. + +This file is self-contained and runnable: + python3 -m examples.hmmv.end_to_end + +It demonstrates: + 1. Harness setup with in-memory stores + 2. Early regime: 22 runs with converging models → baseline building + 3. Mature regime: normal convergent runs (quiet) + 4. Drift trigger: edge case input splits models + 5. CorrectionRunner: full pipeline with stub LLM + stub model runner + 6. Proposal review: inspect → approve → spec registry updated + 7. Verification: re-run drifting input → regime returns to convergent + +No real LLM or model calls are made. All external dependencies are stubs +that simulate realistic behaviour (converging scores for normal inputs, +divergent scores for the one edge-case input). +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import statistics +import uuid +from dataclasses import replace +from datetime import datetime, timezone + +logging.basicConfig( + level=logging.INFO, + format="%(levelname)-8s %(name)s: %(message)s", +) +log = logging.getLogger("example") + +from manifold.testing.convergence import ConvergenceConfig, ConvergenceMonitor +from manifold.testing.correction import CorrectionRunner +from manifold.testing.models import ( + DriftSignal, + DriftType, + ProposalStatus, + ReviewStatus, + SpecProposal, + _compute_mad, +) +from manifold.testing.stores import ( + InMemoryBaselineStore, + InMemoryProposalStore, + InMemorySnapshotStore, + InMemorySpecRegistry, +) + +# ============================================================================= +# STUB: Model stubs simulating 4 diverse LLM models +# ============================================================================= + +EDGE_CASE_INPUT = { + "name": "Caritas Berlin e.V.", + "description": "Catholic welfare organisation providing social services.", + "funding_tags": ["soziale_arbeit", "wohlfahrt"], +} + +# Normal inputs → all 4 models within ~0.05 of each other +NORMAL_SCORES = { + "gpt-4o": 0.82, + "gemini-flash": 0.80, + "llama-3.3": 0.81, + "mistral-small": 0.79, +} + +# Edge case → models split: 2 say strongly religious, 2 say secular welfare +EDGE_CASE_SCORES_BEFORE = { + "gpt-4o": 0.85, + "gemini-flash": 0.85, + "llama-3.3": -0.60, + "mistral-small": -0.65, +} + +# After proposed criteria applied → all models converge +EDGE_CASE_SCORES_AFTER = { + "gpt-4o": 0.80, + "gemini-flash": 0.79, + "llama-3.3": 0.78, + "mistral-small": 0.77, +} + + +async def stub_model_runner( + input_data: dict, + criteria_hint: str, + model_id: str, +) -> float: + """ + Stub model runner used by CorrectionRunner during validation. + + If criteria_hint is non-empty (correction workflow) AND the input + matches the edge case → return converged scores. + Otherwise → return the before-correction edge case scores. 
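+
+    Concretely, with the score tables above (illustrative calls only):
+
+        await stub_model_runner(EDGE_CASE_INPUT, "", "gpt-4o")                       -> 0.85  (edge case, no criteria)
+        await stub_model_runner(EDGE_CASE_INPUT, "religious welfare ...", "gpt-4o")  -> 0.80  (edge case, corrected)
+        await stub_model_runner({"name": "Test NGO 1"}, "", "gpt-4o")                -> 0.82  (normal input)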
+ """ + is_edge_case = "caritas" in input_data.get("name", "").lower() + has_criteria = bool(criteria_hint.strip()) + + if is_edge_case and has_criteria: + return EDGE_CASE_SCORES_AFTER[model_id] + elif is_edge_case: + return EDGE_CASE_SCORES_BEFORE[model_id] + else: + return NORMAL_SCORES.get(model_id, 0.80) + + +async def stub_llm_caller(prompt: str) -> str: + """Stub LLM that returns a canned correction proposal.""" + return json.dumps( + { + "proposed_change": ( + "Add explicit criterion for Catholic/Protestant welfare organisations: " + "organisations whose primary mission is social welfare but are operated " + "by a religious body should be classified as ngo_religious. " + "The religious character of the operating body takes precedence over " + "the service domain." + ), + "proposed_spec_code": ( + "# Criterion: welfare org operated by religious body → ngo_religious\n" + "RELIGIOUS_WELFARE_KEYWORDS = ['caritas', 'diakonie', 'malteser',\n" + " 'johanniter', 'rotes kreuz']\n" + "org_name_lower = candidate.get('name', '').lower()\n" + "if any(kw in org_name_lower for kw in RELIGIOUS_WELFARE_KEYWORDS):\n" + " return SpecResult.ok(rule_id=self.rule_id,\n" + " message='Religious welfare org: ngo_religious',\n" + " data={'religious_welfare': True})" + ), + "hypothesis": ( + "Catholic welfare orgs like Caritas are operated by the Catholic Church " + "and should be classified as ngo_religious regardless of their service domain. " + "Explicit keyword matching for well-known religious welfare bodies removes " + "the ambiguity that causes model divergence." + ), + "target_spec_id": "classify_spec_v1", + } + ) + + +# ============================================================================= +# Lightweight harness (direct wiring — no external manifold dependency) +# ============================================================================= + +MODEL_IDS = ["gpt-4o", "gemini-flash", "llama-3.3", "mistral-small"] + + +class DirectHarness: + """ + Minimal harness for the example. + Directly wires ConvergenceMonitor + CorrectionRunner + stores. + In production, use HMMVTestHarness which wraps all of this. 
+ """ + + def __init__(self): + self.baseline = InMemoryBaselineStore() + self.snapshots = InMemorySnapshotStore() + self.proposals = InMemoryProposalStore() + self.spec_registry = InMemorySpecRegistry() + + self.config = ConvergenceConfig( + min_baseline_size=20, + drift_multiplier=2.5, + min_class_records=5, + ) + self.monitor = ConvergenceMonitor( + baseline_store=self.baseline, + config=self.config, + spec_versions={"classify_spec_v1": "1.0.0"}, + ) + self.correction_runner = CorrectionRunner( + llm_caller=stub_llm_caller, + model_runner=stub_model_runner, + model_ids=MODEL_IDS, + improvement_threshold=0.25, + ) + + self.total_runs: int = 0 + self.drift_signals: list[DriftSignal] = [] + self.proposals_generated: list[SpecProposal] = [] + + async def _refresh_cache(self): + total = await self.baseline.total_records() + snapshot = await self.snapshots.latest() + if snapshot and snapshot.total_records > 0: + mads = snapshot.mad_by_class + counts = snapshot.records_by_class + else: + mads, counts = {}, {} + for r in self.baseline._records: + counts[r.input_class] = counts.get(r.input_class, 0) + 1 + by_class: dict[str, list] = {} + for r in self.baseline._records: + by_class.setdefault(r.input_class, []).append(r.inter_model_mad) + mads = {c: statistics.mean(vs) for c, vs in by_class.items()} + self.monitor.update_baseline_cache(total, mads, counts) + + async def run( + self, + input_data: dict, + model_scores: dict[str, float], + input_class: str = "ngo_religious", + ) -> dict: + """Simulate a single workflow run and return convergence result.""" + self.total_runs += 1 + run_id = str(uuid.uuid4()) + await self._refresh_cache() + + result = self.monitor.evaluate_sync( + run_id=run_id, + input_data=input_data, + input_class=input_class, + cluster_version="v1", + model_scores=model_scores, + raw_outputs={}, + ) + + signals = self.monitor.drain_signals() + records = self.monitor.drain_records() + + for r in records: + await self.baseline.append(r) + + for sig in signals: + sig = replace(sig, triggering_input=input_data) + await self.baseline.append_signal(sig) + self.drift_signals.append(sig) + log.info( + "DRIFT DETECTED: %s on class=%s MAD %.3f (expected %.3f)", + sig.drift_type.value, + sig.input_class, + sig.observed_mad, + sig.expected_mad or 0, + ) + + log.info("Starting correction workflow for signal %s ...", sig.signal_id[:8]) + proposal = await self.correction_runner.run(sig) + if proposal: + await self.proposals.write(proposal) + self.proposals_generated.append(proposal) + log.info( + "Proposal ready: %s status=%s improvement=%.4f", + proposal.proposal_id[:8], + proposal.proposal_status.value, + proposal.mad_improvement or 0.0, + ) + + total = await self.baseline.total_records() + if total > 0 and total % 10 == 0: + snapshot = await self.baseline.take_snapshot(self.spec_registry) + await self.snapshots.write(snapshot) + + return { + "run_id": run_id, + "regime": result["regime"], + "mad": result["mad"], + "message": result["message"], + "had_drift": len(signals) > 0, + } + + +# ============================================================================= +# Main example flow +# ============================================================================= + + +async def main(): + harness = DirectHarness() + + print("\n" + "═" * 70) + print("MANIFOLD TESTING — End-to-End Example") + print("═" * 70) + + # ───────────────────────────────────────────────────────────────────────── + print("\n[1] EARLY REGIME — Building baseline (20 runs needed)") + print("─" * 50) + + for i in range(22): 
+ org = {"name": f"Test NGO {i}", "funding": 10_000 + i * 500} + scores = {m: NORMAL_SCORES[m] + (i % 3 - 1) * 0.01 for m in MODEL_IDS} + r = await harness.run(org, scores) + if (i + 1) % 5 == 0: + print( + f" Run {i+1:2d}: regime={r['regime']:<12} MAD={r['mad']:.4f} {r['message'][:60]}" + ) + + baseline_size = await harness.baseline.total_records() + print(f"\n Baseline built: {baseline_size} records") + assert baseline_size >= 20, "Should have crossed min_baseline_size" + assert harness.drift_signals == [], "No drift expected during baseline building" + print(" No drift signals during baseline phase — as expected.") + + # ───────────────────────────────────────────────────────────────────────── + print("\n[2] MATURE REGIME — Normal convergent runs") + print("─" * 50) + + for i in range(5): + org = {"name": f"Normal NGO {i}", "funding": 50_000} + r = await harness.run(org, NORMAL_SCORES) + print(f" Run {i+1}: regime={r['regime']:<12} MAD={r['mad']:.4f}") + + print(" All runs convergent — drift detection active but quiet.") + + # ───────────────────────────────────────────────────────────────────────── + print("\n[3] DRIFT TRIGGER — Edge case input") + print("─" * 50) + print(f" Input: {EDGE_CASE_INPUT['name']}") + print( + f" Scores: gpt={EDGE_CASE_SCORES_BEFORE['gpt-4o']:.2f} " + f"gemini={EDGE_CASE_SCORES_BEFORE['gemini-flash']:.2f} " + f"llama={EDGE_CASE_SCORES_BEFORE['llama-3.3']:.2f} " + f"mistral={EDGE_CASE_SCORES_BEFORE['mistral-small']:.2f}" + ) + print(f" Expected MAD: ~{_compute_mad(list(EDGE_CASE_SCORES_BEFORE.values())):.3f}") + + r = await harness.run(EDGE_CASE_INPUT, EDGE_CASE_SCORES_BEFORE) + print(f"\n Result: regime={r['regime']} MAD={r['mad']:.4f}") + print(f" Message: {r['message']}") + + assert r["regime"] == "drift", f"Expected drift, got {r['regime']}" + assert len(harness.drift_signals) == 1 + + # ───────────────────────────────────────────────────────────────────────── + print("\n[4] CORRECTION WORKFLOW OUTPUT") + print("─" * 50) + + proposals = await harness.proposals.pending_proposals() + assert len(proposals) == 1, f"Expected 1 proposal, got {len(proposals)}" + proposal = proposals[0] + + print(f" Proposal ID: {proposal.proposal_id[:16]}...") + print(f" Status: {proposal.proposal_status.value}") + print(f" Target spec: {proposal.target_spec_id}") + print(f" MAD before: {proposal.validation_mad_before:.4f}") + print(f" MAD after: {proposal.validation_mad_after:.4f}") + print( + f" Improvement: {proposal.mad_improvement:.4f} " + f"({proposal.mad_improvement / proposal.validation_mad_before * 100:.1f}%)" + ) + print(f"\n Hypothesis:\n {proposal.hypothesis}") + + assert proposal.proposal_status == ProposalStatus.VALIDATED + assert proposal.mad_improvement > 0 + + # ───────────────────────────────────────────────────────────────────────── + print("\n[5] HUMAN REVIEW — Approve proposal") + print("─" * 50) + + approved = replace( + proposal, + review_status=ReviewStatus.APPROVED, + reviewer_notes="Confirmed: Caritas-type orgs should be ngo_religious. Approved.", + applied_at=datetime.now(timezone.utc), + ) + await harness.spec_registry.apply_proposal(approved) + + new_versions = await harness.spec_registry.current_versions() + print(" Approved. 
Spec registry updated:") + for spec_id, version in new_versions.items(): + print(f" {spec_id}: {version}") + + # ───────────────────────────────────────────────────────────────────────── + print("\n[6] VERIFICATION — Re-run edge case with updated model behaviour") + print("─" * 50) + print(" (Models now return converged scores for Caritas-type inputs)") + + initial_signal_count = len(harness.drift_signals) + r2 = await harness.run(EDGE_CASE_INPUT, EDGE_CASE_SCORES_AFTER) + print(f"\n Result: regime={r2['regime']} MAD={r2['mad']:.4f}") + print(f" Message: {r2['message']}") + + assert r2["regime"] in ( + "convergent", + "novel_class", + ), f"Expected convergent after fix, got {r2['regime']}" + assert len(harness.drift_signals) == initial_signal_count, "No new drift after fix" + print(" No new drift signal — edge case now classified consistently.") + + # ───────────────────────────────────────────────────────────────────────── + print("\n" + "═" * 70) + print("SUMMARY") + print("═" * 70) + print(f" Total runs executed: {harness.total_runs}") + print(f" Baseline records built: {await harness.baseline.total_records()}") + print(f" Drift signals detected: {len(harness.drift_signals)}") + print(f" Proposals generated: {len(harness.proposals_generated)}") + print( + f" Proposals validated: " + f"{sum(1 for p in harness.proposals_generated if p.proposal_status == ProposalStatus.VALIDATED)}" + ) + + signal = harness.drift_signals[0] + print(f"\n Drift signal:") + print(f" Type: {signal.drift_type.value}") + print(f" Input class: {signal.input_class}") + print(f" Observed MAD: {signal.observed_mad:.4f}") + print(f" Expected MAD: {signal.expected_mad:.4f}") + + p = harness.proposals_generated[0] + print(f"\n Correction proposal:") + print( + f" MAD reduction: {p.mad_improvement:.4f} " + f"({p.mad_improvement / p.validation_mad_before * 100:.1f}%)" + ) + print(f" Validation: {p.proposal_status.value}") + print(f" Models converged after: {p.models_converged_after}/{len(MODEL_IDS)}") + + print("\n System operated correctly across all phases:") + print(" early regime → baseline accumulated, no false positives") + print(" mature regime → normal runs passed silently") + print(" drift detected → signal emitted, correction triggered") + print(" correction run → proposal generated and validated") + print(" human reviewed → spec registry updated") + print(" re-verification → edge case now convergent") + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/simple_example/example.py b/examples/simple_example/example.py index 56f9d9b..880401c 100644 --- a/examples/simple_example/example.py +++ b/examples/simple_example/example.py @@ -12,63 +12,61 @@ import asyncio from dataclasses import dataclass from manifold import ( - Context, Spec, SpecResult, Agent, AgentOutput, - OrchestratorBuilder, create_context + Context, + Spec, + SpecResult, + Agent, + AgentOutput, + OrchestratorBuilder, + create_context, ) - # ─── SPECS ─────────────────────────────────────────────────────────────── class HasInputData(Spec): """Pre-condition: input_data must exist.""" - + @property def rule_id(self) -> str: return "has_input_data" - + @property def tags(self): return ("precondition", "data") - + def evaluate(self, context: Context, candidate=None) -> SpecResult: if context.has_data("input_data"): - return SpecResult.ok( - self.rule_id, - "Input data is present", - tags=self.tags - ) + return SpecResult.ok(self.rule_id, "Input data is present", tags=self.tags) return SpecResult.fail( self.rule_id, 
"Missing input_data", suggested_fix="Provide 'input_data' in initial context", - tags=self.tags + tags=self.tags, ) class OutputNotEmpty(Spec): """Post-condition: output must not be empty.""" - + @property def rule_id(self) -> str: return "output_not_empty" - + @property def tags(self): return ("postcondition", "output") - + def evaluate(self, context: Context, candidate=None) -> SpecResult: if candidate and len(str(candidate)) > 0: return SpecResult.ok( - self.rule_id, - f"Output produced: {len(str(candidate))} chars", - tags=self.tags + self.rule_id, f"Output produced: {len(str(candidate))} chars", tags=self.tags ) return SpecResult.fail( self.rule_id, "Output is empty or None", suggested_fix="Ensure agent produces non-empty output", - tags=self.tags + tags=self.tags, ) @@ -77,23 +75,23 @@ def evaluate(self, context: Context, candidate=None) -> SpecResult: class DataProcessorAgent(Agent): """Simple agent that processes input data.""" - + @property def agent_id(self) -> str: return "data_processor" - + async def execute(self, context: Context, input_data=None) -> AgentOutput: """Process the input data.""" # Get input from context raw_data = context.get_data("input_data", "") - + # Simple processing: uppercase and add suffix processed = f"{raw_data.upper()} - PROCESSED" - + return AgentOutput( output=processed, delta={"processed_data": processed}, - cost=0.001 # Minimal cost for demo + cost=0.001, # Minimal cost for demo ) @@ -102,12 +100,12 @@ async def execute(self, context: Context, input_data=None) -> AgentOutput: async def main(): """Run the example workflow.""" - + print("=" * 60) print("Manifold Simple Example - Data Processing Workflow") print("=" * 60) print() - + # Build orchestrator print("Building orchestrator...") orchestrator = ( @@ -120,13 +118,11 @@ async def main(): ) print("[OK] Orchestrator built") print() - + # Run workflow print("Running workflow...") - result = await orchestrator.run( - initial_data={"input_data": "hello world"} - ) - + result = await orchestrator.run(initial_data={"input_data": "hello world"}) + print() print("=" * 60) print("RESULTS") @@ -137,13 +133,13 @@ async def main(): print(f"Total Retries: {result.total_retries}") print(f"Duration: {result.duration_ms}ms") print() - + # Show final data print("Final Context Data:") for key, value in result.final_context.data.items(): print(f" {key}: {value}") print() - + # Show trace print("Execution Trace:") for i, entry in enumerate(result.final_context.trace, 1): @@ -155,7 +151,7 @@ async def main(): if entry.error: print(f" - Error: {entry.error}") print() - + print("=" * 60) print("Example completed successfully!") print("=" * 60) diff --git a/examples/sprite_generation/example.py b/examples/sprite_generation/example.py index 63f1855..512225f 100644 --- a/examples/sprite_generation/example.py +++ b/examples/sprite_generation/example.py @@ -43,6 +43,7 @@ async def main(): # This assumes sprite_pipeline is installed try: from sprite_pipeline.providers.fast_hook_provider import FastHookProvider + hook_provider = FastHookProvider() print("[OK] FastHookProvider initialized") except ImportError: diff --git a/examples/sprite_generation/harness/agent.py b/examples/sprite_generation/harness/agent.py index 4684bb9..2fdb367 100644 --- a/examples/sprite_generation/harness/agent.py +++ b/examples/sprite_generation/harness/agent.py @@ -36,7 +36,9 @@ def agent_id(self) -> str: def description(self) -> str: return "Generates sprite images using GPT image models" - async def execute(self, context: Context, input_data: 
dict[str, Any] | None = None) -> AgentOutput: + async def execute( + self, context: Context, input_data: dict[str, Any] | None = None + ) -> AgentOutput: """ Generate sprite image. @@ -54,17 +56,11 @@ async def execute(self, context: Context, input_data: dict[str, Any] | None = No gen_size = context.get_data("gen_size", "1024x1024") if not prompt: - return AgentOutput( - output=None, - delta={}, - cost=0.0 - ) + return AgentOutput(output=None, delta={}, cost=0.0) # Create image generation request request = HookRequest( - task_type=HookTaskType.GENERATE_IMAGE, - prompt_text=prompt, - gen_size=gen_size + task_type=HookTaskType.GENERATE_IMAGE, prompt_text=prompt, gen_size=gen_size ) # Call hook provider @@ -74,34 +70,30 @@ async def execute(self, context: Context, input_data: dict[str, Any] | None = No artifact = response.artifacts[0] return AgentOutput( - output={ - "width": artifact.width, - "height": artifact.height, - "status": "ok" - }, + output={"width": artifact.width, "height": artifact.height, "status": "ok"}, delta={ "generated_image": { "width": artifact.width, "height": artifact.height, - "size_bytes": len(artifact.png_bytes) + "size_bytes": len(artifact.png_bytes), }, - "image_bytes": artifact.png_bytes + "image_bytes": artifact.png_bytes, }, - cost=0.04 # Approximate GPT image generation cost + cost=0.04, # Approximate GPT image generation cost ) elif response.status == "content_policy": return AgentOutput( output={"status": "content_policy", "error": response.error_message}, delta={"error": response.error_message}, - cost=0.0 + cost=0.0, ) else: return AgentOutput( output={"status": "error", "error": response.error_message}, delta={"error": response.error_message}, - cost=0.0 + cost=0.0, ) @@ -123,7 +115,9 @@ def agent_id(self) -> str: def description(self) -> str: return "Builds optimized prompts for sprite generation" - async def execute(self, context: Context, input_data: dict[str, Any] | None = None) -> AgentOutput: + async def execute( + self, context: Context, input_data: dict[str, Any] | None = None + ) -> AgentOutput: """ Build sprite generation prompt. @@ -140,16 +134,10 @@ async def execute(self, context: Context, input_data: dict[str, Any] | None = No global_style = context.get_data("global_style", "Pixel Art") if not spec: - return AgentOutput( - output="", - delta={}, - cost=0.0 - ) + return AgentOutput(output="", delta={}, cost=0.0) request = HookRequest( - task_type=HookTaskType.BUILD_PROMPT, - spec=spec, - global_style=global_style + task_type=HookTaskType.BUILD_PROMPT, spec=spec, global_style=global_style ) response = await self._provider.run(request) @@ -158,13 +146,11 @@ async def execute(self, context: Context, input_data: dict[str, Any] | None = No return AgentOutput( output=response.text_output, delta={"prompt_text": response.text_output}, - cost=0.0 # Prompt building is lightweight + cost=0.0, # Prompt building is lightweight ) return AgentOutput( - output="", - delta={"error": response.error_message or "Prompt build failed"}, - cost=0.0 + output="", delta={"error": response.error_message or "Prompt build failed"}, cost=0.0 ) @@ -186,7 +172,9 @@ def agent_id(self) -> str: def description(self) -> str: return "Builds grid-specific generation briefs" - async def execute(self, context: Context, input_data: dict[str, Any] | None = None) -> AgentOutput: + async def execute( + self, context: Context, input_data: dict[str, Any] | None = None + ) -> AgentOutput: """ Build generation brief with grid constraints. 
@@ -208,22 +196,20 @@ async def execute(self, context: Context, input_data: dict[str, Any] | None = No task_type=HookTaskType.BUILD_BRIEF, spec=spec, global_style=global_style, - prompt_text=prompt_text + prompt_text=prompt_text, ) response = await self._provider.run(request) if response.status == "ok" and response.text_output: return AgentOutput( - output=response.text_output, - delta={"brief_text": response.text_output}, - cost=0.0 + output=response.text_output, delta={"brief_text": response.text_output}, cost=0.0 ) return AgentOutput( output=prompt_text, # Fallback to base prompt delta={"brief_text": prompt_text}, - cost=0.0 + cost=0.0, ) @@ -241,5 +227,5 @@ def create_sprite_agents(hook_provider: Any) -> dict[str, Agent]: return { "prompt_builder": PromptBuilderAgent(hook_provider), "brief_builder": BriefBuilderAgent(hook_provider), - "sprite_generator": SpriteGenerationAgent(hook_provider) + "sprite_generator": SpriteGenerationAgent(hook_provider), } diff --git a/examples/sprite_generation/harness/specs.py b/examples/sprite_generation/harness/specs.py index 92705c1..4602279 100644 --- a/examples/sprite_generation/harness/specs.py +++ b/examples/sprite_generation/harness/specs.py @@ -47,7 +47,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, "No image metadata to validate", suggested_fix="Ensure agent returns dict with 'width' and 'height'", - tags=self.tags + tags=self.tags, ) width = candidate.get("width", 0) @@ -58,7 +58,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, f"Image dimensions OK: {width}x{height}", tags=self.tags, - data={"width": width, "height": height} + data={"width": width, "height": height}, ) return SpecResult.fail( @@ -70,8 +70,8 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: "actual_width": width, "actual_height": height, "min_width": self._min_width, - "min_height": self._min_height - } + "min_height": self._min_height, + }, ) @@ -114,7 +114,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, "No expected_frames defined in context", suggested_fix="Set context.data['expected_frames'] before extraction", - tags=self.tags + tags=self.tags, ) if extracted >= expected: @@ -122,7 +122,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, f"Extraction succeeded: {extracted}/{expected} frames", tags=self.tags, - data={"expected": expected, "extracted": extracted} + data={"expected": expected, "extracted": extracted}, ) return SpecResult.fail( @@ -130,11 +130,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: f"Extraction incomplete: {extracted}/{expected} frames", suggested_fix="Regenerate with clearer grid constraints or adjust extraction logic", tags=self.tags, - data={ - "expected": expected, - "extracted": extracted, - "missing": expected - extracted - } + data={"expected": expected, "extracted": extracted, "missing": expected - extracted}, ) @@ -177,7 +173,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, "No expected grid_size in context", suggested_fix="Set context.data['grid_size'] = N for NxN grid", - tags=self.tags + tags=self.tags, ) if detected_grid is None: @@ -185,7 +181,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, "Grid detection not performed", suggested_fix="Run grid detection on generated image", - tags=self.tags + tags=self.tags, ) 
if detected_grid == expected_grid: @@ -193,7 +189,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, f"Grid layout correct: {detected_grid}x{detected_grid}", tags=self.tags, - data={"grid_size": detected_grid} + data={"grid_size": detected_grid}, ) return SpecResult.fail( @@ -201,10 +197,7 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: f"Grid mismatch: expected {expected_grid}x{expected_grid}, got {detected_grid}x{detected_grid}", suggested_fix="Strengthen grid constraints in prompt or regenerate", tags=self.tags, - data={ - "expected": expected_grid, - "detected": detected_grid - } + data={"expected": expected_grid, "detected": detected_grid}, ) @@ -223,17 +216,13 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: style = context.get_data("global_style") if style and len(str(style).strip()) > 0: - return SpecResult.ok( - self.rule_id, - f"Style defined: {style}", - tags=self.tags - ) + return SpecResult.ok(self.rule_id, f"Style defined: {style}", tags=self.tags) return SpecResult.fail( self.rule_id, "Missing global_style", suggested_fix="Set context.data['global_style'] = 'Pixel Art' (or other style)", - tags=self.tags + tags=self.tags, ) @@ -255,12 +244,12 @@ def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: self.rule_id, f"Prompt generated: {length} chars", tags=self.tags, - data={"prompt_length": length} + data={"prompt_length": length}, ) return SpecResult.fail( self.rule_id, "Generated prompt is empty", suggested_fix="Check prompt generation logic", - tags=self.tags + tags=self.tags, ) diff --git a/experiments/lib/agents/ollama/__init__.py b/experiments/lib/agents/ollama/__init__.py new file mode 100644 index 0000000..90f1317 --- /dev/null +++ b/experiments/lib/agents/ollama/__init__.py @@ -0,0 +1,7 @@ +""" +Ollama model agent wrappers. +""" + +from .chat_agent import OllamaAgent + +__all__ = ["OllamaAgent"] diff --git a/experiments/lib/agents/ollama/chat_agent.py b/experiments/lib/agents/ollama/chat_agent.py new file mode 100644 index 0000000..bb73135 --- /dev/null +++ b/experiments/lib/agents/ollama/chat_agent.py @@ -0,0 +1,204 @@ +""" +Ollama Chat Agent wrapper. + +Supports any locally running Ollama model (qwen2.5:14b, mistral, etc.). +Drop-in replacement for OpenAIChatAgent — implements the same Agent ABC. + +API docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion +""" + +import json +import urllib.request +import urllib.error +from typing import Any + +from manifold import Agent, AgentOutput, Context, ToolCall + + +class OllamaAgent(Agent): + """ + Agent wrapper for locally running Ollama models. + + Structurally identical to OpenAIChatAgent — same message format, + same JSON extraction logic — so experiments can swap backends + without touching manifests or specs. + + Differences from OpenAI: + - Endpoint: /api/chat instead of /v1/chat/completions + - cost is always 0.0 + - JSON mode via "format": "json" instead of response_format + - No streaming (stream=False) + """ + + def __init__( + self, + agent_id: str, + model: str = "qwen2.5:14b", + temperature: float = 0.0, + max_tokens: int | None = None, + json_mode: bool = False, + system_prompt: str | None = None, + base_url: str = "http://localhost:11434", + timeout: int = 120, + ): + """ + Args: + agent_id: Unique identifier for this agent instance + model: Ollama model tag (e.g. 
"qwen2.5:14b", "mistral") + temperature: Sampling temperature (0.0 = deterministic) + max_tokens: Maximum tokens in response (None = model default) + json_mode: If True, forces JSON output via Ollama's format param + system_prompt: Optional system prompt prepended to messages + base_url: Ollama server URL (default: localhost) + timeout: HTTP request timeout in seconds + """ + self._agent_id = agent_id + self._model = model + self._temperature = temperature + self._max_tokens = max_tokens + self._json_mode = json_mode + self._system_prompt = system_prompt + self._base_url = base_url.rstrip("/") + self._timeout = timeout + + @property + def agent_id(self) -> str: + return self._agent_id + + @property + def description(self) -> str: + return f"Ollama agent using {self._model}" + + async def execute( + self, context: Context, input_data: dict[str, Any] | None = None + ) -> AgentOutput: + """ + Run a chat completion via Ollama. + + Reads from context: + - "messages": full message list (takes priority), OR + - "user_message": single user string + + Returns AgentOutput with: + - output: parsed dict if json_mode=True, else raw string + - tool_calls: single ToolCall recording the request/response stats + - cost: always 0.0 + """ + messages = self._build_messages(context) + + if not messages: + return AgentOutput(output=None, tool_calls=[], cost=0.0) + + payload = self._build_payload(messages) + + raw_text, eval_count, prompt_eval_count, error = self._call_api(payload) + + if error: + return AgentOutput( + output=None, + tool_calls=[self._make_tool_call(prompt_eval_count, eval_count, error=error)], + cost=0.0, + ) + + output = self._parse_output(raw_text) + + return AgentOutput( + output=output, + raw=raw_text, + tool_calls=[self._make_tool_call(prompt_eval_count, eval_count)], + cost=0.0, + ) + + # ─── Private helpers ──────────────────────────────────────────────────── + + def _build_messages(self, context: Context) -> list[dict]: + messages = [] + + if self._system_prompt: + messages.append({"role": "system", "content": self._system_prompt}) + + if context.has_data("messages"): + messages.extend(context.get_data("messages")) + elif context.has_data("user_message"): + messages.append({"role": "user", "content": context.get_data("user_message")}) + + return messages + + def _build_payload(self, messages: list[dict]) -> dict: + payload: dict[str, Any] = { + "model": self._model, + "messages": messages, + "stream": False, + "options": { + "temperature": self._temperature, + }, + } + + if self._max_tokens is not None: + payload["options"]["num_predict"] = self._max_tokens + + if self._json_mode: + payload["format"] = "json" + + return payload + + def _call_api( + self, payload: dict + ) -> tuple[str | None, int, int, str | None]: + """ + POST to Ollama /api/chat. 
+ + Returns: (raw_text, eval_count, prompt_eval_count, error_message) + """ + url = f"{self._base_url}/api/chat" + data = json.dumps(payload).encode("utf-8") + headers = {"Content-Type": "application/json"} + + try: + req = urllib.request.Request(url, data=data, headers=headers, method="POST") + with urllib.request.urlopen(req, timeout=self._timeout) as resp: + result = json.loads(resp.read().decode("utf-8")) + + raw_text = result.get("message", {}).get("content", "") + eval_count = result.get("eval_count", 0) # completion tokens + prompt_eval_count = result.get("prompt_eval_count", 0) # prompt tokens + return raw_text, eval_count, prompt_eval_count, None + + except urllib.error.URLError as e: + return None, 0, 0, f"Connection error: {e.reason}" + except Exception as e: + return None, 0, 0, str(e) + + def _parse_output(self, raw_text: str | None) -> Any: + """Parse output — JSON dict if json_mode, else raw string.""" + if not raw_text: + return None + + if self._json_mode: + try: + return json.loads(raw_text) + except json.JSONDecodeError: + # Return string if model didn't comply (happens occasionally) + return raw_text + + return raw_text + + def _make_tool_call( + self, + prompt_tokens: int, + completion_tokens: int, + error: str | None = None, + ) -> ToolCall: + return ToolCall( + name="ollama_chat", + args={ + "model": self._model, + "base_url": self._base_url, + "prompt_tokens": prompt_tokens, + }, + result={ + "completion_tokens": completion_tokens, + "error": error, + }, + duration_ms=0, + ) diff --git a/experiments/lib/agents/ollama/smoke_test.py b/experiments/lib/agents/ollama/smoke_test.py new file mode 100644 index 0000000..97b4289 --- /dev/null +++ b/experiments/lib/agents/ollama/smoke_test.py @@ -0,0 +1,116 @@ +""" +Smoke test for OllamaAgent — verifies the agent works with Manifold +before running a full experiment. + +Usage: + cd C:\\Users\\fbrmp\\Projekte\\manifold + python experiments/lib/agents/ollama/smoke_test.py +""" + +import asyncio +import sys +from pathlib import Path + +# Add repo root to path (5 levels up from experiments/lib/agents/ollama/) +_repo_root = Path(__file__).parent.parent.parent.parent.parent +sys.path.insert(0, str(_repo_root)) +sys.path.insert(1, str(_repo_root / "experiments")) + +from manifold import create_context +from lib.agents.ollama import OllamaAgent + + +SYSTEM_PROMPT = """ +You are a classification assistant. Classify the given organization. 
+Return valid JSON with exactly these fields: +{ + "name": "organization name", + "economic_orientation": "market" | "state" | "mixed" | "unknown", + "cultural_orientation": "progressive" | "conservative" | "neutral" | "unknown", + "sector": "ngo" | "think_tank" | "media" | "political" | "academic" | "other" +} +""" + +TEST_CASES = [ + { + "name": "Greenpeace Germany", + "description": "International environmental NGO focused on climate action and anti-nuclear campaigns.", + "expect_sector": "ngo", + }, + { + "name": "Konrad-Adenauer-Stiftung", + "description": "Political foundation affiliated with the CDU, promotes Christian democratic values.", + "expect_sector": "think_tank", + }, + { + "name": "Bertelsmann Stiftung", + "description": "Operating foundation focused on education, healthcare, and social policy reform.", + "expect_sector": "think_tank", + }, +] + + +async def run_smoke_test(): + agent = OllamaAgent( + agent_id="ollama_classifier", + model="qwen2.5:14b", + temperature=0.0, + json_mode=True, + system_prompt=SYSTEM_PROMPT, + ) + + print(f"Smoke test: OllamaAgent with {agent._model}") + print("=" * 60) + + passed = 0 + for tc in TEST_CASES: + user_message = f"Organization: {tc['name']}\nDescription: {tc['description']}" + + context = create_context( + run_id="smoke_test", + initial_data={"user_message": user_message}, + ) + + result = await agent.execute(context) + + # Check basics + ok = ( + result.output is not None + and isinstance(result.output, dict) + and result.cost == 0.0 + and len(result.tool_calls) == 1 + ) + + sector_match = ( + isinstance(result.output, dict) + and result.output.get("sector") == tc["expect_sector"] + ) + + status = "PASS" if (ok and sector_match) else "FAIL" + if ok and sector_match: + passed += 1 + + print(f"\n[{status}] {tc['name']}") + if isinstance(result.output, dict): + import json + print(json.dumps(result.output, indent=2, ensure_ascii=False)) + else: + print(f" raw output: {result.output!r}") + + tc_entry = result.tool_calls[0] if result.tool_calls else None + if tc_entry: + print(f" tokens: prompt={tc_entry.args.get('prompt_tokens')} " + f"completion={tc_entry.result.get('completion_tokens')}") + if tc_entry.result.get("error"): + print(f" error: {tc_entry.result['error']}") + + print("\n" + "=" * 60) + print(f"Result: {passed}/{len(TEST_CASES)} passed") + + if passed < len(TEST_CASES): + print("\nNote: sector mismatches are model judgment calls, not bugs.") + print("Check for None output or connection errors — those are real failures.") + + +if __name__ == "__main__": + asyncio.run(run_smoke_test()) diff --git a/experiments/lib/agents/openai/__init__.py b/experiments/lib/agents/openai/__init__.py new file mode 100644 index 0000000..69db852 --- /dev/null +++ b/experiments/lib/agents/openai/__init__.py @@ -0,0 +1,11 @@ +""" +OpenAI model agent wrappers. +""" + +from .image_agent import OpenAIImageAgent +from .chat_agent import OpenAIChatAgent + +__all__ = [ + "OpenAIImageAgent", + "OpenAIChatAgent", +] diff --git a/experiments/lib/agents/openai/chat_agent.py b/experiments/lib/agents/openai/chat_agent.py new file mode 100644 index 0000000..db4a192 --- /dev/null +++ b/experiments/lib/agents/openai/chat_agent.py @@ -0,0 +1,206 @@ +""" +OpenAI Chat Completion Agent wrapper. + +Supports GPT-4, GPT-4o, GPT-3.5-turbo, etc. 
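+
+Illustrative usage (example values only; assumes OPENAI_API_KEY is set):
+
+    from manifold import create_context
+
+    agent = OpenAIChatAgent(
+        agent_id="classifier",
+        model="gpt-4o",
+        response_format={"type": "json_object"},
+        system_prompt="Reply with a JSON object.",
+    )
+    context = create_context(run_id="demo", initial_data={"user_message": "Classify: Greenpeace"})
+    result = await agent.execute(context)
+    # result.output is the parsed dict (raw string if JSON parsing fails); result.cost is in USD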
+""" + +from manifold import Agent, AgentOutput, Context, ToolCall +from typing import Any +import os +import urllib.request +import urllib.parse +import json + + +class OpenAIChatAgent(Agent): + """ + Agent wrapper for OpenAI chat completion models. + + Supports: + - gpt-4o + - gpt-4-turbo + - gpt-4 + - gpt-3.5-turbo + """ + + def __init__( + self, + agent_id: str, + model: str = "gpt-4o", + temperature: float = 0.3, + max_tokens: int | None = None, + response_format: dict | None = None, + system_prompt: str | None = None, + api_key: str | None = None + ): + """ + Args: + agent_id: Unique identifier for this agent + model: OpenAI chat model + temperature: Sampling temperature (0-2) + max_tokens: Maximum tokens in response + response_format: Optional response format (e.g., {"type": "json_object"}) + system_prompt: Optional system prompt + api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + """ + self._agent_id = agent_id + self._model = model + self._temperature = temperature + self._max_tokens = max_tokens + self._response_format = response_format + self._system_prompt = system_prompt + self._api_key = api_key or os.getenv("OPENAI_API_KEY") + + if not self._api_key: + raise ValueError("OpenAI API key required (set OPENAI_API_KEY env var)") + + @property + def agent_id(self) -> str: + return self._agent_id + + async def execute(self, context: Context, input_data: dict[str, Any] | None = None) -> AgentOutput: + """ + Generate chat completion. + + Expects context.data to have: + - user_message: User message text OR + - messages: Full message history + + Returns: + AgentOutput with assistant's message (string or parsed JSON) + """ + # Build messages array + messages = [] + + # Add system prompt if configured + if self._system_prompt: + messages.append({ + "role": "system", + "content": self._system_prompt + }) + + # Get user input from context + if context.has_data("messages"): + # Use full message history + messages.extend(context.get_data("messages")) + elif context.has_data("user_message"): + # Simple user message + messages.append({ + "role": "user", + "content": context.get_data("user_message") + }) + else: + # Missing input - return None output + return AgentOutput( + output=None, + tool_calls=[], + cost=0.0 + ) + + # Build request payload + payload = { + "model": self._model, + "messages": messages, + "temperature": self._temperature, + } + + if self._max_tokens: + payload["max_tokens"] = self._max_tokens + + if self._response_format: + payload["response_format"] = self._response_format + + # Make API request + url = "https://api.openai.com/v1/chat/completions" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json" + } + + try: + req = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + with urllib.request.urlopen(req, timeout=120) as response: + result = json.loads(response.read().decode('utf-8')) + + # Extract response + message = result["choices"][0]["message"]["content"] + usage = result.get("usage", {}) + + # Parse JSON response if requested + output = message + if self._response_format and self._response_format.get("type") == "json_object": + try: + output = json.loads(message) + except json.JSONDecodeError: + # Return string if JSON parsing fails + pass + + # Calculate cost + cost = self._calculate_cost( + usage.get("prompt_tokens", 0), + usage.get("completion_tokens", 0) + ) + + tool_call = ToolCall( + name="openai_chat_completion", + args={ + "model": 
self._model, + "messages": len(messages), + "prompt_tokens": usage.get("prompt_tokens", 0) + }, + result={ + "completion_tokens": usage.get("completion_tokens", 0), + "finish_reason": result["choices"][0].get("finish_reason"), + "cost": cost + }, + duration_ms=0 + ) + + return AgentOutput( + output=output, + tool_calls=[tool_call], + cost=cost + ) + + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + # API error - return None output + return AgentOutput( + output=None, + tool_calls=[], + cost=0.0 + ) + + except Exception as e: + # Execution error - return None output + return AgentOutput( + output=None, + tool_calls=[], + cost=0.0 + ) + + def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: + """Calculate API cost based on token usage.""" + # Pricing as of 2024 (per 1M tokens) + pricing = { + "gpt-4o": (2.50, 10.00), # (input, output) per 1M tokens + "gpt-4o-mini": (0.15, 0.60), + "gpt-4-turbo": (10.00, 30.00), + "gpt-4": (30.00, 60.00), + "gpt-3.5-turbo": (0.50, 1.50), + } + + # Get pricing for model (default to gpt-4o) + input_price, output_price = pricing.get(self._model, (2.50, 10.00)) + + # Calculate cost + input_cost = (prompt_tokens / 1_000_000) * input_price + output_cost = (completion_tokens / 1_000_000) * output_price + + return input_cost + output_cost diff --git a/experiments/lib/agents/openai/image_agent.py b/experiments/lib/agents/openai/image_agent.py new file mode 100644 index 0000000..e3cefe2 --- /dev/null +++ b/experiments/lib/agents/openai/image_agent.py @@ -0,0 +1,188 @@ +""" +OpenAI Image Generation Agent wrapper. + +Supports DALL-E 3 and gpt-image-1 models. +""" + +from manifold import Agent, AgentOutput, Context, ToolCall +from typing import Any +import os +import urllib.request +import urllib.parse +import json +import base64 +from io import BytesIO + + +class OpenAIImageAgent(Agent): + """ + Agent wrapper for OpenAI image generation models. + + Supports: + - dall-e-3 + - gpt-image-1 + + Returns image metadata (URL, dimensions, base64 data). + """ + + def __init__( + self, + agent_id: str, + model: str = "dall-e-3", + size: str = "1024x1024", + quality: str = "standard", + api_key: str | None = None + ): + """ + Args: + agent_id: Unique identifier for this agent + model: OpenAI image model ("dall-e-3" or "gpt-image-1") + size: Image size (default: "1024x1024") + quality: Image quality ("standard" or "hd") - dall-e-3 only + api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + """ + self._agent_id = agent_id + self._model = model + self._size = size + self._quality = quality + self._api_key = api_key or os.getenv("OPENAI_API_KEY") + + if not self._api_key: + raise ValueError("OpenAI API key required (set OPENAI_API_KEY env var)") + + @property + def agent_id(self) -> str: + return self._agent_id + + async def execute(self, context: Context, input_data: dict[str, Any] | None = None) -> AgentOutput: + """ + Generate image using OpenAI API. 
+ + Expects context.data to have: + - prompt: Image generation prompt + + Returns: + AgentOutput with image metadata dict: + { + "url": "https://...", + "width": 1024, + "height": 1024, + "b64_data": "base64...", # Optional + "model": "dall-e-3", + "cost": 0.04 + } + """ + prompt = context.get_data("prompt") + + if not prompt: + return AgentOutput( + output=None, + tool_calls=[], + cost=0.0 + ) + + # Build request payload + payload = { + "model": self._model, + "prompt": prompt, + "n": 1, + "size": self._size, + } + + # Model-specific params + if self._model == "dall-e-3": + payload["quality"] = self._quality + payload["response_format"] = "url" # or "b64_json" + elif self._model == "gpt-image-1": + # gpt-image-1 doesn't support quality/response_format + pass + + # Make API request + url = "https://api.openai.com/v1/images/generations" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json" + } + + try: + req = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + with urllib.request.urlopen(req, timeout=120) as response: + result = json.loads(response.read().decode('utf-8')) + + # Extract image data + image_data = result["data"][0] + image_url = image_data.get("url") + + # Download image to get dimensions and b64 + image_bytes = None + if image_url: + img_req = urllib.request.Request(image_url) + with urllib.request.urlopen(img_req, timeout=30) as img_response: + image_bytes = img_response.read() + + # Parse dimensions from size parameter + width, height = map(int, self._size.split('x')) + + # Build output + output = { + "url": image_url, + "width": width, + "height": height, + "model": self._model, + "cost": self._estimate_cost(), + } + + if image_bytes: + output["b64_data"] = base64.b64encode(image_bytes).decode('utf-8') + output["size_bytes"] = len(image_bytes) + + tool_call = ToolCall( + name="openai_image_generation", + args={"model": self._model, "prompt": prompt[:100]}, + result={"url": image_url, "size": self._size}, + duration_ms=0 # Could track this if needed + ) + + return AgentOutput( + output=output, + tool_calls=[tool_call], + cost=self._estimate_cost() + ) + + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + # Return None to indicate failure (error can be logged in delta if needed) + return AgentOutput( + output=None, + tool_calls=[], + cost=0.0 + ) + + except Exception as e: + # Return None to indicate failure + return AgentOutput( + output=None, + tool_calls=[], + cost=0.0 + ) + + def _estimate_cost(self) -> float: + """Estimate API cost based on model and size.""" + # Pricing as of 2024 + if self._model == "dall-e-3": + if self._quality == "hd" and self._size == "1024x1024": + return 0.080 + elif self._quality == "hd": + return 0.120 # 1024x1792 or 1792x1024 + else: + return 0.040 # standard quality + elif self._model == "gpt-image-1": + return 0.040 # Same as dall-e-3 standard + else: + return 0.040 # Default estimate diff --git a/experiments/lib/specs/content/__init__.py b/experiments/lib/specs/content/__init__.py new file mode 100644 index 0000000..acf172e --- /dev/null +++ b/experiments/lib/specs/content/__init__.py @@ -0,0 +1,26 @@ +""" +Content generation specs for validating multi-step content pipelines. 
+ +These specs validate: +- Outline structure +- Content length +- Outline compliance +- Grammar quality +- Research sources +""" + +from .specs import ( + HasMinItemsSpec, + OutlineValidationSpec, + OutlineComplianceSpec, + LengthRangeSpec, + GrammarCheckSpec, +) + +__all__ = [ + "HasMinItemsSpec", + "OutlineValidationSpec", + "OutlineComplianceSpec", + "LengthRangeSpec", + "GrammarCheckSpec", +] diff --git a/experiments/lib/specs/content/specs.py b/experiments/lib/specs/content/specs.py new file mode 100644 index 0000000..ae24321 --- /dev/null +++ b/experiments/lib/specs/content/specs.py @@ -0,0 +1,457 @@ +""" +Specs for multi-step content generation validation. +""" + +from manifold import Spec, SpecResult, Context +from typing import Any +import re + + +class HasMinItemsSpec(Spec): + """ + Validates that a list field has minimum number of items. + + Used for research sources, outline sections, etc. + """ + + def __init__(self, field: str, min_count: int): + """ + Args: + field: Name of list field in context or candidate + min_count: Minimum required items + """ + self._field = field + self._min_count = min_count + + @property + def rule_id(self) -> str: + return f"has_min_items:{self._field}" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "count", "research") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """Check that field has minimum items.""" + # Try candidate first, then context + items = None + + if isinstance(candidate, dict) and self._field in candidate: + items = candidate[self._field] + elif context.has_data(self._field): + items = context.get_data(self._field) + + if items is None: + return SpecResult.fail( + self.rule_id, + f"Field '{self._field}' not found", + suggested_fix=f"Ensure step produces '{self._field}'", + tags=self.tags + ) + + if not isinstance(items, (list, tuple)): + return SpecResult.fail( + self.rule_id, + f"Field '{self._field}' is not a list (got {type(items).__name__})", + suggested_fix=f"Ensure '{self._field}' is a list", + tags=self.tags + ) + + count = len(items) + + if count >= self._min_count: + return SpecResult.ok( + self.rule_id, + f"{self._field} has {count} items (need {self._min_count})", + tags=self.tags, + data={"count": count, "min_count": self._min_count} + ) + + return SpecResult.fail( + self.rule_id, + f"{self._field} has only {count} items (need {self._min_count})", + suggested_fix=f"Generate at least {self._min_count} {self._field}", + tags=self.tags, + data={"count": count, "min_count": self._min_count, "missing": self._min_count - count} + ) + + +class OutlineValidationSpec(Spec): + """ + Validates outline structure (intro, sections, conclusion). + """ + + def __init__( + self, + min_sections: int = 3, + has_intro: bool = True, + has_conclusion: bool = True + ): + """ + Args: + min_sections: Minimum number of main sections + has_intro: Whether intro is required + has_conclusion: Whether conclusion is required + """ + self._min_sections = min_sections + self._has_intro = has_intro + self._has_conclusion = has_conclusion + + @property + def rule_id(self) -> str: + return "outline_structure_valid" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "structure", "outline") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """ + Validate outline structure. + + Expected candidate format (as string): + # Introduction + ... + # Section 1 + ... + # Section 2 + ... + # Conclusion + ... 
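+
+        Worked example: with the defaults (min_sections=3, intro and conclusion
+        required), the outline sketched above would fail: "Introduction" and
+        "Conclusion" are excluded from the main-section count, leaving only two
+        sections. Adding a third "# Section 3" heading makes it pass.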
+ """ + outline = candidate if isinstance(candidate, str) else str(candidate) + + if not outline or len(outline.strip()) == 0: + return SpecResult.fail( + self.rule_id, + "Outline is empty", + suggested_fix="Generate outline with intro, sections, conclusion", + tags=self.tags + ) + + # Count markdown headings (# Header) + headings = re.findall(r'^#+\s+(.+)$', outline, re.MULTILINE) + + if not headings: + return SpecResult.fail( + self.rule_id, + "No headings found in outline", + suggested_fix="Use markdown headings (# Section Name)", + tags=self.tags + ) + + # Check for intro + has_intro_heading = any( + 'intro' in h.lower() or 'overview' in h.lower() + for h in headings + ) + + if self._has_intro and not has_intro_heading: + return SpecResult.fail( + self.rule_id, + "Missing introduction section", + suggested_fix="Add introduction section to outline", + tags=self.tags, + data={"headings": headings} + ) + + # Check for conclusion + has_conclusion_heading = any( + 'conclusion' in h.lower() or 'summary' in h.lower() + for h in headings + ) + + if self._has_conclusion and not has_conclusion_heading: + return SpecResult.fail( + self.rule_id, + "Missing conclusion section", + suggested_fix="Add conclusion section to outline", + tags=self.tags, + data={"headings": headings} + ) + + # Count main sections (exclude intro/conclusion) + main_sections = [ + h for h in headings + if 'intro' not in h.lower() + and 'conclusion' not in h.lower() + and 'overview' not in h.lower() + and 'summary' not in h.lower() + ] + + section_count = len(main_sections) + + if section_count >= self._min_sections: + return SpecResult.ok( + self.rule_id, + f"Valid outline: {section_count} sections, intro={has_intro_heading}, conclusion={has_conclusion_heading}", + tags=self.tags, + data={ + "section_count": section_count, + "has_intro": has_intro_heading, + "has_conclusion": has_conclusion_heading, + "headings": headings + } + ) + + return SpecResult.fail( + self.rule_id, + f"Only {section_count} sections (need {self._min_sections})", + suggested_fix=f"Add {self._min_sections - section_count} more sections", + tags=self.tags, + data={ + "section_count": section_count, + "min_sections": self._min_sections, + "headings": headings + } + ) + + +class OutlineComplianceSpec(Spec): + """ + Validates that draft content follows outline structure. + + Checks that draft has headings matching the outline. + """ + + def __init__(self, strict: bool = False): + """ + Args: + strict: If True, headings must match exactly. If False, allows variations. + """ + self._strict = strict + + @property + def rule_id(self) -> str: + return "follows_outline" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "compliance", "structure") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """ + Check that draft content follows outline. + + Looks for 'outline' in context.data, compares headings with draft. 
+ """ + outline = context.get_data("outline") + + if not outline: + return SpecResult.fail( + self.rule_id, + "No outline in context to compare against", + suggested_fix="Ensure outline step runs before draft", + tags=self.tags + ) + + draft = candidate if isinstance(candidate, str) else str(candidate) + + if not draft or len(draft.strip()) == 0: + return SpecResult.fail( + self.rule_id, + "Draft is empty", + suggested_fix="Generate draft content", + tags=self.tags + ) + + # Extract headings from both + outline_headings = re.findall(r'^#+\s+(.+)$', str(outline), re.MULTILINE) + draft_headings = re.findall(r'^#+\s+(.+)$', draft, re.MULTILINE) + + if not draft_headings: + return SpecResult.fail( + self.rule_id, + "Draft has no headings", + suggested_fix="Include section headings from outline", + tags=self.tags + ) + + # Check compliance + if self._strict: + # Exact match required + if outline_headings == draft_headings: + return SpecResult.ok( + self.rule_id, + f"Draft follows outline exactly ({len(draft_headings)} sections)", + tags=self.tags, + data={"headings": draft_headings} + ) + else: + return SpecResult.fail( + self.rule_id, + "Draft headings don't match outline", + suggested_fix="Follow outline structure exactly", + tags=self.tags, + data={ + "outline_headings": outline_headings, + "draft_headings": draft_headings + } + ) + else: + # Fuzzy match - check if most outline headings are in draft + outline_lower = [h.lower().strip() for h in outline_headings] + draft_lower = [h.lower().strip() for h in draft_headings] + + matched = sum( + 1 for oh in outline_lower + if any(oh in dh or dh in oh for dh in draft_lower) + ) + + match_ratio = matched / len(outline_headings) if outline_headings else 0 + + if match_ratio >= 0.7: # 70% match is good enough + return SpecResult.ok( + self.rule_id, + f"Draft follows outline ({matched}/{len(outline_headings)} sections match)", + tags=self.tags, + data={ + "matched": matched, + "total": len(outline_headings), + "match_ratio": match_ratio + } + ) + + return SpecResult.fail( + self.rule_id, + f"Draft only matches {matched}/{len(outline_headings)} outline sections", + suggested_fix="Follow outline structure more closely", + tags=self.tags, + data={ + "matched": matched, + "total": len(outline_headings), + "match_ratio": match_ratio, + "outline_headings": outline_headings, + "draft_headings": draft_headings + } + ) + + +class LengthRangeSpec(Spec): + """ + Validates that content length (word count) is within range. 
+ """ + + def __init__(self, min_words: int, max_words: int): + """ + Args: + min_words: Minimum word count + max_words: Maximum word count + """ + self._min_words = min_words + self._max_words = max_words + + @property + def rule_id(self) -> str: + return "length_in_range" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "length", "quality") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """Check word count of content.""" + content = candidate if isinstance(candidate, str) else str(candidate) + + # Simple word count (split on whitespace) + words = content.split() + word_count = len(words) + + if self._min_words <= word_count <= self._max_words: + return SpecResult.ok( + self.rule_id, + f"Length OK: {word_count} words (target: {self._min_words}-{self._max_words})", + tags=self.tags, + data={"word_count": word_count, "min": self._min_words, "max": self._max_words} + ) + + if word_count < self._min_words: + return SpecResult.fail( + self.rule_id, + f"Content too short: {word_count} words (need {self._min_words})", + suggested_fix=f"Expand content to at least {self._min_words} words", + tags=self.tags, + data={ + "word_count": word_count, + "min": self._min_words, + "missing": self._min_words - word_count + } + ) + + return SpecResult.fail( + self.rule_id, + f"Content too long: {word_count} words (max {self._max_words})", + suggested_fix=f"Reduce content to max {self._max_words} words", + tags=self.tags, + data={ + "word_count": word_count, + "max": self._max_words, + "excess": word_count - self._max_words + } + ) + + +class GrammarCheckSpec(Spec): + """ + Basic grammar validation (detects obvious errors). + + Note: This is a simplified check. For production, use a real + grammar checker like LanguageTool. + """ + + # Common grammar error patterns + PATTERNS = [ + (r'\bi\s+[a-z]', "Lowercase after 'I'"), # "i am" should be "I am" + (r'\.\s+[a-z]', "Lowercase after period"), + (r'\s{2,}', "Multiple spaces"), + (r'[.!?]{2,}', "Repeated punctuation"), + ] + + @property + def rule_id(self) -> str: + return "no_grammar_errors" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "quality", "grammar") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """Check for obvious grammar errors.""" + content = candidate if isinstance(candidate, str) else str(candidate) + + if not content or len(content.strip()) == 0: + return SpecResult.fail( + self.rule_id, + "Content is empty", + suggested_fix="Generate content", + tags=self.tags + ) + + errors = [] + + for pattern, description in self.PATTERNS: + matches = re.finditer(pattern, content) + for match in matches: + errors.append({ + "type": description, + "position": match.start(), + "text": match.group()[:20] # First 20 chars + }) + + # Limit to first 10 errors (don't overwhelm) + errors = errors[:10] + + if not errors: + return SpecResult.ok( + self.rule_id, + "No obvious grammar errors detected", + tags=self.tags + ) + + return SpecResult.fail( + self.rule_id, + f"Found {len(errors)} potential grammar errors", + suggested_fix="Review and fix grammar issues", + tags=self.tags, + data={"errors": errors, "error_count": len(errors)} + ) diff --git a/experiments/lib/specs/extraction/__init__.py b/experiments/lib/specs/extraction/__init__.py new file mode 100644 index 0000000..da8efe2 --- /dev/null +++ b/experiments/lib/specs/extraction/__init__.py @@ -0,0 +1,26 @@ +""" +Data extraction specs for validating structured data extraction tasks. 
+ +These specs validate: +- Required fields presence +- Email format validation +- Numeric range validation +- Enum/choice validation +- Schema compliance +""" + +from .specs import ( + HasRequiredFieldsSpec, + EmailValidationSpec, + RangeValidationSpec, + EnumValidationSpec, + ProgressSpec, +) + +__all__ = [ + "HasRequiredFieldsSpec", + "EmailValidationSpec", + "RangeValidationSpec", + "EnumValidationSpec", + "ProgressSpec", +] diff --git a/experiments/lib/specs/extraction/specs.py b/experiments/lib/specs/extraction/specs.py new file mode 100644 index 0000000..cb3a1d4 --- /dev/null +++ b/experiments/lib/specs/extraction/specs.py @@ -0,0 +1,378 @@ +""" +Specs for structured data extraction validation. +""" + +from manifold import Spec, SpecResult, Context +from typing import Any +import re + + +class HasRequiredFieldsSpec(Spec): + """ + Validates that all required fields are present in extracted data. + + This is the most critical spec for data extraction - ensures + all mandatory fields were successfully extracted. + """ + + def __init__(self, fields: list[str]): + """ + Args: + fields: List of required field names + """ + self._fields = fields + + @property + def rule_id(self) -> str: + return "has_required_fields" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "schema", "extraction") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """ + Check that candidate dict has all required fields. + + Args: + candidate: Dict with extracted data + """ + if not isinstance(candidate, dict): + return SpecResult.fail( + self.rule_id, + f"Candidate is not a dict (got {type(candidate).__name__})", + suggested_fix="Ensure extraction returns a dictionary", + tags=self.tags + ) + + missing = [f for f in self._fields if f not in candidate or candidate[f] is None] + + if not missing: + return SpecResult.ok( + self.rule_id, + f"All {len(self._fields)} required fields present", + tags=self.tags, + data={"required_count": len(self._fields)} + ) + + return SpecResult.fail( + self.rule_id, + f"Missing required fields: {', '.join(missing)}", + suggested_fix=f"Ensure extraction includes: {', '.join(missing)}", + tags=self.tags, + data={ + "missing_fields": missing, + "required_fields": self._fields, + "present_fields": list(candidate.keys()) + } + ) + + +class EmailValidationSpec(Spec): + """ + Validates that a field contains a properly formatted email address. 
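+
+ Example (illustrative): EmailValidationSpec(field="contact_email") passes
+ {"contact_email": "anna@example.org"} and fails when the field is missing,
+ not a string, or not a plausible address such as "not-an-email".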
+ """ + + # Simple email regex - not RFC 5322 compliant but good enough + EMAIL_PATTERN = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$') + + def __init__(self, field: str = "email"): + """ + Args: + field: Name of the field to validate + """ + self._field = field + + @property + def rule_id(self) -> str: + return f"email_valid:{self._field}" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "format", "extraction") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """Validate email format in candidate dict.""" + if not isinstance(candidate, dict): + return SpecResult.fail( + self.rule_id, + "Candidate is not a dict", + suggested_fix="Ensure extraction returns dict", + tags=self.tags + ) + + if self._field not in candidate: + return SpecResult.fail( + self.rule_id, + f"Field '{self._field}' not found", + suggested_fix=f"Ensure extraction includes '{self._field}'", + tags=self.tags + ) + + email = candidate[self._field] + + if not isinstance(email, str): + return SpecResult.fail( + self.rule_id, + f"Email field is not a string (got {type(email).__name__})", + suggested_fix=f"Ensure '{self._field}' is extracted as string", + tags=self.tags + ) + + if self.EMAIL_PATTERN.match(email): + return SpecResult.ok( + self.rule_id, + f"Valid email: {email}", + tags=self.tags, + data={"email": email} + ) + + return SpecResult.fail( + self.rule_id, + f"Invalid email format: {email}", + suggested_fix="Extract valid email address from input", + tags=self.tags, + data={"invalid_email": email} + ) + + +class RangeValidationSpec(Spec): + """ + Validates that a numeric field falls within an expected range. + """ + + def __init__(self, field: str, min_val: int | float, max_val: int | float): + """ + Args: + field: Name of field to validate + min_val: Minimum acceptable value (inclusive) + max_val: Maximum acceptable value (inclusive) + """ + self._field = field + self._min = min_val + self._max = max_val + + @property + def rule_id(self) -> str: + return f"range_valid:{self._field}" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "range", "extraction") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """Validate numeric range.""" + if not isinstance(candidate, dict): + return SpecResult.fail( + self.rule_id, + "Candidate is not a dict", + suggested_fix="Ensure extraction returns dict", + tags=self.tags + ) + + if self._field not in candidate: + return SpecResult.fail( + self.rule_id, + f"Field '{self._field}' not found", + suggested_fix=f"Ensure extraction includes '{self._field}'", + tags=self.tags + ) + + value = candidate[self._field] + + # Convert to number if string + try: + if isinstance(value, str): + value = float(value) if '.' 
in value else int(value) + elif not isinstance(value, (int, float)): + raise ValueError(f"Cannot convert {type(value).__name__} to number") + except (ValueError, TypeError) as e: + return SpecResult.fail( + self.rule_id, + f"Field '{self._field}' is not numeric: {value}", + suggested_fix=f"Extract '{self._field}' as number between {self._min} and {self._max}", + tags=self.tags, + data={"value": str(value), "error": str(e)} + ) + + if self._min <= value <= self._max: + return SpecResult.ok( + self.rule_id, + f"{self._field}={value} within range [{self._min}, {self._max}]", + tags=self.tags, + data={"value": value, "min": self._min, "max": self._max} + ) + + return SpecResult.fail( + self.rule_id, + f"{self._field}={value} outside range [{self._min}, {self._max}]", + suggested_fix=f"Extract valid {self._field} value between {self._min} and {self._max}", + tags=self.tags, + data={ + "value": value, + "min": self._min, + "max": self._max, + "too_low": value < self._min, + "too_high": value > self._max + } + ) + + +class EnumValidationSpec(Spec): + """ + Validates that a field contains one of the allowed values. + """ + + def __init__(self, field: str, allowed: list[str], case_sensitive: bool = True): + """ + Args: + field: Name of field to validate + allowed: List of allowed values + case_sensitive: Whether to do case-sensitive comparison + """ + self._field = field + self._allowed = allowed + self._case_sensitive = case_sensitive + + @property + def rule_id(self) -> str: + return f"enum_valid:{self._field}" + + @property + def tags(self) -> tuple[str, ...]: + return ("postcondition", "enum", "extraction") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """Validate enum value.""" + if not isinstance(candidate, dict): + return SpecResult.fail( + self.rule_id, + "Candidate is not a dict", + suggested_fix="Ensure extraction returns dict", + tags=self.tags + ) + + if self._field not in candidate: + return SpecResult.fail( + self.rule_id, + f"Field '{self._field}' not found", + suggested_fix=f"Ensure extraction includes '{self._field}'", + tags=self.tags + ) + + value = candidate[self._field] + + if not isinstance(value, str): + value = str(value) + + # Check if value is in allowed list + if self._case_sensitive: + is_valid = value in self._allowed + else: + is_valid = value.lower() in [a.lower() for a in self._allowed] + + if is_valid: + return SpecResult.ok( + self.rule_id, + f"{self._field}='{value}' is valid", + tags=self.tags, + data={"value": value, "allowed": self._allowed} + ) + + return SpecResult.fail( + self.rule_id, + f"{self._field}='{value}' not in allowed values: {self._allowed}", + suggested_fix=f"Extract {self._field} as one of: {', '.join(self._allowed)}", + tags=self.tags, + data={ + "value": value, + "allowed": self._allowed, + "case_sensitive": self._case_sensitive + } + ) + + +class ProgressSpec(Spec): + """ + Anti-loop spec: Validates that the situation has changed between retries. + + This prevents blind retries by checking that at least ONE thing changed: + - New field appeared + - Existing field value changed + - Field count increased + """ + + @property + def rule_id(self) -> str: + return "extraction_progress" + + @property + def tags(self) -> tuple[str, ...]: + return ("progress", "anti-loop") + + def evaluate(self, context: Context, candidate: Any = None) -> SpecResult: + """ + Check if extraction made progress compared to previous attempt. 
+ + Looks for 'last_extraction' in context.data - if present, + compares with current candidate to ensure progress. + """ + if not isinstance(candidate, dict): + # If candidate isn't a dict, can't check progress + # But we don't fail - other specs will catch this + return SpecResult.ok( + self.rule_id, + "No previous extraction to compare", + tags=self.tags + ) + + last = context.get_data("last_extraction") + + if last is None or not isinstance(last, dict): + # First attempt - no previous to compare + return SpecResult.ok( + self.rule_id, + "First extraction attempt", + tags=self.tags, + data={"first_attempt": True} + ) + + # Check for progress: + # 1. More fields extracted? + new_field_count = len(candidate) + old_field_count = len(last) + + if new_field_count > old_field_count: + return SpecResult.ok( + self.rule_id, + f"Progress: {new_field_count} fields (was {old_field_count})", + tags=self.tags, + data={"new_fields": new_field_count - old_field_count} + ) + + # 2. Any field value changed? + changed_fields = [] + for key in candidate: + if key not in last or candidate[key] != last[key]: + changed_fields.append(key) + + if changed_fields: + return SpecResult.ok( + self.rule_id, + f"Progress: {len(changed_fields)} fields changed", + tags=self.tags, + data={"changed_fields": changed_fields} + ) + + # No progress - identical extraction + return SpecResult.fail( + self.rule_id, + "No progress: extraction identical to previous attempt", + suggested_fix="Try different extraction strategy or enrichment", + tags=self.tags, + data={ + "candidate": candidate, + "last": last, + "identical": True + } + ) diff --git a/manifold/testing/__init__.py b/manifold/testing/__init__.py new file mode 100644 index 0000000..b108e4c --- /dev/null +++ b/manifold/testing/__init__.py @@ -0,0 +1,98 @@ +""" +manifold.testing +~~~~~~~~~~~~~~~~ +Adaptive Convergence Testing — Heterogeneous Multi-Model Validation preset. + +Quick start +----------- + from manifold.testing import HMMVTestHarness, CorrectionRunner + from manifold.testing.stores import SQLiteBaselineStore + + async def call_llm(prompt): ... # wire to Anthropic/OpenAI/etc. + async def run_model(inp, hint, mid): ... 
# wire to your model clients + + harness = HMMVTestHarness( + models=["gpt-4o", "gemini-flash", "llama-3.3", "mistral-small"], + workflow_manifest="classify.yaml", + baseline_store=SQLiteBaselineStore("baseline.db"), + correction_runner=CorrectionRunner(call_llm, run_model, + model_ids=[...]), + ) + + await harness.setup() + result = await harness.run({"name": "Caritas Berlin", "type": "welfare"}) + +Standalone example (no manifold dependency on Orchestrator): + python3 -m manifold.testing.example.end_to_end +""" + +from manifold.testing.models import ( + ConvergenceRecord, + BaselineSnapshot, + DriftSignal, + DriftType, + SpecProposal, + ProposalStatus, + ReviewStatus, +) +from manifold.testing.convergence import ( + ConvergenceConfig, + ConvergenceMonitor, + make_convergence_spec, +) +from manifold.testing.correction import ( + CorrectionRunner, + CorrectionAnalysis, + Hypothesis, + ValidationResult, + analyze, + generate_hypothesis, + validate, +) +from manifold.testing.events import Event, EventBus, EventConsumer, EventType +from manifold.testing.stores import ( + InMemoryBaselineStore, + InMemorySnapshotStore, + InMemoryProposalStore, + InMemorySpecRegistry, + SQLiteBaselineStore, +) +from manifold.testing.harness import HMMVTestHarness, HMMVResult, NoOpCorrectionRunner + +__all__ = [ + # models + "ConvergenceRecord", + "BaselineSnapshot", + "DriftSignal", + "DriftType", + "SpecProposal", + "ProposalStatus", + "ReviewStatus", + # convergence + "ConvergenceConfig", + "ConvergenceMonitor", + "make_convergence_spec", + # correction + "CorrectionRunner", + "CorrectionAnalysis", + "Hypothesis", + "ValidationResult", + "analyze", + "generate_hypothesis", + "validate", + # events + "Event", + "EventBus", + "EventConsumer", + "EventType", + # stores + "InMemoryBaselineStore", + "InMemorySnapshotStore", + "InMemoryProposalStore", + "InMemorySpecRegistry", + "SQLiteBaselineStore", + # harness + "HMMVTestHarness", + "HMMVResult", + "NoOpCorrectionRunner", +] diff --git a/manifold/testing/convergence.py b/manifold/testing/convergence.py new file mode 100644 index 0000000..50d181a --- /dev/null +++ b/manifold/testing/convergence.py @@ -0,0 +1,469 @@ +""" +manifold.testing.convergence +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +ConvergenceMonitor — an invariant Spec that tracks inter-model agreement. + +Design constraint +----------------- +Specs in Manifold are pure: they cannot mutate Context. They can only +return a SpecResult (with optional data payload). + +To bridge the gap between "spec detected drift" and "event needs to be +emitted", the ConvergenceMonitor maintains an internal signal queue. +After each orchestrator run the harness calls `drain_signals()` to +collect any pending DriftSignals and emit the appropriate events. + +This keeps the spec system clean while giving the harness full control +over what happens next. + +Invariant behaviour +------------------- +The monitor ALWAYS returns SpecResult.ok() — drift is not a workflow +failure. It is a side-channel signal. The primary workflow completes +and produces its consensus result regardless. + +Regimes (emergent, not configured) +----------------------------------- +EARLY total_records < min_baseline_size + → Detection inactive. Every convergent run silently adds to baseline. + +MIDDLE baseline active, but < 10 records for this input class + → Class is new. Pass silently, mark as novel, accumulate data. + +MATURE ≥ 10 records for input class + → Compare observed MAD against class baseline. 
+ If observed > expected × drift_multiplier → drift signal. + Otherwise → append to baseline. +""" + +from __future__ import annotations + +import statistics +import uuid +from collections import deque +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from manifold.testing.models import ( + ConvergenceRecord, + DriftSignal, + DriftType, + _compute_mad, + _fingerprint, +) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ConvergenceConfig: + """ + Tuning parameters for the convergence monitor. + + Defaults are intentionally conservative — the system will spend more + time in the early regime, but once drift detection activates, it is + well-calibrated. + + Attributes + ---------- + min_baseline_size Total records needed before drift detection activates. + Set high at start. Once crossed, never needs adjustment. + drift_multiplier Observed MAD must exceed expected_MAD × this value to + trigger a drift signal. 2.5 = 150% above baseline. + outlier_threshold How many σ a single model must be from the others + to be classified as MODEL_OUTLIER vs CRITERIA_GAP. + min_class_records Records per input_class before per-class MAD is used. + Below this: class is treated as novel. + record_mad_field Context data key holding the per-model score dict. + record_class_field Context data key holding the input_class label. + record_input_field Context data key holding the raw input (for fingerprint). + record_output_field Context data key holding raw per-model outputs (optional). + """ + + min_baseline_size: int = 500 + drift_multiplier: float = 2.5 + outlier_threshold: float = 2.0 # σ + min_class_records: int = 10 + record_mad_field: str = "model_scores" + record_class_field: str = "input_class" + record_input_field: str = "input_data" + record_output_field: str = "model_outputs" + + +# --------------------------------------------------------------------------- +# ConvergenceMonitor +# --------------------------------------------------------------------------- + + +class ConvergenceMonitor: + """ + Invariant Spec that monitors inter-model convergence. + + This class is NOT a subclass of manifold.core.spec.Spec because it + cannot import from the manifold package at this layer (to keep the + testing module usable standalone). The HMMVTestHarness wraps it in + an adapter that satisfies the Spec protocol. + + To use it with a raw OrchestratorBuilder, use ConvergenceMonitorSpec + (the adapter defined below), which requires manifold to be installed. 
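+
+ Typical per-run cycle driven by the harness (a sketch; the baseline_store
+ and bus calls are assumptions, not methods of this class):
+
+ monitor.update_baseline_cache(total_records, class_mads, class_counts)
+ # ... orchestrator run executes; the adapter calls evaluate_sync() ...
+ for record in monitor.drain_records():
+     await baseline_store.append(record)  # persist convergent runs
+ for signal in monitor.drain_signals():
+     await bus.emit(...)                  # e.g. DRIFT_DETECTED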
+ + Internal state + -------------- + _baseline_store AsyncBaselineStore protocol (injected) + _config ConvergenceConfig + _pending_signals deque of DriftSignal, drained by harness after each run + _pending_records deque of ConvergenceRecord, flushed to store after each run + _snapshot_total how many records existed at last snapshot check + """ + + rule_id = "convergence_monitor" + + def __init__( + self, + baseline_store: Any, + config: ConvergenceConfig | None = None, + spec_versions: dict[str, str] | None = None, + ) -> None: + self._baseline = baseline_store + self._config = config or ConvergenceConfig() + self._spec_versions = spec_versions or {} + self._pending_signals: deque[DriftSignal] = deque() + self._pending_records: deque[ConvergenceRecord] = deque() + + # Synchronous snapshot of baseline state for evaluate() + # Updated by harness before each run via update_baseline_cache() + self._cached_total: int = 0 + self._cached_class_mads: dict[str, float] = {} + self._cached_class_count: dict[str, int] = {} + + # ------------------------------------------------------------------ + # Cache management (called by harness, sync) + # ------------------------------------------------------------------ + + def update_baseline_cache( + self, + total_records: int, + class_mads: dict[str, float], + class_counts: dict[str, int], + ) -> None: + """ + Refresh the synchronous cache used by evaluate(). + + Called by the harness before starting a run. This avoids making + evaluate() async (which the Spec protocol does not support). + """ + self._cached_total = total_records + self._cached_class_mads = class_mads + self._cached_class_count = class_counts + + # ------------------------------------------------------------------ + # Core evaluation (synchronous — called by Spec adapter) + # ------------------------------------------------------------------ + + def evaluate_sync( + self, + run_id: str, + input_data: dict[str, Any], + input_class: str, + cluster_version: str | None, + model_scores: dict[str, float], + raw_outputs: dict[str, Any], + ) -> dict: + """ + Core logic. Returns a result dict consumed by the Spec adapter. + + Always returns regime and mad. If drift is detected, appends a + DriftSignal to _pending_signals. If convergent, appends a + ConvergenceRecord to _pending_records. 
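+
+ Illustrative threshold check (numbers invented): with expected_mad = 0.02
+ and the default drift_multiplier of 2.5, the threshold is 0.05; an observed
+ MAD of 0.72 is flagged as drift, while 0.015 is recorded as convergent.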
+ + Returns + ------- + { + "regime": "early" | "novel_class" | "convergent" | "drift", + "mad": float, + "expected_mad": float | None, + "drift_type": str | None, + "signal_id": str | None, + "message": str, + } + """ + scores = list(model_scores.values()) + if not scores: + return { + "regime": "early", + "mad": 0.0, + "expected_mad": None, + "drift_type": None, + "signal_id": None, + "message": "No model scores available yet", + } + + observed_mad = _compute_mad(scores) + + # ── REGIME 1: Early ────────────────────────────────────────── + if self._cached_total < self._config.min_baseline_size: + record = self._make_record( + run_id, + input_data, + input_class, + cluster_version, + model_scores, + observed_mad, + raw_outputs, + ) + self._pending_records.append(record) + return { + "regime": "early", + "mad": observed_mad, + "expected_mad": None, + "drift_type": None, + "signal_id": None, + "message": ( + f"Baseline building: " + f"{self._cached_total}/{self._config.min_baseline_size} records" + ), + } + + expected_mad = self._cached_class_mads.get(input_class) + class_count = self._cached_class_count.get(input_class, 0) + + # ── REGIME 2: Novel input class ─────────────────────────────── + if expected_mad is None or class_count < self._config.min_class_records: + record = self._make_record( + run_id, + input_data, + input_class, + cluster_version, + model_scores, + observed_mad, + raw_outputs, + ) + self._pending_records.append(record) + return { + "regime": "novel_class", + "mad": observed_mad, + "expected_mad": None, + "drift_type": None, + "signal_id": None, + "message": ( + f"Novel class '{input_class}' " + f"({class_count}/{self._config.min_class_records} records). " + "Accumulating baseline data." + ), + } + + # ── REGIME 3: Mature — compare against baseline ──────────────── + threshold = expected_mad * self._config.drift_multiplier + + if observed_mad > threshold: + drift_type, outlier = self._classify_drift(model_scores, scores) + signal = DriftSignal( + signal_id=str(uuid.uuid4()), + run_id=run_id, + timestamp=datetime.now(timezone.utc), + drift_type=drift_type, + input_fingerprint=_fingerprint(input_data), + input_class=input_class, + model_scores=model_scores, + observed_mad=observed_mad, + expected_mad=expected_mad, + baseline_records=class_count, + outlier_model=outlier, + implicated_specs=list(self._spec_versions.keys()), + representative_fps=[], # filled by harness from baseline store + ) + self._pending_signals.append(signal) + return { + "regime": "drift", + "mad": observed_mad, + "expected_mad": expected_mad, + "drift_type": drift_type.value, + "signal_id": signal.signal_id, + "message": ( + f"Drift detected: MAD {observed_mad:.3f} " + f"> threshold {threshold:.3f} " + f"(expected {expected_mad:.3f} × {self._config.drift_multiplier}). " + f"Type: {drift_type.value}." 
+ ), + } + + # Convergent — append to baseline + record = self._make_record( + run_id, + input_data, + input_class, + cluster_version, + model_scores, + observed_mad, + raw_outputs, + ) + self._pending_records.append(record) + return { + "regime": "convergent", + "mad": observed_mad, + "expected_mad": expected_mad, + "drift_type": None, + "signal_id": None, + "message": (f"Convergent: MAD {observed_mad:.3f} " f"≤ threshold {threshold:.3f}"), + } + + # ------------------------------------------------------------------ + # Drain queues (called by harness after orchestrator.run()) + # ------------------------------------------------------------------ + + def drain_signals(self) -> list[DriftSignal]: + """Return and clear all pending drift signals.""" + out = list(self._pending_signals) + self._pending_signals.clear() + return out + + def drain_records(self) -> list[ConvergenceRecord]: + """Return and clear all pending convergence records.""" + out = list(self._pending_records) + self._pending_records.clear() + return out + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _make_record( + self, + run_id: str, + input_data: dict[str, Any], + input_class: str, + cluster_version: str | None, + model_scores: dict[str, float], + mad: float, + raw_outputs: dict[str, Any], + ) -> ConvergenceRecord: + import statistics as _st + + scores = list(model_scores.values()) + consensus = _st.median(scores) + confidence = max(0.0, 1.0 - mad) + return ConvergenceRecord( + run_id=run_id, + timestamp=datetime.now(timezone.utc), + input_fingerprint=_fingerprint(input_data), + input_class=input_class, + cluster_version=cluster_version, + model_scores=model_scores, + consensus_score=consensus, + inter_model_mad=mad, + confidence=confidence, + spec_versions=dict(self._spec_versions), + raw_outputs=raw_outputs, + ) + + def _classify_drift( + self, + model_scores: dict[str, float], + scores: list[float], + ) -> tuple[DriftType, str | None]: + """ + Determine drift type from the score distribution. + + MODEL_OUTLIER: one model is > outlier_threshold σ from the others. + CRITERIA_GAP: all models disagree roughly equally. + """ + if len(scores) < 2: + return DriftType.UNKNOWN, None + + for model_id, score in model_scores.items(): + others = [s for m, s in model_scores.items() if m != model_id] + if not others: + continue + others_mean = statistics.mean(others) + others_std = statistics.stdev(others) if len(others) > 1 else 0.0 + if ( + others_std > 0 + and abs(score - others_mean) / others_std > self._config.outlier_threshold + ): + return DriftType.MODEL_OUTLIER, model_id + + return DriftType.CRITERIA_GAP, None + + +# --------------------------------------------------------------------------- +# Spec adapter (requires manifold to be installed) +# --------------------------------------------------------------------------- + + +def make_convergence_spec(monitor: ConvergenceMonitor) -> Any: + """ + Wrap a ConvergenceMonitor in a Manifold Spec. + + Lazily imports manifold.core.spec so that models.py / stores.py / + events.py remain usable without manifold installed (e.g. in tests). + + The Spec reads model_scores, input_class, and input_data from the + context data dict. These keys must be populated by the workflow's + consensus step before the invariant runs. 
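+
+ With the default ConvergenceConfig the expected keys and shapes are
+ (values purely illustrative):
+
+ model_scores : {"gpt-4o": 0.82, "gemini-flash": 0.80, "llama-3.3": 0.81}
+ input_class : "ngo_welfare"
+ input_data : {"name": "...", "description": "..."}
+ model_outputs : {"gpt-4o": {...}, ...} (optional raw outputs)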
+ + Usage + ----- + monitor = ConvergenceMonitor(baseline_store, config) + spec = make_convergence_spec(monitor) + orchestrator = OrchestratorBuilder().with_spec(spec).build() + """ + try: + from manifold.core.spec import Spec, SpecResult + from manifold.core.context import Context + except ImportError as e: + raise ImportError( + "manifold must be installed to use make_convergence_spec(). " + "Install with: pip install manifold-ai" + ) from e + + class ConvergenceMonitorSpec(Spec): + """ + Invariant spec that wraps a ConvergenceMonitor. + + Always returns SpecResult.ok() — drift does not fail the workflow. + Drift information is in result.data and drained by the harness + via monitor.drain_signals() after each run. + """ + + rule_id = "convergence_monitor" + + @property + def tags(self) -> tuple[str, ...]: + return ("invariant", "convergence", "hmmv") + + def evaluate(self, context: Context, candidate=None) -> SpecResult: + cfg = monitor._config + scores = context.get_data(cfg.record_mad_field) + cls = context.get_data(cfg.record_class_field, "unknown") + raw_in = context.get_data(cfg.record_input_field, {}) + raw_out = context.get_data(cfg.record_output_field, {}) + c_ver = context.get_data("cluster_version") + + if not scores: + return SpecResult.ok( + rule_id=self.rule_id, + message="No model scores in context — skipping convergence check", + tags=self.tags, + data={"regime": "waiting"}, + ) + + result = monitor.evaluate_sync( + run_id=context.run_id, + input_data=raw_in if isinstance(raw_in, dict) else {"value": raw_in}, + input_class=cls, + cluster_version=c_ver, + model_scores=scores, + raw_outputs=raw_out if isinstance(raw_out, dict) else {}, + ) + + return SpecResult.ok( + rule_id=self.rule_id, + message=result["message"], + tags=self.tags, + data=result, + ) + + return ConvergenceMonitorSpec() diff --git a/manifold/testing/correction.py b/manifold/testing/correction.py new file mode 100644 index 0000000..1a642a8 --- /dev/null +++ b/manifold/testing/correction.py @@ -0,0 +1,664 @@ +""" +manifold.testing.correction +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +CorrectionRunner — the real implementation of the correction workflow. + +Replaces NoOpCorrectionRunner in harness.py once you wire up an LLM caller +and a ModelRunner (the two external dependencies). + +Pipeline +-------- +Given a DriftSignal, the CorrectionRunner executes four steps: + + 1. analyze(signal) → CorrectionAnalysis + Pure struct enrichment. No IO. Summarises what we know about the + drift, identifies the probable root cause, decides which spec to target. + + 2. generate_hypothesis(analysis) → Hypothesis + LLM call. Produces a proposed_change description and proposed_spec_code. + Behaviour differs by drift type: + CRITERIA_GAP → propose a spec addition/modification + MODEL_OUTLIER → recommend model audit (no spec change) + UNKNOWN → conservative: request investigation + + 3. validate(hypothesis, signal) → ValidationResult + Re-runs all models on the triggering input with the proposed criteria + as additional context, measures MAD improvement. + For MODEL_OUTLIER: validates by computing MAD excluding the outlier. + + 4. assemble → SpecProposal + Packages everything into an immutable SpecProposal for human review. + +Protocols +--------- +Two external dependencies are behind protocols so they can be stubbed in tests: + + LLMCaller — async callable: (prompt: str) → str + Wire to any LLM provider (Anthropic, OpenAI, etc.) 
+ + ModelRunner — async callable: (input_data: dict, criteria_hint: str, + model_id: str) → float + Re-runs a single model on a single input. + Use the same model clients you use in your primary workflow. + +Both are injected into CorrectionRunner.__init__, so tests use stubs. + +Usage +----- + from manifold.testing.correction import CorrectionRunner + + async def call_llm(prompt: str) -> str: + response = await anthropic_client.messages.create(...) + return response.content[0].text + + async def run_model(input_data: dict, criteria_hint: str, model_id: str) -> float: + # Call your real model with the criteria hint as system context + ... + + runner = CorrectionRunner( + llm_caller=call_llm, + model_runner=run_model, + model_ids=["gpt-4o", "gemini-flash", "llama-3.3", "mistral-small"], + baseline_store=baseline, + ) + + # In harness: + harness = HMMVTestHarness(..., correction_runner=runner) +""" + +from __future__ import annotations + +import json +import logging +import statistics +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Awaitable, Callable + +from manifold.testing.models import ( + DriftSignal, + DriftType, + ProposalStatus, + SpecProposal, + _compute_mad, +) + +logger = logging.getLogger(__name__) + +# Protocols as type aliases (runtime duck-typing, no ABC overhead) +LLMCaller = Callable[[str], Awaitable[str]] +ModelRunner = Callable[[dict, str, str], Awaitable[float]] +# ModelRunner(input_data, criteria_hint, model_id) → score + + +# --------------------------------------------------------------------------- +# Internal data structures (pipeline-private) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class CorrectionAnalysis: + """ + Pure enrichment of a DriftSignal. No IO. + + Produced by step 1. Feeds step 2 (hypothesis generation). + """ + + signal_id: str + drift_type: DriftType + input_class: str + triggering_input: dict + + # What the models said + model_scores: dict[str, float] + observed_mad: float + expected_mad: float | None + outlier_model: str | None + + # Derived stats + agreeing_models: list[str] # models whose scores are close to consensus + disagreeing_models: list[str] # models far from consensus + + # Context + baseline_records: int + implicated_specs: list[str] # spec_ids likely responsible + + # Diagnosis + probable_cause: str # human-readable single sentence + target_spec_id: str # which spec to target + confidence_in_diagnosis: float # [0, 1] + + +@dataclass(frozen=True) +class Hypothesis: + """ + A proposed correction. Produced by step 2 (LLM call). + """ + + proposed_change: str # human-readable description + proposed_spec_code: str # the actual implementation (or recommendation) + hypothesis: str # why this change should restore convergence + target_spec_id: str + llm_raw_response: str # preserved for audit + + +@dataclass(frozen=True) +class ValidationResult: + """ + Outcome of re-running models with the proposed criteria. Produced by step 3. 
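+
+ validated is True when the fractional MAD reduction meets the runner's
+ improvement_threshold (default 0.3, i.e. a 30% drop), or, for MODEL_OUTLIER
+ drift, when the MAD computed without the outlier falls back under 2× the
+ expected baseline MAD.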
+ """ + + validated: bool + mad_before: float + mad_after: float | None + model_scores_after: dict[str, float] + models_converged: int # models within 2× expected MAD + n_models_tested: int + validation_note: str # why validated/rejected + + +# --------------------------------------------------------------------------- +# Step 1 — Analyze +# --------------------------------------------------------------------------- + + +def analyze(signal: DriftSignal) -> CorrectionAnalysis: + """ + Pure function. Enriches a DriftSignal into a CorrectionAnalysis. + + Identifies agreeing vs disagreeing models, picks the most likely + target spec from implicated_specs, and writes a human-readable + probable_cause sentence. + """ + scores = signal.model_scores + values = list(scores.values()) + consensus = statistics.median(values) + spread = statistics.stdev(values) if len(values) > 1 else 0.0 + + agreeing = [m for m, s in scores.items() if abs(s - consensus) <= spread] + disagreeing = [m for m, s in scores.items() if abs(s - consensus) > spread] + + # When all models fall within 1 stdev (e.g. evenly split scores), + # fall back to splitting by sign relative to the consensus. + if not disagreeing and len(scores) > 1: + above = [m for m, s in scores.items() if s >= consensus] + below = [m for m, s in scores.items() if s < consensus] + if above and below: + agreeing = above + disagreeing = below + + # Target spec: prefer the first implicated spec, fallback to sentinel + target = signal.implicated_specs[0] if signal.implicated_specs else "unknown_spec" + + if signal.drift_type == DriftType.MODEL_OUTLIER: + cause = ( + f"Model '{signal.outlier_model}' diverges from the other " + f"{len(agreeing)} models on input class '{signal.input_class}'. " + "Likely cause: model update, weight drift, or architecture change. " + "Spec change is probably NOT needed — model audit is." + ) + confidence = 0.85 if signal.outlier_model else 0.6 + + elif signal.drift_type == DriftType.CRITERIA_GAP: + cause = ( + f"Models split into {len(agreeing)} agreeing and " + f"{len(disagreeing)} disagreeing on input class '{signal.input_class}'. " + f"Observed MAD {signal.observed_mad:.3f} is " + f"{signal.observed_mad / (signal.expected_mad or 1.0):.1f}× " + f"the expected {signal.expected_mad or 0.0:.3f}. " + "Likely cause: the current spec does not adequately define " + "classification criteria for this input type." + ) + confidence = 0.75 + + else: + cause = ( + f"Drift type UNKNOWN on input class '{signal.input_class}'. " + f"Insufficient baseline data ({signal.baseline_records} records) to diagnose. " + "Investigation required before proposing a change." 
+ ) + confidence = 0.3 + + return CorrectionAnalysis( + signal_id=signal.signal_id, + drift_type=signal.drift_type, + input_class=signal.input_class, + triggering_input=signal.triggering_input, + model_scores=scores, + observed_mad=signal.observed_mad, + expected_mad=signal.expected_mad, + outlier_model=signal.outlier_model, + agreeing_models=agreeing, + disagreeing_models=disagreeing, + baseline_records=signal.baseline_records, + implicated_specs=signal.implicated_specs, + probable_cause=cause, + target_spec_id=target, + confidence_in_diagnosis=confidence, + ) + + +# --------------------------------------------------------------------------- +# Step 2 — Generate hypothesis (LLM) +# --------------------------------------------------------------------------- + + +def _build_prompt(analysis: CorrectionAnalysis, current_spec_code: str) -> str: + scores_fmt = "\n".join(f" {m}: {s:+.3f}" for m, s in sorted(analysis.model_scores.items())) + agreeing_fmt = ", ".join(analysis.agreeing_models) or "none" + disagreeing_fmt = ", ".join(analysis.disagreeing_models) or "none" + + if analysis.drift_type == DriftType.MODEL_OUTLIER: + task = f"""\ +One model ('{analysis.outlier_model}') diverges while others agree. +This is a MODEL HEALTH issue, not a criteria gap. + +Your task: +1. Confirm whether the evidence supports a model drift diagnosis. +2. Recommend what action to take (audit, replace, weight-reduce the outlier model). +3. State clearly whether a spec code change is needed (almost certainly: NO). + +proposed_spec_code should be a comment-only Python block explaining what to do: + # Model audit recommendation: ... + # No criteria change required. +""" + elif analysis.drift_type == DriftType.CRITERIA_GAP: + task = f"""\ +All models diverge from each other on input class '{analysis.input_class}'. +This is a CRITERIA GAP — the current spec does not adequately cover this input type. + +Your task: +1. Analyse what is ambiguous about the triggering input. +2. Propose a concrete, objective, verifiable criteria addition or modification. +3. Write proposed_spec_code as a Python class fragment (the new/modified evaluate() logic). + +The proposed criteria MUST be: +- Objective (checkable by diverse models independently) +- Specific (not vague like "consider context") +- Additive (extend the spec, do not rewrite it entirely) +""" + else: + task = """\ +Drift type is UNKNOWN. Insufficient baseline data to diagnose. +Propose a CONSERVATIVE investigation action — do NOT propose spec code changes. +proposed_spec_code should explain what data to collect to diagnose the issue. +""" + + return f"""\ +You are a classification spec engineer. A multi-model validation system has detected +that {len(analysis.model_scores)} architecturally diverse models diverged on a classification task. 
+ +═══ DRIFT SUMMARY ═══════════════════════════════════════════════════ +Drift type : {analysis.drift_type.value} +Input class : {analysis.input_class} +Triggering input : {json.dumps(analysis.triggering_input, ensure_ascii=False)} +Observed MAD : {analysis.observed_mad:.4f} (expected: {analysis.expected_mad or "N/A"}) +Baseline records : {analysis.baseline_records} + +Model scores: +{scores_fmt} + +Agreeing models : {agreeing_fmt} +Disagreeing models : {disagreeing_fmt} + +Probable cause: {analysis.probable_cause} + +═══ CURRENT SPEC CODE ═══════════════════════════════════════════════ +{current_spec_code} + +═══ YOUR TASK ═══════════════════════════════════════════════════════ +{task} + +═══ RESPONSE FORMAT ═════════════════════════════════════════════════ +Respond ONLY with a valid JSON object. No markdown, no preamble. + +{{ + "proposed_change": "<2-3 sentence human-readable description of what changes and why>", + "proposed_spec_code": "", + "hypothesis": "<1-2 sentence explanation of why this change should restore convergence>", + "target_spec_id": "" +}} +""" + + +async def generate_hypothesis( + analysis: CorrectionAnalysis, + llm_caller: LLMCaller, + current_spec_code: str = "# Spec code not provided", + max_retries: int = 2, +) -> Hypothesis | None: + """ + Call the LLM to generate a correction hypothesis. + + Returns None if the LLM fails to produce parseable output after retries. + """ + prompt = _build_prompt(analysis, current_spec_code) + + for attempt in range(1, max_retries + 1): + try: + raw = await llm_caller(prompt) + parsed = _parse_llm_response(raw) + if parsed: + return Hypothesis( + proposed_change=parsed["proposed_change"], + proposed_spec_code=parsed["proposed_spec_code"], + hypothesis=parsed["hypothesis"], + target_spec_id=parsed.get("target_spec_id", analysis.target_spec_id), + llm_raw_response=raw, + ) + logger.warning("Attempt %d: LLM response not parseable, retrying", attempt) + except Exception as e: + logger.warning("Attempt %d: LLM caller raised: %s", attempt, e) + + logger.error("All %d attempts failed for signal %s", max_retries, analysis.signal_id) + return None + + +def _parse_llm_response(raw: str) -> dict | None: + """ + Extract JSON from LLM response. Tolerates markdown code fences. + Returns None if required fields are missing. 
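+
+ Inputs that parse successfully (illustrative):
+ - a bare JSON object with the three required keys
+ - the same object wrapped in markdown code fences (fences are stripped)
+ - prose with one JSON object embedded (the span from the first '{' to the
+   last '}' is parsed)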
+ """ + text = raw.strip() + + # Strip markdown fences if present + if text.startswith("```"): + lines = text.split("\n") + # Remove first line (```json or ```) and last line (```) + inner = [line for line in lines[1:] if line.strip() != "```"] + text = "\n".join(inner).strip() + + try: + data = json.loads(text) + except json.JSONDecodeError: + # Try to find JSON object within larger text + start = text.find("{") + end = text.rfind("}") + 1 + if start == -1 or end == 0: + return None + try: + data = json.loads(text[start:end]) + except json.JSONDecodeError: + return None + + required = {"proposed_change", "proposed_spec_code", "hypothesis"} + if not required.issubset(data.keys()): + logger.warning("LLM response missing fields: %s", required - data.keys()) + return None + + # Validate non-empty + for key in required: + if not data[key] or not str(data[key]).strip(): + logger.warning("LLM response has empty field: %s", key) + return None + + return dict(data) + + +# --------------------------------------------------------------------------- +# Step 3 — Validate +# --------------------------------------------------------------------------- + + +async def validate( + hypothesis: Hypothesis, + signal: DriftSignal, + model_runner: ModelRunner, + model_ids: list[str], + expected_mad: float | None = None, + improvement_threshold: float = 0.3, +) -> ValidationResult: + """ + Re-run all models on the triggering input with the proposed criteria + as additional context. Measure whether MAD improves. + + For MODEL_OUTLIER: validates by computing MAD excluding the outlier. + For CRITERIA_GAP: re-runs all models with proposed_spec_code as criteria hint. + + Args + ---- + hypothesis Proposed change from step 2. + signal The original DriftSignal (has triggering_input and scores). + model_runner Async callable: (input_data, criteria_hint, model_id) → float. + model_ids All model IDs to test. + expected_mad Historical baseline MAD (from signal or snapshot). + improvement_threshold Fractional MAD reduction required to declare validated. + Default 0.3 = 30% reduction required. + """ + mad_before = signal.observed_mad + + # ── Special case: MODEL_OUTLIER ──────────────────────────────────────── + # Validation = does MAD drop to normal when we exclude the outlier? + # No model re-run needed — we already have the scores. + if signal.drift_type == DriftType.MODEL_OUTLIER and signal.outlier_model: + scores_without_outlier = { + m: s for m, s in signal.model_scores.items() if m != signal.outlier_model + } + mad_without = _compute_mad(list(scores_without_outlier.values())) + threshold = (expected_mad or mad_before) * 2.0 + + validated = mad_without <= threshold + note = ( + f"MAD without '{signal.outlier_model}': {mad_without:.4f} " + f"({'≤' if validated else '>'} threshold {threshold:.4f}). " + + ( + "Model outlier confirmed — audit model." + if validated + else "MAD still high without outlier — may be CRITERIA_GAP instead." 
+ ) + ) + return ValidationResult( + validated=validated, + mad_before=mad_before, + mad_after=mad_without, + model_scores_after=scores_without_outlier, + models_converged=len(scores_without_outlier) if validated else 0, + n_models_tested=len(scores_without_outlier), + validation_note=note, + ) + + # ── CRITERIA_GAP and UNKNOWN: re-run models ──────────────────────────── + if not signal.triggering_input: + return ValidationResult( + validated=False, + mad_before=mad_before, + mad_after=None, + model_scores_after={}, + models_converged=0, + n_models_tested=0, + validation_note=( + "Cannot validate: triggering_input is empty. " + "Harness must populate DriftSignal.triggering_input." + ), + ) + + # Run all models with proposed criteria as additional context + criteria_hint = ( + f"Apply the following updated classification criteria:\n\n" + f"{hypothesis.proposed_spec_code}\n\n" + f"Rationale: {hypothesis.hypothesis}" + ) + + scores_after: dict[str, float] = {} + failed_models: list[str] = [] + + for model_id in model_ids: + try: + score = await model_runner(signal.triggering_input, criteria_hint, model_id) + scores_after[model_id] = score + except Exception as e: + logger.warning("Model '%s' failed during validation: %s", model_id, e) + failed_models.append(model_id) + + if not scores_after: + return ValidationResult( + validated=False, + mad_before=mad_before, + mad_after=None, + model_scores_after={}, + models_converged=0, + n_models_tested=0, + validation_note=f"All models failed during validation: {failed_models}", + ) + + mad_after = _compute_mad(list(scores_after.values())) + reduction = (mad_before - mad_after) / mad_before if mad_before > 0 else 0.0 + validated = reduction >= improvement_threshold + + # Count converged models: within 2× expected MAD of consensus + if expected_mad and expected_mad > 0: + consensus = statistics.median(list(scores_after.values())) + converged = sum(1 for s in scores_after.values() if abs(s - consensus) <= expected_mad * 2) + else: + converged = len(scores_after) + + note = ( + f"MAD {mad_before:.4f} → {mad_after:.4f} " + f"({reduction * 100:+.1f}% change, threshold {improvement_threshold * 100:.0f}%). " + ) + if failed_models: + note += f"Failed models (excluded): {failed_models}. " + note += f"{converged}/{len(scores_after)} models converged after proposed change. " + ( + "✓ Validated." if validated else "✗ Insufficient improvement — proposal rejected." + ) + + return ValidationResult( + validated=validated, + mad_before=mad_before, + mad_after=mad_after, + model_scores_after=scores_after, + models_converged=converged, + n_models_tested=len(scores_after), + validation_note=note, + ) + + +# --------------------------------------------------------------------------- +# CorrectionRunner — the public interface +# --------------------------------------------------------------------------- + + +class CorrectionRunner: + """ + Real correction runner. Replaces NoOpCorrectionRunner. + + Executes: analyze → generate_hypothesis → validate → assemble SpecProposal. + + All four steps are logged. If any non-final step fails, the runner + returns None (not an exception) — the EventConsumer handles the + CORRECTION_FAILED event. + + Args + ---- + llm_caller Async callable: (prompt: str) → str. + model_runner Async callable: (input_data, criteria_hint, model_id) → float. + model_ids All model IDs used in the primary workflow. + baseline_store For fetching current spec code (optional — uses placeholder if None). 
+ improvement_threshold MAD reduction fraction required to validate. Default 0.3. + max_llm_retries How many times to retry a failed LLM call. Default 2. + """ + + def __init__( + self, + llm_caller: LLMCaller, + model_runner: ModelRunner, + model_ids: list[str], + baseline_store: Any = None, # BaselineStore (optional) + current_spec_codes: dict[str, str] | None = None, + improvement_threshold: float = 0.3, + max_llm_retries: int = 2, + ) -> None: + self._llm = llm_caller + self._model_runner = model_runner + self._model_ids = model_ids + self._baseline = baseline_store + self._spec_codes = current_spec_codes or {} + self._threshold = improvement_threshold + self._max_retries = max_llm_retries + + async def run(self, signal: DriftSignal) -> SpecProposal | None: + """ + Full correction pipeline. + + Returns a SpecProposal (proposal_status=VALIDATED or REJECTED) + or None if the pipeline itself failed (e.g. LLM unreachable). + """ + logger.info( + "CorrectionRunner: starting for signal_id=%s type=%s class=%s", + signal.signal_id, + signal.drift_type.value, + signal.input_class, + ) + + # ── Step 1: Analyze ───────────────────────────────────────────────── + analysis = analyze(signal) + logger.info( + "Analysis: cause=%s confidence=%.2f target_spec=%s", + analysis.drift_type.value, + analysis.confidence_in_diagnosis, + analysis.target_spec_id, + ) + + # ── Step 2: Generate hypothesis ───────────────────────────────────── + current_code = self._spec_codes.get( + analysis.target_spec_id, f"# Spec '{analysis.target_spec_id}' code not available" + ) + hypothesis = await generate_hypothesis(analysis, self._llm, current_code, self._max_retries) + if hypothesis is None: + logger.error("Hypothesis generation failed for signal_id=%s", signal.signal_id) + return None + logger.info( + "Hypothesis: target=%s change=%s", + hypothesis.target_spec_id, + ( + hypothesis.proposed_change[:80] + "..." 
+ if len(hypothesis.proposed_change) > 80 + else hypothesis.proposed_change + ), + ) + + # ── Step 3: Validate ──────────────────────────────────────────────── + validation = await validate( + hypothesis=hypothesis, + signal=signal, + model_runner=self._model_runner, + model_ids=self._model_ids, + expected_mad=signal.expected_mad, + improvement_threshold=self._threshold, + ) + logger.info( + "Validation: validated=%s mad_before=%.4f mad_after=%s note=%s", + validation.validated, + validation.mad_before, + f"{validation.mad_after:.4f}" if validation.mad_after is not None else "N/A", + validation.validation_note, + ) + + # ── Step 4: Assemble SpecProposal ──────────────────────────────────── + status = ProposalStatus.VALIDATED if validation.validated else ProposalStatus.REJECTED + + proposal = SpecProposal( + proposal_id=str(uuid.uuid4()), + created_at=datetime.now(timezone.utc), + triggered_by_signal_id=signal.signal_id, + target_spec_id=hypothesis.target_spec_id, + current_spec_version="unknown", # caller can enrich from registry + proposed_change=hypothesis.proposed_change, + proposed_spec_code=hypothesis.proposed_spec_code, + hypothesis=hypothesis.hypothesis, + drift_examples=[signal.input_fingerprint], + convergence_examples=signal.representative_fps, + proposal_status=status, + validation_mad_before=validation.mad_before, + validation_mad_after=validation.mad_after, + models_converged_after=validation.models_converged, + ) + + logger.info( + "Proposal assembled: proposal_id=%s status=%s improvement=%.4f", + proposal.proposal_id, + status.value, + proposal.mad_improvement or 0.0, + ) + return proposal diff --git a/manifold/testing/events.py b/manifold/testing/events.py new file mode 100644 index 0000000..16a845b --- /dev/null +++ b/manifold/testing/events.py @@ -0,0 +1,672 @@ +""" +manifold.testing.events +~~~~~~~~~~~~~~~~~~~~~~~ +Event schema and central EventConsumer. + +Architecture +------------ +The system is fully event-driven. Every significant state transition +emits an Event. The EventConsumer is the single entry point that +receives all events and decides what to do next. 
+ +This means: +- No direct coupling between components (baseline store, correction + workflow, spec registry don't call each other) +- Full audit trail of every decision +- Easy to replay, test, and extend +- Natural backpressure: the consumer processes one event at a time + per queue; concurrent runs emit events that queue cleanly + +Event flow +---------- + Primary workflow completes + → RUN_COMPLETED + + EventConsumer receives RUN_COMPLETED + → reads artifacts from context + → if DriftSignal found: emits DRIFT_DETECTED + → if converged: emits BASELINE_UPDATED + + EventConsumer receives DRIFT_DETECTED + → schedules correction workflow + → emits CORRECTION_STARTED + + Correction workflow completes + → CORRECTION_COMPLETED + + EventConsumer receives CORRECTION_COMPLETED + → reads SpecProposal from result + → writes to ProposalStore + → emits PROPOSAL_READY + + Human approves proposal + → PROPOSAL_APPROVED + + EventConsumer receives PROPOSAL_APPROVED + → applies to SpecRegistry + → emits SPEC_UPDATED + + EventConsumer receives SPEC_UPDATED + → marks affected baseline records as stale (not deleted) + → triggers snapshot if baseline is large enough + → emits BASELINE_SNAPSHOT_TAKEN (if triggered) +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Awaitable, Callable + +from manifold.testing.models import ( + BaselineSnapshot, + ConvergenceRecord, + DriftSignal, + SpecProposal, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Event types +# --------------------------------------------------------------------------- + + +class EventType(Enum): + """ + All events in the system. Ordered by rough lifecycle position. 
+ + PRIMARY WORKFLOW EVENTS + ----------------------- + RUN_COMPLETED — a primary workflow run finished (success or fail) + BASELINE_UPDATED — a convergent record was appended to the baseline + DRIFT_DETECTED — ConvergenceMonitor emitted a DriftSignal + + CORRECTION WORKFLOW EVENTS + -------------------------- + CORRECTION_STARTED — correction workflow was triggered for a signal + CORRECTION_COMPLETED — correction workflow finished + CORRECTION_FAILED — correction workflow could not produce a proposal + + PROPOSAL EVENTS + --------------- + PROPOSAL_READY — SpecProposal written, awaiting human review + PROPOSAL_APPROVED — human approved a proposal + PROPOSAL_REJECTED — human rejected a proposal + + SPEC REGISTRY EVENTS + -------------------- + SPEC_UPDATED — a spec was updated in the registry + BASELINE_STALE — some baseline records were marked stale after spec change + + SNAPSHOT EVENTS + --------------- + BASELINE_SNAPSHOT_TAKEN — a new snapshot was persisted + """ + + # Primary workflow + RUN_COMPLETED = "run_completed" + BASELINE_UPDATED = "baseline_updated" + DRIFT_DETECTED = "drift_detected" + + # Correction workflow + CORRECTION_STARTED = "correction_started" + CORRECTION_COMPLETED = "correction_completed" + CORRECTION_FAILED = "correction_failed" + + # Proposals + PROPOSAL_READY = "proposal_ready" + PROPOSAL_APPROVED = "proposal_approved" + PROPOSAL_REJECTED = "proposal_rejected" + + # Spec registry + SPEC_UPDATED = "spec_updated" + BASELINE_STALE = "baseline_stale" + + # Snapshots + BASELINE_SNAPSHOT_TAKEN = "baseline_snapshot_taken" + + +# --------------------------------------------------------------------------- +# Event +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class Event: + """ + A single event in the system. + + Every state transition emits an Event. Events are immutable + once created. The payload contains the relevant data for handlers. + + Fields + ------ + event_id : globally unique identifier + event_type : what happened + timestamp : UTC + source : component that emitted the event (for logging/debugging) + payload : type-specific data (see payload shapes below) + correlation_id : links events that belong to the same logical chain + (e.g. all events from one drift→correction→proposal cycle + share a correlation_id) + """ + + event_id: str + event_type: EventType + timestamp: datetime + source: str + payload: dict[str, Any] + correlation_id: str | None = None + + @classmethod + def create( + cls, + event_type: EventType, + source: str, + payload: dict[str, Any], + correlation_id: str | None = None, + ) -> "Event": + return cls( + event_id=str(uuid.uuid4()), + event_type=event_type, + timestamp=datetime.now(timezone.utc), + source=source, + payload=payload, + correlation_id=correlation_id or str(uuid.uuid4()), + ) + + def to_dict(self) -> dict: + return { + "event_id": self.event_id, + "event_type": self.event_type.value, + "timestamp": self.timestamp.isoformat(), + "source": self.source, + "payload": self.payload, + "correlation_id": self.correlation_id, + } + + +# --------------------------------------------------------------------------- +# Payload shapes (documentation + helpers) +# --------------------------------------------------------------------------- +# +# These are not enforced at runtime (dicts are flexible), but all handlers +# should follow these shapes. Each factory function is the canonical way +# to build a payload for a given event type. 
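+#
+# Example (illustrative): emitting a DRIFT_DETECTED event with its canonical
+# payload, correlated to the triggering signal:
+#
+#   await bus.emit(Event.create(
+#       event_type=EventType.DRIFT_DETECTED,
+#       source="hmmv_harness",
+#       payload=payload_drift_detected(signal),
+#       correlation_id=signal.signal_id,
+#   ))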
+ + +def payload_run_completed( + run_id: str, + success: bool, + had_drift: bool, + drift_signal_id: str | None, + convergence_record: ConvergenceRecord | None, +) -> dict: + return { + "run_id": run_id, + "success": success, + "had_drift": had_drift, + "drift_signal_id": drift_signal_id, + "convergence_record": convergence_record.to_dict() if convergence_record else None, + } + + +def payload_drift_detected(signal: DriftSignal) -> dict: + return {"drift_signal": signal.to_dict()} + + +def payload_baseline_updated(record: ConvergenceRecord) -> dict: + return {"convergence_record": record.to_dict()} + + +def payload_correction_started( + signal_id: str, + workflow_run_id: str, +) -> dict: + return {"signal_id": signal_id, "workflow_run_id": workflow_run_id} + + +def payload_correction_completed( + signal_id: str, + proposal: SpecProposal, +) -> dict: + return {"signal_id": signal_id, "proposal": proposal.to_dict()} + + +def payload_correction_failed( + signal_id: str, + reason: str, +) -> dict: + return {"signal_id": signal_id, "reason": reason} + + +def payload_proposal_ready(proposal: SpecProposal) -> dict: + return {"proposal": proposal.to_dict()} + + +def payload_proposal_approved( + proposal_id: str, + reviewer_notes: str, +) -> dict: + return {"proposal_id": proposal_id, "reviewer_notes": reviewer_notes} + + +def payload_proposal_rejected( + proposal_id: str, + reviewer_notes: str, +) -> dict: + return {"proposal_id": proposal_id, "reviewer_notes": reviewer_notes} + + +def payload_spec_updated( + spec_id: str, + old_version: str, + new_version: str, + proposal_id: str, +) -> dict: + return { + "spec_id": spec_id, + "old_version": old_version, + "new_version": new_version, + "proposal_id": proposal_id, + } + + +def payload_baseline_stale( + spec_id: str, + stale_record_count: int, +) -> dict: + return {"spec_id": spec_id, "stale_record_count": stale_record_count} + + +def payload_snapshot_taken(snapshot: BaselineSnapshot) -> dict: + return {"snapshot": snapshot.to_dict()} + + +# --------------------------------------------------------------------------- +# EventBus — thin async pub/sub +# --------------------------------------------------------------------------- + +Handler = Callable[[Event], Awaitable[None]] + + +class EventBus: + """ + Minimal async pub/sub bus. + + Handlers are registered per EventType. When an event is emitted, + all registered handlers for that type are called concurrently. + + This is intentionally thin — no persistence, no retry logic at this + layer. The EventConsumer (below) is responsible for durability and + error handling. The bus is just the wiring. + + Usage + ----- + bus = EventBus() + bus.subscribe(EventType.DRIFT_DETECTED, my_handler) + await bus.emit(Event.create(EventType.DRIFT_DETECTED, ...)) + """ + + def __init__(self) -> None: + self._handlers: dict[EventType, list[Handler]] = {} + + def subscribe(self, event_type: EventType, handler: Handler) -> None: + """Register a handler for an event type.""" + self._handlers.setdefault(event_type, []).append(handler) + + def subscribe_many( + self, + handlers: dict[EventType, Handler | list[Handler]], + ) -> None: + """Register multiple handlers at once.""" + for event_type, handler_or_list in handlers.items(): + if isinstance(handler_or_list, list): + for h in handler_or_list: + self.subscribe(event_type, h) + else: + self.subscribe(event_type, handler_or_list) + + async def emit(self, event: Event) -> None: + """ + Emit an event and await all handlers. + + Handlers run concurrently. 
If any handler raises, the exception + is logged but does not prevent other handlers from running. + The bus never raises. + """ + handlers = self._handlers.get(event.event_type, []) + if not handlers: + logger.debug("No handlers for %s", event.event_type.value) + return + + results = await asyncio.gather( + *[h(event) for h in handlers], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error( + "Handler %d for %s raised: %s", + i, + event.event_type.value, + result, + exc_info=result, + ) + + +# --------------------------------------------------------------------------- +# EventConsumer — the nervous system +# --------------------------------------------------------------------------- + + +class EventConsumer: + """ + Central coordinator. All routing logic lives here. + + The EventConsumer subscribes to the EventBus and decides what to do + in response to each event. It does not contain business logic — + it delegates to stores, the correction workflow runner, and the spec + registry. Its job is coordination only. + + Dependencies (injected, all behind protocols defined in stores.py) + ---------- + baseline_store : reads/writes ConvergenceRecords + snapshot_store : reads/writes BaselineSnapshots + proposal_store : reads/writes SpecProposals + spec_registry : manages spec versions, applies proposals + correction_runner : runs the correction workflow + bus : the EventBus to emit downstream events + + Configuration + ------------- + snapshot_interval : take a snapshot every N new convergence records + """ + + def __init__( + self, + baseline_store: Any, # BaselineStore protocol (see stores.py) + snapshot_store: Any, # SnapshotStore protocol + proposal_store: Any, # ProposalStore protocol + spec_registry: Any, # SpecRegistry protocol + correction_runner: Any, # CorrectionRunner protocol + bus: EventBus, + snapshot_interval: int = 100, + ) -> None: + self._baseline_store = baseline_store + self._snapshot_store = snapshot_store + self._proposal_store = proposal_store + self._spec_registry = spec_registry + self._correction_runner = correction_runner + self._bus = bus + self._snapshot_interval = snapshot_interval + + # Wire up handlers + bus.subscribe_many( + { + EventType.RUN_COMPLETED: self._on_run_completed, + EventType.DRIFT_DETECTED: self._on_drift_detected, + EventType.CORRECTION_COMPLETED: self._on_correction_completed, + EventType.CORRECTION_FAILED: self._on_correction_failed, + EventType.PROPOSAL_APPROVED: self._on_proposal_approved, + EventType.PROPOSAL_REJECTED: self._on_proposal_rejected, + EventType.SPEC_UPDATED: self._on_spec_updated, + } + ) + + # ------------------------------------------------------------------ + # Handlers + # ------------------------------------------------------------------ + + async def _on_run_completed(self, event: Event) -> None: + """ + A primary workflow run finished. + + If it had drift: emit DRIFT_DETECTED. + If it converged: append record to baseline, maybe take snapshot. 
+ """ + p = event.payload + logger.info( + "Run completed: run_id=%s success=%s drift=%s", + p["run_id"], + p["success"], + p["had_drift"], + ) + + if p["had_drift"] and p.get("drift_signal_id"): + # Retrieve the DriftSignal from the baseline store's signal log + signal = await self._baseline_store.get_signal(p["drift_signal_id"]) + if signal: + await self._bus.emit( + Event.create( + EventType.DRIFT_DETECTED, + source="event_consumer", + payload=payload_drift_detected(signal), + correlation_id=event.correlation_id, + ) + ) + return + + record_dict = p.get("convergence_record") + if record_dict: + record = ConvergenceRecord.from_dict(record_dict) + await self._baseline_store.append(record) + await self._bus.emit( + Event.create( + EventType.BASELINE_UPDATED, + source="event_consumer", + payload=payload_baseline_updated(record), + correlation_id=event.correlation_id, + ) + ) + + # Take snapshot if interval hit + count = await self._baseline_store.total_records() + if count % self._snapshot_interval == 0: + await self._maybe_take_snapshot(correlation_id=event.correlation_id) + + async def _on_drift_detected(self, event: Event) -> None: + """ + Drift was detected. Start the correction workflow. + """ + signal = DriftSignal.from_dict(event.payload["drift_signal"]) + logger.warning( + "Drift detected: signal_id=%s type=%s class=%s mad=%.3f (expected=%.3f)", + signal.signal_id, + signal.drift_type.value, + signal.input_class, + signal.observed_mad, + signal.expected_mad or 0.0, + ) + + workflow_run_id = str(uuid.uuid4()) + + await self._bus.emit( + Event.create( + EventType.CORRECTION_STARTED, + source="event_consumer", + payload=payload_correction_started(signal.signal_id, workflow_run_id), + correlation_id=event.correlation_id, + ) + ) + + # Run correction workflow (async, non-blocking for bus) + asyncio.create_task(self._run_correction(signal, workflow_run_id, event.correlation_id)) + + async def _run_correction( + self, + signal: DriftSignal, + workflow_run_id: str, + correlation_id: str | None, + ) -> None: + """Run the correction workflow and emit outcome event.""" + try: + proposal = await self._correction_runner.run(signal) + if proposal is not None: + await self._bus.emit( + Event.create( + EventType.CORRECTION_COMPLETED, + source="event_consumer.correction_runner", + payload=payload_correction_completed(signal.signal_id, proposal), + correlation_id=correlation_id, + ) + ) + else: + await self._bus.emit( + Event.create( + EventType.CORRECTION_FAILED, + source="event_consumer.correction_runner", + payload=payload_correction_failed( + signal.signal_id, "Correction workflow produced no proposal" + ), + correlation_id=correlation_id, + ) + ) + except Exception as e: + logger.error("Correction workflow raised: %s", e, exc_info=True) + await self._bus.emit( + Event.create( + EventType.CORRECTION_FAILED, + source="event_consumer.correction_runner", + payload=payload_correction_failed(signal.signal_id, str(e)), + correlation_id=correlation_id, + ) + ) + + async def _on_correction_completed(self, event: Event) -> None: + """Correction workflow produced a proposal. 
Write it and notify.""" + proposal = SpecProposal.from_dict(event.payload["proposal"]) + await self._proposal_store.write(proposal) + logger.info( + "Proposal ready: proposal_id=%s spec=%s mad_improvement=%.3f", + proposal.proposal_id, + proposal.target_spec_id, + proposal.mad_improvement or 0.0, + ) + await self._bus.emit( + Event.create( + EventType.PROPOSAL_READY, + source="event_consumer", + payload=payload_proposal_ready(proposal), + correlation_id=event.correlation_id, + ) + ) + + async def _on_correction_failed(self, event: Event) -> None: + """Correction workflow could not produce a proposal. Log and escalate.""" + logger.error( + "Correction failed: signal_id=%s reason=%s", + event.payload["signal_id"], + event.payload["reason"], + ) + # TODO: escalation hook (e.g. Slack notification, PagerDuty) + + async def _on_proposal_approved(self, event: Event) -> None: + """Human approved a proposal. Apply it to the spec registry.""" + proposal_id = event.payload["proposal_id"] + + proposal = await self._proposal_store.get(proposal_id) + if proposal is None: + logger.error("Approved unknown proposal_id=%s", proposal_id) + return + + old_version = proposal.current_spec_version + new_version = await self._spec_registry.apply_proposal(proposal) + + await self._bus.emit( + Event.create( + EventType.SPEC_UPDATED, + source="event_consumer", + payload=payload_spec_updated( + spec_id=proposal.target_spec_id, + old_version=old_version, + new_version=new_version, + proposal_id=proposal_id, + ), + correlation_id=event.correlation_id, + ) + ) + + async def _on_proposal_rejected(self, event: Event) -> None: + """Human rejected a proposal. Update status, no spec change.""" + proposal_id = event.payload["proposal_id"] + logger.info("Proposal rejected: proposal_id=%s", proposal_id) + await self._proposal_store.mark_rejected( + proposal_id, + event.payload.get("reviewer_notes", ""), + ) + + async def _on_spec_updated(self, event: Event) -> None: + """ + A spec changed. Mark affected baseline records as stale. + + Records collected under the old spec version may no longer be + comparable to new runs. We mark them stale (not delete — they + are still evidence of what the old spec produced). 
+ """ + spec_id = event.payload["spec_id"] + old_version = event.payload["old_version"] + + stale_count = await self._baseline_store.mark_stale_for_spec_version(spec_id, old_version) + + logger.info( + "Marked %d baseline records stale after spec update: spec=%s v%s→v%s", + stale_count, + spec_id, + old_version, + event.payload["new_version"], + ) + + await self._bus.emit( + Event.create( + EventType.BASELINE_STALE, + source="event_consumer", + payload=payload_baseline_stale(spec_id, stale_count), + correlation_id=event.correlation_id, + ) + ) + + # Take a snapshot to checkpoint the pre-change baseline + await self._maybe_take_snapshot( + notes=f"Triggered by spec update: {spec_id}", + correlation_id=event.correlation_id, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _maybe_take_snapshot( + self, + notes: str = "", + correlation_id: str | None = None, + ) -> None: + """Take a baseline snapshot and emit BASELINE_SNAPSHOT_TAKEN.""" + try: + snapshot = await self._baseline_store.take_snapshot( + spec_registry=self._spec_registry, + notes=notes, + ) + await self._snapshot_store.write(snapshot) + await self._bus.emit( + Event.create( + EventType.BASELINE_SNAPSHOT_TAKEN, + source="event_consumer", + payload=payload_snapshot_taken(snapshot), + correlation_id=correlation_id, + ) + ) + logger.info( + "Snapshot taken: snapshot_id=%s total_records=%d", + snapshot.snapshot_id, + snapshot.total_records, + ) + except Exception as e: + logger.error("Failed to take snapshot: %s", e, exc_info=True) diff --git a/manifold/testing/harness.py b/manifold/testing/harness.py new file mode 100644 index 0000000..81ac206 --- /dev/null +++ b/manifold/testing/harness.py @@ -0,0 +1,536 @@ +""" +manifold.testing.harness +~~~~~~~~~~~~~~~~~~~~~~~~ +HMMVTestHarness — the "fill in your models and go" preset. + +Usage +----- + from manifold.testing import HMMVTestHarness + from manifold.testing.stores import SQLiteBaselineStore + + harness = HMMVTestHarness( + models=["gpt-4o", "gemini-flash", "llama-3.3", "mistral-small"], + workflow_manifest="classify.yaml", + baseline_store=SQLiteBaselineStore("baseline.db"), + ) + + await harness.setup() + result = await harness.run({"name": "Caritas Berlin", "type": "welfare"}) + + print(result.regime) # "convergent" | "drift" | "early" | "novel_class" + print(result.consensus_score) # float + print(result.drift_signal) # DriftSignal | None + +Design +------ +The harness owns the wiring between all components: + + 1. Builds EventBus + EventConsumer + 2. Creates ConvergenceMonitor with the user's baseline store + 3. Wraps it in a Spec via make_convergence_spec() + 4. Builds the Orchestrator via OrchestratorBuilder + 5. After each run: + a. Drains pending records → writes to baseline store + b. Drains pending signals → stores signal → emits RUN_COMPLETED event + c. Refreshes baseline cache for next run + 6. Exposes hooks for human review (on_proposal_ready callback) + +The harness does NOT decide what happens to proposals — that is the +EventConsumer's job. The harness just emits events and wires the plumbing. + +Correction runner +----------------- +The harness accepts an optional `correction_runner` argument. If None, +a NoOpCorrectionRunner is used which logs signals but produces no proposals. +Replace with a real implementation when the correction workflow is built. 
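
A custom correction runner only needs an async `run(signal)` method that
returns a SpecProposal or None (the same interface NoOpCorrectionRunner in
this module implements). A minimal sketch, with an illustrative class name:

    class MyCorrectionRunner:
        async def run(self, signal):
            # inspect the DriftSignal, draft and validate a candidate fix,
            # then return a SpecProposal (or None if no fix was found)
            ...

    harness = HMMVTestHarness(
        models=["gpt-4o", "gemini-flash", "llama-3.3", "mistral-small"],
        workflow_manifest="classify.yaml",
        baseline_store=SQLiteBaselineStore("baseline.db"),
        correction_runner=MyCorrectionRunner(),
    )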
+""" + +from __future__ import annotations + +import asyncio +import logging +import statistics +import uuid +from dataclasses import dataclass +from typing import Any, Awaitable, Callable + +from manifold.testing.convergence import ( + ConvergenceConfig, + ConvergenceMonitor, + make_convergence_spec, +) +from manifold.testing.events import ( + Event, + EventBus, + EventConsumer, + EventType, + payload_run_completed, +) +from manifold.testing.models import DriftSignal, SpecProposal +from manifold.testing.stores import ( + InMemoryProposalStore, + InMemorySnapshotStore, + InMemorySpecRegistry, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# HMMVResult — what harness.run() returns +# --------------------------------------------------------------------------- + + +@dataclass +class HMMVResult: + """ + Result of a single harness run. + + Fields + ------ + run_id : unique run identifier + success : whether the primary workflow completed + regime : "early" | "novel_class" | "convergent" | "drift" + consensus_score : median across model scores (None if workflow failed) + model_scores : {model_id: score} + inter_model_mad : mean absolute deviation + input_class : cluster label assigned to this input + drift_signal : populated if regime == "drift", else None + workflow_summary: summary string from the underlying WorkflowResult + error : populated if success == False + """ + + run_id: str + success: bool + regime: str + consensus_score: float | None + model_scores: dict[str, float] + inter_model_mad: float + input_class: str + drift_signal: DriftSignal | None = None + workflow_summary: str = "" + error: str | None = None + + +# --------------------------------------------------------------------------- +# No-op correction runner (placeholder) +# --------------------------------------------------------------------------- + + +class NoOpCorrectionRunner: + """ + Placeholder correction runner. + + Logs drift signals but produces no SpecProposal. Replace with a real + implementation backed by a correction workflow manifest. + """ + + async def run(self, signal: DriftSignal) -> SpecProposal | None: + logger.warning( + "NoOpCorrectionRunner received drift signal %s (type=%s, class=%s). " + "No proposal generated — wire up a real correction runner.", + signal.signal_id, + signal.drift_type.value, + signal.input_class, + ) + return None + + +# --------------------------------------------------------------------------- +# HMMVTestHarness +# --------------------------------------------------------------------------- + + +class HMMVTestHarness: + """ + Preset harness for Heterogeneous Multi-Model Validation workflows. + + Minimal required arguments: + models — list of model identifiers (used to name agents) + workflow_manifest — path to your Manifold workflow YAML + baseline_store — BaselineStore implementation + + Everything else has sensible defaults. + + Lifecycle + --------- + 1. Instantiate + 2. Call await harness.setup() ← initialises stores and caches + 3. Call await harness.run(input) ← as many times as needed + 4. Human reviews proposals via harness.pending_proposals() + 5. 
Call await harness.approve_proposal(id, notes) to apply a proposal + """ + + def __init__( + self, + models: list[str], + workflow_manifest: str, + baseline_store: Any, + *, + # Optional stores (in-memory defaults for development) + snapshot_store: Any | None = None, + proposal_store: Any | None = None, + spec_registry: Any | None = None, + # Correction runner (no-op by default) + correction_runner: Any | None = None, + # Convergence tuning + config: ConvergenceConfig | None = None, + # Snapshot frequency + snapshot_interval: int = 100, + # Hooks + on_proposal_ready: Callable[[SpecProposal], Awaitable[None]] | None = None, + on_drift_detected: Callable[[DriftSignal], Awaitable[None]] | None = None, + ) -> None: + self._models = models + self._manifest_path = workflow_manifest + self._baseline = baseline_store + self._snapshots = snapshot_store or InMemorySnapshotStore() + self._proposals = proposal_store or InMemoryProposalStore() + self._registry = spec_registry or InMemorySpecRegistry() + self._correction_runner = correction_runner or NoOpCorrectionRunner() + self._config = config or ConvergenceConfig() + self._snapshot_interval = snapshot_interval + self._on_proposal_ready = on_proposal_ready + self._on_drift_detected = on_drift_detected + + # Built during setup() + self._monitor: ConvergenceMonitor | None = None + self._orchestrator: Any | None = None + self._bus: EventBus | None = None + self._consumer: EventConsumer | None = None + self._ready = False + + # ------------------------------------------------------------------ + # Setup + # ------------------------------------------------------------------ + + async def setup(self) -> None: + """ + Initialise all components. + + Must be called before run(). Idempotent. + """ + if self._ready: + return + + # Initialise baseline store if it supports it + if hasattr(self._baseline, "initialise"): + await self._baseline.initialise() + + # Build event bus + self._bus = EventBus() + + # Build correction runner wrapper that emits correction events + self._consumer = EventConsumer( + baseline_store=self._baseline, + snapshot_store=self._snapshots, + proposal_store=self._proposals, + spec_registry=self._registry, + correction_runner=self._correction_runner, + bus=self._bus, + snapshot_interval=self._snapshot_interval, + ) + + # Wire user callbacks + if self._on_proposal_ready: + + async def _proposal_handler(event: Event) -> None: + proposal = SpecProposal.from_dict(event.payload["proposal"]) + await self._on_proposal_ready(proposal) # type: ignore[misc] + + self._bus.subscribe(EventType.PROPOSAL_READY, _proposal_handler) + + if self._on_drift_detected: + + async def _drift_handler(event: Event) -> None: + signal = DriftSignal.from_dict(event.payload["drift_signal"]) + await self._on_drift_detected(signal) # type: ignore[misc] + + self._bus.subscribe(EventType.DRIFT_DETECTED, _drift_handler) + + # Get current spec versions from registry for record annotation + spec_versions = await self._registry.current_versions() + + # Build ConvergenceMonitor + self._monitor = ConvergenceMonitor( + baseline_store=self._baseline, + config=self._config, + spec_versions=spec_versions, + ) + + # Refresh cache before building orchestrator + await self._refresh_cache() + + # Build orchestrator + self._orchestrator = self._build_orchestrator() + + self._ready = True + logger.info( + "HMMVTestHarness ready: %d models, manifest=%s", + len(self._models), + self._manifest_path, + ) + + # ------------------------------------------------------------------ + # Run + # 
------------------------------------------------------------------ + + async def run( + self, + input_data: dict[str, Any], + input_class: str | None = None, + cluster_version: str | None = None, + ) -> HMMVResult: + """ + Run the primary workflow on a single input. + + Args + ---- + input_data Raw input dict (will be fingerprinted). + input_class Override cluster label. If None, must be set in + input_data or populated by a clustering step in + the workflow manifest. + cluster_version Version tag of the clustering model. + + Returns + ------- + HMMVResult with convergence info and optional DriftSignal. + """ + if not self._ready: + await self.setup() + + # Refresh baseline cache before each run + await self._refresh_cache() + + # Merge input_class into initial data if provided + initial: dict[str, Any] = { + self._config.record_input_field: input_data, + **input_data, + } + if input_class: + initial[self._config.record_class_field] = input_class + if cluster_version: + initial["cluster_version"] = cluster_version + + # Run the orchestrator + run_id = str(uuid.uuid4()) + try: + from manifold.core.context import create_context + + create_context(run_id=run_id, initial_data=initial) + assert self._orchestrator is not None, "Call setup() before run()" + workflow_result = await self._orchestrator.run(initial_data=initial) + except Exception as e: + logger.error("Orchestrator failed: %s", e, exc_info=True) + return HMMVResult( + run_id=run_id, + success=False, + regime="error", + consensus_score=None, + model_scores={}, + inter_model_mad=0.0, + input_class=input_class or "unknown", + error=str(e), + ) + + # Drain convergence monitor + assert self._monitor is not None, "Call setup() before run()" + signals = self._monitor.drain_signals() + records = self._monitor.drain_records() + + had_drift = len(signals) > 0 + drift_signal = signals[0] if signals else None + convergence_r = records[0] if records else None + + # Store signals and records + for sig in signals: + # Enrich: add representative fingerprints from baseline + fps = await self._baseline.sample_fingerprints_for_class(sig.input_class, n=5) + from dataclasses import replace + + sig = replace(sig, representative_fps=fps, triggering_input=input_data) + await self._baseline.append_signal(sig) + + # Emit RUN_COMPLETED event (triggers baseline update or correction workflow) + event = Event.create( + EventType.RUN_COMPLETED, + source="hmmv_harness", + payload=payload_run_completed( + run_id=run_id, + success=workflow_result.success, + had_drift=had_drift, + drift_signal_id=drift_signal.signal_id if drift_signal else None, + convergence_record=convergence_r, + ), + ) + await self._bus.emit(event) # type: ignore[union-attr] + + # Give event tasks a cycle to start + await asyncio.sleep(0) + + # Extract convergence result from monitor output + model_scores = workflow_result.final_context.get_data(self._config.record_mad_field, {}) + detected_class = workflow_result.final_context.get_data( + self._config.record_class_field, input_class or "unknown" + ) + + scores_list = list(model_scores.values()) + mad = _compute_mad_sync(scores_list) + consensus = statistics.median(scores_list) if scores_list else None + + # Determine regime from most recent spec result + regime = self._extract_regime(workflow_result) + + return HMMVResult( + run_id=run_id, + success=workflow_result.success, + regime=regime, + consensus_score=consensus, + model_scores=model_scores, + inter_model_mad=mad, + input_class=detected_class, + drift_signal=drift_signal, + 
workflow_summary=workflow_result.summary, + ) + + # ------------------------------------------------------------------ + # Proposal review API + # ------------------------------------------------------------------ + + async def pending_proposals(self) -> list[SpecProposal]: + """Return all SpecProposals awaiting human review.""" + return await self._proposals.pending_proposals() + + async def approve_proposal( + self, + proposal_id: str, + reviewer_notes: str = "", + ) -> None: + """ + Approve a proposal. This: + 1. Marks it approved in the proposal store + 2. Applies it to the spec registry + 3. Marks affected baseline records as stale + 4. Refreshes the monitor's spec_versions + """ + assert self._bus is not None, "Call setup() before approve_proposal()" + await self._bus.emit( + Event.create( + EventType.PROPOSAL_APPROVED, + source="harness.approve_proposal", + payload={"proposal_id": proposal_id, "reviewer_notes": reviewer_notes}, + ) + ) + await asyncio.sleep(0) # let EventConsumer process + # Refresh spec versions in monitor + assert self._monitor is not None + self._monitor._spec_versions = await self._registry.current_versions() + + async def reject_proposal( + self, + proposal_id: str, + reviewer_notes: str = "", + ) -> None: + """Reject a proposal without applying it.""" + assert self._bus is not None, "Call setup() before reject_proposal()" + await self._bus.emit( + Event.create( + EventType.PROPOSAL_REJECTED, + source="harness.reject_proposal", + payload={"proposal_id": proposal_id, "reviewer_notes": reviewer_notes}, + ) + ) + + # ------------------------------------------------------------------ + # Introspection + # ------------------------------------------------------------------ + + async def baseline_stats(self) -> dict: + """Return a summary of current baseline state.""" + total = await self._baseline.total_records() + snapshot = await self._snapshots.latest() + return { + "total_records": total, + "drift_detection_active": total >= self._config.min_baseline_size, + "latest_snapshot_id": snapshot.snapshot_id if snapshot else None, + "latest_snapshot_records": snapshot.total_records if snapshot else 0, + "pending_proposals": len(await self._proposals.pending_proposals()), + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _refresh_cache(self) -> None: + """Update the convergence monitor's synchronous baseline cache.""" + if self._monitor is None: + return + + total = await self._baseline.total_records() + snapshot = await self._snapshots.latest() + + if snapshot and snapshot.total_records > 0: + class_mads = snapshot.mad_by_class + class_counts = snapshot.records_by_class + else: + # No snapshot yet — compute from raw records (slower, only in early regime) + class_mads, class_counts = await self._compute_class_stats() + + self._monitor.update_baseline_cache(total, class_mads, class_counts) + + async def _compute_class_stats(self) -> tuple[dict[str, float], dict[str, int]]: + """Compute per-class MAD stats directly from baseline records (no snapshot).""" + # In early regime this is called infrequently, so the cost is acceptable. + # Once a snapshot exists, _refresh_cache uses it instead. + # This is a best-effort fallback — only works for InMemoryBaselineStore. 
+ if not hasattr(self._baseline, "_records"): + return {}, {} + records = self._baseline._records + from collections import defaultdict + + by_class: dict[str, list] = defaultdict(list) + for r in records: + by_class[r.input_class].append(r) + mads = {c: statistics.mean(r.inter_model_mad for r in rs) for c, rs in by_class.items()} + counts = {c: len(rs) for c, rs in by_class.items()} + return mads, counts + + def _build_orchestrator(self) -> Any: + """Build the Manifold orchestrator with the convergence spec injected.""" + try: + from manifold import OrchestratorBuilder + except ImportError as e: + raise ImportError( + "manifold must be installed to use HMMVTestHarness. " "pip install manifold-ai" + ) from e + + spec = make_convergence_spec(self._monitor) # type: ignore[arg-type] + + builder = OrchestratorBuilder().with_manifest_file(self._manifest_path).with_spec(spec) + return builder.build() + + def _extract_regime(self, workflow_result: Any) -> str: + """ + Pull the regime string from the convergence monitor's last spec result. + + Falls back to "convergent" if the spec result is not findable + (e.g. the workflow failed before the invariant ran). + """ + ctx = workflow_result.final_context + for entry in reversed(ctx.trace): + for spec_ref in entry.spec_results: + if spec_ref.rule_id == "convergence_monitor": + return str(spec_ref.data.get("regime", "convergent")) + return "convergent" + + +# --------------------------------------------------------------------------- +# Convenience (standalone, without manifold installed) +# --------------------------------------------------------------------------- + + +def _compute_mad_sync(values: list[float]) -> float: + if not values: + return 0.0 + mean = statistics.mean(values) + return statistics.mean(abs(v - mean) for v in values) diff --git a/manifold/testing/models.py b/manifold/testing/models.py new file mode 100644 index 0000000..4d75765 --- /dev/null +++ b/manifold/testing/models.py @@ -0,0 +1,493 @@ +""" +manifold.testing.models +~~~~~~~~~~~~~~~~~~~~~~~ +Core data structures for adaptive convergence testing. + +Design principles (consistent with manifold.core): +- All models are frozen dataclasses: immutable after creation +- All models have to_dict() for serialisation +- No business logic here — pure data, pure types +""" + +from __future__ import annotations + +import hashlib +import json +import statistics +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class DriftType(Enum): + """ + Classification of why convergence broke down. + + MODEL_OUTLIER — one model diverges; others still agree. + The criteria are fine; that model may have changed. + + CRITERIA_GAP — all models diverge from each other on a new input class. + The criteria don't cover this case yet. + + UNKNOWN — insufficient data to classify. Treated conservatively: + correction workflow runs but proposes investigation, + not a spec change. + + Note: SILENT_CONSENSUS (all agree but wrong) is explicitly out of scope. + It requires a human baseline layer and cannot be detected automatically. 
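
    Illustrative score patterns (example values only, scores on [-1, 1]):
        MODEL_OUTLIER: {0.82, 0.80, 0.81, -0.40}  (one model far from an
                       otherwise tight cluster)
        CRITERIA_GAP:  {0.85, 0.84, -0.60, -0.65}  (models split into camps,
                       no single outlier)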
+ """ + + MODEL_OUTLIER = "model_outlier" + CRITERIA_GAP = "criteria_gap" + UNKNOWN = "unknown" + + +class ProposalStatus(Enum): + PENDING = "pending" + VALIDATED = "validated" # technically validated via re-run + REJECTED = "rejected" # validation showed no improvement + + +class ReviewStatus(Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + MODIFIED = "modified" # human modified before approving + + +# --------------------------------------------------------------------------- +# ConvergenceRecord — the atom of the baseline +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ConvergenceRecord: + """ + A single run that achieved acceptable convergence. + + These records accumulate in the BaselineStore and define + what "normal" looks like for each input class. Drift detection + compares live observations against this history. + + Fields + ------ + run_id : globally unique run identifier + timestamp : UTC time of the run + input_fingerprint : sha256 of canonical input (not raw input — privacy) + input_class : cluster label (assigned by clustering agent, may be + provisional until cluster model stabilises) + cluster_version : version tag of the clustering model that assigned the + label. Records with different cluster versions may not + be directly comparable — used to detect when re-clustering + is needed. + model_scores : {model_id: score} — numeric output per model + consensus_score : agreed-upon value (e.g. median across models) + inter_model_mad : mean absolute deviation across model_scores + confidence : derived confidence (1 - normalised MAD); range [0, 1] + spec_versions : {spec_id: version} at time of run — for changelog + raw_outputs : {model_id: raw output dict} — full detail for debugging + """ + + run_id: str + timestamp: datetime + input_fingerprint: str + input_class: str # cluster label + cluster_version: str | None # None = pre-stable cluster + model_scores: dict[str, float] # model_id → score + consensus_score: float + inter_model_mad: float + confidence: float # [0, 1] + spec_versions: dict[str, str] # spec_id → version + raw_outputs: dict[str, Any] = field(default_factory=dict) + + # ------------------------------------------------------------------ + # Constructors + # ------------------------------------------------------------------ + + @classmethod + def create( + cls, + run_id: str, + input_data: dict[str, Any], + input_class: str, + cluster_version: str | None, + model_scores: dict[str, float], + spec_versions: dict[str, str], + raw_outputs: dict[str, Any] | None = None, + ) -> "ConvergenceRecord": + """ + Factory that computes derived fields automatically. + + Args: + run_id: Unique run ID from orchestrator. + input_data: The raw input (used only for fingerprinting). + input_class: Cluster label assigned to this input. + cluster_version: Version of the clustering model. + model_scores: {model_id: numeric_score}. + spec_versions: {spec_id: version_string}. + raw_outputs: Optional full outputs for debugging. + """ + scores = list(model_scores.values()) + if not scores: + raise ValueError("model_scores must not be empty") + + mad = _compute_mad(scores) + consensus = statistics.median(scores) + + # Normalise MAD to [0, 1] using expected max range. 
+ # Scores are expected to be on [-1, 1] → max MAD = 1.0 + confidence = max(0.0, 1.0 - mad) + + fingerprint = _fingerprint(input_data) + + return cls( + run_id=run_id, + timestamp=datetime.now(timezone.utc), + input_fingerprint=fingerprint, + input_class=input_class, + cluster_version=cluster_version, + model_scores=model_scores, + consensus_score=consensus, + inter_model_mad=mad, + confidence=confidence, + spec_versions=spec_versions, + raw_outputs=raw_outputs or {}, + ) + + # ------------------------------------------------------------------ + # Serialisation + # ------------------------------------------------------------------ + + def to_dict(self) -> dict: + return { + "run_id": self.run_id, + "timestamp": self.timestamp.isoformat(), + "input_fingerprint": self.input_fingerprint, + "input_class": self.input_class, + "cluster_version": self.cluster_version, + "model_scores": self.model_scores, + "consensus_score": self.consensus_score, + "inter_model_mad": self.inter_model_mad, + "confidence": self.confidence, + "spec_versions": self.spec_versions, + } + + @classmethod + def from_dict(cls, d: dict) -> "ConvergenceRecord": + return cls( + run_id=d["run_id"], + timestamp=datetime.fromisoformat(d["timestamp"]), + input_fingerprint=d["input_fingerprint"], + input_class=d["input_class"], + cluster_version=d.get("cluster_version"), + model_scores=d["model_scores"], + consensus_score=d["consensus_score"], + inter_model_mad=d["inter_model_mad"], + confidence=d["confidence"], + spec_versions=d["spec_versions"], + raw_outputs=d.get("raw_outputs", {}), + ) + + +# --------------------------------------------------------------------------- +# BaselineSnapshot — a point-in-time summary of the baseline +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class BaselineSnapshot: + """ + A versioned, point-in-time snapshot of the baseline statistics. + + The BaselineStore accumulates raw ConvergenceRecords (append-only, + never modified). Periodically, when the baseline crosses a confidence + threshold, a snapshot is taken and persisted separately. + + The system reads from the most recent valid snapshot during operation, + not from raw records directly. This means the hot path is fast (small + snapshot) while the full history is always available for analysis. + + Fields + ------ + snapshot_id : unique identifier for this snapshot + created_at : when it was taken + total_records : total convergence records at snapshot time + records_by_class : {input_class: count} + mad_by_class : {input_class: mean MAD across records} + mad_stddev_by_class : {input_class: std deviation of MAD} + confidence_by_class : {input_class: mean confidence} + spec_versions : spec versions active when snapshot was taken + proposals_since_last: list of SpecProposal IDs applied since previous snapshot + cluster_version : clustering model version used for records in snapshot + is_valid : False if a spec change invalidated some records + notes : free-text (e.g. 
why snapshot was triggered) + """ + + snapshot_id: str + created_at: datetime + total_records: int + records_by_class: dict[str, int] + mad_by_class: dict[str, float] + mad_stddev_by_class: dict[str, float] + confidence_by_class: dict[str, float] + spec_versions: dict[str, str] + proposals_since_last: list[str] # SpecProposal IDs + cluster_version: str | None + is_valid: bool = True + notes: str = "" + + def to_dict(self) -> dict: + return { + "snapshot_id": self.snapshot_id, + "created_at": self.created_at.isoformat(), + "total_records": self.total_records, + "records_by_class": self.records_by_class, + "mad_by_class": self.mad_by_class, + "mad_stddev_by_class": self.mad_stddev_by_class, + "confidence_by_class": self.confidence_by_class, + "spec_versions": self.spec_versions, + "proposals_since_last": self.proposals_since_last, + "cluster_version": self.cluster_version, + "is_valid": self.is_valid, + "notes": self.notes, + } + + @classmethod + def from_dict(cls, d: dict) -> "BaselineSnapshot": + return cls( + snapshot_id=d["snapshot_id"], + created_at=datetime.fromisoformat(d["created_at"]), + total_records=d["total_records"], + records_by_class=d["records_by_class"], + mad_by_class=d["mad_by_class"], + mad_stddev_by_class=d["mad_stddev_by_class"], + confidence_by_class=d["confidence_by_class"], + spec_versions=d["spec_versions"], + proposals_since_last=d["proposals_since_last"], + cluster_version=d.get("cluster_version"), + is_valid=d.get("is_valid", True), + notes=d.get("notes", ""), + ) + + +# --------------------------------------------------------------------------- +# DriftSignal — emitted when convergence breaks down +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class DriftSignal: + """ + Emitted by the ConvergenceMonitor spec when inter-model agreement + exceeds the drift threshold. + + Stored as an artifact on the context, then consumed by the + EventConsumer which triggers the correction workflow. + + Drift does NOT fail the primary workflow. + The primary output (consensus_score) is still valid and used. + Drift is a signal to investigate criteria, not a production failure. + + Fields + ------ + signal_id : unique identifier + run_id : the run that triggered this signal + timestamp : UTC + drift_type : MODEL_OUTLIER | CRITERIA_GAP | UNKNOWN + input_fingerprint : fingerprint of the triggering input + input_class : cluster label of the triggering input + model_scores : what each model produced + observed_mad : MAD observed in this run + expected_mad : historical baseline MAD for this class (None = new class) + baseline_records : how many baseline records exist for this class + outlier_model : model_id if drift_type == MODEL_OUTLIER, else None + implicated_specs : spec rule_ids that were evaluating when drift occurred + representative_fps : fingerprints of other inputs from same class that converged + (context for correction workflow) + """ + + signal_id: str + run_id: str + timestamp: datetime + drift_type: DriftType + input_fingerprint: str + input_class: str + model_scores: dict[str, float] + observed_mad: float + expected_mad: float | None + baseline_records: int + outlier_model: str | None + implicated_specs: list[str] + representative_fps: list[str] # fingerprints for correction context + triggering_input: dict = field(default_factory=dict) + # Raw input that triggered drift — populated by harness, used by CorrectionRunner + # for re-running models during hypothesis validation. 
+ + def to_dict(self) -> dict: + return { + "signal_id": self.signal_id, + "run_id": self.run_id, + "timestamp": self.timestamp.isoformat(), + "drift_type": self.drift_type.value, + "input_fingerprint": self.input_fingerprint, + "input_class": self.input_class, + "model_scores": self.model_scores, + "observed_mad": self.observed_mad, + "expected_mad": self.expected_mad, + "baseline_records": self.baseline_records, + "outlier_model": self.outlier_model, + "implicated_specs": self.implicated_specs, + "representative_fps": self.representative_fps, + "triggering_input": self.triggering_input, + } + + @classmethod + def from_dict(cls, d: dict) -> "DriftSignal": + return cls( + signal_id=d["signal_id"], + run_id=d["run_id"], + timestamp=datetime.fromisoformat(d["timestamp"]), + drift_type=DriftType(d["drift_type"]), + input_fingerprint=d["input_fingerprint"], + input_class=d["input_class"], + model_scores=d["model_scores"], + observed_mad=d["observed_mad"], + expected_mad=d.get("expected_mad"), + baseline_records=d["baseline_records"], + outlier_model=d.get("outlier_model"), + implicated_specs=d.get("implicated_specs", []), + representative_fps=d.get("representative_fps", []), + triggering_input=d.get("triggering_input", {}), + ) + + +# --------------------------------------------------------------------------- +# SpecProposal — output of the correction workflow +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class SpecProposal: + """ + A proposed spec change, produced by the correction workflow. + + Never applied automatically. Requires: + 1. Technical validation: re-run drifting inputs with proposed spec, + confirm MAD improves (done by correction workflow itself). + 2. Human review: semantic correctness cannot be automated. + + The proposal is an immutable record of the entire decision chain: + what triggered it, what was proposed, and what the evidence was. 
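
    Review flow sketch (the `harness` object is an HMMVTestHarness from
    manifold.testing.harness; see its proposal review API):

        proposals = await harness.pending_proposals()
        for p in proposals:
            print(p.proposed_change, p.mad_improvement)
        await harness.approve_proposal(proposals[0].proposal_id, "criteria gap confirmed")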
+ + Fields + ------ + proposal_id : unique identifier + created_at : UTC + triggered_by : DriftSignal that started the correction workflow + target_spec_id : which spec is being changed + current_spec_version : version before change + proposed_change : human-readable description of the change + proposed_spec_code : the actual Python implementation (as string) + hypothesis : why this change should restore convergence + drift_examples : input fingerprints that triggered drift + convergence_examples : input fingerprints that still converged (controls) + proposal_status : pending | validated | rejected + validation_mad_before : MAD on drift examples under current spec + validation_mad_after : MAD on drift examples under proposed spec + models_converged_after: how many models agreed under proposed spec + review_status : pending | approved | rejected | modified + reviewer_notes : free-text from human reviewer + applied_at : UTC when applied to registry (None until applied) + """ + + proposal_id: str + created_at: datetime + triggered_by_signal_id: str # FK to DriftSignal + target_spec_id: str + current_spec_version: str + proposed_change: str + proposed_spec_code: str + hypothesis: str + drift_examples: list[str] # input fingerprints + convergence_examples: list[str] # input fingerprints + proposal_status: ProposalStatus = ProposalStatus.PENDING + validation_mad_before: float | None = None + validation_mad_after: float | None = None + models_converged_after: int | None = None + review_status: ReviewStatus = ReviewStatus.PENDING + reviewer_notes: str | None = None + applied_at: datetime | None = None + + @property + def mad_improvement(self) -> float | None: + """Absolute MAD reduction. Positive = improvement.""" + if self.validation_mad_before is None or self.validation_mad_after is None: + return None + return self.validation_mad_before - self.validation_mad_after + + def to_dict(self) -> dict: + return { + "proposal_id": self.proposal_id, + "created_at": self.created_at.isoformat(), + "triggered_by_signal_id": self.triggered_by_signal_id, + "target_spec_id": self.target_spec_id, + "current_spec_version": self.current_spec_version, + "proposed_change": self.proposed_change, + "proposed_spec_code": self.proposed_spec_code, + "hypothesis": self.hypothesis, + "drift_examples": self.drift_examples, + "convergence_examples": self.convergence_examples, + "proposal_status": self.proposal_status.value, + "validation_mad_before": self.validation_mad_before, + "validation_mad_after": self.validation_mad_after, + "models_converged_after": self.models_converged_after, + "review_status": self.review_status.value, + "reviewer_notes": self.reviewer_notes, + "applied_at": self.applied_at.isoformat() if self.applied_at else None, + } + + @classmethod + def from_dict(cls, d: dict) -> "SpecProposal": + return cls( + proposal_id=d["proposal_id"], + created_at=datetime.fromisoformat(d["created_at"]), + triggered_by_signal_id=d["triggered_by_signal_id"], + target_spec_id=d["target_spec_id"], + current_spec_version=d["current_spec_version"], + proposed_change=d["proposed_change"], + proposed_spec_code=d["proposed_spec_code"], + hypothesis=d["hypothesis"], + drift_examples=d["drift_examples"], + convergence_examples=d["convergence_examples"], + proposal_status=ProposalStatus(d["proposal_status"]), + validation_mad_before=d.get("validation_mad_before"), + validation_mad_after=d.get("validation_mad_after"), + models_converged_after=d.get("models_converged_after"), + review_status=ReviewStatus(d.get("review_status", 
"pending")), + reviewer_notes=d.get("reviewer_notes"), + applied_at=datetime.fromisoformat(d["applied_at"]) if d.get("applied_at") else None, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _compute_mad(values: list[float]) -> float: + """Mean absolute deviation from the mean.""" + if not values: + return 0.0 + mean = statistics.mean(values) + return statistics.mean(abs(v - mean) for v in values) + + +def _fingerprint(data: dict[str, Any]) -> str: + """Stable sha256 fingerprint of canonical input data.""" + serialised = json.dumps(data, sort_keys=True, default=str) + return hashlib.sha256(serialised.encode()).hexdigest()[:16] diff --git a/manifold/testing/stores.py b/manifold/testing/stores.py new file mode 100644 index 0000000..076b815 --- /dev/null +++ b/manifold/testing/stores.py @@ -0,0 +1,556 @@ +""" +manifold.testing.stores +~~~~~~~~~~~~~~~~~~~~~~~ +Storage protocols and a reference SQLite implementation. + +All stores are defined as Protocols first. This means: +- The EventConsumer depends only on the protocol, not the implementation +- Tests use in-memory implementations (also defined here) +- Production uses SQLite (defined here) +- Future implementations (Postgres, etc.) just implement the protocol + +Protocol summary +---------------- +BaselineStore — append-only store for ConvergenceRecords + also holds DriftSignals (they arrive before records + and the EventConsumer needs to retrieve them by ID) +SnapshotStore — versioned store for BaselineSnapshots +ProposalStore — store for SpecProposals with status transitions +SpecRegistry — manages spec versions and applies approved proposals +""" + +from __future__ import annotations + +import json +import sqlite3 +import statistics +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +from manifold.testing.models import ( + BaselineSnapshot, + ConvergenceRecord, + DriftSignal, + ReviewStatus, + SpecProposal, +) + +# --------------------------------------------------------------------------- +# Protocols +# --------------------------------------------------------------------------- + + +@runtime_checkable +class BaselineStore(Protocol): + """ + Append-only store for ConvergenceRecords and DriftSignals. + + Records are never modified after writing. They can be marked stale + (a flag, not deletion) when the spec they were collected under changes. + """ + + async def append(self, record: ConvergenceRecord) -> None: + """Append a new convergence record.""" + ... + + async def append_signal(self, signal: DriftSignal) -> None: + """Store a drift signal (for later retrieval by event consumer).""" + ... + + async def get_signal(self, signal_id: str) -> DriftSignal | None: + """Retrieve a stored drift signal by ID.""" + ... + + async def total_records(self) -> int: + """Total number of records (including stale).""" + ... + + async def records_for_class( + self, + input_class: str, + exclude_stale: bool = True, + limit: int | None = None, + ) -> list[ConvergenceRecord]: + """Fetch records for a specific input class.""" + ... + + async def expected_mad_for_class(self, input_class: str) -> float | None: + """ + Historical mean MAD for this input class. + Returns None if fewer than 10 valid records exist. + """ + ... 
+ + async def sample_fingerprints_for_class( + self, + input_class: str, + n: int, + ) -> list[str]: + """Sample n input fingerprints from convergent records for this class.""" + ... + + async def mark_stale_for_spec_version( + self, + spec_id: str, + spec_version: str, + ) -> int: + """ + Mark records collected under spec_id@spec_version as stale. + Returns count of records marked. + """ + ... + + async def take_snapshot( + self, + spec_registry: Any, + notes: str = "", + ) -> BaselineSnapshot: + """Compute and return a snapshot from current records.""" + ... + + +@runtime_checkable +class SnapshotStore(Protocol): + """Versioned store for BaselineSnapshots.""" + + async def write(self, snapshot: BaselineSnapshot) -> None: + """Persist a snapshot.""" + ... + + async def latest(self) -> BaselineSnapshot | None: + """Return the most recent valid snapshot.""" + ... + + async def all(self) -> list[BaselineSnapshot]: + """Return all snapshots, newest first.""" + ... + + +@runtime_checkable +class ProposalStore(Protocol): + """Store for SpecProposals with status transitions.""" + + async def write(self, proposal: SpecProposal) -> None: + """Write a new proposal.""" + ... + + async def get(self, proposal_id: str) -> SpecProposal | None: + """Retrieve by ID.""" + ... + + async def pending_proposals(self) -> list[SpecProposal]: + """All proposals awaiting human review.""" + ... + + async def mark_rejected( + self, + proposal_id: str, + reviewer_notes: str, + ) -> None: + """Mark a proposal as rejected by reviewer.""" + ... + + async def mark_approved( + self, + proposal_id: str, + reviewer_notes: str, + applied_at: datetime, + ) -> None: + """Mark a proposal as approved and applied.""" + ... + + +@runtime_checkable +class SpecRegistry(Protocol): + """Manages spec versions and applies approved proposals.""" + + async def current_versions(self) -> dict[str, str]: + """Return {spec_id: version} for all registered specs.""" + ... + + async def apply_proposal(self, proposal: SpecProposal) -> str: + """ + Apply an approved proposal to the registry. + Returns the new spec version string. + """ + ... + + +# --------------------------------------------------------------------------- +# In-memory implementations (for tests and development) +# --------------------------------------------------------------------------- + + +class InMemoryBaselineStore: + """ + In-memory BaselineStore for tests. + + Thread-safety: not guaranteed. Use for single-threaded tests only. 
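
    Usage
    -----
    store = InMemoryBaselineStore()       # no initialise() call needed
    await store.append(record)            # record: ConvergenceRecord
    count = await store.total_records()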
+ """ + + def __init__(self) -> None: + self._records: list[ConvergenceRecord] = [] + self._signals: dict[str, DriftSignal] = {} + self._stale: set[str] = set() # run_ids + + async def append(self, record: ConvergenceRecord) -> None: + self._records.append(record) + + async def append_signal(self, signal: DriftSignal) -> None: + self._signals[signal.signal_id] = signal + + async def get_signal(self, signal_id: str) -> DriftSignal | None: + return self._signals.get(signal_id) + + async def total_records(self) -> int: + return len(self._records) + + async def records_for_class( + self, + input_class: str, + exclude_stale: bool = True, + limit: int | None = None, + ) -> list[ConvergenceRecord]: + result = [ + r + for r in self._records + if r.input_class == input_class and (not exclude_stale or r.run_id not in self._stale) + ] + if limit: + result = result[-limit:] + return result + + async def expected_mad_for_class(self, input_class: str) -> float | None: + records = await self.records_for_class(input_class) + if len(records) < 10: + return None + return statistics.mean(r.inter_model_mad for r in records) + + async def sample_fingerprints_for_class( + self, + input_class: str, + n: int, + ) -> list[str]: + records = await self.records_for_class(input_class, limit=n * 2) + return [r.input_fingerprint for r in records[:n]] + + async def mark_stale_for_spec_version( + self, + spec_id: str, + spec_version: str, + ) -> int: + count = 0 + for r in self._records: + if r.spec_versions.get(spec_id) == spec_version: + self._stale.add(r.run_id) + count += 1 + return count + + async def take_snapshot( + self, + spec_registry: Any, + notes: str = "", + ) -> BaselineSnapshot: + valid_records = [r for r in self._records if r.run_id not in self._stale] + + # Compute per-class stats + classes: dict[str, list[ConvergenceRecord]] = {} + for r in valid_records: + classes.setdefault(r.input_class, []).append(r) + + records_by_class = {c: len(rs) for c, rs in classes.items()} + mad_by_class = { + c: statistics.mean(r.inter_model_mad for r in rs) for c, rs in classes.items() + } + mad_stddev_by_class = { + c: (statistics.stdev(r.inter_model_mad for r in rs) if len(rs) > 1 else 0.0) + for c, rs in classes.items() + } + confidence_by_class = { + c: statistics.mean(r.confidence for r in rs) for c, rs in classes.items() + } + + spec_versions = await spec_registry.current_versions() + + return BaselineSnapshot( + snapshot_id=str(uuid.uuid4()), + created_at=datetime.now(timezone.utc), + total_records=len(valid_records), + records_by_class=records_by_class, + mad_by_class=mad_by_class, + mad_stddev_by_class=mad_stddev_by_class, + confidence_by_class=confidence_by_class, + spec_versions=spec_versions, + proposals_since_last=[], # consumer fills this in + cluster_version=None, + notes=notes, + ) + + +class InMemorySnapshotStore: + def __init__(self) -> None: + self._snapshots: list[BaselineSnapshot] = [] + + async def write(self, snapshot: BaselineSnapshot) -> None: + self._snapshots.append(snapshot) + + async def latest(self) -> BaselineSnapshot | None: + valid = [s for s in self._snapshots if s.is_valid] + return valid[-1] if valid else None + + async def all(self) -> list[BaselineSnapshot]: + return list(reversed(self._snapshots)) + + +class InMemoryProposalStore: + def __init__(self) -> None: + self._proposals: dict[str, SpecProposal] = {} + + async def write(self, proposal: SpecProposal) -> None: + self._proposals[proposal.proposal_id] = proposal + + async def get(self, proposal_id: str) -> SpecProposal | None: + 
return self._proposals.get(proposal_id) + + async def pending_proposals(self) -> list[SpecProposal]: + return [p for p in self._proposals.values() if p.review_status == ReviewStatus.PENDING] + + async def mark_rejected(self, proposal_id: str, reviewer_notes: str) -> None: + p = self._proposals.get(proposal_id) + if p: + # frozen dataclass → replace + from dataclasses import replace + + self._proposals[proposal_id] = replace( + p, + review_status=ReviewStatus.REJECTED, + reviewer_notes=reviewer_notes, + ) + + async def mark_approved( + self, + proposal_id: str, + reviewer_notes: str, + applied_at: datetime, + ) -> None: + p = self._proposals.get(proposal_id) + if p: + from dataclasses import replace + + self._proposals[proposal_id] = replace( + p, + review_status=ReviewStatus.APPROVED, + reviewer_notes=reviewer_notes, + applied_at=applied_at, + ) + + +class InMemorySpecRegistry: + def __init__(self, initial_versions: dict[str, str] | None = None) -> None: + self._versions: dict[str, str] = initial_versions or {} + self._history: list[dict] = [] + + async def current_versions(self) -> dict[str, str]: + return dict(self._versions) + + async def apply_proposal(self, proposal: SpecProposal) -> str: + old = self._versions.get(proposal.target_spec_id, "0.0.0") + # Bump patch version + parts = old.split(".") + new = f"{parts[0]}.{parts[1]}.{int(parts[2]) + 1}" if len(parts) == 3 else "0.0.1" + self._versions[proposal.target_spec_id] = new + self._history.append( + { + "spec_id": proposal.target_spec_id, + "old_version": old, + "new_version": new, + "proposal_id": proposal.proposal_id, + "applied_at": datetime.now(timezone.utc).isoformat(), + } + ) + return new + + +# --------------------------------------------------------------------------- +# SQLite implementation (production) +# --------------------------------------------------------------------------- + + +class SQLiteBaselineStore: + """ + Production BaselineStore backed by SQLite. + + Single file, no external dependencies, suitable for development and + single-node production. Replace with Postgres-backed implementation + for multi-node setups. 
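+
+    Note: the async methods below issue their sqlite3 calls synchronously on
+    the calling event loop; if that ever becomes a bottleneck, wrapping the
+    calls in asyncio.to_thread (or switching to an aiosqlite-backed variant)
+    is one option.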
+ + Usage + ----- + store = SQLiteBaselineStore("baseline.db") + await store.initialise() + """ + + SCHEMA = """ + CREATE TABLE IF NOT EXISTS convergence_records ( + run_id TEXT PRIMARY KEY, + timestamp TEXT NOT NULL, + input_fingerprint TEXT NOT NULL, + input_class TEXT NOT NULL, + cluster_version TEXT, + model_scores TEXT NOT NULL, -- JSON + consensus_score REAL NOT NULL, + inter_model_mad REAL NOT NULL, + confidence REAL NOT NULL, + spec_versions TEXT NOT NULL, -- JSON + raw_outputs TEXT NOT NULL, -- JSON + is_stale INTEGER NOT NULL DEFAULT 0 + ); + + CREATE INDEX IF NOT EXISTS idx_input_class + ON convergence_records(input_class, is_stale); + + CREATE TABLE IF NOT EXISTS drift_signals ( + signal_id TEXT PRIMARY KEY, + data TEXT NOT NULL -- full JSON + ); + """ + + def __init__(self, db_path: str | Path) -> None: + self._db_path = str(db_path) + self._conn: sqlite3.Connection | None = None + + async def initialise(self) -> None: + """Create tables if they don't exist.""" + self._conn = sqlite3.connect(self._db_path) + self._conn.executescript(self.SCHEMA) + self._conn.commit() + + def _cx(self) -> sqlite3.Connection: + if self._conn is None: + raise RuntimeError("Call initialise() before using the store") + return self._conn + + async def append(self, record: ConvergenceRecord) -> None: + cx = self._cx() + cx.execute( + """INSERT OR IGNORE INTO convergence_records + (run_id, timestamp, input_fingerprint, input_class, + cluster_version, model_scores, consensus_score, + inter_model_mad, confidence, spec_versions, raw_outputs) + VALUES (?,?,?,?,?,?,?,?,?,?,?)""", + ( + record.run_id, + record.timestamp.isoformat(), + record.input_fingerprint, + record.input_class, + record.cluster_version, + json.dumps(record.model_scores), + record.consensus_score, + record.inter_model_mad, + record.confidence, + json.dumps(record.spec_versions), + json.dumps(record.raw_outputs), + ), + ) + cx.commit() + + async def append_signal(self, signal: DriftSignal) -> None: + cx = self._cx() + cx.execute( + "INSERT OR IGNORE INTO drift_signals (signal_id, data) VALUES (?,?)", + (signal.signal_id, json.dumps(signal.to_dict())), + ) + cx.commit() + + async def get_signal(self, signal_id: str) -> DriftSignal | None: + row = ( + self._cx() + .execute("SELECT data FROM drift_signals WHERE signal_id=?", (signal_id,)) + .fetchone() + ) + return DriftSignal.from_dict(json.loads(row[0])) if row else None + + async def total_records(self) -> int: + row = self._cx().execute("SELECT COUNT(*) FROM convergence_records").fetchone() + return int(row[0]) + + async def records_for_class( + self, + input_class: str, + exclude_stale: bool = True, + limit: int | None = None, + ) -> list[ConvergenceRecord]: + q = "SELECT * FROM convergence_records WHERE input_class=?" 
+ params: list[Any] = [input_class] + if exclude_stale: + q += " AND is_stale=0" + q += " ORDER BY timestamp DESC" + if limit: + q += f" LIMIT {limit}" + rows = self._cx().execute(q, params).fetchall() + return [self._row_to_record(r) for r in rows] + + async def expected_mad_for_class(self, input_class: str) -> float | None: + records = await self.records_for_class(input_class) + if len(records) < 10: + return None + return statistics.mean(r.inter_model_mad for r in records) + + async def sample_fingerprints_for_class( + self, + input_class: str, + n: int, + ) -> list[str]: + records = await self.records_for_class(input_class, limit=n * 2) + return [r.input_fingerprint for r in records[:n]] + + async def mark_stale_for_spec_version( + self, + spec_id: str, + spec_version: str, + ) -> int: + # Fetch all non-stale records and filter in Python for compatibility + rows = self._cx().execute("SELECT * FROM convergence_records WHERE is_stale=0").fetchall() + records = [self._row_to_record(r) for r in rows] + stale_ids = [r.run_id for r in records if r.spec_versions.get(spec_id) == spec_version] + if stale_ids: + placeholders = ",".join("?" * len(stale_ids)) + self._cx().execute( + f"UPDATE convergence_records SET is_stale=1 WHERE run_id IN ({placeholders})", + stale_ids, + ) + self._cx().commit() + return len(stale_ids) + + async def take_snapshot( + self, + spec_registry: Any, + notes: str = "", + ) -> BaselineSnapshot: + """Delegate to in-memory logic after loading valid records.""" + mem = InMemoryBaselineStore() + valid_records = ( + self._cx().execute("SELECT * FROM convergence_records WHERE is_stale=0").fetchall() + ) + for row in valid_records: + mem._records.append(self._row_to_record(row)) + return await mem.take_snapshot(spec_registry, notes) + + def _row_to_record(self, row: tuple) -> ConvergenceRecord: + # Column order matches INSERT + run_id, timestamp, fp, ic, cv, ms, cs, mad, conf, sv, ro, _stale = row + return ConvergenceRecord( + run_id=run_id, + timestamp=datetime.fromisoformat(timestamp), + input_fingerprint=fp, + input_class=ic, + cluster_version=cv, + model_scores=json.loads(ms), + consensus_score=cs, + inter_model_mad=mad, + confidence=conf, + spec_versions=json.loads(sv), + raw_outputs=json.loads(ro), + ) diff --git a/tests/testing/__init__.py b/tests/testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/testing/test_convergence.py b/tests/testing/test_convergence.py new file mode 100644 index 0000000..1d6e9e1 --- /dev/null +++ b/tests/testing/test_convergence.py @@ -0,0 +1,316 @@ +""" +Tests for manifold.testing.convergence + +ConvergenceMonitor is tested standalone (no manifold dependency). 
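+
+The fixtures below drive the monitor through its full surface (a sketch; the
+exact keyword arguments follow make_monitor / run_eval defined in this file):
+
+    monitor = ConvergenceMonitor(InMemoryBaselineStore(), ConvergenceConfig())
+    monitor.update_baseline_cache(total_records=500,
+                                  class_mads={"ngo_religious": 0.04},
+                                  class_counts={"ngo_religious": 20})
+    result = monitor.evaluate_sync(run_id="r1", input_data={"name": "Test Org"},
+                                   input_class="ngo_religious", cluster_version="v1",
+                                   model_scores=CONVERGED_SCORES, raw_outputs={})
+    signals = monitor.drain_signals()
+    records = monitor.drain_records()
+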
+We verify: +- All three regimes behave correctly +- Drift classification (MODEL_OUTLIER vs CRITERIA_GAP) +- Signal/record queues are properly drained +- Cache update affects behaviour +- Idempotency: multiple drains are safe +""" + +from __future__ import annotations + +import uuid +from datetime import timezone + +import pytest + +from manifold.testing.convergence import ConvergenceConfig, ConvergenceMonitor +from manifold.testing.models import DriftType +from manifold.testing.stores import InMemoryBaselineStore + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def make_monitor( + min_baseline_size: int = 10, + drift_multiplier: float = 2.5, + outlier_threshold: float = 2.0, + min_class_records: int = 3, +) -> ConvergenceMonitor: + store = InMemoryBaselineStore() + config = ConvergenceConfig( + min_baseline_size=min_baseline_size, + drift_multiplier=drift_multiplier, + outlier_threshold=outlier_threshold, + min_class_records=min_class_records, + ) + return ConvergenceMonitor(store, config) + + +def run_eval( + monitor: ConvergenceMonitor, + model_scores: dict[str, float], + input_class: str = "ngo_religious", + run_id: str | None = None, +) -> dict: + return monitor.evaluate_sync( + run_id=run_id or str(uuid.uuid4()), + input_data={"name": "Test Org"}, + input_class=input_class, + cluster_version="v1", + model_scores=model_scores, + raw_outputs={}, + ) + + +CONVERGED_SCORES = {"a": 0.80, "b": 0.78, "c": 0.82, "d": 0.79} +# Two-camp split: a+b agree, c+d disagree → high MAD, no single outlier → CRITERIA_GAP +DIVERGED_SCORES = {"a": 0.80, "b": 0.80, "c": -0.80, "d": -0.80} +# One model far from the other three → MODEL_OUTLIER +OUTLIER_SCORES = {"a": 0.80, "b": 0.78, "c": 0.82, "d": -0.90} + + +# --------------------------------------------------------------------------- +# Regime 1: Early +# --------------------------------------------------------------------------- + + +class TestEarlyRegime: + def test_regime_is_early_below_baseline_threshold(self): + m = make_monitor(min_baseline_size=100) + m.update_baseline_cache(total_records=5, class_mads={}, class_counts={}) + result = run_eval(m, CONVERGED_SCORES) + assert result["regime"] == "early" + + def test_no_signal_emitted_in_early_regime(self): + m = make_monitor(min_baseline_size=100) + m.update_baseline_cache(5, {}, {}) + run_eval(m, DIVERGED_SCORES) # even diverged — no signal in early regime + assert m.drain_signals() == [] + + def test_record_always_added_in_early_regime(self): + m = make_monitor(min_baseline_size=100) + m.update_baseline_cache(5, {}, {}) + run_eval(m, CONVERGED_SCORES) + records = m.drain_records() + assert len(records) == 1 + assert records[0].input_class == "ngo_religious" + + def test_early_message_contains_progress(self): + m = make_monitor(min_baseline_size=50) + m.update_baseline_cache(12, {}, {}) + result = run_eval(m, CONVERGED_SCORES) + assert "12/50" in result["message"] + + +# --------------------------------------------------------------------------- +# Regime 2: Novel input class +# --------------------------------------------------------------------------- + + +class TestNovelClassRegime: + def _mature_cache(self, monitor: ConvergenceMonitor, class_mads=None, class_counts=None): + monitor.update_baseline_cache( + total_records=500, + class_mads=class_mads or {}, + class_counts=class_counts or {}, + ) + + def test_novel_class_no_signal(self): + m = 
make_monitor(min_baseline_size=10, min_class_records=5)
+        self._mature_cache(m)
+        run_eval(m, DIVERGED_SCORES, input_class="new_class")
+        assert m.drain_signals() == []
+
+    def test_novel_class_adds_record(self):
+        m = make_monitor(min_baseline_size=10, min_class_records=5)
+        self._mature_cache(m)
+        run_eval(m, CONVERGED_SCORES, input_class="new_class")
+        records = m.drain_records()
+        assert len(records) == 1
+
+    def test_known_class_below_min_records_treated_as_novel(self):
+        m = make_monitor(min_baseline_size=10, min_class_records=5)
+        self._mature_cache(m, class_mads={"cls": 0.04}, class_counts={"cls": 2})
+        run_eval(m, DIVERGED_SCORES, input_class="cls")
+        assert m.drain_signals() == []  # only 2 records, below min_class_records=5
+
+
+# ---------------------------------------------------------------------------
+# Regime 3: Mature / drift detection
+# ---------------------------------------------------------------------------
+
+
+class TestMatureRegime:
+    def _setup_mature(self, monitor: ConvergenceMonitor, expected_mad=0.04, class_count=20):
+        monitor.update_baseline_cache(
+            total_records=500,
+            class_mads={"ngo_religious": expected_mad},
+            class_counts={"ngo_religious": class_count},
+        )
+
+    def test_convergent_run_no_signal(self):
+        m = make_monitor(drift_multiplier=2.5)
+        self._setup_mature(m, expected_mad=0.04)
+        # MAD of CONVERGED_SCORES ≈ 0.013 — well below threshold 0.04 * 2.5 = 0.10
+        result = run_eval(m, CONVERGED_SCORES)
+        assert result["regime"] == "convergent"
+        assert m.drain_signals() == []
+
+    def test_convergent_run_adds_record(self):
+        m = make_monitor(drift_multiplier=2.5)
+        self._setup_mature(m, expected_mad=0.04)
+        run_eval(m, CONVERGED_SCORES)
+        assert len(m.drain_records()) == 1
+
+    def test_diverged_run_emits_signal(self):
+        m = make_monitor(drift_multiplier=2.5)
+        self._setup_mature(m, expected_mad=0.04)
+        # MAD of DIVERGED_SCORES ≈ 0.80 — well above threshold 0.10
+        result = run_eval(m, DIVERGED_SCORES)
+        assert result["regime"] == "drift"
+        signals = m.drain_signals()
+        assert len(signals) == 1
+
+    def test_drift_signal_has_correct_run_id(self):
+        m = make_monitor()
+        self._setup_mature(m)
+        run_id = str(uuid.uuid4())
+        run_eval(m, DIVERGED_SCORES, run_id=run_id)
+        signal = m.drain_signals()[0]
+        assert signal.run_id == run_id
+
+    def test_drift_signal_expected_mad_correct(self):
+        m = make_monitor()
+        self._setup_mature(m, expected_mad=0.04)
+        run_eval(m, DIVERGED_SCORES)
+        signal = m.drain_signals()[0]
+        assert signal.expected_mad == pytest.approx(0.04)
+
+    def test_drift_signal_has_timestamp(self):
+        m = make_monitor()
+        self._setup_mature(m)
+        run_eval(m, DIVERGED_SCORES)
+        signal = m.drain_signals()[0]
+        assert signal.timestamp.tzinfo is not None  # timezone-aware
+
+    def test_diverged_run_no_record_added(self):
+        m = make_monitor()
+        self._setup_mature(m)
+        run_eval(m, DIVERGED_SCORES)
+        m.drain_signals()
+        assert m.drain_records() == []  # drifted run doesn't add to baseline
+
+
+# ---------------------------------------------------------------------------
+# Drift classification
+# ---------------------------------------------------------------------------
+
+
+class TestDriftClassification:
+    def _setup(self) -> ConvergenceMonitor:
+        m = make_monitor(drift_multiplier=2.0, outlier_threshold=2.0)
+        m.update_baseline_cache(500, {"cls": 0.03}, {"cls": 20})
+        return m
+
+    def test_model_outlier_detected(self):
+        m = self._setup()
+        # Model 'd' is -0.90, others are ~0.80 — clear outlier
+        run_eval(m, OUTLIER_SCORES, input_class="cls")
+        signals = 
m.drain_signals() + assert len(signals) == 1 + assert signals[0].drift_type == DriftType.MODEL_OUTLIER + assert signals[0].outlier_model == "d" + + def test_criteria_gap_when_all_diverge(self): + m = self._setup() + run_eval(m, DIVERGED_SCORES, input_class="cls") + signals = m.drain_signals() + assert len(signals) == 1 + assert signals[0].drift_type == DriftType.CRITERIA_GAP + assert signals[0].outlier_model is None + + def test_no_outlier_with_identical_scores(self): + m = self._setup() + # All the same — no drift expected + scores = {"a": 0.8, "b": 0.8, "c": 0.8, "d": 0.8} + run_eval(m, scores, input_class="cls") + assert m.drain_signals() == [] + + +# --------------------------------------------------------------------------- +# Queue drain behaviour +# --------------------------------------------------------------------------- + + +class TestDrainBehaviour: + def test_drain_signals_empties_queue(self): + m = make_monitor() + m.update_baseline_cache(500, {"ngo_religious": 0.04}, {"ngo_religious": 20}) + run_eval(m, DIVERGED_SCORES) + assert len(m.drain_signals()) == 1 + assert len(m.drain_signals()) == 0 # drained + + def test_drain_records_empties_queue(self): + m = make_monitor(min_baseline_size=1) + m.update_baseline_cache(500, {"ngo_religious": 0.04}, {"ngo_religious": 20}) + run_eval(m, CONVERGED_SCORES) + assert len(m.drain_records()) == 1 + assert len(m.drain_records()) == 0 + + def test_multiple_runs_accumulate(self): + m = make_monitor(min_baseline_size=1) + m.update_baseline_cache(500, {"ngo_religious": 0.04}, {"ngo_religious": 20}) + run_eval(m, CONVERGED_SCORES) + run_eval(m, CONVERGED_SCORES) + run_eval(m, CONVERGED_SCORES) + assert len(m.drain_records()) == 3 + + def test_no_scores_returns_early_gracefully(self): + m = make_monitor() + m.update_baseline_cache(500, {}, {}) + result = m.evaluate_sync( + run_id="r1", + input_data={}, + input_class="cls", + cluster_version=None, + model_scores={}, # empty + raw_outputs={}, + ) + assert result["regime"] == "early" + assert m.drain_signals() == [] + + +# --------------------------------------------------------------------------- +# Cache update behaviour +# --------------------------------------------------------------------------- + + +class TestCacheUpdate: + def test_updating_cache_changes_regime(self): + m = make_monitor(min_baseline_size=50, min_class_records=3) + + # First: early regime + m.update_baseline_cache(5, {}, {}) + r1 = run_eval(m, DIVERGED_SCORES) + assert r1["regime"] == "early" + m.drain_records() + + # Update cache to mature + m.update_baseline_cache(500, {"ngo_religious": 0.04}, {"ngo_religious": 20}) + r2 = run_eval(m, DIVERGED_SCORES) + assert r2["regime"] == "drift" + + def test_record_contains_correct_spec_versions(self): + m = ConvergenceMonitor( + InMemoryBaselineStore(), + ConvergenceConfig(min_baseline_size=1), + spec_versions={"classify_spec": "1.0.0", "threshold_spec": "2.1.0"}, + ) + m.update_baseline_cache(500, {"ngo_religious": 0.04}, {"ngo_religious": 20}) + run_eval(m, CONVERGED_SCORES) + record = m.drain_records()[0] + assert record.spec_versions["classify_spec"] == "1.0.0" + assert record.spec_versions["threshold_spec"] == "2.1.0" + + def test_signal_timestamp_is_timezone_aware(self): + m = make_monitor() + m.update_baseline_cache(500, {"ngo_religious": 0.04}, {"ngo_religious": 20}) + run_eval(m, DIVERGED_SCORES) + signal = m.drain_signals()[0] + assert signal.timestamp.tzinfo == timezone.utc diff --git a/tests/testing/test_correction.py b/tests/testing/test_correction.py new file mode 
100644 index 0000000..6c15357 --- /dev/null +++ b/tests/testing/test_correction.py @@ -0,0 +1,504 @@ +""" +Tests for manifold.testing.correction + +All steps are tested independently using stubs — no real LLM or model +calls. The CorrectionRunner integration test exercises the full pipeline +end-to-end with deterministic fakes. + +Coverage: +- analyze(): pure function, all three drift types +- _parse_llm_response(): JSON parsing, fence stripping, missing fields +- generate_hypothesis(): happy path, LLM failure retry, all-fail returns None +- validate(): MODEL_OUTLIER (no re-run), CRITERIA_GAP improvement, degradation, + empty triggering_input guard +- CorrectionRunner.run(): CRITERIA_GAP validated, CRITERIA_GAP rejected, + MODEL_OUTLIER, LLM failure → None +""" + +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timezone + +import pytest + +from manifold.testing.correction import ( + CorrectionRunner, + Hypothesis, + analyze, + generate_hypothesis, + validate, + _parse_llm_response, +) +from manifold.testing.models import ( + DriftSignal, + DriftType, + ProposalStatus, + SpecProposal, + _compute_mad, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +MODEL_IDS = ["gpt", "gemini", "llama", "mistral"] + + +def make_signal( + drift_type: DriftType = DriftType.CRITERIA_GAP, + model_scores: dict | None = None, + outlier_model: str | None = None, + triggering_input: dict | None = None, + expected_mad: float = 0.04, + baseline_records: int = 50, +) -> DriftSignal: + if model_scores is None: + if drift_type == DriftType.MODEL_OUTLIER: + model_scores = {"gpt": 0.80, "gemini": 0.78, "llama": 0.82, "mistral": -0.90} + outlier_model = outlier_model or "mistral" + else: + model_scores = {"gpt": 0.80, "gemini": 0.80, "llama": -0.80, "mistral": -0.80} + + return DriftSignal( + signal_id=str(uuid.uuid4()), + run_id=str(uuid.uuid4()), + timestamp=datetime.now(timezone.utc), + drift_type=drift_type, + input_fingerprint="abc123", + input_class="ngo_religious", + model_scores=model_scores, + observed_mad=_compute_mad(list(model_scores.values())), + expected_mad=expected_mad, + baseline_records=baseline_records, + outlier_model=outlier_model, + implicated_specs=["classify_spec", "threshold_spec"], + representative_fps=["fp1", "fp2", "fp3"], + triggering_input=triggering_input or {"name": "Caritas Berlin", "type": "welfare"}, + ) + + +def make_hypothesis(target: str = "classify_spec") -> Hypothesis: + return Hypothesis( + proposed_change="Add explicit criteria for welfare organisations in religious context.", + proposed_spec_code=( + "# Welfare orgs with religious affiliation → classify as ngo_religious\n" + "if 'welfare' in candidate.lower() and 'religious' in context.get_data('tags', []):\n" + " return SpecResult.ok(...)" + ), + hypothesis="Welfare orgs with partial religious affiliation cause model splits. 
" + "Explicit criteria should restore convergence.", + target_spec_id=target, + llm_raw_response='{"proposed_change":"...","proposed_spec_code":"...","hypothesis":"..."}', + ) + + +# --------------------------------------------------------------------------- +# Step 1 — analyze() +# --------------------------------------------------------------------------- + + +class TestAnalyze: + def test_criteria_gap_sets_cause(self): + sig = make_signal(DriftType.CRITERIA_GAP) + a = analyze(sig) + assert a.drift_type == DriftType.CRITERIA_GAP + assert "CRITERIA_GAP" in a.probable_cause or "split" in a.probable_cause.lower() + + def test_model_outlier_names_outlier(self): + sig = make_signal(DriftType.MODEL_OUTLIER) + a = analyze(sig) + assert "mistral" in a.probable_cause + + def test_unknown_has_low_confidence(self): + sig = make_signal(DriftType.UNKNOWN) + a = analyze(sig) + assert a.confidence_in_diagnosis < 0.5 + + def test_model_outlier_high_confidence(self): + sig = make_signal(DriftType.MODEL_OUTLIER) + a = analyze(sig) + assert a.confidence_in_diagnosis > 0.7 + + def test_target_spec_from_implicated(self): + sig = make_signal() + a = analyze(sig) + assert a.target_spec_id == "classify_spec" # first in list + + def test_target_spec_fallback_when_empty(self): + sig = make_signal() + from dataclasses import replace + + sig = replace(sig, implicated_specs=[]) + a = analyze(sig) + assert a.target_spec_id == "unknown_spec" + + def test_agreeing_disagreeing_populated(self): + sig = make_signal(DriftType.CRITERIA_GAP) + a = analyze(sig) + assert len(a.agreeing_models) > 0 + assert len(a.disagreeing_models) > 0 + # All model IDs appear exactly once across the two lists + all_ids = set(a.agreeing_models) | set(a.disagreeing_models) + assert all_ids == set(sig.model_scores.keys()) + + def test_pure_no_side_effects(self): + sig = make_signal() + a1 = analyze(sig) + a2 = analyze(sig) + assert a1.probable_cause == a2.probable_cause + + +# --------------------------------------------------------------------------- +# LLM response parsing +# --------------------------------------------------------------------------- + + +class TestParseLLMResponse: + def _valid_json(self, **overrides) -> str: + data = { + "proposed_change": "Add welfare criteria.", + "proposed_spec_code": "# code", + "hypothesis": "This will converge.", + "target_spec_id": "spec_a", + **overrides, + } + return json.dumps(data) + + def test_clean_json(self): + result = _parse_llm_response(self._valid_json()) + assert result["proposed_change"] == "Add welfare criteria." + + def test_strips_markdown_fences(self): + raw = "```json\n" + self._valid_json() + "\n```" + result = _parse_llm_response(raw) + assert result is not None + + def test_strips_plain_fences(self): + raw = "```\n" + self._valid_json() + "\n```" + result = _parse_llm_response(raw) + assert result is not None + + def test_json_embedded_in_text(self): + raw = "Here is the response: " + self._valid_json() + " Thank you." 
+ result = _parse_llm_response(raw) + assert result is not None + + def test_missing_field_returns_none(self): + raw = json.dumps({"proposed_change": "x", "hypothesis": "y"}) + assert _parse_llm_response(raw) is None + + def test_empty_field_returns_none(self): + raw = self._valid_json(proposed_change="") + assert _parse_llm_response(raw) is None + + def test_invalid_json_returns_none(self): + assert _parse_llm_response("not json at all") is None + + def test_preserves_target_spec_id(self): + raw = self._valid_json(target_spec_id="my_spec") + result = _parse_llm_response(raw) + assert result["target_spec_id"] == "my_spec" + + +# --------------------------------------------------------------------------- +# Step 2 — generate_hypothesis() +# --------------------------------------------------------------------------- + + +class TestGenerateHypothesis: + def _mock_llm(self, response: str): + """Returns a stub LLM caller that always responds with `response`.""" + + async def caller(prompt: str) -> str: + return response + + return caller + + def _valid_llm_response(self, **overrides) -> str: + data = { + "proposed_change": "Add welfare/religious intersection criterion.", + "proposed_spec_code": "# if welfare and religious → classify ngo_religious", + "hypothesis": "Models diverge on this edge case. Explicit criterion fixes it.", + "target_spec_id": "classify_spec", + **overrides, + } + return json.dumps(data) + + @pytest.mark.asyncio + async def test_happy_path(self): + sig = make_signal() + analysis = analyze(sig) + h = await generate_hypothesis(analysis, self._mock_llm(self._valid_llm_response())) + assert h is not None + assert h.proposed_change == "Add welfare/religious intersection criterion." + + @pytest.mark.asyncio + async def test_preserves_raw_response(self): + sig = make_signal() + analysis = analyze(sig) + raw = self._valid_llm_response() + h = await generate_hypothesis(analysis, self._mock_llm(raw)) + assert h.llm_raw_response == raw + + @pytest.mark.asyncio + async def test_retries_on_bad_response_then_succeeds(self): + call_count = [0] + good = self._valid_llm_response() + + async def flaky_llm(prompt: str) -> str: + call_count[0] += 1 + if call_count[0] == 1: + return "not json" + return good + + analysis = analyze(make_signal()) + h = await generate_hypothesis(analysis, flaky_llm, max_retries=2) + assert h is not None + assert call_count[0] == 2 + + @pytest.mark.asyncio + async def test_all_retries_fail_returns_none(self): + async def bad_llm(prompt: str) -> str: + return "not json ever" + + analysis = analyze(make_signal()) + h = await generate_hypothesis(analysis, bad_llm, max_retries=2) + assert h is None + + @pytest.mark.asyncio + async def test_llm_exception_counts_as_failure(self): + async def exploding_llm(prompt: str) -> str: + raise RuntimeError("LLM unavailable") + + analysis = analyze(make_signal()) + h = await generate_hypothesis(analysis, exploding_llm, max_retries=2) + assert h is None + + @pytest.mark.asyncio + async def test_uses_fallback_target_if_llm_omits_it(self): + """If LLM doesn't return target_spec_id, falls back to analysis.target_spec_id.""" + data = { + "proposed_change": "x", + "proposed_spec_code": "# x", + "hypothesis": "y", + # target_spec_id deliberately omitted + } + analysis = analyze(make_signal()) + h = await generate_hypothesis(analysis, self._mock_llm(json.dumps(data))) + assert h is not None + assert h.target_spec_id == analysis.target_spec_id + + +# --------------------------------------------------------------------------- +# Step 3 — 
validate() +# --------------------------------------------------------------------------- + + +class TestValidate: + def _make_model_runner(self, scores_by_model: dict[str, float]): + """Stub model runner returning predefined scores.""" + + async def runner(input_data: dict, criteria_hint: str, model_id: str) -> float: + return scores_by_model[model_id] + + return runner + + @pytest.mark.asyncio + async def test_model_outlier_validated_without_rerun(self): + sig = make_signal(DriftType.MODEL_OUTLIER) + # mistral is the outlier; without it MAD should be very low + h = make_hypothesis() + result = await validate(h, sig, None, MODEL_IDS, expected_mad=0.04) + assert result.validated is True + assert "mistral" not in result.model_scores_after + + @pytest.mark.asyncio + async def test_model_outlier_not_validated_if_mad_still_high(self): + # All models diverge — excluding outlier doesn't help + sig = make_signal( + DriftType.MODEL_OUTLIER, + model_scores={"gpt": 0.80, "gemini": -0.80, "llama": 0.80, "mistral": -0.90}, + outlier_model="mistral", + ) + h = make_hypothesis() + result = await validate(h, sig, None, MODEL_IDS, expected_mad=0.04) + assert result.validated is False + + @pytest.mark.asyncio + async def test_criteria_gap_validated_on_improvement(self): + # After proposed criteria, models converge: MAD goes from ~0.8 to ~0.02 + sig = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + h = make_hypothesis() + # Tight scores after correction + runner = self._make_model_runner( + {"gpt": 0.79, "gemini": 0.80, "llama": 0.81, "mistral": 0.78} + ) + result = await validate(h, sig, runner, MODEL_IDS, expected_mad=0.04) + assert result.validated is True + assert result.mad_after < result.mad_before + + @pytest.mark.asyncio + async def test_criteria_gap_rejected_on_no_improvement(self): + sig = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + h = make_hypothesis() + # Scores still diverge after correction + runner = self._make_model_runner( + {"gpt": 0.80, "gemini": 0.80, "llama": -0.80, "mistral": -0.80} + ) + result = await validate(h, sig, runner, MODEL_IDS, expected_mad=0.04) + assert result.validated is False + + @pytest.mark.asyncio + async def test_empty_triggering_input_guard(self): + from dataclasses import replace + + sig = make_signal(DriftType.CRITERIA_GAP) + sig = replace(sig, triggering_input={}) + h = make_hypothesis() + result = await validate(h, sig, None, MODEL_IDS) + assert result.validated is False + assert "triggering_input is empty" in result.validation_note + + @pytest.mark.asyncio + async def test_model_runner_exception_excluded_gracefully(self): + sig = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + h = make_hypothesis() + + async def flaky_runner(input_data, criteria_hint, model_id): + if model_id == "llama": + raise RuntimeError("model timeout") + return {"gpt": 0.79, "gemini": 0.80, "mistral": 0.78}[model_id] + + result = await validate(h, sig, flaky_runner, MODEL_IDS, expected_mad=0.04) + assert "llama" not in result.model_scores_after + assert result.n_models_tested == 3 + + @pytest.mark.asyncio + async def test_all_models_fail_returns_not_validated(self): + sig = make_signal(DriftType.CRITERIA_GAP) + h = make_hypothesis() + + async def broken_runner(input_data, criteria_hint, model_id): + raise RuntimeError("all broken") + + result = await validate(h, sig, broken_runner, MODEL_IDS) + assert result.validated is False + assert result.mad_after is None + + +# --------------------------------------------------------------------------- +# 
CorrectionRunner — full pipeline +# --------------------------------------------------------------------------- + + +class TestCorrectionRunner: + def _make_runner( + self, + llm_response: str | None = None, + scores_after: dict[str, float] | None = None, + ) -> CorrectionRunner: + if llm_response is None: + llm_response = json.dumps( + { + "proposed_change": "Add explicit edge-case criterion.", + "proposed_spec_code": "# Explicit criterion: welfare + religious → ngo_religious", + "hypothesis": "Models converge once ambiguity is resolved.", + "target_spec_id": "classify_spec", + } + ) + + async def llm(prompt: str) -> str: + return llm_response + + if scores_after is None: + scores_after = {"gpt": 0.79, "gemini": 0.80, "llama": 0.81, "mistral": 0.78} + + async def model_runner(input_data, criteria_hint, model_id) -> float: + return scores_after[model_id] + + return CorrectionRunner( + llm_caller=llm, + model_runner=model_runner, + model_ids=MODEL_IDS, + improvement_threshold=0.2, + ) + + @pytest.mark.asyncio + async def test_criteria_gap_produces_validated_proposal(self): + runner = self._make_runner() + signal = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + proposal = await runner.run(signal) + + assert proposal is not None + assert proposal.proposal_status == ProposalStatus.VALIDATED + assert proposal.triggered_by_signal_id == signal.signal_id + assert proposal.target_spec_id == "classify_spec" + assert proposal.validation_mad_before > proposal.validation_mad_after + + @pytest.mark.asyncio + async def test_criteria_gap_still_diverging_produces_rejected_proposal(self): + still_diverged = {"gpt": 0.80, "gemini": 0.80, "llama": -0.80, "mistral": -0.80} + runner = self._make_runner(scores_after=still_diverged) + signal = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + proposal = await runner.run(signal) + + assert proposal is not None + assert proposal.proposal_status == ProposalStatus.REJECTED + + @pytest.mark.asyncio + async def test_model_outlier_produces_validated_proposal(self): + runner = self._make_runner() + signal = make_signal(DriftType.MODEL_OUTLIER) + proposal = await runner.run(signal) + + # MODEL_OUTLIER validation excludes the outlier — should converge + assert proposal is not None + assert proposal.proposal_status == ProposalStatus.VALIDATED + + @pytest.mark.asyncio + async def test_llm_failure_returns_none(self): + async def bad_llm(prompt: str) -> str: + raise RuntimeError("LLM down") + + async def model_runner(input_data, criteria_hint, model_id) -> float: + return 0.8 + + runner = CorrectionRunner(bad_llm, model_runner, MODEL_IDS, max_llm_retries=1) + proposal = await runner.run(make_signal()) + assert proposal is None + + @pytest.mark.asyncio + async def test_proposal_contains_full_audit_trail(self): + runner = self._make_runner() + signal = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + proposal = await runner.run(signal) + + assert proposal is not None + assert proposal.proposed_change + assert proposal.proposed_spec_code + assert proposal.hypothesis + assert proposal.drift_examples # at least the triggering fingerprint + assert proposal.validation_mad_before is not None + assert proposal.validation_mad_after is not None + assert proposal.created_at.tzinfo == timezone.utc + + @pytest.mark.asyncio + async def test_convergence_examples_from_signal(self): + runner = self._make_runner() + signal = make_signal() + proposal = await runner.run(signal) + assert proposal.convergence_examples == signal.representative_fps + + @pytest.mark.asyncio + 
async def test_round_trip_serialization(self): + runner = self._make_runner() + signal = make_signal(DriftType.CRITERIA_GAP, expected_mad=0.04) + proposal = await runner.run(signal) + assert proposal is not None + + roundtripped = SpecProposal.from_dict(proposal.to_dict()) + assert roundtripped.proposal_id == proposal.proposal_id + assert roundtripped.proposal_status == proposal.proposal_status + assert roundtripped.validation_mad_before == proposal.validation_mad_before diff --git a/tests/testing/test_data_layer.py b/tests/testing/test_data_layer.py new file mode 100644 index 0000000..d53908b --- /dev/null +++ b/tests/testing/test_data_layer.py @@ -0,0 +1,605 @@ +""" +Tests for manifold.testing data layer. + +Coverage: +- models: construction, derived fields, serialisation round-trip +- stores: in-memory happy path + edge cases +- events: EventBus dispatch, EventConsumer routing +""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import pytest + +from manifold.testing.models import ( + ConvergenceRecord, + DriftSignal, + DriftType, + ProposalStatus, + ReviewStatus, + SpecProposal, + _compute_mad, + _fingerprint, +) +from manifold.testing.stores import ( + InMemoryBaselineStore, + InMemoryProposalStore, + InMemorySnapshotStore, + InMemorySpecRegistry, + SQLiteBaselineStore, +) +from manifold.testing.events import ( + Event, + EventBus, + EventConsumer, + EventType, + payload_run_completed, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def make_record( + input_class: str = "ngo_religious", + scores: dict | None = None, + run_id: str | None = None, + spec_versions: dict | None = None, +) -> ConvergenceRecord: + return ConvergenceRecord.create( + run_id=run_id or str(uuid.uuid4()), + input_data={"name": "Caritas Berlin", "type": "welfare"}, + input_class=input_class, + cluster_version="v1", + model_scores=scores or {"gpt4o": 0.8, "gemini": 0.75, "llama": 0.82, "mistral": 0.79}, + spec_versions=spec_versions or {"classify_spec": "1.0.0"}, + ) + + +def make_signal(drift_type: DriftType = DriftType.CRITERIA_GAP) -> DriftSignal: + return DriftSignal( + signal_id=str(uuid.uuid4()), + run_id=str(uuid.uuid4()), + timestamp=datetime.now(timezone.utc), + drift_type=drift_type, + input_fingerprint="abc123", + input_class="ngo_religious", + model_scores={"gpt4o": 0.8, "gemini": -0.3, "llama": 0.1, "mistral": 0.6}, + observed_mad=0.42, + expected_mad=0.04, + baseline_records=150, + outlier_model=None, + implicated_specs=["classify_spec"], + representative_fps=["fp1", "fp2"], + ) + + +def make_proposal(signal: DriftSignal | None = None) -> SpecProposal: + s = signal or make_signal() + return SpecProposal( + proposal_id=str(uuid.uuid4()), + created_at=datetime.now(timezone.utc), + triggered_by_signal_id=s.signal_id, + target_spec_id="classify_spec", + current_spec_version="1.0.0", + proposed_change="Add 'ambiguous_religious' sub-class to criteria", + proposed_spec_code="class ClassifySpec(Spec): ...", + hypothesis="Models diverge because criteria don't cover welfare orgs with religious affiliation", + drift_examples=["abc123"], + convergence_examples=["fp1", "fp2"], + proposal_status=ProposalStatus.VALIDATED, + validation_mad_before=0.42, + validation_mad_after=0.06, + models_converged_after=4, + ) + + +# 
--------------------------------------------------------------------------- +# models.py +# --------------------------------------------------------------------------- + + +class TestComputeMAD: + def test_identical_values_gives_zero(self): + assert _compute_mad([0.5, 0.5, 0.5]) == 0.0 + + def test_known_values(self): + # values [0, 1]: mean=0.5, deviations=[0.5, 0.5], MAD=0.5 + assert _compute_mad([0.0, 1.0]) == pytest.approx(0.5) + + def test_empty_gives_zero(self): + assert _compute_mad([]) == 0.0 + + def test_single_value_gives_zero(self): + assert _compute_mad([0.7]) == 0.0 + + +class TestFingerprint: + def test_deterministic(self): + data = {"name": "Caritas", "type": "welfare"} + assert _fingerprint(data) == _fingerprint(data) + + def test_different_data_different_fingerprint(self): + assert _fingerprint({"a": 1}) != _fingerprint({"a": 2}) + + def test_key_order_irrelevant(self): + assert _fingerprint({"a": 1, "b": 2}) == _fingerprint({"b": 2, "a": 1}) + + def test_returns_16_chars(self): + assert len(_fingerprint({"x": "y"})) == 16 + + +class TestConvergenceRecordCreate: + def test_derives_mad_correctly(self): + record = make_record(scores={"a": 0.0, "b": 1.0}) + assert record.inter_model_mad == pytest.approx(0.5) + + def test_confidence_is_one_minus_mad(self): + record = make_record(scores={"a": 0.8, "b": 0.8, "c": 0.8}) + assert record.confidence == pytest.approx(1.0) + + def test_confidence_clamps_at_zero(self): + # MAD > 1.0 should not give negative confidence + record = make_record(scores={"a": -1.0, "b": 1.0}) + assert record.confidence >= 0.0 + + def test_raises_on_empty_scores(self): + with pytest.raises(ValueError, match="model_scores must not be empty"): + ConvergenceRecord.create( + run_id="x", + input_data={}, + input_class="test", + cluster_version=None, + model_scores={}, + spec_versions={}, + ) + + def test_serialisation_round_trip(self): + r = make_record() + assert ConvergenceRecord.from_dict(r.to_dict()).run_id == r.run_id + assert ConvergenceRecord.from_dict(r.to_dict()).inter_model_mad == pytest.approx( + r.inter_model_mad + ) + + def test_timestamp_is_utc_datetime(self): + r = make_record() + assert isinstance(r.timestamp, datetime) + + +class TestDriftSignal: + def test_serialisation_round_trip(self): + s = make_signal() + s2 = DriftSignal.from_dict(s.to_dict()) + assert s2.signal_id == s.signal_id + assert s2.drift_type == DriftType.CRITERIA_GAP + assert s2.expected_mad == pytest.approx(0.04) + + def test_null_expected_mad_survives_round_trip(self): + s = make_signal() + d = s.to_dict() + d["expected_mad"] = None + s2 = DriftSignal.from_dict(d) + assert s2.expected_mad is None + + +class TestSpecProposal: + def test_mad_improvement_computed(self): + p = make_proposal() + assert p.mad_improvement == pytest.approx(0.42 - 0.06) + + def test_mad_improvement_none_when_not_validated(self): + p = SpecProposal( + proposal_id="x", + created_at=datetime.now(timezone.utc), + triggered_by_signal_id="s", + target_spec_id="classify_spec", + current_spec_version="1.0.0", + proposed_change="test", + proposed_spec_code="...", + hypothesis="test", + drift_examples=[], + convergence_examples=[], + ) + assert p.mad_improvement is None + + def test_serialisation_round_trip(self): + p = make_proposal() + p2 = SpecProposal.from_dict(p.to_dict()) + assert p2.proposal_id == p.proposal_id + assert p2.proposal_status == ProposalStatus.VALIDATED + assert p2.review_status == ReviewStatus.PENDING + + +# --------------------------------------------------------------------------- +# 
stores.py — InMemory +# --------------------------------------------------------------------------- + + +class TestInMemoryBaselineStore: + @pytest.mark.asyncio + async def test_append_and_count(self): + store = InMemoryBaselineStore() + await store.append(make_record()) + await store.append(make_record()) + assert await store.total_records() == 2 + + @pytest.mark.asyncio + async def test_records_for_class_filters(self): + store = InMemoryBaselineStore() + await store.append(make_record(input_class="class_a")) + await store.append(make_record(input_class="class_b")) + results = await store.records_for_class("class_a") + assert len(results) == 1 + assert results[0].input_class == "class_a" + + @pytest.mark.asyncio + async def test_expected_mad_none_below_threshold(self): + store = InMemoryBaselineStore() + for _ in range(9): + await store.append(make_record()) + assert await store.expected_mad_for_class("ngo_religious") is None + + @pytest.mark.asyncio + async def test_expected_mad_computed_above_threshold(self): + store = InMemoryBaselineStore() + for _ in range(10): + await store.append(make_record()) + mad = await store.expected_mad_for_class("ngo_religious") + assert mad is not None + assert 0.0 <= mad <= 1.0 + + @pytest.mark.asyncio + async def test_mark_stale_excludes_from_query(self): + store = InMemoryBaselineStore() + r = make_record(spec_versions={"classify_spec": "1.0.0"}) + await store.append(r) + count = await store.mark_stale_for_spec_version("classify_spec", "1.0.0") + assert count == 1 + results = await store.records_for_class("ngo_religious", exclude_stale=True) + assert len(results) == 0 + + @pytest.mark.asyncio + async def test_stale_records_visible_when_not_excluded(self): + store = InMemoryBaselineStore() + r = make_record(spec_versions={"classify_spec": "1.0.0"}) + await store.append(r) + await store.mark_stale_for_spec_version("classify_spec", "1.0.0") + results = await store.records_for_class("ngo_religious", exclude_stale=False) + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_signal_store_and_retrieve(self): + store = InMemoryBaselineStore() + signal = make_signal() + await store.append_signal(signal) + retrieved = await store.get_signal(signal.signal_id) + assert retrieved is not None + assert retrieved.signal_id == signal.signal_id + + @pytest.mark.asyncio + async def test_get_signal_returns_none_for_unknown(self): + store = InMemoryBaselineStore() + assert await store.get_signal("does-not-exist") is None + + @pytest.mark.asyncio + async def test_snapshot_stats_correct(self): + store = InMemoryBaselineStore() + registry = InMemorySpecRegistry({"classify_spec": "1.0.0"}) + for _ in range(5): + await store.append(make_record(input_class="class_a")) + for _ in range(3): + await store.append(make_record(input_class="class_b")) + + snap = await store.take_snapshot(registry) + assert snap.total_records == 8 + assert snap.records_by_class["class_a"] == 5 + assert snap.records_by_class["class_b"] == 3 + assert "class_a" in snap.mad_by_class + + +class TestInMemoryProposalStore: + @pytest.mark.asyncio + async def test_write_and_retrieve(self): + store = InMemoryProposalStore() + p = make_proposal() + await store.write(p) + retrieved = await store.get(p.proposal_id) + assert retrieved is not None + assert retrieved.proposal_id == p.proposal_id + + @pytest.mark.asyncio + async def test_pending_proposals(self): + store = InMemoryProposalStore() + await store.write(make_proposal()) + await store.write(make_proposal()) + pending = await 
store.pending_proposals() + assert len(pending) == 2 + + @pytest.mark.asyncio + async def test_mark_rejected(self): + store = InMemoryProposalStore() + p = make_proposal() + await store.write(p) + await store.mark_rejected(p.proposal_id, "criteria are fine") + updated = await store.get(p.proposal_id) + assert updated.review_status == ReviewStatus.REJECTED + assert updated.reviewer_notes == "criteria are fine" + + @pytest.mark.asyncio + async def test_approved_leaves_pending_list(self): + store = InMemoryProposalStore() + p = make_proposal() + await store.write(p) + await store.mark_approved(p.proposal_id, "looks good", datetime.now(timezone.utc)) + assert len(await store.pending_proposals()) == 0 + + +class TestInMemorySpecRegistry: + @pytest.mark.asyncio + async def test_apply_proposal_bumps_patch_version(self): + registry = InMemorySpecRegistry({"classify_spec": "1.0.0"}) + p = make_proposal() + new_version = await registry.apply_proposal(p) + assert new_version == "1.0.1" + + @pytest.mark.asyncio + async def test_current_versions(self): + registry = InMemorySpecRegistry({"a": "1.0.0", "b": "2.1.3"}) + versions = await registry.current_versions() + assert versions["a"] == "1.0.0" + assert versions["b"] == "2.1.3" + + +# --------------------------------------------------------------------------- +# stores.py — SQLite +# --------------------------------------------------------------------------- + + +class TestSQLiteBaselineStore: + @pytest.mark.asyncio + async def test_append_and_retrieve(self, tmp_path): + store = SQLiteBaselineStore(tmp_path / "test.db") + await store.initialise() + r = make_record() + await store.append(r) + results = await store.records_for_class("ngo_religious") + assert len(results) == 1 + assert results[0].run_id == r.run_id + + @pytest.mark.asyncio + async def test_total_records(self, tmp_path): + store = SQLiteBaselineStore(tmp_path / "test.db") + await store.initialise() + await store.append(make_record()) + await store.append(make_record()) + assert await store.total_records() == 2 + + @pytest.mark.asyncio + async def test_idempotent_append(self, tmp_path): + """INSERT OR IGNORE — same run_id twice does not duplicate.""" + store = SQLiteBaselineStore(tmp_path / "test.db") + await store.initialise() + r = make_record() + await store.append(r) + await store.append(r) + assert await store.total_records() == 1 + + @pytest.mark.asyncio + async def test_signal_round_trip(self, tmp_path): + store = SQLiteBaselineStore(tmp_path / "test.db") + await store.initialise() + s = make_signal() + await store.append_signal(s) + s2 = await store.get_signal(s.signal_id) + assert s2 is not None + assert s2.drift_type == DriftType.CRITERIA_GAP + + @pytest.mark.asyncio + async def test_mark_stale(self, tmp_path): + store = SQLiteBaselineStore(tmp_path / "test.db") + await store.initialise() + await store.append(make_record(spec_versions={"classify_spec": "1.0.0"})) + count = await store.mark_stale_for_spec_version("classify_spec", "1.0.0") + assert count == 1 + assert len(await store.records_for_class("ngo_religious")) == 0 + + +# --------------------------------------------------------------------------- +# events.py +# --------------------------------------------------------------------------- + + +class TestEventBus: + @pytest.mark.asyncio + async def test_handler_called_on_emit(self): + bus = EventBus() + received = [] + + async def handler(event: Event): + received.append(event) + + bus.subscribe(EventType.DRIFT_DETECTED, handler) + event = 
Event.create(EventType.DRIFT_DETECTED, "test", {}) + await bus.emit(event) + + assert len(received) == 1 + assert received[0].event_id == event.event_id + + @pytest.mark.asyncio + async def test_multiple_handlers_all_called(self): + bus = EventBus() + calls = [] + bus.subscribe(EventType.DRIFT_DETECTED, AsyncMock(side_effect=lambda e: calls.append("h1"))) + bus.subscribe(EventType.DRIFT_DETECTED, AsyncMock(side_effect=lambda e: calls.append("h2"))) + + await bus.emit(Event.create(EventType.DRIFT_DETECTED, "test", {})) + assert set(calls) == {"h1", "h2"} + + @pytest.mark.asyncio + async def test_failing_handler_does_not_stop_others(self): + bus = EventBus() + calls = [] + + async def bad_handler(event): + raise RuntimeError("oops") + + async def good_handler(event): + calls.append("ok") + + bus.subscribe(EventType.DRIFT_DETECTED, bad_handler) + bus.subscribe(EventType.DRIFT_DETECTED, good_handler) + + await bus.emit(Event.create(EventType.DRIFT_DETECTED, "test", {})) + assert calls == ["ok"] + + @pytest.mark.asyncio + async def test_no_handlers_does_not_raise(self): + bus = EventBus() + await bus.emit(Event.create(EventType.SPEC_UPDATED, "test", {})) + + @pytest.mark.asyncio + async def test_subscribe_many(self): + bus = EventBus() + calls = [] + bus.subscribe_many( + { + EventType.DRIFT_DETECTED: AsyncMock(side_effect=lambda e: calls.append("drift")), + EventType.SPEC_UPDATED: AsyncMock(side_effect=lambda e: calls.append("spec")), + } + ) + await bus.emit(Event.create(EventType.DRIFT_DETECTED, "x", {})) + await bus.emit(Event.create(EventType.SPEC_UPDATED, "x", {})) + assert "drift" in calls + assert "spec" in calls + + +class TestEventConsumer: + def _make_consumer(self): + baseline = InMemoryBaselineStore() + snapshots = InMemorySnapshotStore() + proposals = InMemoryProposalStore() + registry = InMemorySpecRegistry({"classify_spec": "1.0.0"}) + correction_runner = AsyncMock() + bus = EventBus() + consumer = EventConsumer( + baseline_store=baseline, + snapshot_store=snapshots, + proposal_store=proposals, + spec_registry=registry, + correction_runner=correction_runner, + bus=bus, + snapshot_interval=5, + ) + return consumer, baseline, snapshots, proposals, registry, correction_runner, bus + + @pytest.mark.asyncio + async def test_converged_run_updates_baseline(self): + consumer, baseline, *_ = self._make_consumer() + record = make_record() + + event = Event.create( + EventType.RUN_COMPLETED, + source="test", + payload=payload_run_completed( + run_id=record.run_id, + success=True, + had_drift=False, + drift_signal_id=None, + convergence_record=record, + ), + ) + await consumer._on_run_completed(event) + assert await baseline.total_records() == 1 + + @pytest.mark.asyncio + async def test_drifted_run_emits_drift_detected(self): + consumer, baseline, _, _, _, _, bus = self._make_consumer() + signal = make_signal() + await baseline.append_signal(signal) + + received_events = [] + + async def _capture(e: Event) -> None: + received_events.append(e) + + bus.subscribe(EventType.DRIFT_DETECTED, _capture) + + event = Event.create( + EventType.RUN_COMPLETED, + source="test", + payload=payload_run_completed( + run_id="run-1", + success=True, + had_drift=True, + drift_signal_id=signal.signal_id, + convergence_record=None, + ), + ) + await consumer._on_run_completed(event) + await asyncio.sleep(0) + assert len(received_events) == 1 + + @pytest.mark.asyncio + async def test_snapshot_taken_at_interval(self): + consumer, baseline, snapshots, _, registry, _, _ = self._make_consumer() + + for _ in range(4): + 
await baseline.append(make_record()) + + event = Event.create( + EventType.RUN_COMPLETED, + source="test", + payload=payload_run_completed( + run_id=str(uuid.uuid4()), + success=True, + had_drift=False, + drift_signal_id=None, + convergence_record=make_record(), + ), + ) + await consumer._on_run_completed(event) + + snap = await snapshots.latest() + assert snap is not None + + @pytest.mark.asyncio + async def test_proposal_approved_updates_spec_registry(self): + consumer, _, _, proposals, registry, _, bus = self._make_consumer() + p = make_proposal() + await proposals.write(p) + + event = Event.create( + EventType.PROPOSAL_APPROVED, + source="test", + payload={"proposal_id": p.proposal_id, "reviewer_notes": "looks good"}, + ) + await consumer._on_proposal_approved(event) + + versions = await registry.current_versions() + assert versions["classify_spec"] == "1.0.1" + + @pytest.mark.asyncio + async def test_spec_updated_marks_baseline_stale(self): + consumer, baseline, _, _, _, _, _ = self._make_consumer() + await baseline.append(make_record(spec_versions={"classify_spec": "1.0.0"})) + + event = Event.create( + EventType.SPEC_UPDATED, + source="test", + payload={ + "spec_id": "classify_spec", + "old_version": "1.0.0", + "new_version": "1.0.1", + "proposal_id": "p1", + }, + ) + await consumer._on_spec_updated(event) + + records = await baseline.records_for_class("ngo_religious", exclude_stale=True) + assert len(records) == 0