diff --git a/src/clio_agent/arc/cache.py b/src/clio_agent/arc/cache.py index 24cb9e1..1c74e4e 100644 --- a/src/clio_agent/arc/cache.py +++ b/src/clio_agent/arc/cache.py @@ -6,6 +6,7 @@ try: from lru import LRU # lru-dict package + HAS_LRU_DICT = True except ImportError: HAS_LRU_DICT = False diff --git a/src/clio_agent/arc/index.py b/src/clio_agent/arc/index.py index 0985714..a4ae43f 100644 --- a/src/clio_agent/arc/index.py +++ b/src/clio_agent/arc/index.py @@ -71,10 +71,7 @@ def search(self, key: Tuple[str, float]) -> Optional[Any]: return self._index.get(key) def range_query( - self, - start_key: Tuple[str, float], - end_key: Tuple[str, float], - inclusive: bool = True + self, start_key: Tuple[str, float], end_key: Tuple[str, float], inclusive: bool = True ) -> List[Any]: """ Retrieve all values within key range [start_key, end_key]. @@ -103,13 +100,13 @@ def range_query( ... ) """ # irange returns keys, so we need to extract values - return [self._index[k] for k in self._index.irange(start_key, end_key, inclusive=(True, inclusive))] + return [ + self._index[k] + for k in self._index.irange(start_key, end_key, inclusive=(True, inclusive)) + ] def range_query_keys( - self, - start_key: Tuple[str, float], - end_key: Tuple[str, float], - inclusive: bool = True + self, start_key: Tuple[str, float], end_key: Tuple[str, float], inclusive: bool = True ) -> List[Tuple[str, float]]: """ Retrieve all keys within range [start_key, end_key]. @@ -135,10 +132,7 @@ def range_query_keys( return list(self._index.irange(start_key, end_key, inclusive=(True, inclusive))) def range_query_items( - self, - start_key: Tuple[str, float], - end_key: Tuple[str, float], - inclusive: bool = True + self, start_key: Tuple[str, float], end_key: Tuple[str, float], inclusive: bool = True ) -> List[Tuple[Tuple[str, float], Any]]: """ Retrieve all (key, value) pairs within range. @@ -162,7 +156,10 @@ def range_query_items( ... print(f"{key}: {value}") """ # Return (key, value) tuples for all keys in range - return [(k, self._index[k]) for k in self._index.irange(start_key, end_key, inclusive=(True, inclusive))] + return [ + (k, self._index[k]) + for k in self._index.irange(start_key, end_key, inclusive=(True, inclusive)) + ] def delete(self, key: Tuple[str, float]) -> bool: """ @@ -249,10 +246,7 @@ def get_session_range(self, session_id: str) -> List[Any]: Examples: >>> conversations = index.get_session_range("session_1") """ - return self.range_query( - (session_id, 0.0), - (session_id, float('inf')) - ) + return self.range_query((session_id, 0.0), (session_id, float("inf"))) def get_latest_in_session(self, session_id: str, n: int = 1) -> List[Any]: """ @@ -279,11 +273,9 @@ def get_latest_in_session(self, session_id: str, n: int = 1) -> List[Any]: >>> recent = index.get_latest_in_session("session_1", n=5) """ # Get keys in reverse order (most recent first) - all_keys = list(self._index.irange( - (session_id, 0.0), - (session_id, float('inf')), - reverse=True - )) + all_keys = list( + self._index.irange((session_id, 0.0), (session_id, float("inf")), reverse=True) + ) # Take first n keys (the n most recent) latest_keys = all_keys[:n] # Return values in chronological order (oldest to newest) diff --git a/src/clio_agent/arc/lsm.py b/src/clio_agent/arc/lsm.py index 59eaed8..4d3a4f5 100644 --- a/src/clio_agent/arc/lsm.py +++ b/src/clio_agent/arc/lsm.py @@ -230,9 +230,7 @@ def range_scan(self, start_ts: float, end_ts: float) -> List[Dict[str, Any]]: if end_ts < sstable.min_key or start_ts > sstable.max_key: continue - sstable_results = self._range_scan_sstable( - sstable, start_ts, end_ts - ) + sstable_results = self._range_scan_sstable(sstable, start_ts, end_ts) for ts, metric in sstable_results.items(): # MemTable has priority (newer data) if ts not in results: diff --git a/src/clio_agent/arc/storage.py b/src/clio_agent/arc/storage.py index 099a3cd..43092ab 100644 --- a/src/clio_agent/arc/storage.py +++ b/src/clio_agent/arc/storage.py @@ -66,8 +66,8 @@ def __init__( # Tier migration policy (days) self.tier_policy = tier_policy or { - "hot_to_warm": 1, # 1 day in hot tier before eviction - "warm_to_cold": 7, # 1 week in warm tier + "hot_to_warm": 1, # 1 day in hot tier before eviction + "warm_to_cold": 7, # 1 week in warm tier "cold_to_archive": 30, # 1 month in cold tier } @@ -476,9 +476,7 @@ def _maybe_migrate_tiers(self) -> None: # Parse timestamp try: - last_accessed = datetime.fromisoformat( - last_accessed_str.replace("Z", "+00:00") - ) + last_accessed = datetime.fromisoformat(last_accessed_str.replace("Z", "+00:00")) except Exception: continue diff --git a/src/clio_agent/conversation_manager.py b/src/clio_agent/conversation_manager.py index 21e5742..a823dd4 100644 --- a/src/clio_agent/conversation_manager.py +++ b/src/clio_agent/conversation_manager.py @@ -39,6 +39,7 @@ class ConversationManagerSignature(dspy.Signature): - key_topics: Main topics discussed - context_for_response: Relevant context """ + history: dspy.History = dspy.InputField(desc="Conversation history") current_question: str = dspy.InputField(desc="Current question") summary: str = dspy.OutputField(desc="Summary of context") @@ -63,7 +64,7 @@ def add_message(self, role: str, content: str): self.history_buffer.append({"role": role, "content": content}) if len(self.history_buffer) > self.max_history_length: # Keep only recent messages - self.history_buffer = self.history_buffer[-self.max_history_length:] + self.history_buffer = self.history_buffer[-self.max_history_length :] def get_history(self) -> dspy.History: """Get current history as dspy.History.""" diff --git a/src/clio_agent/experts/__init__.py b/src/clio_agent/experts/__init__.py index 56b9d8f..cfc6d68 100644 --- a/src/clio_agent/experts/__init__.py +++ b/src/clio_agent/experts/__init__.py @@ -47,6 +47,7 @@ # EXPERT REGISTRY # ============================================================================ + def get_all_experts() -> Dict[str, dspy.Module]: """Get all available expert instances. diff --git a/src/clio_agent/experts/ndp_expert.py b/src/clio_agent/experts/ndp_expert.py index 457daf7..090ef2e 100644 --- a/src/clio_agent/experts/ndp_expert.py +++ b/src/clio_agent/experts/ndp_expert.py @@ -62,9 +62,7 @@ def __init__(self, tool_executor: ToolExecutor | None = None) -> None: self._owns_executor = tool_executor is None self._tool_executor = tool_executor or create_sync_tool_executor(gateway) self._tools = [ - tool - for tool in self._tool_executor.to_dspy_tools() - if tool.name.startswith("ndp_") + tool for tool in self._tool_executor.to_dspy_tools() if tool.name.startswith("ndp_") ] def forward(self, question: str, file_context: str = "") -> dspy.Prediction: diff --git a/src/clio_agent/experts/sac_format_expert.py b/src/clio_agent/experts/sac_format_expert.py index 068cb38..b5eeaab 100644 --- a/src/clio_agent/experts/sac_format_expert.py +++ b/src/clio_agent/experts/sac_format_expert.py @@ -58,9 +58,7 @@ def __init__(self, tool_executor: ToolExecutor | None = None) -> None: self._owns_executor = tool_executor is None self._tool_executor = tool_executor or create_sync_tool_executor(gateway) self._tools = [ - tool - for tool in self._tool_executor.to_dspy_tools() - if tool.name.startswith("sac_") + tool for tool in self._tool_executor.to_dspy_tools() if tool.name.startswith("sac_") ] def forward(self, question: str, file_context: str = "") -> dspy.Prediction: @@ -229,8 +227,7 @@ def plot_traces(self, filepath: str) -> ExpertResult: if isinstance(result, dict) and result.get("error"): return ExpertResult( analysis=( - "Could not create SAC waveform plot: " - f"{format_tool_error(result['error'])}" + f"Could not create SAC waveform plot: {format_tool_error(result['error'])}" ), recommendations="Verify the staged file and SAC plotting tool contract.", source="deterministic", diff --git a/src/clio_agent/gact/events.py b/src/clio_agent/gact/events.py index f46080a..30fc295 100644 --- a/src/clio_agent/gact/events.py +++ b/src/clio_agent/gact/events.py @@ -91,9 +91,7 @@ class EventBus: instances which serialize their own operations. """ - def __init__( - self, *, queue_capacity: int = 256, history_per_session: int = 256 - ) -> None: + def __init__(self, *, queue_capacity: int = 256, history_per_session: int = 256) -> None: self._capacity = queue_capacity self._history_cap = history_per_session # session_id -> list of subscriber queues @@ -126,9 +124,7 @@ def publish(self, event: Event) -> None: except asyncio.QueueFull: pass - async def subscribe( - self, session_id: str, *, last_event_id: int = 0 - ) -> AsyncIterator[Event]: + async def subscribe(self, session_id: str, *, last_event_id: int = 0) -> AsyncIterator[Event]: """Yield events for ``session_id`` until the consumer drops. ``last_event_id`` is the highest event id the client already diff --git a/src/clio_agent/gact/scheduler.py b/src/clio_agent/gact/scheduler.py index 4216753..24816de 100644 --- a/src/clio_agent/gact/scheduler.py +++ b/src/clio_agent/gact/scheduler.py @@ -120,9 +120,7 @@ def _flush(self) -> None: ) tmp.replace(self._path) - def add( - self, *, session_id: str, cron: str, question: str - ) -> Schedule: + def add(self, *, session_id: str, cron: str, question: str) -> Schedule: sid = "sched_" + uuid.uuid4().hex[:12] sch = Schedule( id=sid, @@ -140,9 +138,7 @@ def get(self, sid: str) -> Optional[Schedule]: with self._lock: return self._schedules.get(sid) - def list( - self, *, session_id: Optional[str] = None - ) -> list[Schedule]: + def list(self, *, session_id: Optional[str] = None) -> list[Schedule]: with self._lock: rows = list(self._schedules.values()) if session_id is not None: @@ -174,9 +170,7 @@ def due_now(self, when: datetime) -> Iterable[Schedule]: for sch in rows: if not sch.enabled: continue - if sch.last_fired_at and sch.last_fired_at.startswith( - when_minute - ): + if sch.last_fired_at and sch.last_fired_at.startswith(when_minute): continue if cron_matches(sch.cron, when): yield sch diff --git a/src/clio_agent/gact/sessions.py b/src/clio_agent/gact/sessions.py index 80207f6..dcda289 100644 --- a/src/clio_agent/gact/sessions.py +++ b/src/clio_agent/gact/sessions.py @@ -324,7 +324,10 @@ def update( if edit_mode is not None and edit_mode in {"diff", "whole", "patch"}: sess.edit_mode = edit_mode if routing_mode is not None and routing_mode in { - "auto", "chat", "experts", "reasoning_only", + "auto", + "chat", + "experts", + "reasoning_only", }: sess.routing_mode = routing_mode if model is not None: diff --git a/src/clio_agent/gact/user_agents.py b/src/clio_agent/gact/user_agents.py index 346df04..1add0a4 100644 --- a/src/clio_agent/gact/user_agents.py +++ b/src/clio_agent/gact/user_agents.py @@ -67,20 +67,29 @@ def _load(self) -> None: return for row in data.get("agents", []): try: - self._agents[row["id"]] = UserAgent(**{ - k: row[k] - for k in ( - "id", "title", "description", "source", - "system_prompt", "default_provider", "default_model", - "tier", "specialization", - ) - if k in row - } | { - "parameters": dict(row.get("parameters", {})), - "keywords": list(row.get("keywords", [])), - "tools": list(row.get("tools", [])), - "metadata": dict(row.get("metadata", {})), - }) + self._agents[row["id"]] = UserAgent( + **{ + k: row[k] + for k in ( + "id", + "title", + "description", + "source", + "system_prompt", + "default_provider", + "default_model", + "tier", + "specialization", + ) + if k in row + } + | { + "parameters": dict(row.get("parameters", {})), + "keywords": list(row.get("keywords", [])), + "tools": list(row.get("tools", [])), + "metadata": dict(row.get("metadata", {})), + } + ) except Exception: continue diff --git a/src/clio_agent/optimizer/instrumentation.py b/src/clio_agent/optimizer/instrumentation.py index 0b2a7de..6175a8e 100644 --- a/src/clio_agent/optimizer/instrumentation.py +++ b/src/clio_agent/optimizer/instrumentation.py @@ -35,6 +35,7 @@ def instrumented_forward(arc_memory: Any, agent_id: str) -> Callable: ... def forward(self, question, file_context=""): ... return dspy.Prediction(analysis="...", recommendations="...") """ + def decorator(forward_fn: Callable) -> Callable: @functools.wraps(forward_fn) def wrapper(*args, **kwargs): @@ -86,6 +87,7 @@ def wrapper(*args, **kwargs): arc_memory.store_invocation(invocation) return wrapper + return decorator @@ -147,7 +149,13 @@ def _extract_output(result: Any) -> Dict[str, Any]: pass else: # Fallback: try common expert output fields - for field in ("analysis", "recommendations", "visualization_description", "file_path", "answer"): + for field in ( + "analysis", + "recommendations", + "visualization_description", + "file_path", + "answer", + ): val = getattr(result, field, None) if val is not None: output_data[field] = _to_safe_text(val)[:500] diff --git a/src/clio_agent/optimizer/runner.py b/src/clio_agent/optimizer/runner.py index 8597eb0..ed675d9 100644 --- a/src/clio_agent/optimizer/runner.py +++ b/src/clio_agent/optimizer/runner.py @@ -96,8 +96,7 @@ def run( """ if len(trainset) < 5: raise ValueError( - f"Need at least 5 training examples for 20/80 split. " - f"Got {len(trainset)}." + f"Need at least 5 training examples for 20/80 split. Got {len(trainset)}." ) if metric_fn is None: diff --git a/src/clio_agent/optimizer/trainer.py b/src/clio_agent/optimizer/trainer.py index 59830c4..eadd59e 100644 --- a/src/clio_agent/optimizer/trainer.py +++ b/src/clio_agent/optimizer/trainer.py @@ -16,10 +16,19 @@ from clio_agent.arc.schema import Invocation, decode_invocation # Error keywords that indicate problematic output -_ERROR_KEYWORDS = frozenset([ - "error:", "error,", "traceback", "exception:", "failed to", - "could not", "unable to", "runtime error", "type error", -]) +_ERROR_KEYWORDS = frozenset( + [ + "error:", + "error,", + "traceback", + "exception:", + "failed to", + "could not", + "unable to", + "runtime error", + "type error", + ] +) class TrainingSetGenerator: @@ -46,9 +55,7 @@ def __init__(self, arc_memory: Any) -> None: """ self._arc = arc_memory - def generate( - self, agent_id: str, min_examples: int = 30 - ) -> list[dspy.Example]: + def generate(self, agent_id: str, min_examples: int = 30) -> list[dspy.Example]: """Generate training set from ARC invocations for a specific expert. Calls arc_memory.get_invocations_by_agent with status="success", @@ -73,9 +80,7 @@ def generate( >>> assert len(examples) >= 30 >>> assert "question" in examples[0].inputs() """ - invocations = self._arc.get_invocations_by_agent( - agent_id, status="success" - ) + invocations = self._arc.get_invocations_by_agent(agent_id, status="success") if len(invocations) < min_examples: raise ValueError( @@ -127,9 +132,7 @@ def get_available_counts(self) -> dict[str, int]: return counts @staticmethod - def _invocation_to_example( - inv: Invocation, agent_id: str - ) -> dspy.Example | None: + def _invocation_to_example(inv: Invocation, agent_id: str) -> dspy.Example | None: """Convert a single Invocation to a dspy.Example. Maps invocation input/output fields to expert signature fields. @@ -182,9 +185,7 @@ def _invocation_to_example( return None -def clio_expert_metric( - example: dspy.Example, pred: Any, trace: Any = None -) -> float | bool: +def clio_expert_metric(example: dspy.Example, pred: Any, trace: Any = None) -> float | bool: """Multi-signal metric for CLIO expert optimization. Scores expert outputs on three weighted signals: diff --git a/src/clio_agent/optimizer/variants.py b/src/clio_agent/optimizer/variants.py index d151b82..c5342e9 100644 --- a/src/clio_agent/optimizer/variants.py +++ b/src/clio_agent/optimizer/variants.py @@ -137,9 +137,7 @@ def load_variant( """ variant_path = self._variants_dir / f"{variant_id}.json" if not variant_path.exists(): - raise FileNotFoundError( - f"Variant file not found: {variant_path}" - ) + raise FileNotFoundError(f"Variant file not found: {variant_path}") existing_module.load(path=str(variant_path)) return existing_module @@ -169,9 +167,7 @@ def deploy(self, variant_id: str, agent_id: str) -> None: found = True if not found: - raise ValueError( - f"Variant '{variant_id}' not found for agent '{agent_id}'" - ) + raise ValueError(f"Variant '{variant_id}' not found for agent '{agent_id}'") # Deactivate all, then activate target for record in records: diff --git a/src/clio_agent/registry/capability_matcher.py b/src/clio_agent/registry/capability_matcher.py index e822488..bd919f3 100644 --- a/src/clio_agent/registry/capability_matcher.py +++ b/src/clio_agent/registry/capability_matcher.py @@ -37,13 +37,64 @@ def __init__(self): """Initialize matcher with default stopwords.""" # Common English stopwords to filter from queries self._stopwords = { - 'the', 'a', 'an', 'in', 'on', 'at', 'to', 'for', 'of', 'with', - 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', - 'had', 'do', 'does', 'did', 'will', 'would', 'should', 'could', - 'can', 'may', 'might', 'must', 'shall', 'i', 'you', 'he', 'she', - 'it', 'we', 'they', 'them', 'this', 'that', 'these', 'those', - 'my', 'your', 'his', 'her', 'its', 'our', 'their', 'what', 'which', - 'who', 'when', 'where', 'why', 'how' + "the", + "a", + "an", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "should", + "could", + "can", + "may", + "might", + "must", + "shall", + "i", + "you", + "he", + "she", + "it", + "we", + "they", + "them", + "this", + "that", + "these", + "those", + "my", + "your", + "his", + "her", + "its", + "our", + "their", + "what", + "which", + "who", + "when", + "where", + "why", + "how", } def extract_keywords(self, query: str) -> List[str]: @@ -70,18 +121,14 @@ def extract_keywords(self, query: str) -> List[str]: ['best', 'compression', 'adios'] """ # Lowercase and split on non-alphanumeric (keep numbers for hdf5, etc.) - words = re.findall(r'\w+', query.lower()) + words = re.findall(r"\w+", query.lower()) # Filter stopwords and empty strings keywords = [w for w in words if w and w not in self._stopwords] return keywords - def match_query( - self, - query: str, - capabilities: Dict[str, Any] - ) -> List[Tuple[str, float]]: + def match_query(self, query: str, capabilities: Dict[str, Any]) -> List[Tuple[str, float]]: """Match query to agent capabilities and return ranked list. Matching algorithm: @@ -125,7 +172,7 @@ def match_query( "No agent matching performed. " "Consider rewording the query with more specific terms.", UserWarning, - stacklevel=2 + stacklevel=2, ) return [] @@ -133,7 +180,7 @@ def match_query( scores = [] for agent_id, capability in capabilities.items(): # Get agent's keywords (handle missing keywords gracefully) - agent_keywords = capability.get('keywords', []) + agent_keywords = capability.get("keywords", []) if not agent_keywords: continue @@ -149,11 +196,7 @@ def match_query( return scores - def _calculate_score( - self, - query_keywords: List[str], - agent_keywords: List[str] - ) -> float: + def _calculate_score(self, query_keywords: List[str], agent_keywords: List[str]) -> float: """Calculate match score between query and agent keywords. Score formula: @@ -183,7 +226,7 @@ def _calculate_score( # Example: "parallel io" becomes ["parallel", "io"] expanded_keywords = [] for kw in agent_keywords: - expanded_keywords.extend(re.findall(r'\w+', kw.lower())) + expanded_keywords.extend(re.findall(r"\w+", kw.lower())) # Convert to set for O(1) lookup agent_kw_set = set(expanded_keywords) diff --git a/src/clio_agent/runtime/hooks.py b/src/clio_agent/runtime/hooks.py index e18e46e..835e2ff 100644 --- a/src/clio_agent/runtime/hooks.py +++ b/src/clio_agent/runtime/hooks.py @@ -62,9 +62,7 @@ class HookRegistry: def __init__(self, *, hooks_dir: Optional[Path] = None) -> None: self._dir = hooks_dir if hooks_dir is not None else _default_hooks_dir() - self._hooks: dict[str, list[Callable[..., Any]]] = { - event: [] for event in _KNOWN_EVENTS - } + self._hooks: dict[str, list[Callable[..., Any]]] = {event: [] for event in _KNOWN_EVENTS} self._lock = threading.Lock() self._load() @@ -75,9 +73,7 @@ def _load(self) -> None: try: module = self._import_path(path) except Exception as exc: - logger.warning( - "[clio-hooks] failed to load %s: %r", path, exc - ) + logger.warning("[clio-hooks] failed to load %s: %r", path, exc) continue for event in _KNOWN_EVENTS: fn = getattr(module, event, None) @@ -98,9 +94,7 @@ def _import_path(path: Path) -> Any: spec.loader.exec_module(module) return module - def fire( - self, event: str, /, *args: Any, **kwargs: Any - ) -> list[Any]: + def fire(self, event: str, /, *args: Any, **kwargs: Any) -> list[Any]: """Invoke every registered hook for ``event`` in load order. Returns the list of return values for hooks that didn't @@ -119,9 +113,7 @@ def fire( results.append(fn(*args, **kwargs)) except PermissionError: if event.startswith("post_") or event == "on_error": - logger.warning( - "[clio-hooks] post-event hook raised; swallowing" - ) + logger.warning("[clio-hooks] post-event hook raised; swallowing") continue raise except Exception as exc: # noqa: BLE001 diff --git a/src/clio_agent/tools/servers/adios_server.py b/src/clio_agent/tools/servers/adios_server.py index b537939..9c0846b 100644 --- a/src/clio_agent/tools/servers/adios_server.py +++ b/src/clio_agent/tools/servers/adios_server.py @@ -130,7 +130,9 @@ def _adios2_unavailable() -> dict[str, Any]: } -def _inspect_variables_with_adios2(filepath: Path, variable_name: str | None = None) -> dict[str, Any]: +def _inspect_variables_with_adios2( + filepath: Path, variable_name: str | None = None +) -> dict[str, Any]: """Inspect BP variables through ADIOS2 when the optional dependency exists.""" try: from adios2 import FileReader # type: ignore[import-not-found] diff --git a/src/clio_agent/tools/servers/fs_server.py b/src/clio_agent/tools/servers/fs_server.py index 311884e..2735d7a 100644 --- a/src/clio_agent/tools/servers/fs_server.py +++ b/src/clio_agent/tools/servers/fs_server.py @@ -92,19 +92,17 @@ def propose_edit(filepath: str, new_content: str) -> dict[str, Any]: p = Path(safe_read) old = p.read_text(encoding="utf-8", errors="replace") new = new_content if isinstance(new_content, str) else str(new_content) - diff_lines = list(difflib.unified_diff( - old.splitlines(keepends=True), - new.splitlines(keepends=True), - fromfile=f"a/{p.name}", - tofile=f"b/{p.name}", - lineterm="", - )) - added = sum( - 1 for ln in diff_lines if ln.startswith("+") and not ln.startswith("+++") - ) - removed = sum( - 1 for ln in diff_lines if ln.startswith("-") and not ln.startswith("---") + diff_lines = list( + difflib.unified_diff( + old.splitlines(keepends=True), + new.splitlines(keepends=True), + fromfile=f"a/{p.name}", + tofile=f"b/{p.name}", + lineterm="", + ) ) + added = sum(1 for ln in diff_lines if ln.startswith("+") and not ln.startswith("+++")) + removed = sum(1 for ln in diff_lines if ln.startswith("-") and not ln.startswith("---")) return { "path": str(p), "unified_diff": "\n".join(diff_lines), diff --git a/tests/test_arc/test_lsm.py b/tests/test_arc/test_lsm.py index 91638f2..94c8b08 100644 --- a/tests/test_arc/test_lsm.py +++ b/tests/test_arc/test_lsm.py @@ -130,9 +130,7 @@ def test_range_scan_across_flush(self, lsm): def test_compaction(self, temp_dir): """Test SSTable compaction.""" # Create LSM with low compaction threshold - lsm = LSMTree( - data_dir=temp_dir, memtable_size=5, compaction_threshold=3 - ) + lsm = LSMTree(data_dir=temp_dir, memtable_size=5, compaction_threshold=3) try: # Write enough data to create multiple SSTables diff --git a/tests/test_arc/test_memory_coverage.py b/tests/test_arc/test_memory_coverage.py index 9d14f7d..9fb36ab 100644 --- a/tests/test_arc/test_memory_coverage.py +++ b/tests/test_arc/test_memory_coverage.py @@ -357,12 +357,17 @@ def test_clear_all(self, arc): """clear_all should remove everything.""" now = time.time() arc.store_invocation(_inv("t1")) - arc.store_conversation(Conversation( - session_id="s1", user_id="u1", - created_at=now, updated_at=now, - last_accessed=now, status="active", - messages=[Message(role="user", content="hi", timestamp=now)], - )) + arc.store_conversation( + Conversation( + session_id="s1", + user_id="u1", + created_at=now, + updated_at=now, + last_accessed=now, + status="active", + messages=[Message(role="user", content="hi", timestamp=now)], + ) + ) arc.store_context(Context(domain="d1", created_at=now, updated_at=now)) arc.clear_all() diff --git a/tests/test_arc/test_retrieval.py b/tests/test_arc/test_retrieval.py index bddc985..9560a8a 100644 --- a/tests/test_arc/test_retrieval.py +++ b/tests/test_arc/test_retrieval.py @@ -44,10 +44,7 @@ def _make_conversation(session_id: str, messages: list[tuple[str, str]]) -> Conv updated_at=now, last_accessed=now, status="active", - messages=[ - Message(role=role, content=content, timestamp=now) - for role, content in messages - ], + messages=[Message(role=role, content=content, timestamp=now) for role, content in messages], ) @@ -94,25 +91,19 @@ class TestCalculateRelevanceScore: def test_identical_content(self, retriever): """Identical query and conversation should score high.""" conv = _make_conversation("s1", [("user", "optimize HDF5 compression")]) - score = retriever._calculate_relevance_score( - "optimize HDF5 compression", conv - ) + score = retriever._calculate_relevance_score("optimize HDF5 compression", conv) assert score > 0.5 def test_no_overlap(self, retriever): """No keyword overlap should score 0.""" conv = _make_conversation("s1", [("user", "weather forecast tomorrow")]) - score = retriever._calculate_relevance_score( - "optimize HDF5 compression", conv - ) + score = retriever._calculate_relevance_score("optimize HDF5 compression", conv) assert score == 0.0 def test_partial_overlap(self, retriever): """Partial overlap should score between 0 and 1.""" conv = _make_conversation("s1", [("user", "HDF5 file structure analysis")]) - score = retriever._calculate_relevance_score( - "optimize HDF5 compression", conv - ) + score = retriever._calculate_relevance_score("optimize HDF5 compression", conv) assert 0.0 < score < 1.0 def test_empty_query(self, retriever): @@ -152,10 +143,13 @@ class TestExtractKeyTopics: def test_extracts_frequent_words(self, retriever): """Should extract most frequent meaningful words.""" - conv = _make_conversation("s1", [ - ("user", "HDF5 compression HDF5 gzip HDF5"), - ("assistant", "HDF5 uses gzip compression by default"), - ]) + conv = _make_conversation( + "s1", + [ + ("user", "HDF5 compression HDF5 gzip HDF5"), + ("assistant", "HDF5 uses gzip compression by default"), + ], + ) topics = retriever.extract_key_topics([conv]) assert "hdf5" in topics assert "compression" in topics @@ -170,31 +164,28 @@ class TestRetrieveContextForQuery: def test_returns_context_object(self, retriever, arc): """Should return a Context object.""" - context = retriever.retrieve_context_for_query( - "optimize HDF5", "session-1" - ) + context = retriever.retrieve_context_for_query("optimize HDF5", "session-1") assert context is not None assert context.domain.startswith("query_context_") def test_with_conversation_history(self, retriever, arc): """Should include learned patterns from conversation.""" - conv = _make_conversation("session-1", [ - ("user", "How to optimize HDF5 compression?"), - ("assistant", "Use gzip compression with chunking."), - ]) + conv = _make_conversation( + "session-1", + [ + ("user", "How to optimize HDF5 compression?"), + ("assistant", "Use gzip compression with chunking."), + ], + ) arc.store_conversation(conv) - context = retriever.retrieve_context_for_query( - "HDF5 compression", "session-1" - ) + context = retriever.retrieve_context_for_query("HDF5 compression", "session-1") # Should have learned patterns assert len(context.learned_patterns) >= 0 def test_without_conversation(self, retriever, arc): """Should handle missing conversation gracefully.""" - context = retriever.retrieve_context_for_query( - "test query", "nonexistent-session" - ) + context = retriever.retrieve_context_for_query("test query", "nonexistent-session") assert context is not None @@ -203,9 +194,7 @@ class TestCompileExpertContext: def test_returns_string(self, retriever, arc): """Should return a compiled context string.""" - result = retriever.compile_expert_context( - "analyze HDF5", "session-1", tier=2 - ) + result = retriever.compile_expert_context("analyze HDF5", "session-1", tier=2) assert isinstance(result, str) def test_lazy_init_compiler(self, retriever): @@ -220,9 +209,7 @@ class TestGetRelevantToolResults: def test_no_context_returns_empty(self, retriever): """Missing context should return empty list.""" - results = retriever.get_relevant_tool_results( - "analyze HDF5", "hdf5_domain" - ) + results = retriever.get_relevant_tool_results("analyze HDF5", "hdf5_domain") assert results == [] def test_with_cached_results(self, retriever, arc): @@ -251,9 +238,7 @@ def test_with_cached_results(self, retriever, arc): ) arc.store_context(ctx) - results = retriever.get_relevant_tool_results( - "HDF5 compression analysis", "hdf5_domain" - ) + results = retriever.get_relevant_tool_results("HDF5 compression analysis", "hdf5_domain") # Should return results with hdf5-related content ranked higher assert len(results) >= 1 assert "tool" in results[0] diff --git a/tests/test_arc/test_shared_context.py b/tests/test_arc/test_shared_context.py index a5913b7..4de08af 100644 --- a/tests/test_arc/test_shared_context.py +++ b/tests/test_arc/test_shared_context.py @@ -90,9 +90,7 @@ def test_get_session_profiles_multiple(self, arc): """Store 3 profiles in same session, retrieve all.""" arc.store_dataset_profile(_make_profile(filepath="/data/a.parquet")) arc.store_dataset_profile(_make_profile(filepath="/data/b.parquet")) - arc.store_dataset_profile( - _make_profile(filepath="/data/c.hdf5", file_format="hdf5") - ) + arc.store_dataset_profile(_make_profile(filepath="/data/c.hdf5", file_format="hdf5")) profiles = arc.get_session_profiles("session-1") assert len(profiles) == 3 @@ -120,12 +118,8 @@ def test_dataset_profile_cross_expert(self, arc): def test_dataset_profile_different_sessions(self, arc): """Session A profiles are not visible to session B.""" - arc.store_dataset_profile( - _make_profile(session_id="session-A", filepath="/data/a.parquet") - ) - arc.store_dataset_profile( - _make_profile(session_id="session-B", filepath="/data/b.parquet") - ) + arc.store_dataset_profile(_make_profile(session_id="session-A", filepath="/data/a.parquet")) + arc.store_dataset_profile(_make_profile(session_id="session-B", filepath="/data/b.parquet")) profiles_a = arc.get_session_profiles("session-A") profiles_b = arc.get_session_profiles("session-B") @@ -171,9 +165,7 @@ def test_store_and_get_procedural_memory(self, arc): def test_procedural_memory_filter_by_expert(self, arc): """Filter procedural memories by expert_id.""" - arc.store_procedural_memory( - _make_procedural(expert_id="data", description="data pattern") - ) + arc.store_procedural_memory(_make_procedural(expert_id="data", description="data pattern")) arc.store_procedural_memory( _make_procedural(expert_id="analysis", description="analysis pattern") ) @@ -192,15 +184,9 @@ def test_procedural_memory_ordering(self, arc): t2 = time.time() - 50 t3 = time.time() - arc.store_procedural_memory( - _make_procedural(description="oldest", learned_at=t1) - ) - arc.store_procedural_memory( - _make_procedural(description="middle", learned_at=t2) - ) - arc.store_procedural_memory( - _make_procedural(description="newest", learned_at=t3) - ) + arc.store_procedural_memory(_make_procedural(description="oldest", learned_at=t1)) + arc.store_procedural_memory(_make_procedural(description="middle", learned_at=t2)) + arc.store_procedural_memory(_make_procedural(description="newest", learned_at=t3)) memories = arc.get_procedural_memories("session-1") assert len(memories) == 3 diff --git a/tests/test_arc/test_storage.py b/tests/test_arc/test_storage.py index f1a16cc..555e06b 100644 --- a/tests/test_arc/test_storage.py +++ b/tests/test_arc/test_storage.py @@ -47,9 +47,7 @@ def test_default_tier_policy(self, backend): def test_custom_tier_policy(self, tmp_path): """Custom tier policy should override defaults.""" policy = {"hot_to_warm": 2, "warm_to_cold": 14, "cold_to_archive": 60} - backend = IOWarpCTEBackend( - base_dir=str(tmp_path / "s"), tier_policy=policy - ) + backend = IOWarpCTEBackend(base_dir=str(tmp_path / "s"), tier_policy=policy) assert backend.tier_policy["warm_to_cold"] == 14 def test_performance_counters_start_at_zero(self, backend): diff --git a/tests/test_core/test_codex_provider.py b/tests/test_core/test_codex_provider.py index 455d112..a1d1573 100644 --- a/tests/test_core/test_codex_provider.py +++ b/tests/test_core/test_codex_provider.py @@ -55,9 +55,7 @@ def _reset_registration(): class TestMessagesToCodexPrompt: def test_single_user_message(self): - prompt = _messages_to_codex_prompt( - [{"role": "user", "content": "hello"}] - ) + prompt = _messages_to_codex_prompt([{"role": "user", "content": "hello"}]) assert "JSON Lines" in prompt assert json.loads(prompt.splitlines()[-1]) == { "role": "user", @@ -112,9 +110,7 @@ def test_role_like_content_cannot_create_prompt_boundary(self): } def test_unknown_role_is_downgraded_to_user(self): - prompt = _messages_to_codex_prompt( - [{"role": "root", "content": "hello"}] - ) + prompt = _messages_to_codex_prompt([{"role": "root", "content": "hello"}]) row = json.loads(prompt.splitlines()[-1]) assert row == {"role": "user", "content": "hello"} @@ -218,6 +214,7 @@ def test_unavailable_raises(self): from clio_agent.providers.codex_litellm import ( _resolve_codex_binary, ) + _resolve_codex_binary() @@ -417,8 +414,6 @@ def test_registered_entry_points_at_codex(self): import litellm ensure_registered() - codex_entries = [ - e for e in litellm.custom_provider_map if e.get("provider") == "codex" - ] + codex_entries = [e for e in litellm.custom_provider_map if e.get("provider") == "codex"] assert len(codex_entries) == 1 assert isinstance(codex_entries[0]["custom_handler"], CodexLLM) diff --git a/tests/test_core/test_instrumentation.py b/tests/test_core/test_instrumentation.py index 698bddd..5a0e352 100644 --- a/tests/test_core/test_instrumentation.py +++ b/tests/test_core/test_instrumentation.py @@ -232,12 +232,14 @@ def test_get_invocations_by_agent_filtered(): arc = _make_arc(tmp) # Store invocations for two agents - for i, (agent, status) in enumerate([ - ("data", "success"), - ("data", "failure"), - ("analysis", "success"), - ("data", "success"), - ]): + for i, (agent, status) in enumerate( + [ + ("data", "success"), + ("data", "failure"), + ("analysis", "success"), + ("data", "success"), + ] + ): inv = Invocation( trace_id=f"trace-{i}", session_id="session-1", diff --git a/tests/test_core/test_runner.py b/tests/test_core/test_runner.py index 7351734..cdc2632 100644 --- a/tests/test_core/test_runner.py +++ b/tests/test_core/test_runner.py @@ -129,8 +129,7 @@ class TestRun: def test_run_rejects_small_trainset(self, runner): """run() rejects trainset with fewer than 5 examples.""" small_trainset = [ - dspy.Example(question="q", analysis="a").with_inputs("question") - for _ in range(4) + dspy.Example(question="q", analysis="a").with_inputs("question") for _ in range(4) ] with pytest.raises(ValueError, match="at least 5"): @@ -165,8 +164,15 @@ def test_run_full_pipeline(self, mock_evaluate_cls, mock_simba_cls, runner, tmp_ # Verify result dict has all expected keys expected_keys = { - "optimized", "before_score", "after_score", "improvement_delta", - "p_value", "is_significant", "variant_record", "train_size", "val_size", + "optimized", + "before_score", + "after_score", + "improvement_delta", + "p_value", + "is_significant", + "variant_record", + "train_size", + "val_size", } assert set(result.keys()) == expected_keys @@ -186,7 +192,9 @@ def test_run_full_pipeline(self, mock_evaluate_cls, mock_simba_cls, runner, tmp_ @patch("clio_agent.optimizer.runner.dspy.SIMBA") @patch("clio_agent.optimizer.runner.dspy.evaluate.Evaluate") - def test_run_saves_variant_with_correct_args(self, mock_evaluate_cls, mock_simba_cls, runner, tmp_path): + def test_run_saves_variant_with_correct_args( + self, mock_evaluate_cls, mock_simba_cls, runner, tmp_path + ): """run() calls variant_manager.save_variant with correct arguments.""" mock_evaluator = MagicMock() mock_evaluator.side_effect = [50.0, 70.0] @@ -214,4 +222,7 @@ def test_run_saves_variant_with_correct_args(self, mock_evaluate_cls, mock_simba # Evaluate was constructed with the custom metric_fn mock_evaluate_cls.assert_called() call_kwargs = mock_evaluate_cls.call_args - assert call_kwargs.kwargs.get("metric") == metric_fn or call_kwargs[1].get("metric") == metric_fn + assert ( + call_kwargs.kwargs.get("metric") == metric_fn + or call_kwargs[1].get("metric") == metric_fn + ) diff --git a/tests/test_core/test_variants.py b/tests/test_core/test_variants.py index 0f0e2be..7b83ecb 100644 --- a/tests/test_core/test_variants.py +++ b/tests/test_core/test_variants.py @@ -65,9 +65,7 @@ class TestSaveVariant: def test_save_creates_file_on_disk(self, vm, mock_module, tmp_path): """save_variant creates a JSON file on disk.""" - vm.save_variant( - mock_module, "data", 0.6, 0.85, 50, 0.003, True - ) + vm.save_variant(mock_module, "data", 0.6, 0.85, 50, 0.003, True) variant_path = tmp_path / "variants" / "data_v1.json" assert variant_path.exists() @@ -75,9 +73,7 @@ def test_save_creates_file_on_disk(self, vm, mock_module, tmp_path): def test_save_stores_variant_record_in_arc(self, vm, arc, mock_module): """save_variant stores VariantRecord in ARC memory.""" - vm.save_variant( - mock_module, "data", 0.6, 0.85, 50, 0.003, True - ) + vm.save_variant(mock_module, "data", 0.6, 0.85, 50, 0.003, True) stored = arc.get_variant_records("data") assert len(stored) == 1 @@ -92,9 +88,7 @@ def test_save_stores_variant_record_in_arc(self, vm, arc, mock_module): def test_save_returns_variant_record(self, vm, mock_module): """save_variant returns a VariantRecord with correct metadata.""" - record = vm.save_variant( - mock_module, "data", 0.6, 0.85, 50, 0.003, True - ) + record = vm.save_variant(mock_module, "data", 0.6, 0.85, 50, 0.003, True) assert isinstance(record, VariantRecord) assert record.variant_id == "data_v1" diff --git a/tests/test_gact/conftest.py b/tests/test_gact/conftest.py index 722ac36..7d71744 100644 --- a/tests/test_gact/conftest.py +++ b/tests/test_gact/conftest.py @@ -59,6 +59,4 @@ def complete_turn( break time.sleep(poll_interval) - raise TimeoutError( - f"turn for user message {user_id!r} did not settle within {timeout:g}s" - ) + raise TimeoutError(f"turn for user message {user_id!r} did not settle within {timeout:g}s") diff --git a/tests/test_gact/test_cluster5.py b/tests/test_gact/test_cluster5.py index 23a461e..42d3d73 100644 --- a/tests/test_gact/test_cluster5.py +++ b/tests/test_gact/test_cluster5.py @@ -88,11 +88,10 @@ def test_share_expiry(client: TestClient) -> None: """ttl_s=1 returns a token that 410s after we trick the clock.""" sid = client.post("/v1/sessions", json={"title": "t"}).json()["id"] - share = client.post( - f"/v1/sessions/{sid}/share", json={"ttl_s": 1} - ).json() + share = client.post(f"/v1/sessions/{sid}/share", json={"ttl_s": 1}).json() # Force expiry by overwriting the row's expires_at. import time + state = client.app.state state.shared_tokens[share["token"]]["expires_at"] = time.time() - 10 resp = client.get(f"/v1/shared/{share['token']}") @@ -118,19 +117,27 @@ def test_extract_agent_from_sessions(client: TestClient) -> None: (sid1, ["hdf5_list_datasets", "hdf5_analyze_file"]), (sid2, ["hdf5_list_datasets", "parquet_analyze_schema"]), ]: - state.messages.setdefault(sid, []).append(Message( - id="msg_user_x", session_id=sid, role="user", - created_at="2026-04-25T00:00:00Z", - updated_at="2026-04-25T00:00:00Z", - parts=[Part(id="part_x", type="text", text=f"analyze {sid}")], - )) - state.messages[sid].append(Message( - id=f"msg_asst_{sid}", session_id=sid, role="assistant", - created_at="2026-04-25T00:00:00Z", - updated_at="2026-04-25T00:00:00Z", - parts=[Part(id="part_y", type="text", text="done")], - metadata={"tools_called": [{"name": t} for t in tools]}, - )) + state.messages.setdefault(sid, []).append( + Message( + id="msg_user_x", + session_id=sid, + role="user", + created_at="2026-04-25T00:00:00Z", + updated_at="2026-04-25T00:00:00Z", + parts=[Part(id="part_x", type="text", text=f"analyze {sid}")], + ) + ) + state.messages[sid].append( + Message( + id=f"msg_asst_{sid}", + session_id=sid, + role="assistant", + created_at="2026-04-25T00:00:00Z", + updated_at="2026-04-25T00:00:00Z", + parts=[Part(id="part_y", type="text", text="done")], + metadata={"tools_called": [{"name": t} for t in tools]}, + ) + ) resp = client.post( "/v1/agents/extract", diff --git a/tests/test_gact/test_commands.py b/tests/test_gact/test_commands.py index d9a77a9..6145d94 100644 --- a/tests/test_gact/test_commands.py +++ b/tests/test_gact/test_commands.py @@ -30,9 +30,7 @@ def get_cache_stats(self) -> dict[str, object]: @pytest.fixture() def client(tmp_path: Path) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent()) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent())) def test_commands_listed(client: TestClient) -> None: @@ -55,9 +53,7 @@ def test_dispatch_clear_drops_messages(client: TestClient) -> None: sid = client.post("/v1/sessions", json={"title": "t"}).json()["id"] complete_turn(client, sid, "first") - assert ( - len(client.get(f"/v1/sessions/{sid}/messages").json()["messages"]) == 2 - ) + assert len(client.get(f"/v1/sessions/{sid}/messages").json()["messages"]) == 2 resp = client.post(f"/v1/sessions/{sid}/commands/clear").json() assert resp["command"] == "/clear" diff --git a/tests/test_gact/test_context_injection.py b/tests/test_gact/test_context_injection.py index 9c11f06..c21ae84 100644 --- a/tests/test_gact/test_context_injection.py +++ b/tests/test_gact/test_context_injection.py @@ -23,10 +23,15 @@ def __init__(self) -> None: def forward(self, question: str, session_id: str = "default"): self.calls.append((question, session_id)) - return type("Pred", (), { - "answer": "ok", "selected_expert": "", - "routing_rationale": "", - })() + return type( + "Pred", + (), + { + "answer": "ok", + "selected_expert": "", + "routing_rationale": "", + }, + )() @pytest.fixture() @@ -37,9 +42,7 @@ def setup(tmp_path: Path): app = build_app(sessions_path=tmp_path / "s.json", agent=agent) # Update ws_default's root_path so its files pass the # workspace check inside _enrich_with_context_files. - app.state.workspaces.update( - "ws_default", root_path=str(tmp_path) - ) + app.state.workspaces.update("ws_default", root_path=str(tmp_path)) return app, TestClient(app), agent, tmp_path @@ -134,10 +137,7 @@ def test_read_file_deleted_after_attach_surfaces_error(setup) -> None: "retry", "exit", ] - completed = [ - ev for ev in app.state.bus._history.get(sid, []) - if ev.type == "message.completed" - ] + completed = [ev for ev in app.state.bus._history.get(sid, []) if ev.type == "message.completed"] assert completed, "turn did not publish message.completed" payload = completed[-1].payload assert payload["message_id"] == assistant["id"] @@ -145,9 +145,7 @@ def test_read_file_deleted_after_attach_surfaces_error(setup) -> None: assert payload["error_info"]["error"] == "context_file_error" -def test_path_outside_workspace_root_is_inlined_for_reads( - setup, tmp_path: Path -) -> None: +def test_path_outside_workspace_root_is_inlined_for_reads(setup, tmp_path: Path) -> None: """iowarp/clio-agent#5 (revised): the workspace-root check used to silently skip context_files outside the root, but the user explicitly attaches via the API — they know what they're doing. diff --git a/tests/test_gact/test_cost_tracking.py b/tests/test_gact/test_cost_tracking.py index bb3b552..7c84dbe 100644 --- a/tests/test_gact/test_cost_tracking.py +++ b/tests/test_gact/test_cost_tracking.py @@ -39,13 +39,12 @@ def forward(self, question: str, session_id: str): def _client(tmp_path: Path, pred) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent(pred)) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent(pred))) def _turn(client: TestClient, sid: str) -> dict: from .conftest import complete_turn + return complete_turn(client, sid, "hello") diff --git a/tests/test_gact/test_diffs.py b/tests/test_gact/test_diffs.py index c53b663..213053b 100644 --- a/tests/test_gact/test_diffs.py +++ b/tests/test_gact/test_diffs.py @@ -27,13 +27,12 @@ def forward(self, question: str, session_id: str): def _client(tmp_path: Path, diffs) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent(diffs)) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent(diffs))) def _turn(client: TestClient, sid: str) -> dict: from .conftest import complete_turn + return complete_turn(client, sid, "propose an edit") @@ -47,9 +46,12 @@ def _turn(client: TestClient, sid: str) -> dict: def test_assistant_emits_file_diff_part(tmp_path: Path) -> None: - client = _client(tmp_path, diffs=[ - {"path": "main.go", "unified_diff": SAMPLE_DIFF}, - ]) + client = _client( + tmp_path, + diffs=[ + {"path": "main.go", "unified_diff": SAMPLE_DIFF}, + ], + ) sid = client.post("/v1/sessions", json={"title": "t"}).json()["id"] a = _turn(client, sid) parts = a["parts"] @@ -61,16 +63,17 @@ def test_assistant_emits_file_diff_part(tmp_path: Path) -> None: def test_apply_flips_status_and_returns_paths(tmp_path: Path) -> None: - client = _client(tmp_path, diffs=[ - {"path": "a.py", "unified_diff": SAMPLE_DIFF}, - {"path": "b.py", "unified_diff": SAMPLE_DIFF}, - ]) + client = _client( + tmp_path, + diffs=[ + {"path": "a.py", "unified_diff": SAMPLE_DIFF}, + {"path": "b.py", "unified_diff": SAMPLE_DIFF}, + ], + ) sid = client.post("/v1/sessions", json={"title": "t"}).json()["id"] _turn(client, sid) - resp = client.post( - f"/v1/sessions/{sid}/diffs/apply", json={"paths": ["a.py"]} - ).json() + resp = client.post(f"/v1/sessions/{sid}/diffs/apply", json={"paths": ["a.py"]}).json() assert resp["applied"] == ["a.py"] # b.py still pending — apply-all picks it up. @@ -83,9 +86,12 @@ def test_apply_flips_status_and_returns_paths(tmp_path: Path) -> None: def test_reject_flips_status(tmp_path: Path) -> None: - client = _client(tmp_path, diffs=[ - {"path": "a.py", "unified_diff": SAMPLE_DIFF}, - ]) + client = _client( + tmp_path, + diffs=[ + {"path": "a.py", "unified_diff": SAMPLE_DIFF}, + ], + ) sid = client.post("/v1/sessions", json={"title": "t"}).json()["id"] _turn(client, sid) resp = client.post(f"/v1/sessions/{sid}/diffs/reject", json={}).json() diff --git a/tests/test_gact/test_doctor_integrations.py b/tests/test_gact/test_doctor_integrations.py index 275f8ad..c4e7dd6 100644 --- a/tests/test_gact/test_doctor_integrations.py +++ b/tests/test_gact/test_doctor_integrations.py @@ -65,11 +65,13 @@ def test_no_agent_no_arc_overall_is_unavailable(tmp_path: Path) -> None: def test_fake_agent_flagged_degraded(tmp_path: Path) -> None: - resp = _health_response(build_app( - sessions_path=tmp_path / "s.json", - agent=_RealishAgent(), - arc=_FakeARC(), - )) + resp = _health_response( + build_app( + sessions_path=tmp_path / "s.json", + agent=_RealishAgent(), + arc=_FakeARC(), + ) + ) assert resp.status_code == 200 body = resp.json() rows = {r["name"]: r for r in body["integrations"]} @@ -83,11 +85,13 @@ def test_fake_agent_flagged_degraded(tmp_path: Path) -> None: def test_broken_arc_reports_unavailable(tmp_path: Path) -> None: - resp = _health_response(build_app( - sessions_path=tmp_path / "s.json", - agent=_RealishAgent(), - arc=_BrokenARC(), - )) + resp = _health_response( + build_app( + sessions_path=tmp_path / "s.json", + agent=_RealishAgent(), + arc=_BrokenARC(), + ) + ) assert resp.status_code == 503 body = resp.json() rows = {r["name"]: r for r in body["integrations"]} @@ -96,11 +100,13 @@ def test_broken_arc_reports_unavailable(tmp_path: Path) -> None: def test_ready_health_returns_200(tmp_path: Path) -> None: - resp = _health_response(build_app( - sessions_path=tmp_path / "s.json", - agent=_ProductionLikeAgent(), - arc=_FakeARC(), - )) + resp = _health_response( + build_app( + sessions_path=tmp_path / "s.json", + agent=_ProductionLikeAgent(), + arc=_FakeARC(), + ) + ) assert resp.status_code == 200 body = resp.json() assert body["overall_status"] == "ready" diff --git a/tests/test_gact/test_fork_and_search.py b/tests/test_gact/test_fork_and_search.py index bb8282b..5f7c939 100644 --- a/tests/test_gact/test_fork_and_search.py +++ b/tests/test_gact/test_fork_and_search.py @@ -30,13 +30,12 @@ def forward(self, question: str, session_id: str): def _client(tmp_path: Path) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent()) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent())) def _turn(client: TestClient, sid: str, text: str) -> dict: from .conftest import complete_turn + return complete_turn(client, sid, text) @@ -80,9 +79,7 @@ def test_search_returns_ranked_snippets(tmp_path: Path) -> None: _turn(client, sid, "load /tmp/alpha.parquet") _turn(client, sid, "compare /tmp/alpha.parquet to /tmp/beta.parquet") - body = client.get( - f"/v1/sessions/{sid}/messages/search?q=alpha.parquet" - ).json() + body = client.get(f"/v1/sessions/{sid}/messages/search?q=alpha.parquet").json() matches = body["matches"] assert len(matches) >= 2 for m in matches: diff --git a/tests/test_gact/test_hooks.py b/tests/test_gact/test_hooks.py index f089329..f9f1862 100644 --- a/tests/test_gact/test_hooks.py +++ b/tests/test_gact/test_hooks.py @@ -42,11 +42,14 @@ def _hook_dir(tmp_path: Path, **events: str) -> Path: def test_pre_tool_hook_can_block(tmp_path: Path) -> None: """A pre_tool hook that raises PermissionError vetoes the call.""" - d = _hook_dir(tmp_path, pre_tool=""" + d = _hook_dir( + tmp_path, + pre_tool=""" def pre_tool(name, args): if name.startswith("hdf5_"): raise PermissionError("no hdf5 today") -""") +""", + ) reg = HookRegistry(hooks_dir=d) with pytest.raises(PermissionError, match="no hdf5"): reg.fire("pre_tool", "hdf5_list_datasets", {"path": "/tmp/x"}) @@ -58,10 +61,13 @@ def test_post_tool_hook_swallows_exceptions(tmp_path: Path) -> None: """post_* hooks must NEVER crash a turn; exceptions are swallowed + logged.""" - d = _hook_dir(tmp_path, post_tool=""" + d = _hook_dir( + tmp_path, + post_tool=""" def post_tool(name, args, result=None, error=None): raise RuntimeError("boom") -""") +""", + ) reg = HookRegistry(hooks_dir=d) # Should not raise. reg.fire("post_tool", "fs_read_file", {"x": 1}, result="ok") @@ -72,16 +78,17 @@ def test_pre_message_hook_blocks_via_app(tmp_path: Path) -> None: end-to-end through GACT — the assistant message comes back with error_info.error == permission_error.""" - d = _hook_dir(tmp_path, pre_message=""" + d = _hook_dir( + tmp_path, + pre_message=""" def pre_message(session_id, text): if "secret" in text.lower(): raise PermissionError("blocked by policy") -""") +""", + ) install_global_registry(HookRegistry(hooks_dir=d)) try: - app = build_app( - sessions_path=tmp_path / "s.json", agent=_Agent() - ) + app = build_app(sessions_path=tmp_path / "s.json", agent=_Agent()) with TestClient(app) as c: sid = c.post("/v1/sessions", json={"title": "t"}).json()["id"] ack = c.post( @@ -91,6 +98,7 @@ def pre_message(session_id, text): assert ack.status_code == 200 # Wait for the background turn to settle into error. import time as _t + for _ in range(30): sess = c.get(f"/v1/sessions/{sid}").json() if sess["status"] == "error": @@ -106,15 +114,16 @@ def test_post_message_hook_runs_after_settle(tmp_path: Path) -> None: (write a marker file here).""" marker = tmp_path / "post_message_fired.txt" - d = _hook_dir(tmp_path, post_message=f""" + d = _hook_dir( + tmp_path, + post_message=f""" def post_message(session_id, assistant): open({str(marker)!r}, "w").write(assistant['id']) -""") +""", + ) install_global_registry(HookRegistry(hooks_dir=d)) try: - app = build_app( - sessions_path=tmp_path / "s.json", agent=_Agent() - ) + app = build_app(sessions_path=tmp_path / "s.json", agent=_Agent()) with TestClient(app) as c: from .conftest import complete_turn diff --git a/tests/test_gact/test_memory_stats.py b/tests/test_gact/test_memory_stats.py index 5f1204a..3a3a948 100644 --- a/tests/test_gact/test_memory_stats.py +++ b/tests/test_gact/test_memory_stats.py @@ -51,9 +51,7 @@ def get_cache_stats(self) -> dict[str, Any]: @pytest.fixture() def client_with_arc(tmp_path: Path) -> TestClient: arc = FakeARC(hits=80, misses=20, conv_index_size=12, inv_index_size=42) - return TestClient( - build_app(sessions_path=tmp_path / "s.json", arc=arc) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", arc=arc)) def test_memory_stats_reports_cache_counters(client_with_arc: TestClient) -> None: diff --git a/tests/test_gact/test_metrics.py b/tests/test_gact/test_metrics.py index c6c4830..a022a0f 100644 --- a/tests/test_gact/test_metrics.py +++ b/tests/test_gact/test_metrics.py @@ -32,9 +32,7 @@ def forward(self, question: str, session_id: str): @pytest.fixture() def client(tmp_path: Path) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "sessions.json", agent=_FakeAgent()) - ) + return TestClient(build_app(sessions_path=tmp_path / "sessions.json", agent=_FakeAgent())) def test_metrics_empty_state(client: TestClient) -> None: diff --git a/tests/test_gact/test_nanoagent_runtime.py b/tests/test_gact/test_nanoagent_runtime.py index fe2a841..c37060a 100644 --- a/tests/test_gact/test_nanoagent_runtime.py +++ b/tests/test_gact/test_nanoagent_runtime.py @@ -70,10 +70,9 @@ def test_spawn_many_empty_returns_empty() -> None: def test_render_input_uses_question_field_when_present() -> None: from clio_agent.runtime.nanoagent import _render_input + assert _render_input({"question": "hello"}) == "hello" - assert ( - _render_input({"file": "/tmp/x", "mode": "read"}).startswith("file=") - ) + assert _render_input({"file": "/tmp/x", "mode": "read"}).startswith("file=") def test_analysis_expert_detects_parallel_items() -> None: @@ -84,12 +83,10 @@ def test_analysis_expert_detects_parallel_items() -> None: # Multi-item triggers fan-out. assert _detect_parallel_items("validate /tmp/a.parquet and /tmp/b.parquet") == [ - "/tmp/a.parquet", "/tmp/b.parquet", + "/tmp/a.parquet", + "/tmp/b.parquet", ] - assert ( - len(_detect_parallel_items("check schema, statistics, and quality")) - == 3 - ) + assert len(_detect_parallel_items("check schema, statistics, and quality")) == 3 # Single-item questions don't fan out. assert _detect_parallel_items("validate /tmp/a.parquet") == [] # No trigger word -> no fan-out. diff --git a/tests/test_gact/test_permissions.py b/tests/test_gact/test_permissions.py index d33b59e..3d8ee66 100644 --- a/tests/test_gact/test_permissions.py +++ b/tests/test_gact/test_permissions.py @@ -27,9 +27,7 @@ def forward(self, question: str, session_id: str): def _client(tmp_path: Path, perms) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent(perms)) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent(perms))) def _turn(client: TestClient, sid: str) -> dict: @@ -40,16 +38,19 @@ def _turn(client: TestClient, sid: str) -> dict: def test_permission_requested_then_allowed(tmp_path: Path) -> None: - client = _client(tmp_path, perms=[ - { - "tool_call": { - "call_id": "c1", - "tool_name": "shell.exec", - "input": {"cmd": "rm -rf /tmp/scratch"}, - }, - "summary": "destructive shell command", - } - ]) + client = _client( + tmp_path, + perms=[ + { + "tool_call": { + "call_id": "c1", + "tool_name": "shell.exec", + "input": {"cmd": "rm -rf /tmp/scratch"}, + }, + "summary": "destructive shell command", + } + ], + ) sid = client.post("/v1/sessions", json={"title": "t"}).json()["id"] _turn(client, sid) diff --git a/tests/test_gact/test_session_export.py b/tests/test_gact/test_session_export.py index d696e71..0363205 100644 --- a/tests/test_gact/test_session_export.py +++ b/tests/test_gact/test_session_export.py @@ -23,9 +23,7 @@ def forward(self, *args, **kwargs): def _client(tmp_path: Path) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent()) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent())) def test_export_unknown_session_404s(tmp_path: Path) -> None: @@ -59,9 +57,7 @@ def test_export_then_import_round_trip(tmp_path: Path) -> None: assert len(rows) == 4 # Original user prompts preserved. user_texts = { - p["text"] for m in rows - for p in m["parts"] - if m["role"] == "user" and p["type"] == "text" + p["text"] for m in rows for p in m["parts"] if m["role"] == "user" and p["type"] == "text" } assert {"first", "second"} == user_texts diff --git a/tests/test_gact/test_session_tasks.py b/tests/test_gact/test_session_tasks.py index 43b66d2..6d33807 100644 --- a/tests/test_gact/test_session_tasks.py +++ b/tests/test_gact/test_session_tasks.py @@ -26,9 +26,7 @@ def test_empty_task_list(client: TestClient, sid: str) -> None: def test_create_then_list(client: TestClient, sid: str) -> None: - new = client.post( - f"/v1/sessions/{sid}/tasks", json={"title": "validate schema"} - ).json() + new = client.post(f"/v1/sessions/{sid}/tasks", json={"title": "validate schema"}).json() assert new["title"] == "validate schema" assert new["status"] == "pending" assert new["id"].startswith("task_") @@ -38,21 +36,15 @@ def test_create_then_list(client: TestClient, sid: str) -> None: def test_patch_status(client: TestClient, sid: str) -> None: - new = client.post( - f"/v1/sessions/{sid}/tasks", json={"title": "x"} - ).json() - patched = client.patch( - f"/v1/tasks/{new['id']}", json={"status": "completed"} - ).json() + new = client.post(f"/v1/sessions/{sid}/tasks", json={"title": "x"}).json() + patched = client.patch(f"/v1/tasks/{new['id']}", json={"status": "completed"}).json() assert patched["status"] == "completed" rows = client.get(f"/v1/sessions/{sid}/tasks").json()["tasks"] assert rows[0]["status"] == "completed" def test_delete(client: TestClient, sid: str) -> None: - new = client.post( - f"/v1/sessions/{sid}/tasks", json={"title": "x"} - ).json() + new = client.post(f"/v1/sessions/{sid}/tasks", json={"title": "x"}).json() assert client.delete(f"/v1/tasks/{new['id']}").status_code == 204 rows = client.get(f"/v1/sessions/{sid}/tasks").json()["tasks"] assert rows == [] diff --git a/tests/test_gact/test_sse.py b/tests/test_gact/test_sse.py index c68340c..5074d33 100644 --- a/tests/test_gact/test_sse.py +++ b/tests/test_gact/test_sse.py @@ -52,9 +52,7 @@ def test_format_sse_emits_canonical_wire_shape() -> None: assert "data: " in raw assert raw.endswith("\n\n") # The data line is valid JSON matching the envelope. - data_line = next( - ln for ln in raw.splitlines() if ln.startswith("data: ") - ) + data_line = next(ln for ln in raw.splitlines() if ln.startswith("data: ")) payload = json.loads(data_line.removeprefix("data: ")) assert payload["type"] == "message.completed" assert payload["payload"] == {"a": 1} @@ -146,9 +144,7 @@ def forward(self, question: str, session_id: str) -> Any: @pytest.fixture() def client(tmp_path: Path) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_FakeAgent()) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_FakeAgent())) def test_sse_endpoint_404s_for_unknown_session(client: TestClient) -> None: diff --git a/tests/test_gact/test_thinking_blocks.py b/tests/test_gact/test_thinking_blocks.py index 0589b4c..487b5a0 100644 --- a/tests/test_gact/test_thinking_blocks.py +++ b/tests/test_gact/test_thinking_blocks.py @@ -28,9 +28,7 @@ def forward(self, *args, **kwargs): def _client(tmp_path: Path, pred) -> TestClient: - return TestClient( - build_app(sessions_path=tmp_path / "s.json", agent=_Agent(pred)) - ) + return TestClient(build_app(sessions_path=tmp_path / "s.json", agent=_Agent(pred))) def test_chain_of_thought_reasoning_becomes_thinking_part( @@ -51,12 +49,14 @@ def test_chain_of_thought_reasoning_becomes_thinking_part( def test_react_trajectory_dict_becomes_thinking_part(tmp_path: Path) -> None: from .conftest import complete_turn - pred = _Pred(trajectory={ - "step_0_thought": "first probe the schema", - "step_0_tool_name": "hdf5_list_datasets", - "step_1_thought": "now read /sim/temperature", - "step_1_tool_name": "hdf5_analyze_dataset", - }) + pred = _Pred( + trajectory={ + "step_0_thought": "first probe the schema", + "step_0_tool_name": "hdf5_list_datasets", + "step_1_thought": "now read /sim/temperature", + "step_1_tool_name": "hdf5_analyze_dataset", + } + ) c = _client(tmp_path, pred) sid = c.post("/v1/sessions", json={"title": "t"}).json()["id"] a = complete_turn(c, sid, "analyze") diff --git a/tests/test_integration_contract/conftest.py b/tests/test_integration_contract/conftest.py index b07b094..c53546d 100644 --- a/tests/test_integration_contract/conftest.py +++ b/tests/test_integration_contract/conftest.py @@ -105,9 +105,7 @@ def post_user(http: httpx.Client, sid: str, text: str) -> str: return ack.json()["message_id"] -def turn( - http: httpx.Client, sid: str, text: str, *, timeout: float = 180.0 -) -> dict[str, Any]: +def turn(http: httpx.Client, sid: str, text: str, *, timeout: float = 180.0) -> dict[str, Any]: """POST + wait for assistant + return the assistant dict.""" user_id = post_user(http, sid, text) diff --git a/tests/test_tools/test_file_policy.py b/tests/test_tools/test_file_policy.py index 338f838..616a103 100644 --- a/tests/test_tools/test_file_policy.py +++ b/tests/test_tools/test_file_policy.py @@ -1,6 +1,5 @@ """Tests for file access policy validation.""" - from clio_agent.tools import file_policy from clio_agent.tools.file_policy import FileAccessPolicy, FilePolicyError diff --git a/tests/test_tools/test_fs_server.py b/tests/test_tools/test_fs_server.py index 539e90f..4ccfa61 100644 --- a/tests/test_tools/test_fs_server.py +++ b/tests/test_tools/test_fs_server.py @@ -84,18 +84,14 @@ def test_gateway_apply_edit_write_respects_late_global_permission_gate( seen: list[str] = [] try: - set_global_permission_gate( - lambda name, _args: seen.append(name) or "allow" - ) + set_global_permission_gate(lambda name, _args: seen.append(name) or "allow") executor.call_tool( "fs_apply_edit_write", {"filepath": str(target), "new_content": "allowed\n"}, ) assert target.read_text(encoding="utf-8") == "allowed\n" - set_global_permission_gate( - lambda name, _args: seen.append(name) or "deny" - ) + set_global_permission_gate(lambda name, _args: seen.append(name) or "deny") with pytest.raises(PermissionError, match="denied"): executor.call_tool( "fs_apply_edit_write", diff --git a/tests/test_tools/test_gateway.py b/tests/test_tools/test_gateway.py index 47d57fa..623ce64 100644 --- a/tests/test_tools/test_gateway.py +++ b/tests/test_tools/test_gateway.py @@ -160,7 +160,8 @@ async def test_list_capabilities_inside_running_loop_has_no_unawaited_warning(re gc.collect() leaked = [ - warning for warning in recwarn + warning + for warning in recwarn if issubclass(warning.category, RuntimeWarning) and "was never awaited" in str(warning.message) ] @@ -171,8 +172,6 @@ async def test_list_capabilities_inside_running_loop_has_no_unawaited_warning(re async def test_gateway_error_handling(): """Test that errors propagate correctly through gateway.""" async with Client(gateway) as client: - result = await client.call_tool( - "hdf5_analyze_file", {"filepath": "/nonexistent/file.h5"} - ) + result = await client.call_tool("hdf5_analyze_file", {"filepath": "/nonexistent/file.h5"}) data = _parse_result(result) assert "error" in data diff --git a/tests/test_tools/test_shell_server.py b/tests/test_tools/test_shell_server.py index 4335c8d..0fa0b89 100644 --- a/tests/test_tools/test_shell_server.py +++ b/tests/test_tools/test_shell_server.py @@ -23,7 +23,9 @@ def _parse_result(result: object) -> dict: @pytest.mark.asyncio -async def test_shell_bash_runs_simple_command(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: +async def test_shell_bash_runs_simple_command( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: """The shell tool should execute a bounded command and return output.""" monkeypatch.setenv("CLIO_ALLOWED_ROOTS", str(tmp_path))