From bea4abd96ef791b025be5a4469083f076eb546e6 Mon Sep 17 00:00:00 2001 From: Parman Mohammadalizadeh Date: Thu, 18 Jun 2026 14:37:07 +0330 Subject: [PATCH] chore: apply ruff formatting Signed-off-by: Parman Mohammadalizadeh --- src/agent/claude_agent/cli.py | 4 +- src/agent/claude_agent/runner.py | 34 +- src/agent/claude_agent/tests/test_runner.py | 21 +- src/agent/deep_agent/cli.py | 4 +- src/agent/deep_agent/runner.py | 8 +- src/agent/deep_agent/tests/test_runner.py | 47 ++- src/agent/direct_llm_agent/__init__.py | 2 +- src/agent/direct_llm_agent/cli.py | 2 +- src/agent/direct_llm_agent/runner.py | 2 +- src/agent/openai_agent/cli.py | 4 +- src/agent/openai_agent/runner.py | 23 +- src/agent/plan_execute/executor.py | 10 +- src/agent/plan_execute/runner.py | 5 +- src/agent/stirrup_agent/__init__.py | 2 +- src/agent/stirrup_agent/cli.py | 8 +- src/agent/stirrup_agent/runner.py | 15 +- src/agent/stirrup_agent/tests/__init__.py | 30 +- src/agent/stirrup_agent/tests/test_runner.py | 30 +- src/agent/stirrup_agent/trajectory.py | 2 +- src/agent/tests/test_planner.py | 9 +- src/agent/tests/test_runner.py | 167 +++++--- src/couchdb/init_data.py | 91 +++-- src/couchdb/loader.py | 61 ++- src/couchdb/transforms.py | 26 +- src/evaluation/cli.py | 3 +- src/evaluation/evaluator.py | 4 +- src/evaluation/loader.py | 24 +- src/evaluation/metrics.py | 10 +- src/evaluation/scorers/__init__.py | 6 +- src/evaluation/scorers/llm_judge.py | 4 +- src/evaluation/scorers/static_json.py | 12 +- src/evaluation/tests/test_loader.py | 14 +- src/evaluation/tests/test_metrics.py | 47 ++- src/evaluation/tests/test_models.py | 4 +- src/evaluation/tests/test_report.py | 13 +- src/evaluation/tests/test_runner.py | 12 +- .../tests/test_static_json_scorer.py | 4 +- src/llm/base.py | 4 +- src/llm/litellm.py | 4 +- src/llm/openai_compat.py | 8 +- src/llm/routers.py | 6 +- src/llm/tests/test_backends.py | 10 +- src/observability/persistence.py | 4 +- src/observability/runspan.py | 2 - src/observability/tests/test_persistence.py | 8 +- src/observability/tests/test_tracing.py | 4 +- src/observability/tracing.py | 4 +- src/servers/fmsr/main.py | 25 +- src/servers/fmsr/tests/conftest.py | 6 +- src/servers/fmsr/tests/test_tools.py | 18 +- src/servers/iot/main.py | 5 +- src/servers/iot/tests/conftest.py | 1 + src/servers/iot/tests/test_couchdb.py | 4 +- src/servers/tsfm/main.py | 10 +- src/servers/tsfm/tests/conftest.py | 2 + src/servers/tsfm/tests/test_tools.py | 85 ++-- src/servers/utilities/main.py | 13 +- src/servers/vibration/data_store.py | 4 +- src/servers/vibration/dsp/__init__.py | 14 +- src/servers/vibration/dsp/bearing_freqs.py | 6 +- src/servers/vibration/dsp/envelope.py | 4 +- src/servers/vibration/dsp/fault_detection.py | 8 +- src/servers/vibration/main.py | 5 +- .../generate_synthetic_vibration.py | 52 +-- src/servers/vibration/tests/test_dsp.py | 5 +- src/servers/vibration/tests/test_mcp_e2e.py | 63 ++- src/servers/vibration/tests/test_tools.py | 169 +++++--- src/servers/wo/couch.py | 41 +- src/servers/wo/envelope.py | 12 +- src/servers/wo/main.py | 273 +++++++++---- src/servers/wo/models.py | 2 +- src/servers/wo/tests/test_models_boundary.py | 44 ++- src/servers/wo/tests/test_workorders.py | 86 ++-- src/servers/wo/workorders.py | 367 +++++++++++++----- 74 files changed, 1500 insertions(+), 642 deletions(-) diff --git a/src/agent/claude_agent/cli.py b/src/agent/claude_agent/cli.py index 3d477a039..5bda40bf0 100644 --- a/src/agent/claude_agent/cli.py +++ b/src/agent/claude_agent/cli.py @@ -52,7 +52,9 @@ async def _run(args: argparse.Namespace) -> None: runner = ClaudeAgentRunner(model=args.model_id, max_turns=args.max_turns) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: diff --git a/src/agent/claude_agent/runner.py b/src/agent/claude_agent/runner.py index 499432893..4f7f3e9a4 100644 --- a/src/agent/claude_agent/runner.py +++ b/src/agent/claude_agent/runner.py @@ -20,7 +20,13 @@ import time from pathlib import Path -from claude_agent_sdk import AssistantMessage, ClaudeAgentOptions, HookMatcher, ResultMessage, query +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + HookMatcher, + ResultMessage, + query, +) from claude_agent_sdk import TextBlock, ToolUseBlock from observability import agent_run_span, persist_trajectory @@ -132,8 +138,14 @@ async def run(self, question: str) -> AgentResult: last_turn_start = run_started tool_outputs: dict[str, object] = {} - async def _capture_tool_output(input_data, tool_use_id: str, context) -> dict: - resp = input_data.get("tool_response") if isinstance(input_data, dict) else input_data + async def _capture_tool_output( + input_data, tool_use_id: str, context + ) -> dict: + resp = ( + input_data.get("tool_response") + if isinstance(input_data, dict) + else input_data + ) if isinstance(resp, dict): tool_outputs[tool_use_id] = resp.get("content", resp) else: @@ -145,7 +157,9 @@ async def _capture_tool_output(input_data, tool_use_id: str, context) -> dict: # per-tool duration for claude-agent is therefore not captured # (matches openai-agent / deep-agent). options.hooks = { - "PostToolUse": [HookMatcher(matcher=".*", hooks=[_capture_tool_output])], + "PostToolUse": [ + HookMatcher(matcher=".*", hooks=[_capture_tool_output]) + ], } def _flush_tool_outputs() -> None: @@ -169,7 +183,9 @@ def _flush_tool_outputs() -> None: text += block.text elif isinstance(block, ToolUseBlock): tool_calls.append( - ToolCall(name=block.name, input=block.input, id=block.id) + ToolCall( + name=block.name, input=block.input, id=block.id + ) ) usage = message.usage or {} trajectory.turns.append( @@ -197,8 +213,12 @@ def _flush_tool_outputs() -> None: duration_ms = (time.perf_counter() - run_started) * 1000 span.set_attribute("agent.answer.length", len(answer)) - span.set_attribute("gen_ai.usage.input_tokens", trajectory.total_input_tokens) - span.set_attribute("gen_ai.usage.output_tokens", trajectory.total_output_tokens) + span.set_attribute( + "gen_ai.usage.input_tokens", trajectory.total_input_tokens + ) + span.set_attribute( + "gen_ai.usage.output_tokens", trajectory.total_output_tokens + ) span.set_attribute("agent.turns", len(trajectory.turns)) span.set_attribute("agent.tool_calls", len(trajectory.all_tool_calls)) span.set_attribute("agent.duration_ms", duration_ms) diff --git a/src/agent/claude_agent/tests/test_runner.py b/src/agent/claude_agent/tests/test_runner.py index 3c4ddbc59..c3f439e7e 100644 --- a/src/agent/claude_agent/tests/test_runner.py +++ b/src/agent/claude_agent/tests/test_runner.py @@ -113,7 +113,12 @@ async def fake_query(prompt, options): @pytest.mark.anyio async def test_run_collects_trajectory(): - from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock + from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + TextBlock, + ToolUseBlock, + ) mock_tool = MagicMock(spec=ToolUseBlock) mock_tool.name = "sensors" @@ -157,7 +162,12 @@ async def fake_query(prompt, options): @pytest.mark.anyio async def test_run_tool_output_captured(): """PostToolUse hook output is attached to the matching ToolCall.""" - from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock + from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + TextBlock, + ToolUseBlock, + ) mock_tool = MagicMock(spec=ToolUseBlock) mock_tool.name = "sensors" @@ -206,7 +216,12 @@ async def fake_query(prompt, options): @pytest.mark.anyio async def test_run_tool_output_string_response(): """PostToolUse hook handles string tool_response (no .get).""" - from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock + from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + TextBlock, + ToolUseBlock, + ) mock_tool = MagicMock(spec=ToolUseBlock) mock_tool.name = "sites" diff --git a/src/agent/deep_agent/cli.py b/src/agent/deep_agent/cli.py index f823307f7..8efcabd8e 100644 --- a/src/agent/deep_agent/cli.py +++ b/src/agent/deep_agent/cli.py @@ -58,7 +58,9 @@ async def _run(args: argparse.Namespace) -> None: recursion_limit=args.recursion_limit, ) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: diff --git a/src/agent/deep_agent/runner.py b/src/agent/deep_agent/runner.py index 5ff3f4dbf..e06a975a5 100644 --- a/src/agent/deep_agent/runner.py +++ b/src/agent/deep_agent/runner.py @@ -239,8 +239,12 @@ async def run(self, question: str) -> AgentResult: ) span.set_attribute("agent.answer.length", len(answer)) - span.set_attribute("gen_ai.usage.input_tokens", trajectory.total_input_tokens) - span.set_attribute("gen_ai.usage.output_tokens", trajectory.total_output_tokens) + span.set_attribute( + "gen_ai.usage.input_tokens", trajectory.total_input_tokens + ) + span.set_attribute( + "gen_ai.usage.output_tokens", trajectory.total_output_tokens + ) span.set_attribute("agent.turns", len(trajectory.turns)) span.set_attribute("agent.tool_calls", len(trajectory.all_tool_calls)) span.set_attribute( diff --git a/src/agent/deep_agent/tests/test_runner.py b/src/agent/deep_agent/tests/test_runner.py index 2014612d8..834ada5ea 100644 --- a/src/agent/deep_agent/tests/test_runner.py +++ b/src/agent/deep_agent/tests/test_runner.py @@ -122,12 +122,20 @@ def test_build_trajectory_tool_calls_and_outputs(): AIMessage( content="", tool_calls=[{"name": "sensors", "args": {"asset_id": "CH-6"}, "id": "c1"}], - usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, + usage_metadata={ + "input_tokens": 100, + "output_tokens": 20, + "total_tokens": 120, + }, ), ToolMessage(content="5 sensors found", tool_call_id="c1"), AIMessage( content="Chiller 6 has 5 sensors.", - usage_metadata={"input_tokens": 150, "output_tokens": 30, "total_tokens": 180}, + usage_metadata={ + "input_tokens": 150, + "output_tokens": 30, + "total_tokens": 180, + }, ), ] traj = _build_trajectory(messages) @@ -149,7 +157,12 @@ def test_build_trajectory_tool_calls_and_outputs(): def test_build_trajectory_list_content(): messages = [ - AIMessage(content=[{"type": "text", "text": "part one "}, {"type": "text", "text": "part two"}]) + AIMessage( + content=[ + {"type": "text", "text": "part one "}, + {"type": "text", "text": "part two"}, + ] + ) ] traj = _build_trajectory(messages) assert traj.turns[0].text == "part one part two" @@ -172,13 +185,21 @@ def test_build_trajectory_multiple_tool_calls_one_turn(): {"name": "sites", "args": {}, "id": "c1"}, {"name": "assets", "args": {"site_id": "MAIN"}, "id": "c2"}, ], - usage_metadata={"input_tokens": 50, "output_tokens": 10, "total_tokens": 60}, + usage_metadata={ + "input_tokens": 50, + "output_tokens": 10, + "total_tokens": 60, + }, ), ToolMessage(content=["MAIN"], tool_call_id="c1"), ToolMessage(content=["Chiller 6"], tool_call_id="c2"), AIMessage( content="Found Chiller 6 at site MAIN.", - usage_metadata={"input_tokens": 80, "output_tokens": 15, "total_tokens": 95}, + usage_metadata={ + "input_tokens": 80, + "output_tokens": 15, + "total_tokens": 95, + }, ), ] traj = _build_trajectory(messages) @@ -242,13 +263,23 @@ async def test_run_collects_trajectory(): HumanMessage(content="What sensors are on Chiller 6?"), AIMessage( content="", - tool_calls=[{"name": "sensors", "args": {"asset_id": "CH-6"}, "id": "c1"}], - usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, + tool_calls=[ + {"name": "sensors", "args": {"asset_id": "CH-6"}, "id": "c1"} + ], + usage_metadata={ + "input_tokens": 100, + "output_tokens": 20, + "total_tokens": 120, + }, ), ToolMessage(content="sensor data", tool_call_id="c1"), AIMessage( content="Chiller 6 has 5 sensors.", - usage_metadata={"input_tokens": 150, "output_tokens": 30, "total_tokens": 180}, + usage_metadata={ + "input_tokens": 150, + "output_tokens": 30, + "total_tokens": 180, + }, ), ] } diff --git a/src/agent/direct_llm_agent/__init__.py b/src/agent/direct_llm_agent/__init__.py index 102c9726c..6f317b52b 100644 --- a/src/agent/direct_llm_agent/__init__.py +++ b/src/agent/direct_llm_agent/__init__.py @@ -2,4 +2,4 @@ from .runner import DirectLLMAgentRunner -__all__ = ["DirectLLMAgentRunner"] \ No newline at end of file +__all__ = ["DirectLLMAgentRunner"] diff --git a/src/agent/direct_llm_agent/cli.py b/src/agent/direct_llm_agent/cli.py index b17328a37..be8a7c661 100644 --- a/src/agent/direct_llm_agent/cli.py +++ b/src/agent/direct_llm_agent/cli.py @@ -75,4 +75,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/agent/direct_llm_agent/runner.py b/src/agent/direct_llm_agent/runner.py index ef6c4b305..6b4a35c44 100644 --- a/src/agent/direct_llm_agent/runner.py +++ b/src/agent/direct_llm_agent/runner.py @@ -100,4 +100,4 @@ async def run(self, question: str) -> AgentResult: question=question, answer=answer, trajectory=trajectory, - ) \ No newline at end of file + ) diff --git a/src/agent/openai_agent/cli.py b/src/agent/openai_agent/cli.py index d6e73fe8c..c26d831ca 100644 --- a/src/agent/openai_agent/cli.py +++ b/src/agent/openai_agent/cli.py @@ -54,7 +54,9 @@ async def _run(args: argparse.Namespace) -> None: runner = OpenAIAgentRunner(model=args.model_id, max_turns=args.max_turns) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: diff --git a/src/agent/openai_agent/runner.py b/src/agent/openai_agent/runner.py index 1d6809878..f301b904b 100644 --- a/src/agent/openai_agent/runner.py +++ b/src/agent/openai_agent/runner.py @@ -24,7 +24,14 @@ from openai import AsyncOpenAI -from agents import Agent, ModelProvider, OpenAIChatCompletionsModel, RunConfig, Runner, set_tracing_disabled +from agents import ( + Agent, + ModelProvider, + OpenAIChatCompletionsModel, + RunConfig, + Runner, + set_tracing_disabled, +) from agents.mcp import MCPServerStdio from observability import agent_run_span, persist_trajectory @@ -136,7 +143,9 @@ def _flush() -> None: tc_id = getattr(raw, "call_id", "") or getattr(raw, "id", "") or "" tc_args = getattr(raw, "arguments", "{}") or "{}" try: - tc_input = json.loads(tc_args) if isinstance(tc_args, str) else tc_args + tc_input = ( + json.loads(tc_args) if isinstance(tc_args, str) else tc_args + ) except (json.JSONDecodeError, TypeError): tc_input = {"raw": tc_args} tool_calls.append(ToolCall(name=tc_name, input=tc_input, id=tc_id)) @@ -251,8 +260,12 @@ async def run(self, question: str) -> AgentResult: ) span.set_attribute("agent.answer.length", len(answer)) - span.set_attribute("gen_ai.usage.input_tokens", trajectory.total_input_tokens) - span.set_attribute("gen_ai.usage.output_tokens", trajectory.total_output_tokens) + span.set_attribute( + "gen_ai.usage.input_tokens", trajectory.total_input_tokens + ) + span.set_attribute( + "gen_ai.usage.output_tokens", trajectory.total_output_tokens + ) span.set_attribute("agent.turns", len(trajectory.turns)) span.set_attribute("agent.tool_calls", len(trajectory.all_tool_calls)) span.set_attribute( @@ -270,5 +283,3 @@ async def run(self, question: str) -> AgentResult: answer=answer, trajectory=trajectory, ) - - diff --git a/src/agent/plan_execute/executor.py b/src/agent/plan_execute/executor.py index 2e27b5144..d785f97ea 100644 --- a/src/agent/plan_execute/executor.py +++ b/src/agent/plan_execute/executor.py @@ -112,7 +112,9 @@ async def execute_plan(self, plan: Plan, question: str) -> list[StepResult]: ) schema = tool_schemas.get(step.server, {}).get(step.tool, "") step_started = time.perf_counter() - result = await self.execute_step(step, context, question, tool_schema=schema) + result = await self.execute_step( + step, context, question, tool_schema=schema + ) result.duration_ms = (time.perf_counter() - step_started) * 1000 if result.success: _log.info("Step %d OK.", step.step_number) @@ -202,8 +204,7 @@ async def _resolve_args_with_llm( f"Step {n}: {r.response}" for n, r in sorted(context.items()) ) prompt = ( - _ARG_RESOLUTION_PROMPT - .replace("{question}", question) + _ARG_RESOLUTION_PROMPT.replace("{question}", question) .replace("{task}", task) .replace("{tool}", tool) .replace("{tool_schema}", tool_schema or "(unknown)") @@ -214,7 +215,8 @@ async def _resolve_args_with_llm( if resolved is None: _log.warning( "Tool '%s': arg resolution returned no parseable JSON (response: %r…)", - tool, raw[:120], + tool, + raw[:120], ) return {} return resolved diff --git a/src/agent/plan_execute/runner.py b/src/agent/plan_execute/runner.py index ea684a650..445f47b09 100644 --- a/src/agent/plan_execute/runner.py +++ b/src/agent/plan_execute/runner.py @@ -50,9 +50,7 @@ def generate(self, prompt: str, temperature: float = 0.0) -> str: self.output_tokens += result.output_tokens return result.text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: result = self._inner.generate_with_usage(prompt, temperature) self.input_tokens += result.input_tokens self.output_tokens += result.output_tokens @@ -62,6 +60,7 @@ def generate_with_usage( def model_id(self) -> str: return self._inner.model_id + _log = logging.getLogger(__name__) _SUMMARIZE_PROMPT = """\ diff --git a/src/agent/stirrup_agent/__init__.py b/src/agent/stirrup_agent/__init__.py index afa573e86..f32d12e07 100644 --- a/src/agent/stirrup_agent/__init__.py +++ b/src/agent/stirrup_agent/__init__.py @@ -6,4 +6,4 @@ from .runner import StirrupAgentRunner -__all__ = ["StirrupAgentRunner"] \ No newline at end of file +__all__ = ["StirrupAgentRunner"] diff --git a/src/agent/stirrup_agent/cli.py b/src/agent/stirrup_agent/cli.py index 7d66ec064..efee63720 100644 --- a/src/agent/stirrup_agent/cli.py +++ b/src/agent/stirrup_agent/cli.py @@ -81,7 +81,7 @@ def _build_parser() -> argparse.ArgumentParser: default=16_384, metavar="N", help="Max output tokens per model call; must stay under the provider " - "limit (watsonx caps new tokens at 100k). Default: 16384.", + "limit (watsonx caps new tokens at 100k). Default: 16384.", ) return parser @@ -97,7 +97,9 @@ async def _run(args: argparse.Namespace) -> None: max_tokens=args.max_tokens, ) result = await runner.run(args.question) - print_result(result, show_trajectory=args.show_trajectory, output_json=args.output_json) + print_result( + result, show_trajectory=args.show_trajectory, output_json=args.output_json + ) def main() -> None: @@ -105,4 +107,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/agent/stirrup_agent/runner.py b/src/agent/stirrup_agent/runner.py index 351c665d0..eea8e916d 100644 --- a/src/agent/stirrup_agent/runner.py +++ b/src/agent/stirrup_agent/runner.py @@ -160,7 +160,9 @@ async def run(self, question: str) -> AgentResult: _log.info( "StirrupAgentRunner: starting (model=%s, code=%s, backend=%s)", - self._model_id, self._code_enabled, self._code_backend, + self._model_id, + self._code_enabled, + self._code_backend, ) async with agent.session() as session: @@ -180,7 +182,9 @@ async def run(self, question: str) -> AgentResult: ) return AgentResult(question=question, answer=answer, trajectory=trajectory) - def _annotate_span(self, span, trajectory: Trajectory, answer: str, started: float) -> None: + def _annotate_span( + self, span, trajectory: Trajectory, answer: str, started: float + ) -> None: domain_servers = set(self._server_paths) counts = {"domain": 0, "code": 0, "other": 0} for tc in trajectory.all_tool_calls: @@ -201,5 +205,8 @@ def _annotate_span(self, span, trajectory: Trajectory, answer: str, started: flo _log.info( "StirrupAgentRunner: done (turns=%d, domain=%d, code=%d, bypass=%s)", - len(trajectory.turns), counts["domain"], counts["code"], bypass, - ) \ No newline at end of file + len(trajectory.turns), + counts["domain"], + counts["code"], + bypass, + ) diff --git a/src/agent/stirrup_agent/tests/__init__.py b/src/agent/stirrup_agent/tests/__init__.py index 23718fecf..baf177c5b 100644 --- a/src/agent/stirrup_agent/tests/__init__.py +++ b/src/agent/stirrup_agent/tests/__init__.py @@ -80,8 +80,13 @@ def test_build_trajectory_maps_turns_calls_and_outputs(): request_start_time=1.0, request_end_time=2.5, ), - _Tool(content="[{'wo': 7}]", tool_call_id="t1", name="wo__get_work_order", - tool_start_time=2.5, tool_end_time=3.0), + _Tool( + content="[{'wo': 7}]", + tool_call_id="t1", + name="wo__get_work_order", + tool_start_time=2.5, + tool_end_time=3.0, + ), ], [ _Assistant( @@ -98,20 +103,29 @@ def test_build_trajectory_maps_turns_calls_and_outputs(): call = traj.all_tool_calls[0] assert call.name == "wo__get_work_order" - assert call.input == {"asset": "CWC04013"} # JSON string parsed + assert call.input == {"asset": "CWC04013"} # JSON string parsed assert call.output == "[{'wo': 7}]" - assert call.duration_ms == 500.0 # (3.0 - 2.5) * 1000 - assert traj.turns[0].duration_ms == 1500.0 # (2.5 - 1.0) * 1000 + assert call.duration_ms == 500.0 # (3.0 - 2.5) * 1000 + assert traj.turns[0].duration_ms == 1500.0 # (2.5 - 1.0) * 1000 assert final_answer(history, _Finish("done")) == "there are 7 open work orders" def test_final_answer_falls_back_to_finish_reason(): history = [[_Assistant(content="")]] - assert final_answer(history, _Finish("computed RUL = 142 days")) == "computed RUL = 142 days" + assert ( + final_answer(history, _Finish("computed RUL = 142 days")) + == "computed RUL = 142 days" + ) def test_arguments_parsed_when_already_dict(): - history = [[_Assistant(content="x", tool_calls=[_TC("iot__get_sensors", {"asset": "CH6"}, "a")])]] + history = [ + [ + _Assistant( + content="x", tool_calls=[_TC("iot__get_sensors", {"asset": "CH6"}, "a")] + ) + ] + ] traj = build_trajectory(history) - assert traj.all_tool_calls[0].input == {"asset": "CH6"} \ No newline at end of file + assert traj.all_tool_calls[0].input == {"asset": "CH6"} diff --git a/src/agent/stirrup_agent/tests/test_runner.py b/src/agent/stirrup_agent/tests/test_runner.py index 23718fecf..baf177c5b 100644 --- a/src/agent/stirrup_agent/tests/test_runner.py +++ b/src/agent/stirrup_agent/tests/test_runner.py @@ -80,8 +80,13 @@ def test_build_trajectory_maps_turns_calls_and_outputs(): request_start_time=1.0, request_end_time=2.5, ), - _Tool(content="[{'wo': 7}]", tool_call_id="t1", name="wo__get_work_order", - tool_start_time=2.5, tool_end_time=3.0), + _Tool( + content="[{'wo': 7}]", + tool_call_id="t1", + name="wo__get_work_order", + tool_start_time=2.5, + tool_end_time=3.0, + ), ], [ _Assistant( @@ -98,20 +103,29 @@ def test_build_trajectory_maps_turns_calls_and_outputs(): call = traj.all_tool_calls[0] assert call.name == "wo__get_work_order" - assert call.input == {"asset": "CWC04013"} # JSON string parsed + assert call.input == {"asset": "CWC04013"} # JSON string parsed assert call.output == "[{'wo': 7}]" - assert call.duration_ms == 500.0 # (3.0 - 2.5) * 1000 - assert traj.turns[0].duration_ms == 1500.0 # (2.5 - 1.0) * 1000 + assert call.duration_ms == 500.0 # (3.0 - 2.5) * 1000 + assert traj.turns[0].duration_ms == 1500.0 # (2.5 - 1.0) * 1000 assert final_answer(history, _Finish("done")) == "there are 7 open work orders" def test_final_answer_falls_back_to_finish_reason(): history = [[_Assistant(content="")]] - assert final_answer(history, _Finish("computed RUL = 142 days")) == "computed RUL = 142 days" + assert ( + final_answer(history, _Finish("computed RUL = 142 days")) + == "computed RUL = 142 days" + ) def test_arguments_parsed_when_already_dict(): - history = [[_Assistant(content="x", tool_calls=[_TC("iot__get_sensors", {"asset": "CH6"}, "a")])]] + history = [ + [ + _Assistant( + content="x", tool_calls=[_TC("iot__get_sensors", {"asset": "CH6"}, "a")] + ) + ] + ] traj = build_trajectory(history) - assert traj.all_tool_calls[0].input == {"asset": "CH6"} \ No newline at end of file + assert traj.all_tool_calls[0].input == {"asset": "CH6"} diff --git a/src/agent/stirrup_agent/trajectory.py b/src/agent/stirrup_agent/trajectory.py index 228da70a7..d6e3e0df8 100644 --- a/src/agent/stirrup_agent/trajectory.py +++ b/src/agent/stirrup_agent/trajectory.py @@ -157,4 +157,4 @@ def final_answer(history: Iterable[Any], finish_params: Any) -> str: if text: return text reason = getattr(finish_params, "reason", None) - return reason if isinstance(reason, str) else "" \ No newline at end of file + return reason if isinstance(reason, str) else "" diff --git a/src/agent/tests/test_planner.py b/src/agent/tests/test_planner.py index 77bc4497d..6ce6e78de 100644 --- a/src/agent/tests/test_planner.py +++ b/src/agent/tests/test_planner.py @@ -144,7 +144,9 @@ def test_generate_plan_uses_llm_output(self, mock_llm): planner = Planner(llm) plan = planner.generate_plan( "List all assets", - {"iot": " - sites(): List sites\n - assets(site_name: string): List assets"}, + { + "iot": " - sites(): List sites\n - assets(site_name: string): List assets" + }, ) assert len(plan.steps) == 2 assert plan.steps[0].server == "iot" @@ -170,7 +172,10 @@ def test_generate_plan_prompt_contains_agent_names(self, mock_llm, monkeypatch): Planner(llm).generate_plan( "Q", - {"iot": " - sites(): List sites", "utilities": " - current_date_time(): Get time"}, + { + "iot": " - sites(): List sites", + "utilities": " - current_date_time(): Get time", + }, ) assert "iot" in captured[0] assert "utilities" in captured[0] diff --git a/src/agent/tests/test_runner.py b/src/agent/tests/test_runner.py index 6bdbac98b..062d78656 100644 --- a/src/agent/tests/test_runner.py +++ b/src/agent/tests/test_runner.py @@ -37,7 +37,11 @@ _MOCK_TOOLS = [ {"name": "sites", "description": "List IoT sites", "parameters": []}, - {"name": "current_date_time", "description": "Get current datetime", "parameters": []}, + { + "name": "current_date_time", + "description": "Get current datetime", + "parameters": [], + }, ] _TOOL_RESPONSE = json.dumps({"sites": ["MAIN"]}) @@ -51,9 +55,13 @@ def _patch_mcp(tool_response: str = _TOOL_RESPONSE): return ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), patch( - "agent.plan_execute.executor._call_tool", new=AsyncMock(return_value=tool_response) + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", + new=AsyncMock(return_value=tool_response), ), ) @@ -93,12 +101,14 @@ def generate(self, prompt: str, **_kw) -> str: @pytest.mark.anyio async def test_orchestrator_run_returns_result(sequential_llm): - llm = sequential_llm([ - _TWO_STEP_PLAN, # planner call - _STEP1_ARGS, # arg resolution for step 1 - _STEP2_ARGS, # arg resolution for step 2 - _FINAL_ANSWER, # summarisation - ]) + llm = sequential_llm( + [ + _TWO_STEP_PLAN, # planner call + _STEP1_ARGS, # arg resolution for step 1 + _STEP2_ARGS, # arg resolution for step 2 + _FINAL_ANSWER, # summarisation + ] + ) with _patch_mcp()[0], _patch_mcp()[1]: result = await PlanExecuteRunner(llm).run("What are the IoT sites?") @@ -145,9 +155,7 @@ def __init__(self, items: list[tuple[str, int, int]]) -> None: def generate(self, prompt: str, temperature: float = 0.0) -> str: return self.generate_with_usage(prompt, temperature).text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: text, in_tok, out_tok = next(self._items, ("", 0, 0)) return LLMResult(text=text, input_tokens=in_tok, output_tokens=out_tok) @@ -155,12 +163,14 @@ def generate_with_usage( @pytest.mark.anyio async def test_orchestrator_accumulates_token_usage_across_llm_calls(): """Plan + 2 arg-resolution + summarise → summed input/output tokens.""" - llm = _UsageReportingLLM([ - (_TWO_STEP_PLAN, 100, 50), # planner - (_STEP1_ARGS, 20, 5), # step 1 arg resolution - (_STEP2_ARGS, 30, 5), # step 2 arg resolution - (_FINAL_ANSWER, 200, 40), # summarise - ]) + llm = _UsageReportingLLM( + [ + (_TWO_STEP_PLAN, 100, 50), # planner + (_STEP1_ARGS, 20, 5), # step 1 arg resolution + (_STEP2_ARGS, 30, 5), # step 2 arg resolution + (_FINAL_ANSWER, 200, 40), # summarise + ] + ) runner = PlanExecuteRunner(llm) with _patch_mcp()[0], _patch_mcp()[1]: await runner.run("Q") @@ -245,8 +255,13 @@ async def test_executor_step_result_carries_resolved_args(sequential_llm): step = _make_step(1, tool="assets") with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), - patch("agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}")), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}") + ), ): result = await executor.execute_step(step, {}, "List assets at MAIN") @@ -258,13 +273,19 @@ async def test_executor_tool_call_exception_recorded_as_error(sequential_llm): """If _call_tool raises, the error is captured in StepResult (no crash).""" from pathlib import Path - llm = sequential_llm(['{}']) + llm = sequential_llm(["{}"]) executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) step = _make_step(1, tool="sites") with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), - patch("agent.plan_execute.executor._call_tool", new=AsyncMock(side_effect=RuntimeError("timeout"))), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", + new=AsyncMock(side_effect=RuntimeError("timeout")), + ), ): result = await executor.execute_step(step, {}, "Q") @@ -277,10 +298,12 @@ async def test_executor_calls_llm_to_generate_args(sequential_llm): """Each tool step triggers exactly one LLM call for arg generation.""" from pathlib import Path - llm = sequential_llm([ - '{}', # step 1: sites (no args) - '{"site_name": "MAIN", "asset_id": "CH-1"}', # step 2: sensors - ]) + llm = sequential_llm( + [ + "{}", # step 1: sites (no args) + '{"site_name": "MAIN", "asset_id": "CH-1"}', # step 2: sensors + ] + ) executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) plan = Plan( @@ -290,12 +313,17 @@ async def test_executor_calls_llm_to_generate_args(sequential_llm): ], raw="", ) - call_mock = AsyncMock(side_effect=[ - json.dumps({"sites": ["MAIN"]}), - json.dumps({"sensors": ["temp"]}), - ]) + call_mock = AsyncMock( + side_effect=[ + json.dumps({"sites": ["MAIN"]}), + json.dumps({"sensors": ["temp"]}), + ] + ) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): results = await executor.execute_plan(plan, "Q") @@ -324,7 +352,10 @@ async def test_executor_prior_step_results_in_llm_prompt(): site_resp = json.dumps({"sites": ["MAIN"]}) call_mock = AsyncMock(side_effect=[site_resp, '{"sensors": []}']) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): await executor.execute_plan(plan, "List sensors for CH-1") @@ -338,13 +369,18 @@ async def test_executor_no_prior_context_shows_none_in_prompt(): """When no prior steps exist the prompt contains the literal '(none)'.""" from pathlib import Path - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) # type: ignore[arg-type] step = _make_step(1, tool="sites") with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), - patch("agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}")), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), + patch( + "agent.plan_execute.executor._call_tool", new=AsyncMock(return_value="{}") + ), ): await executor.execute_step(step, {}, "Q") @@ -356,7 +392,7 @@ async def test_executor_context_accumulates_across_steps(): """Step 3's LLM prompt contains results from both steps 1 and 2.""" from pathlib import Path - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") executor = Executor(llm, server_paths={"iot": Path("/fake/server.py")}) # type: ignore[arg-type] plan = Plan( @@ -370,7 +406,10 @@ async def test_executor_context_accumulates_across_steps(): resp1, resp2, resp3 = '{"sites":["MAIN"]}', '{"assets":["CH-1"]}', '{"sensors":[]}' call_mock = AsyncMock(side_effect=[resp1, resp2, resp3]) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): await executor.execute_plan(plan, "Q") @@ -395,16 +434,21 @@ async def test_pipeline_uses_llm_args_for_each_step(sequential_llm): "#Dependency2: #S1\n" "#ExpectedOutput2: List of assets" ) - llm = sequential_llm([ - planner_output, # planner call - '{}', # arg resolution for step 1 (sites needs no args) - '{"site_name": "MAIN"}', # arg resolution for step 2 (uses step 1 result) - "Final answer.", # summarisation - ]) + llm = sequential_llm( + [ + planner_output, # planner call + "{}", # arg resolution for step 1 (sites needs no args) + '{"site_name": "MAIN"}', # arg resolution for step 2 (uses step 1 result) + "Final answer.", # summarisation + ] + ) call_mock = AsyncMock(side_effect=['{"sites": ["MAIN"]}', '{"assets": ["CH-1"]}']) with ( - patch("agent.plan_execute.executor._list_tools", new=AsyncMock(return_value=_MOCK_TOOLS)), + patch( + "agent.plan_execute.executor._list_tools", + new=AsyncMock(return_value=_MOCK_TOOLS), + ), patch("agent.plan_execute.executor._call_tool", new=call_mock), ): result = await PlanExecuteRunner(llm).run("List all assets at site MAIN") @@ -420,10 +464,18 @@ async def test_pipeline_uses_llm_args_for_each_step(sequential_llm): @pytest.mark.anyio async def test_resolve_args_with_llm_uses_context(mock_llm): llm = mock_llm('{"asset_id": "CH-1"}') - ctx = {1: StepResult(step_number=1, task="t", server="a", - response='{"assets": ["CH-1", "CH-2"]}')} + ctx = { + 1: StepResult( + step_number=1, task="t", server="a", response='{"assets": ["CH-1", "CH-2"]}' + ) + } result = await _resolve_args_with_llm( - "What sensors does CH-1 have?", "get sensors", "sensors", "", ctx, llm, + "What sensors does CH-1 have?", + "get sensors", + "sensors", + "", + ctx, + llm, ) assert result["asset_id"] == "CH-1" @@ -440,14 +492,19 @@ async def test_resolve_args_with_llm_fallback_on_bad_json(mock_llm): async def test_resolve_args_with_llm_question_in_prompt(): llm = _CapturingLLM('{"site_name": "MAIN"}') await _resolve_args_with_llm( - "What sites exist?", "List sites", "sites", "", {}, llm # type: ignore[arg-type] + "What sites exist?", + "List sites", + "sites", + "", + {}, + llm, # type: ignore[arg-type] ) assert "What sites exist?" in llm.prompts[0] @pytest.mark.anyio async def test_resolve_args_with_llm_tool_in_prompt(): - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") await _resolve_args_with_llm("Q", "List IoT sites", "sites", "", {}, llm) # type: ignore[arg-type] assert "sites" in llm.prompts[0] @@ -465,7 +522,7 @@ async def test_resolve_args_with_llm_schema_in_prompt(): @pytest.mark.anyio async def test_resolve_args_with_llm_unknown_schema_shows_sentinel(): """Empty schema renders as '(unknown)' in the prompt.""" - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") await _resolve_args_with_llm("Q", "task", "tool", "", {}, llm) # type: ignore[arg-type] assert "(unknown)" in llm.prompts[0] @@ -473,15 +530,17 @@ async def test_resolve_args_with_llm_unknown_schema_shows_sentinel(): @pytest.mark.anyio async def test_resolve_args_with_llm_context_in_prompt(): """Prior step results appear verbatim in the generated prompt.""" - llm = _CapturingLLM('{}') - ctx = {1: StepResult(step_number=1, task="t", server="a", response="step-one-result")} + llm = _CapturingLLM("{}") + ctx = { + 1: StepResult(step_number=1, task="t", server="a", response="step-one-result") + } await _resolve_args_with_llm("Q", "task", "tool", "", ctx, llm) # type: ignore[arg-type] assert "step-one-result" in llm.prompts[0] @pytest.mark.anyio async def test_resolve_args_with_llm_empty_context_shows_none(): - llm = _CapturingLLM('{}') + llm = _CapturingLLM("{}") await _resolve_args_with_llm("Q", "task", "tool", "", {}, llm) # type: ignore[arg-type] assert "(none)" in llm.prompts[0] diff --git a/src/couchdb/init_data.py b/src/couchdb/init_data.py index 86a73bc9c..dee9f74d3 100644 --- a/src/couchdb/init_data.py +++ b/src/couchdb/init_data.py @@ -24,9 +24,9 @@ from dotenv import load_dotenv -try: # works as a package (python -m couchdb.init_data / imports) +try: # works as a package (python -m couchdb.init_data / imports) from . import loader -except ImportError: # works as a script (python3 /couchdb/init_data.py) +except ImportError: # works as a script (python3 /couchdb/init_data.py) import loader load_dotenv() @@ -35,9 +35,12 @@ _HERE = os.path.dirname(os.path.abspath(__file__)) -SCENARIOS_DATA_DIR = os.environ.get("SCENARIOS_DATA_DIR", os.path.join(_HERE, "scenarios_data")) +SCENARIOS_DATA_DIR = os.environ.get( + "SCENARIOS_DATA_DIR", os.path.join(_HERE, "scenarios_data") +) DEFAULT_MANIFEST_FILE = os.environ.get( - "DEFAULT_MANIFEST", os.path.join(SCENARIOS_DATA_DIR, "default", "manifest.json")) + "DEFAULT_MANIFEST", os.path.join(SCENARIOS_DATA_DIR, "default", "manifest.json") +) # --------------------------------------------------------------------------- # @@ -48,14 +51,21 @@ def _load_default_manifest() -> tuple: if not os.path.isfile(DEFAULT_MANIFEST_FILE): raise FileNotFoundError( f"default manifest not found: {DEFAULT_MANIFEST_FILE}. " - "Create scenarios_data/default/manifest.json (or set DEFAULT_MANIFEST).") + "Create scenarios_data/default/manifest.json (or set DEFAULT_MANIFEST)." + ) with open(DEFAULT_MANIFEST_FILE) as f: return json.load(f), os.path.dirname(DEFAULT_MANIFEST_FILE) def manifest_path(scenario_id) -> str: - folder = os.path.join(SCENARIOS_DATA_DIR, f"scenario_{scenario_id}", "manifest.json") - return folder if os.path.isfile(folder) else os.path.join(SCENARIOS_DATA_DIR, f"scenario_{scenario_id}.json") + folder = os.path.join( + SCENARIOS_DATA_DIR, f"scenario_{scenario_id}", "manifest.json" + ) + return ( + folder + if os.path.isfile(folder) + else os.path.join(SCENARIOS_DATA_DIR, f"scenario_{scenario_id}.json") + ) def _resolve_manifest(scenario_id) -> tuple: @@ -77,8 +87,11 @@ def _resolve_manifest(scenario_id) -> tuple: raise FileNotFoundError( f"no manifest for scenario {scenario_id}: expected " f"{os.path.join(SCENARIOS_DATA_DIR, f'scenario_{scenario_id}', 'manifest.json')} " - f"or {os.path.join(SCENARIOS_DATA_DIR, f'scenario_{scenario_id}.json')}") - base_dir = os.path.dirname(path) if os.path.basename(path) == "manifest.json" else None + f"or {os.path.join(SCENARIOS_DATA_DIR, f'scenario_{scenario_id}.json')}" + ) + base_dir = ( + os.path.dirname(path) if os.path.basename(path) == "manifest.json" else None + ) with open(path) as f: return json.load(f), base_dir @@ -113,34 +126,61 @@ def reset(managed_only: bool = False) -> list: # --------------------------------------------------------------------------- # # Load # --------------------------------------------------------------------------- # -def init_data(scenario_id=None, force: bool = True, reset_first: bool = False, - managed_only: bool = False) -> dict: +def init_data( + scenario_id=None, + force: bool = True, + reset_first: bool = False, + managed_only: bool = False, +) -> dict: """Load a scenario's data (or the default) into CouchDB. Returns {collection: (db, n)}. Resolves the manifest first, so an unknown ``scenario_id`` raises FileNotFoundError before anything is dropped. ``reset_first=True`` then drops databases so collections absent from the manifest are left empty rather than carrying over. """ - manifest, base_dir = _resolve_manifest(scenario_id) # validate first (raises on unknown id) + manifest, base_dir = _resolve_manifest( + scenario_id + ) # validate first (raises on unknown id) if reset_first: reset(managed_only=managed_only) results = {} for key, spec in manifest.items(): - results[key] = loader.load_collection(key, spec, drop=force, base_dir=base_dir) # database name = key - logger.info("Scenario %s: '%s' → %s (%d docs).", scenario_id, key, *results[key]) + results[key] = loader.load_collection( + key, spec, drop=force, base_dir=base_dir + ) # database name = key + logger.info( + "Scenario %s: '%s' → %s (%d docs).", scenario_id, key, *results[key] + ) return results def main() -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") - p = argparse.ArgumentParser(description="Load CouchDB data for a scenario (default if omitted).") - p.add_argument("scenario", nargs="?", default=None, - help="Scenario id → scenarios_data/scenario_.json (omit for default).") - p.add_argument("--reuse", action="store_true", help="Reuse instead of reloading from scratch.") - p.add_argument("--reset", action="store_true", help="Drop databases first, then load (clean start).") - p.add_argument("--reset-only", action="store_true", help="Drop databases and exit (no load).") - p.add_argument("--managed-only", action="store_true", - help="With --reset/--reset-only: drop only the default-manifest collections.") + p = argparse.ArgumentParser( + description="Load CouchDB data for a scenario (default if omitted)." + ) + p.add_argument( + "scenario", + nargs="?", + default=None, + help="Scenario id → scenarios_data/scenario_.json (omit for default).", + ) + p.add_argument( + "--reuse", action="store_true", help="Reuse instead of reloading from scratch." + ) + p.add_argument( + "--reset", + action="store_true", + help="Drop databases first, then load (clean start).", + ) + p.add_argument( + "--reset-only", action="store_true", help="Drop databases and exit (no load)." + ) + p.add_argument( + "--managed-only", + action="store_true", + help="With --reset/--reset-only: drop only the default-manifest collections.", + ) a = p.parse_args() if a.reset_only: @@ -148,10 +188,11 @@ def main() -> None: print(f"dropped\t{db}") return - for key, (db, n) in init_data(a.scenario, force=not a.reuse, - reset_first=a.reset, managed_only=a.managed_only).items(): + for key, (db, n) in init_data( + a.scenario, force=not a.reuse, reset_first=a.reset, managed_only=a.managed_only + ).items(): print(f"{key}\t{db}\t{n}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/couchdb/loader.py b/src/couchdb/loader.py index 22fa1e799..a95489300 100644 --- a/src/couchdb/loader.py +++ b/src/couchdb/loader.py @@ -38,8 +38,13 @@ _HERE = os.path.dirname(os.path.abspath(__file__)) COUCHDB_URL = os.environ.get("COUCHDB_URL", "http://localhost:5984") -_AUTH = (os.environ.get("COUCHDB_USERNAME", "admin"), os.environ.get("COUCHDB_PASSWORD", "password")) -COLLECTIONS_FILE = os.environ.get("COLLECTIONS_CONFIG", os.path.join(_HERE, "collections.json")) +_AUTH = ( + os.environ.get("COUCHDB_USERNAME", "admin"), + os.environ.get("COUCHDB_PASSWORD", "password"), +) +COLLECTIONS_FILE = os.environ.get( + "COLLECTIONS_CONFIG", os.path.join(_HERE, "collections.json") +) SAMPLE_DATA_DIR = os.path.join(_HERE, "sample_data") @@ -87,7 +92,11 @@ def parse_csv(path, cfg) -> list: for row in df.to_dict(orient="records"): doc = {} for col, val in row.items(): - if val is None or (isinstance(val, float) and pd.isna(val)) or str(val).strip() == "": + if ( + val is None + or (isinstance(val, float) and pd.isna(val)) + or str(val).strip() == "" + ): continue v = _coerce(col, val, int_f, float_f, json_f) if "." in col: @@ -134,9 +143,9 @@ def files_from(s): return [p] docs = [] - for item in (source if isinstance(source, list) else [source]): + for item in source if isinstance(source, list) else [source]: if isinstance(item, dict): - docs.append(item) # inline document + docs.append(item) # inline document elif isinstance(item, str): for fp in files_from(item): if not os.path.isfile(fp): @@ -153,9 +162,9 @@ def _transform_for(key): """Optional per-collection transform: a function named in transforms.py.""" try: try: - from . import transforms # package context + from . import transforms # package context except ImportError: - import transforms # script context + import transforms # script context return getattr(transforms, key, None) except Exception: return None @@ -227,26 +236,41 @@ def _install_design(db, design_doc): if existing.status_code == 200: design["_rev"] = existing.json()["_rev"] resp = requests.put(url, json=design, auth=_AUTH, timeout=10) - if not resp.ok: # surface CouchDB's actual reason (e.g. compilation_error) - raise RuntimeError(f"design doc install failed for '{db}' ({resp.status_code}): {resp.text}") + if not resp.ok: # surface CouchDB's actual reason (e.g. compilation_error) + raise RuntimeError( + f"design doc install failed for '{db}' ({resp.status_code}): {resp.text}" + ) def _create_indexes(db, indexes): for fields in indexes or []: - requests.post(_db_url(db, "_index"), json={"index": {"fields": fields}, "type": "json"}, - auth=_AUTH, timeout=10).raise_for_status() + requests.post( + _db_url(db, "_index"), + json={"index": {"fields": fields}, "type": "json"}, + auth=_AUTH, + timeout=10, + ).raise_for_status() def _bulk_insert(db, docs, batch_size=500): total = len(docs) for i in range(0, total, batch_size): - batch = docs[i:i + batch_size] - r = requests.post(_db_url(db, "_bulk_docs"), json={"docs": batch}, auth=_AUTH, timeout=60) + batch = docs[i : i + batch_size] + r = requests.post( + _db_url(db, "_bulk_docs"), json={"docs": batch}, auth=_AUTH, timeout=60 + ) r.raise_for_status() errors = [x for x in r.json() if x.get("error")] if errors: - logger.warning("%d bulk-insert errors in batch %d", len(errors), i // batch_size) - logger.info("Inserted batch %d/%d (%d docs)", i // batch_size + 1, math.ceil(total / batch_size), len(batch)) + logger.warning( + "%d bulk-insert errors in batch %d", len(errors), i // batch_size + ) + logger.info( + "Inserted batch %d/%d (%d docs)", + i // batch_size + 1, + math.ceil(total / batch_size), + len(batch), + ) # --------------------------------------------------------------------------- # @@ -256,7 +280,10 @@ def load_collection(key, source, drop=True, base_dir=None) -> tuple: """Load one collection's data into a database named after the key. Returns (db, n).""" cfg = collection_config(key) transform = _transform_for(key) - docs = [_normalise(d, key, cfg, transform) for d in _collect_docs(key, source, cfg, base_dir)] + docs = [ + _normalise(d, key, cfg, transform) + for d in _collect_docs(key, source, cfg, base_dir) + ] db = key if docs: _ensure_db(db, drop=drop) @@ -264,4 +291,4 @@ def load_collection(key, source, drop=True, base_dir=None) -> tuple: _install_design(db, cfg["design_doc"]) _bulk_insert(db, docs) _create_indexes(db, cfg.get("indexes")) - return db, len(docs) \ No newline at end of file + return db, len(docs) diff --git a/src/couchdb/transforms.py b/src/couchdb/transforms.py index ca3513413..229b8cb00 100644 --- a/src/couchdb/transforms.py +++ b/src/couchdb/transforms.py @@ -14,11 +14,25 @@ # Work-order CSV columns that must be typed (CSV gives everything as strings). _WO_INT = ("wopriority", "taskid") _WO_FLOAT = ( - "estlabhrs", "actlabhrs", "estlabcost", "actlabcost", "estmatcost", "actmatcost", - "estservcost", "actservcost", "esttoolcost", "acttoolcost", "estatapprtotalcost", - "esttotalcost", "acttotalcost", + "estlabhrs", + "actlabhrs", + "estlabcost", + "actlabcost", + "estmatcost", + "actmatcost", + "estservcost", + "actservcost", + "esttoolcost", + "acttoolcost", + "estatapprtotalcost", + "esttotalcost", + "acttotalcost", ) -_WO_EVIDENCE_FLOAT = ("anomaly_score", "threshold", "observed_value") # nested under aob_source.evidence +_WO_EVIDENCE_FLOAT = ( + "anomaly_score", + "threshold", + "observed_value", +) # nested under aob_source.evidence def workorder(doc): @@ -32,7 +46,7 @@ def workorder(doc): if isinstance(doc.get(f), str): doc[f] = float(doc[f]) - if isinstance(doc.get("wplabor"), str): # JSON-string column → list + if isinstance(doc.get("wplabor"), str): # JSON-string column → list doc["wplabor"] = json.loads(doc["wplabor"]) evidence = doc.get("aob_source", {}).get("evidence") @@ -41,4 +55,4 @@ def workorder(doc): if isinstance(evidence.get(f), str): evidence[f] = float(evidence[f]) - return doc \ No newline at end of file + return doc diff --git a/src/evaluation/cli.py b/src/evaluation/cli.py index 66ee508d3..226617252 100644 --- a/src/evaluation/cli.py +++ b/src/evaluation/cli.py @@ -47,8 +47,7 @@ def _build_parser() -> argparse.ArgumentParser: "--scorer-default", dest="scorer_default", default="llm_judge", - help="Scorer name when scenario.scoring_method is unset. " - "Default: llm_judge.", + help="Scorer name when scenario.scoring_method is unset. Default: llm_judge.", ) p.add_argument( "--judge-model", diff --git a/src/evaluation/evaluator.py b/src/evaluation/evaluator.py index 27845f088..a9a1bed96 100644 --- a/src/evaluation/evaluator.py +++ b/src/evaluation/evaluator.py @@ -82,7 +82,9 @@ def _score_one( def _resolve(name: str) -> Scorer: return scorer_registry.get(name) - def _validate_judge_model(self, scorer_name: str, traj: PersistedTrajectory) -> None: + def _validate_judge_model( + self, scorer_name: str, traj: PersistedTrajectory + ) -> None: if scorer_name != "llm_judge" or not self.judge_model: return diff --git a/src/evaluation/loader.py b/src/evaluation/loader.py index 1d5c0c9db..e1481f0fd 100644 --- a/src/evaluation/loader.py +++ b/src/evaluation/loader.py @@ -45,20 +45,20 @@ def _load_one_trajectory(path: Path) -> PersistedTrajectory: def load_scenarios(paths: Iterable[Path] | Path) -> list[Scenario]: """Load scenarios from one or more files or directories. - Supported inputs: + Supported inputs: - 1. Existing JSON / JSONL scenario files. - 2. A directory containing scenario subdirectories, each with - ``groundtruth.txt``. For example: + 1. Existing JSON / JSONL scenario files. + 2. A directory containing scenario subdirectories, each with + ``groundtruth.txt``. For example: - scenarios_data/ - scenario_11/ - groundtruth.txt - scenario_12/ - groundtruth.txt + scenarios_data/ + scenario_11/ + groundtruth.txt + scenario_12/ + groundtruth.txt - For folder-based scenarios, the folder name becomes the scenario id and - ``groundtruth.txt`` becomes ``expected_answer``. + For folder-based scenarios, the folder name becomes the scenario id and + ``groundtruth.txt`` becomes ``expected_answer``. """ if isinstance(paths, (str, Path)): paths = [Path(paths)] @@ -154,4 +154,4 @@ def join_records( continue scenario = by_id.get(traj.scenario_id) if scenario is not None: - yield scenario, traj \ No newline at end of file + yield scenario, traj diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py index 325074a7e..0255263c3 100644 --- a/src/evaluation/metrics.py +++ b/src/evaluation/metrics.py @@ -40,7 +40,9 @@ def _from_sdk_trajectory(traj: dict, model: str) -> OpsMetrics: tokens_in = sum(int(t.get("input_tokens") or 0) for t in turns) tokens_out = sum(int(t.get("output_tokens") or 0) for t in turns) - durations_ms = [t.get("duration_ms") for t in turns if t.get("duration_ms") is not None] + durations_ms = [ + t.get("duration_ms") for t in turns if t.get("duration_ms") is not None + ] duration_ms = sum(durations_ms) if durations_ms else None tool_names: list[str] = [] @@ -65,11 +67,7 @@ def _from_plan_execute(steps: list[Any], model: str) -> OpsMetrics: # plan-execute persists ``list[StepResult]``; the dataclass exposes # ``server`` / ``tool`` / ``response`` fields but no per-step token # counts, so we surface what is available and leave the rest at zero. - tool_names = [ - s.get("tool") - for s in steps - if isinstance(s, dict) and s.get("tool") - ] + tool_names = [s.get("tool") for s in steps if isinstance(s, dict) and s.get("tool")] return OpsMetrics( turn_count=len(steps), tool_call_count=len(tool_names), diff --git a/src/evaluation/scorers/__init__.py b/src/evaluation/scorers/__init__.py index f681844a8..37973fae6 100644 --- a/src/evaluation/scorers/__init__.py +++ b/src/evaluation/scorers/__init__.py @@ -30,9 +30,7 @@ def register(name: str, scorer: Scorer) -> None: def get(name: str) -> Scorer: if name not in _REGISTRY: - raise KeyError( - f"unknown scorer {name!r}; registered: {sorted(_REGISTRY)}" - ) + raise KeyError(f"unknown scorer {name!r}; registered: {sorted(_REGISTRY)}") return _REGISTRY[name] @@ -48,4 +46,4 @@ def names() -> list[str]: from . import semantic # noqa: E402,F401 from .static_json import install as _install_static_json # noqa: E402 -_install_static_json() \ No newline at end of file +_install_static_json() diff --git a/src/evaluation/scorers/llm_judge.py b/src/evaluation/scorers/llm_judge.py index e37ecc219..139744ddb 100644 --- a/src/evaluation/scorers/llm_judge.py +++ b/src/evaluation/scorers/llm_judge.py @@ -140,9 +140,7 @@ def __call__( if review.get("hallucinations") is True: score = max(0.0, score - 0.2) - rationale = str( - review.get("suggestions") or review.get("reason") or "" - )[:500] + rationale = str(review.get("suggestions") or review.get("reason") or "")[:500] return ScorerResult( scorer=self.name, passed=passed, diff --git a/src/evaluation/scorers/static_json.py b/src/evaluation/scorers/static_json.py index a26c53db2..20fe72621 100644 --- a/src/evaluation/scorers/static_json.py +++ b/src/evaluation/scorers/static_json.py @@ -114,9 +114,7 @@ def _extract_balanced_structure(content: str) -> str: (content.find("("), "(", ")"), ] candidates = [ - (idx, open_ch, close_ch) - for idx, open_ch, close_ch in candidates - if idx != -1 + (idx, open_ch, close_ch) for idx, open_ch, close_ch in candidates if idx != -1 ] if not candidates: @@ -367,9 +365,7 @@ def evaluate_static_json( precision = exact_matches / total_model_keys if total_model_keys else 0.0 recall = exact_matches / total_gold_keys if total_gold_keys else 0.0 f1 = ( - 2 * precision * recall / (precision + recall) - if precision + recall > 0 - else 0.0 + 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0 ) partial_exact = exact_matches / total_gold_keys if total_gold_keys else 0.0 @@ -392,6 +388,7 @@ def evaluate_static_json( details=details, ) + def evaluate_static_json_batch( pairs: list[tuple[Any, Any]], *, @@ -439,6 +436,7 @@ def evaluate_static_json_batch( "examples": [score.to_dict() for score in scores], } + class StaticJsonScorer: """Evaluation scorer wrapper for the trajectory-based pipeline.""" @@ -482,4 +480,4 @@ def __call__( def install(name: str = "static_json") -> None: """Register the static JSON scorer.""" - register(name, StaticJsonScorer(name=name)) \ No newline at end of file + register(name, StaticJsonScorer(name=name)) diff --git a/src/evaluation/tests/test_loader.py b/src/evaluation/tests/test_loader.py index 27d5c9a92..72b3b3e5b 100644 --- a/src/evaluation/tests/test_loader.py +++ b/src/evaluation/tests/test_loader.py @@ -21,7 +21,9 @@ def test_load_trajectories_from_dir(trajectory_dir: Path): def test_load_trajectories_skips_unparseable(tmp_path: Path, make_persisted_record): - (tmp_path / "good.json").write_text(json.dumps(make_persisted_record()), encoding="utf-8") + (tmp_path / "good.json").write_text( + json.dumps(make_persisted_record()), encoding="utf-8" + ) (tmp_path / "bad.json").write_text("{not json", encoding="utf-8") records = load_trajectories(tmp_path) assert len(records) == 1 @@ -30,9 +32,7 @@ def test_load_trajectories_skips_unparseable(tmp_path: Path, make_persisted_reco def test_load_scenarios_json_list(tmp_path: Path): p = tmp_path / "s.json" p.write_text( - json.dumps( - [{"id": 1, "text": "Q1"}, {"id": "2", "text": "Q2"}] - ), + json.dumps([{"id": 1, "text": "Q1"}, {"id": "2", "text": "Q2"}]), encoding="utf-8", ) out = load_scenarios(p) @@ -65,7 +65,9 @@ def test_join_drops_orphans(make_persisted_record): ] trajs = [ PersistedTrajectory.from_raw(make_persisted_record(scenario_id=1)), - PersistedTrajectory.from_raw(make_persisted_record(run_id="r2", scenario_id=99)), + PersistedTrajectory.from_raw( + make_persisted_record(run_id="r2", scenario_id=99) + ), ] pairs = list(join_records(scenarios, trajs)) assert len(pairs) == 1 @@ -108,4 +110,4 @@ def test_load_scenarios_from_groundtruth_folders(tmp_path): assert len(scenarios) == 1 assert scenarios[0].id == "11" assert scenarios[0].expected_answer == "{'energy': 14, 'material': 48}" - assert scenarios[0].scoring_method == "static_json" \ No newline at end of file + assert scenarios[0].scoring_method == "static_json" diff --git a/src/evaluation/tests/test_metrics.py b/src/evaluation/tests/test_metrics.py index 21f097b1c..df096d032 100644 --- a/src/evaluation/tests/test_metrics.py +++ b/src/evaluation/tests/test_metrics.py @@ -47,9 +47,27 @@ def test_plan_execute_list_trajectory(self, make_persisted_record): rec = PersistedTrajectory.from_raw( make_persisted_record( trajectory=[ - {"step_number": 1, "task": "t", "server": "iot", "tool": "sites", "response": "ok"}, - {"step_number": 2, "task": "t2", "server": "iot", "tool": "assets", "response": "ok"}, - {"step_number": 3, "task": "t3", "server": "iot", "tool": "sites", "response": "ok"}, + { + "step_number": 1, + "task": "t", + "server": "iot", + "tool": "sites", + "response": "ok", + }, + { + "step_number": 2, + "task": "t2", + "server": "iot", + "tool": "assets", + "response": "ok", + }, + { + "step_number": 3, + "task": "t3", + "server": "iot", + "tool": "sites", + "response": "ok", + }, ] ) ) @@ -67,9 +85,21 @@ def test_empty(self): def test_sums_and_percentiles(self): results = [ - _result(ops=OpsMetrics(tokens_in=10, tokens_out=5, duration_ms=100.0, tool_call_count=1)), - _result(ops=OpsMetrics(tokens_in=20, tokens_out=10, duration_ms=300.0, tool_call_count=2)), - _result(ops=OpsMetrics(tokens_in=30, tokens_out=15, duration_ms=500.0, tool_call_count=3)), + _result( + ops=OpsMetrics( + tokens_in=10, tokens_out=5, duration_ms=100.0, tool_call_count=1 + ) + ), + _result( + ops=OpsMetrics( + tokens_in=20, tokens_out=10, duration_ms=300.0, tool_call_count=2 + ) + ), + _result( + ops=OpsMetrics( + tokens_in=30, tokens_out=15, duration_ms=500.0, tool_call_count=3 + ) + ), ] agg = aggregate_ops(results) assert agg.tokens_in_total == 60 @@ -90,7 +120,10 @@ def test_cost_only_when_some_present(self): class TestNormalizeModel: def test_strips_provider_prefix(self): - assert _normalize_model("litellm_proxy/anthropic/claude-opus-4-5") == "claude-opus-4-5" + assert ( + _normalize_model("litellm_proxy/anthropic/claude-opus-4-5") + == "claude-opus-4-5" + ) assert _normalize_model("watsonx/ibm/granite-13b") == "granite-13b" def test_strips_long_numeric_suffix(self): diff --git a/src/evaluation/tests/test_models.py b/src/evaluation/tests/test_models.py index 4aca4d551..621107a02 100644 --- a/src/evaluation/tests/test_models.py +++ b/src/evaluation/tests/test_models.py @@ -10,7 +10,9 @@ def test_scenario_from_raw_coerces_int_id_to_str(): def test_scenario_preserves_extra_fields(): - s = Scenario.from_raw({"id": "1", "text": "Q", "characteristic_form": "X", "tolerance": 0.01}) + s = Scenario.from_raw( + {"id": "1", "text": "Q", "characteristic_form": "X", "tolerance": 0.01} + ) extra = s.model_extra or {} assert extra.get("tolerance") == 0.01 diff --git a/src/evaluation/tests/test_report.py b/src/evaluation/tests/test_report.py index 7c71788dc..aabb5042c 100644 --- a/src/evaluation/tests/test_report.py +++ b/src/evaluation/tests/test_report.py @@ -27,7 +27,9 @@ def _result(stype: str, passed: bool, run_id: str = "", **ops_kwargs) -> Scenari model="watsonx/ibm/granite", question="q", answer="a", - score=ScorerResult(scorer="llm_judge", passed=passed, score=1.0 if passed else 0.0), + score=ScorerResult( + scorer="llm_judge", passed=passed, score=1.0 if passed else 0.0 + ), ops=OpsMetrics(**ops_kwargs), ) @@ -98,7 +100,14 @@ def test_write_reports_dir_falls_back_to_scenario_id(tmp_path: Path): def test_render_summary_includes_headlines(): results = [ - _result("iot", True, tokens_in=10, tokens_out=5, duration_ms=100.0, tool_call_count=1), + _result( + "iot", + True, + tokens_in=10, + tokens_out=5, + duration_ms=100.0, + tool_call_count=1, + ), _result("iot", False, tokens_in=8, tokens_out=4, duration_ms=200.0), ] text = render_summary(build_report(results)) diff --git a/src/evaluation/tests/test_runner.py b/src/evaluation/tests/test_runner.py index f8a936db0..b82123f74 100644 --- a/src/evaluation/tests/test_runner.py +++ b/src/evaluation/tests/test_runner.py @@ -10,7 +10,9 @@ from evaluation import scorers as registry -def _always_pass_scorer(scenario: Scenario, answer: str, trajectory_text: str) -> ScorerResult: +def _always_pass_scorer( + scenario: Scenario, answer: str, trajectory_text: str +) -> ScorerResult: return ScorerResult(scorer="stub", passed=True, score=1.0) @@ -46,11 +48,15 @@ def test_evaluate_end_to_end(tmp_path: Path, make_persisted_record): assert report.ops.tokens_in_total > 0 -def _always_fail_scorer(scenario: Scenario, answer: str, trajectory_text: str) -> ScorerResult: +def _always_fail_scorer( + scenario: Scenario, answer: str, trajectory_text: str +) -> ScorerResult: return ScorerResult(scorer="stub-fail", passed=False, score=0.0) -def test_evaluate_uses_per_scenario_scoring_method(tmp_path: Path, make_persisted_record): +def test_evaluate_uses_per_scenario_scoring_method( + tmp_path: Path, make_persisted_record +): rec = make_persisted_record(run_id="run-x", scenario_id=1, answer="A.") (tmp_path / "run-x.json").write_text(json.dumps(rec), encoding="utf-8") diff --git a/src/evaluation/tests/test_static_json_scorer.py b/src/evaluation/tests/test_static_json_scorer.py index 175a320c6..97ce72399 100644 --- a/src/evaluation/tests/test_static_json_scorer.py +++ b/src/evaluation/tests/test_static_json_scorer.py @@ -5,6 +5,7 @@ parse_structured_answer, ) + def test_parse_json_object_from_noisy_markdown_answer(): raw = 'Answer:\n```json\n{"energy": 3, "material": 12}\n```' @@ -122,7 +123,6 @@ def test_batch_evaluation(): assert result["strict_exact_match_accuracy"] == 0.5 - from evaluation.models import Scenario from evaluation.scorers.static_json import StaticJsonScorer @@ -147,4 +147,4 @@ def test_static_json_scorer_wrapper_exact_match(): assert result.scorer == "static_json" assert result.passed is True assert result.score == 1.0 - assert result.details["strict_exact_match_accuracy"] == 1.0 \ No newline at end of file + assert result.details["strict_exact_match_accuracy"] == 1.0 diff --git a/src/llm/base.py b/src/llm/base.py index a6b085141..6df322ab5 100644 --- a/src/llm/base.py +++ b/src/llm/base.py @@ -27,9 +27,7 @@ def generate(self, prompt: str, temperature: float = 0.0) -> str: """Generate text given a prompt.""" ... - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: """Generate text and report token usage. Default impl delegates to :meth:`generate` and reports zero usage — diff --git a/src/llm/litellm.py b/src/llm/litellm.py index 85067c7c1..3a1edd1b7 100644 --- a/src/llm/litellm.py +++ b/src/llm/litellm.py @@ -36,9 +36,7 @@ def __init__(self, model_id: str) -> None: def generate(self, prompt: str, temperature: float = 0.0) -> str: return self.generate_with_usage(prompt, temperature).text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: import litellm kwargs: dict = { diff --git a/src/llm/openai_compat.py b/src/llm/openai_compat.py index 428ed72f6..6b0531ba5 100644 --- a/src/llm/openai_compat.py +++ b/src/llm/openai_compat.py @@ -29,18 +29,14 @@ class OpenAICompatBackend(LLMBackend): def __init__(self, model_id: str) -> None: if not is_openai_compat(model_id): - raise ValueError( - f"unsupported OpenAI-compatible model id: {model_id!r}" - ) + raise ValueError(f"unsupported OpenAI-compatible model id: {model_id!r}") self._model_id = model_id self._model_name = resolve_model(model_id) def generate(self, prompt: str, temperature: float = 0.0) -> str: return self.generate_with_usage(prompt, temperature).text - def generate_with_usage( - self, prompt: str, temperature: float = 0.0 - ) -> LLMResult: + def generate_with_usage(self, prompt: str, temperature: float = 0.0) -> LLMResult: from openai import OpenAI creds = resolve_router_creds(self._model_id) # strict: clear error if unset diff --git a/src/llm/routers.py b/src/llm/routers.py index d89cf7fb8..5dd0121e4 100644 --- a/src/llm/routers.py +++ b/src/llm/routers.py @@ -57,7 +57,7 @@ def resolve_model(model_id: str) -> str: ``"anthropic/claude-sonnet-4-6"`` -> unchanged. """ prefix = router_prefix(model_id) - return model_id[len(prefix):] if prefix else model_id + return model_id[len(prefix) :] if prefix else model_id def is_openai_compat(model_id: str) -> bool: @@ -65,9 +65,7 @@ def is_openai_compat(model_id: str) -> bool: return model_id.startswith(OPENAI_COMPAT_PREFIXES) -def resolve_router_creds( - model_id: str, *, strict: bool = True -) -> RouterCreds | None: +def resolve_router_creds(model_id: str, *, strict: bool = True) -> RouterCreds | None: """Resolve endpoint + key for *model_id*, or ``None`` if not proxied. Args: diff --git a/src/llm/tests/test_backends.py b/src/llm/tests/test_backends.py index 632b137d0..38046ce9e 100644 --- a/src/llm/tests/test_backends.py +++ b/src/llm/tests/test_backends.py @@ -17,9 +17,7 @@ def create(**kwargs): captured.update(kwargs) return types.SimpleNamespace( choices=[ - types.SimpleNamespace( - message=types.SimpleNamespace(content="hi") - ) + types.SimpleNamespace(message=types.SimpleNamespace(content="hi")) ], usage=types.SimpleNamespace(prompt_tokens=3, completion_tokens=2), ) @@ -70,5 +68,7 @@ def test_tokenrouter_strips_prefix_and_routes(monkeypatch): def test_model_id_property_keeps_full_string(): - assert OpenAICompatBackend("tokenrouter/MiniMax-M3").model_id == "tokenrouter/MiniMax-M3" - + assert ( + OpenAICompatBackend("tokenrouter/MiniMax-M3").model_id + == "tokenrouter/MiniMax-M3" + ) diff --git a/src/observability/persistence.py b/src/observability/persistence.py index 692e13503..49f7443e9 100644 --- a/src/observability/persistence.py +++ b/src/observability/persistence.py @@ -79,9 +79,7 @@ def persist_trajectory( } try: - out_path.write_text( - json.dumps(record, indent=2, default=str), encoding="utf-8" - ) + out_path.write_text(json.dumps(record, indent=2, default=str), encoding="utf-8") except OSError: _log.exception("persist_trajectory: write failed at %s", out_path) return None diff --git a/src/observability/runspan.py b/src/observability/runspan.py index c22435806..6fe53688e 100644 --- a/src/observability/runspan.py +++ b/src/observability/runspan.py @@ -76,5 +76,3 @@ def agent_run_span( span.record_exception(exc) span.set_status(Status(StatusCode.ERROR, str(exc))) raise - - diff --git a/src/observability/tests/test_persistence.py b/src/observability/tests/test_persistence.py index d881adf5d..555ab465e 100644 --- a/src/observability/tests/test_persistence.py +++ b/src/observability/tests/test_persistence.py @@ -64,7 +64,9 @@ def test_persist_writes_file(monkeypatch, tmp_path: Path): _FakeTurn( index=0, text="hello", - tool_calls=[_FakeToolCall(name="sensors", input={"id": "CH-6"}, output="ok")], + tool_calls=[ + _FakeToolCall(name="sensors", input={"id": "CH-6"}, output="ok") + ], input_tokens=100, output_tokens=20, ), @@ -111,7 +113,9 @@ class _FakeStep: ) record = json.loads(out.read_text()) - assert record["trajectory"] == [{"step_number": 1, "task": "do thing", "success": True}] + assert record["trajectory"] == [ + {"step_number": 1, "task": "do thing", "success": True} + ] def test_persist_skips_when_no_run_id(monkeypatch, tmp_path: Path, caplog): diff --git a/src/observability/tests/test_tracing.py b/src/observability/tests/test_tracing.py index cea33a38a..8be10235f 100644 --- a/src/observability/tests/test_tracing.py +++ b/src/observability/tests/test_tracing.py @@ -88,7 +88,9 @@ def test_agent_run_span_emits_attributes(memory_exporter): assert s.attributes["agent.runner"] == "plan-execute" assert s.attributes["gen_ai.system"] == "anthropic" assert s.attributes["gen_ai.request.model"] == "litellm_proxy/aws/claude-opus-4-6" - assert s.attributes["agent.question.length"] == len("What sensors are on Chiller 6?") + assert s.attributes["agent.question.length"] == len( + "What sensors are on Chiller 6?" + ) assert s.attributes["custom.flag"] is True diff --git a/src/observability/tracing.py b/src/observability/tracing.py index d4b12ce44..7c0eb9afd 100644 --- a/src/observability/tracing.py +++ b/src/observability/tracing.py @@ -73,7 +73,9 @@ def init_tracing(service_name: str) -> None: if _initialized: return - provider = TracerProvider(resource=Resource.create({"service.name": service_name})) + provider = TracerProvider( + resource=Resource.create({"service.name": service_name}) + ) if (path := _traces_file_path()) is not None: from .file_exporter import OTLPJsonFileExporter diff --git a/src/servers/fmsr/main.py b/src/servers/fmsr/main.py index 8091dd148..d858f1587 100644 --- a/src/servers/fmsr/main.py +++ b/src/servers/fmsr/main.py @@ -30,7 +30,9 @@ load_dotenv() -_log_level = getattr(logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING) +_log_level = getattr( + logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING +) logging.basicConfig(level=_log_level) logger = logging.getLogger("fmsr-mcp-server") @@ -62,6 +64,7 @@ # ── Output parsers ──────────────────────────────────────────────────────────── + def _parse_numbered_list(text: str) -> list[str]: """Parse a numbered list response into a plain list of strings.""" items = [] @@ -97,15 +100,23 @@ def _build_llm(): model_id = os.environ.get("FMSR_MODEL_ID", _DEFAULT_MODEL_ID) if model_id.startswith("watsonx/"): - missing = [v for v in ("WATSONX_APIKEY", "WATSONX_PROJECT_ID") if not os.environ.get(v)] + missing = [ + v for v in ("WATSONX_APIKEY", "WATSONX_PROJECT_ID") if not os.environ.get(v) + ] if missing: raise RuntimeError(f"Missing env vars for WatsonX: {missing}") elif model_id.startswith("tokenrouter/"): - missing = [v for v in ("TOKENROUTER_API_KEY", "TOKENROUTER_BASE_URL") if not os.environ.get(v)] + missing = [ + v + for v in ("TOKENROUTER_API_KEY", "TOKENROUTER_BASE_URL") + if not os.environ.get(v) + ] if missing: raise RuntimeError(f"Missing env vars for TokenRouter: {missing}") else: - missing = [v for v in ("LITELLM_API_KEY", "LITELLM_BASE_URL") if not os.environ.get(v)] + missing = [ + v for v in ("LITELLM_API_KEY", "LITELLM_BASE_URL") if not os.environ.get(v) + ] if missing: raise RuntimeError(f"Missing env vars for LiteLLM: {missing}") return make_backend(model_id) @@ -159,6 +170,7 @@ def _call_relevancy(asset_name: str, failure_mode: str, sensor: str) -> dict: # ── Result models ───────────────────────────────────────────────────────────── + class ErrorResult(BaseModel): error: str @@ -192,7 +204,10 @@ class FailureModeSensorMappingResult(BaseModel): # ── FastMCP server ──────────────────────────────────────────────────────────── -mcp = FastMCP("fmsr", instructions="Failure mode and sensor reasoning: get failure modes for assets and determine which sensors can detect each failure.") +mcp = FastMCP( + "fmsr", + instructions="Failure mode and sensor reasoning: get failure modes for assets and determine which sensors can detect each failure.", +) @mcp.tool(title="Get Failure Modes") diff --git a/src/servers/fmsr/tests/conftest.py b/src/servers/fmsr/tests/conftest.py index 4a1959c69..b2b10e0db 100644 --- a/src/servers/fmsr/tests/conftest.py +++ b/src/servers/fmsr/tests/conftest.py @@ -27,7 +27,11 @@ def no_llm(): def mock_relevancy_chain(): """Patch _call_relevancy so it always returns 'Yes' without calling the LLM.""" mock = MagicMock( - return_value={"answer": "Yes", "reason": "Relevant sensor", "temporal_behavior": "Increases"} + return_value={ + "answer": "Yes", + "reason": "Relevant sensor", + "temporal_behavior": "Increases", + } ) with patch("servers.fmsr.main._call_relevancy", mock): with patch("servers.fmsr.main._llm_available", True): diff --git a/src/servers/fmsr/tests/test_tools.py b/src/servers/fmsr/tests/test_tools.py index 3bbc3129d..cece28473 100644 --- a/src/servers/fmsr/tests/test_tools.py +++ b/src/servers/fmsr/tests/test_tools.py @@ -75,7 +75,11 @@ async def test_returns_expected_keys(self, mock_relevancy_chain): data = await call_tool( mcp, "get_failure_mode_sensor_mapping", - {"asset_name": "Chiller 6", "failure_modes": _FAILURE_MODES, "sensors": _SENSORS}, + { + "asset_name": "Chiller 6", + "failure_modes": _FAILURE_MODES, + "sensors": _SENSORS, + }, ) assert "fm2sensor" in data assert "sensor2fm" in data @@ -88,7 +92,11 @@ async def test_full_relevancy_count(self, mock_relevancy_chain): data = await call_tool( mcp, "get_failure_mode_sensor_mapping", - {"asset_name": "Chiller 6", "failure_modes": _FAILURE_MODES, "sensors": _SENSORS}, + { + "asset_name": "Chiller 6", + "failure_modes": _FAILURE_MODES, + "sensors": _SENSORS, + }, ) assert len(data["full_relevancy"]) == 4 @@ -115,7 +123,11 @@ async def test_llm_unavailable_returns_error(self, no_llm): data = await call_tool( mcp, "get_failure_mode_sensor_mapping", - {"asset_name": "Chiller 6", "failure_modes": _FAILURE_MODES, "sensors": _SENSORS}, + { + "asset_name": "Chiller 6", + "failure_modes": _FAILURE_MODES, + "sensors": _SENSORS, + }, ) assert "error" in data diff --git a/src/servers/iot/main.py b/src/servers/iot/main.py index 9e4732087..c69503796 100644 --- a/src/servers/iot/main.py +++ b/src/servers/iot/main.py @@ -37,7 +37,10 @@ logger.error(f"Failed to connect to CouchDB: {e}") db = None -mcp = FastMCP("iot", instructions="IoT sensor data: browse sites, assets, sensors, and query historical readings from CouchDB.") +mcp = FastMCP( + "iot", + instructions="IoT sensor data: browse sites, assets, sensors, and query historical readings from CouchDB.", +) # Static site as per original requirement SITES = ["MAIN"] diff --git a/src/servers/iot/tests/conftest.py b/src/servers/iot/tests/conftest.py index 83a9ef3df..b99bcd0c2 100644 --- a/src/servers/iot/tests/conftest.py +++ b/src/servers/iot/tests/conftest.py @@ -16,6 +16,7 @@ def _couchdb_reachable() -> bool: return False try: import requests + requests.get(url, timeout=2) return True except Exception: diff --git a/src/servers/iot/tests/test_couchdb.py b/src/servers/iot/tests/test_couchdb.py index 36fea30cc..fc27ae6ce 100644 --- a/src/servers/iot/tests/test_couchdb.py +++ b/src/servers/iot/tests/test_couchdb.py @@ -28,7 +28,9 @@ def couchdb_client(): @requires_couchdb class TestCouchDBInfrastructure: def test_connection(self): - resp = requests.get(f"http://{COUCHDB_HOST}", auth=(COUCHDB_USERNAME, COUCHDB_PASSWORD)) + resp = requests.get( + f"http://{COUCHDB_HOST}", auth=(COUCHDB_USERNAME, COUCHDB_PASSWORD) + ) assert resp.status_code == 200 client = couchdb3.Server(FULL_URL) diff --git a/src/servers/tsfm/main.py b/src/servers/tsfm/main.py index 288a388b2..81b9f2ce9 100644 --- a/src/servers/tsfm/main.py +++ b/src/servers/tsfm/main.py @@ -72,7 +72,6 @@ logger = logging.getLogger("tsfm-mcp-server") - # ── Internal helpers ────────────────────────────────────────────────────────── @@ -115,7 +114,10 @@ def _tsad_output_to_df(output: dict) -> pd.DataFrame: # ── FastMCP server ──────────────────────────────────────────────────────────── -mcp = FastMCP("tsfm", instructions="Time-series foundation models: forecasting, finetuning, and anomaly detection using IBM Granite TinyTimeMixer.") +mcp = FastMCP( + "tsfm", + instructions="Time-series foundation models: forecasting, finetuning, and anomaly detection using IBM Granite TinyTimeMixer.", +) # ── Static tools ────────────────────────────────────────────────────────────── @@ -576,7 +578,9 @@ def run_integrated_tsad( frequency_sampling, autoregressive_modeling, ) - full_data_df = _read_ts_data(dataset_path, dataset_config_dictionary=full_config) + full_data_df = _read_ts_data( + dataset_path, dataset_config_dictionary=full_config + ) for col in target_columns: col_config = _build_dataset_config( diff --git a/src/servers/tsfm/tests/conftest.py b/src/servers/tsfm/tests/conftest.py index 169484aad..6c4cc510b 100644 --- a/src/servers/tsfm/tests/conftest.py +++ b/src/servers/tsfm/tests/conftest.py @@ -7,10 +7,12 @@ import pytest + # Skip marker for tests that require tsfm_public + its ML dependencies. def _tsfm_available() -> bool: try: import tsfm_public # noqa: F401 + return True except ImportError: return False diff --git a/src/servers/tsfm/tests/test_tools.py b/src/servers/tsfm/tests/test_tools.py index 744b4abc9..78790f34f 100644 --- a/src/servers/tsfm/tests/test_tools.py +++ b/src/servers/tsfm/tests/test_tools.py @@ -15,6 +15,7 @@ # ── get_ai_tasks ────────────────────────────────────────────────────────────── + class TestGetAITasks: @pytest.mark.anyio async def test_returns_tasks_list(self): @@ -40,6 +41,7 @@ async def test_each_task_has_description(self): # ── get_tsfm_models ─────────────────────────────────────────────────────────── + class TestGetTSFMModels: @pytest.mark.anyio async def test_returns_models_list(self): @@ -65,11 +67,13 @@ async def test_each_model_has_checkpoint_and_description(self): # ── run_tsfm_forecasting — input validation ─────────────────────────────────── + class TestRunTSFMForecastingValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_tsfm_forecasting", + mcp, + "run_tsfm_forecasting", {"dataset_path": "", "timestamp_column": "ts", "target_columns": ["val"]}, ) assert "error" in data @@ -78,8 +82,13 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_tsfm_forecasting", - {"dataset_path": "/tmp/data.csv", "timestamp_column": "ts", "target_columns": []}, + mcp, + "run_tsfm_forecasting", + { + "dataset_path": "/tmp/data.csv", + "timestamp_column": "ts", + "target_columns": [], + }, ) assert "error" in data assert "target_columns" in data["error"] @@ -89,7 +98,8 @@ async def test_missing_deps_returns_error(self): # tsfm_public is not expected to be installed in the CI/MCP environment. # If it IS installed this test is a no-op (the import succeeds). data = await call_tool( - mcp, "run_tsfm_forecasting", + mcp, + "run_tsfm_forecasting", { "dataset_path": "/nonexistent/data.csv", "timestamp_column": "Timestamp", @@ -103,11 +113,13 @@ async def test_missing_deps_returns_error(self): # ── run_tsfm_finetuning — input validation ──────────────────────────────────── + class TestRunTSFMFinetuningValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_tsfm_finetuning", + mcp, + "run_tsfm_finetuning", {"dataset_path": "", "timestamp_column": "ts", "target_columns": ["val"]}, ) assert "error" in data @@ -116,8 +128,13 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_tsfm_finetuning", - {"dataset_path": "/tmp/data.csv", "timestamp_column": "ts", "target_columns": []}, + mcp, + "run_tsfm_finetuning", + { + "dataset_path": "/tmp/data.csv", + "timestamp_column": "ts", + "target_columns": [], + }, ) assert "error" in data assert "target_columns" in data["error"] @@ -125,11 +142,13 @@ async def test_empty_target_columns_returns_error(self): # ── run_tsad — input validation ─────────────────────────────────────────────── + class TestRunTSADValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "", "tsfm_output_json": "/tmp/pred.json", @@ -143,7 +162,8 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_tsfm_output_json_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "/tmp/data.csv", "tsfm_output_json": "", @@ -157,7 +177,8 @@ async def test_empty_tsfm_output_json_returns_error(self): @pytest.mark.anyio async def test_invalid_task_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "/tmp/data.csv", "tsfm_output_json": "/tmp/pred.json", @@ -172,7 +193,8 @@ async def test_invalid_task_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_tsad", + mcp, + "run_tsad", { "dataset_path": "/tmp/data.csv", "tsfm_output_json": "/tmp/pred.json", @@ -186,11 +208,13 @@ async def test_empty_target_columns_returns_error(self): # ── run_integrated_tsad — input validation ──────────────────────────────────── + class TestRunIntegratedTSADValidation: @pytest.mark.anyio async def test_empty_dataset_path_returns_error(self): data = await call_tool( - mcp, "run_integrated_tsad", + mcp, + "run_integrated_tsad", {"dataset_path": "", "timestamp_column": "ts", "target_columns": ["val"]}, ) assert "error" in data @@ -199,8 +223,13 @@ async def test_empty_dataset_path_returns_error(self): @pytest.mark.anyio async def test_empty_target_columns_returns_error(self): data = await call_tool( - mcp, "run_integrated_tsad", - {"dataset_path": "/tmp/data.csv", "timestamp_column": "ts", "target_columns": []}, + mcp, + "run_integrated_tsad", + { + "dataset_path": "/tmp/data.csv", + "timestamp_column": "ts", + "target_columns": [], + }, ) assert "error" in data assert "target_columns" in data["error"] @@ -208,6 +237,7 @@ async def test_empty_target_columns_returns_error(self): # ── Integration tests (requires tsfm_public) ───────────────────────────────── + @requires_tsfm class TestTSFMForecastingIntegration: @pytest.mark.anyio @@ -218,15 +248,18 @@ async def test_forecasting_returns_results_file(self, tmp_path): # Create a small synthetic sine-wave CSV n = 200 - df = pd.DataFrame({ - "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), - "sensor_1": np.sin(np.linspace(0, 4 * np.pi, n)), - }) + df = pd.DataFrame( + { + "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), + "sensor_1": np.sin(np.linspace(0, 4 * np.pi, n)), + } + ) csv_path = str(tmp_path / "synthetic.csv") df.to_csv(csv_path, index=False) data = await call_tool( - mcp, "run_tsfm_forecasting", + mcp, + "run_tsfm_forecasting", { "dataset_path": csv_path, "timestamp_column": "Timestamp", @@ -249,15 +282,19 @@ async def test_integrated_tsad_returns_csv(self, tmp_path): import numpy as np n = 300 - df = pd.DataFrame({ - "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), - "sensor_1": np.sin(np.linspace(0, 6 * np.pi, n)) + np.random.randn(n) * 0.05, - }) + df = pd.DataFrame( + { + "Timestamp": pd.date_range("2024-01-01", periods=n, freq="15min"), + "sensor_1": np.sin(np.linspace(0, 6 * np.pi, n)) + + np.random.randn(n) * 0.05, + } + ) csv_path = str(tmp_path / "synthetic_ad.csv") df.to_csv(csv_path, index=False) data = await call_tool( - mcp, "run_integrated_tsad", + mcp, + "run_integrated_tsad", { "dataset_path": csv_path, "timestamp_column": "Timestamp", diff --git a/src/servers/utilities/main.py b/src/servers/utilities/main.py index 48e1858b8..42858783c 100644 --- a/src/servers/utilities/main.py +++ b/src/servers/utilities/main.py @@ -13,11 +13,16 @@ # Setup logging — default WARNING so stderr stays quiet when used as MCP server; # set LOG_LEVEL=INFO (or DEBUG) in the environment to see verbose output. -_log_level = getattr(logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING) +_log_level = getattr( + logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING +) logging.basicConfig(level=_log_level) logger = logging.getLogger("utilities-mcp-server") -mcp = FastMCP("utilities", instructions="General utilities: read JSON files and get current date/time.") +mcp = FastMCP( + "utilities", + instructions="General utilities: read JSON files and get current date/time.", +) class DateTimeResult(BaseModel): @@ -75,7 +80,9 @@ def current_date_time() -> DateTimeResult: description = f"Today's date is {date_part} and time is {time_part}." - return DateTimeResult(currentDateTime=now_iso, currentDateTimeDescription=description) + return DateTimeResult( + currentDateTime=now_iso, currentDateTimeDescription=description + ) @mcp.tool(title="Get Current Time in English") diff --git a/src/servers/vibration/data_store.py b/src/servers/vibration/data_store.py index e546899aa..b5088691b 100644 --- a/src/servers/vibration/data_store.py +++ b/src/servers/vibration/data_store.py @@ -86,9 +86,7 @@ def summary(self) -> dict: "sample_rate_hz": self.sample_rate, "duration_s": round(self.duration_s, 4), "channel_stats": channel_stats, - "metadata": { - k: v for k, v in self.metadata.items() if k != "axis_labels" - }, + "metadata": {k: v for k, v in self.metadata.items() if k != "axis_labels"}, } diff --git a/src/servers/vibration/dsp/__init__.py b/src/servers/vibration/dsp/__init__.py index 4f7dc3977..3522139c4 100644 --- a/src/servers/vibration/dsp/__init__.py +++ b/src/servers/vibration/dsp/__init__.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/LGDiMaggio/claude-stwinbox-diagnostics/tree/main/mcp-servers/vibration-analysis-mcp -from .fft_analysis import compute_fft, compute_psd, compute_spectrogram, find_peaks_in_spectrum +from .fft_analysis import ( + compute_fft, + compute_psd, + compute_spectrogram, + find_peaks_in_spectrum, +) from .envelope import envelope_spectrum, check_bearing_peaks -from .bearing_freqs import compute_bearing_frequencies, get_bearing, list_bearings, COMMON_BEARINGS +from .bearing_freqs import ( + compute_bearing_frequencies, + get_bearing, + list_bearings, + COMMON_BEARINGS, +) from .fault_detection import ( assess_iso10816, extract_shaft_features, diff --git a/src/servers/vibration/dsp/bearing_freqs.py b/src/servers/vibration/dsp/bearing_freqs.py index 149fb9975..9a22da7de 100644 --- a/src/servers/vibration/dsp/bearing_freqs.py +++ b/src/servers/vibration/dsp/bearing_freqs.py @@ -91,8 +91,10 @@ def compute_bearing_frequencies( ftf = f_shaft * 0.5 * (1.0 - ratio * math.cos(alpha_rad)) bpfo = f_shaft * (n_balls / 2.0) * (1.0 - ratio * math.cos(alpha_rad)) bpfi = f_shaft * (n_balls / 2.0) * (1.0 + ratio * math.cos(alpha_rad)) - bsf = f_shaft * (pitch_dia / (2.0 * ball_dia)) * ( - 1.0 - (ratio * math.cos(alpha_rad)) ** 2 + bsf = ( + f_shaft + * (pitch_dia / (2.0 * ball_dia)) + * (1.0 - (ratio * math.cos(alpha_rad)) ** 2) ) return BearingFrequencies( diff --git a/src/servers/vibration/dsp/envelope.py b/src/servers/vibration/dsp/envelope.py index f60b23ec3..90888f33a 100644 --- a/src/servers/vibration/dsp/envelope.py +++ b/src/servers/vibration/dsp/envelope.py @@ -189,7 +189,9 @@ def check_bearing_peaks( "harmonics_checked": n_harmonics, "harmonics_detected": detected_count, "confidence": ( - "high" if detected_count >= 2 else ("medium" if detected_count == 1 else "none") + "high" + if detected_count >= 2 + else ("medium" if detected_count == 1 else "none") ), "details": results, } diff --git a/src/servers/vibration/dsp/fault_detection.py b/src/servers/vibration/dsp/fault_detection.py index d613a8147..a588a7c36 100644 --- a/src/servers/vibration/dsp/fault_detection.py +++ b/src/servers/vibration/dsp/fault_detection.py @@ -51,9 +51,7 @@ def assess_iso10816( Returns: dict with zone (A/B/C/D), description, and thresholds used. """ - thresholds = ISO_10816_THRESHOLDS.get( - machine_group, ISO_10816_THRESHOLDS["group2"] - ) + thresholds = ISO_10816_THRESHOLDS.get(machine_group, ISO_10816_THRESHOLDS["group2"]) if rms_velocity_mm_s <= thresholds["A_good"]: zone, desc = "A", "Good - newly commissioned machines" @@ -246,9 +244,7 @@ def classify_faults( # --- Mechanical looseness: many harmonics + sub-harmonics --- n_significant = sum( - 1 - for a in [features.amp_1x, features.amp_2x, features.amp_3x] - if a / rms > 1.5 + 1 for a in [features.amp_1x, features.amp_2x, features.amp_3x] if a / rms > 1.5 ) if n_significant >= 3 or (features.amp_half_x / rms > 1.5): evidence = [f"Harmonics above threshold: {n_significant}/3"] diff --git a/src/servers/vibration/main.py b/src/servers/vibration/main.py index 398da8bd6..bcc8985b0 100644 --- a/src/servers/vibration/main.py +++ b/src/servers/vibration/main.py @@ -43,7 +43,10 @@ logging.basicConfig(level=_log_level) logger = logging.getLogger("vibration-mcp-server") -mcp = FastMCP("vibration", instructions="Vibration signal analysis: FFT, envelope spectrum, bearing fault detection, and ISO 10816 severity assessment.") +mcp = FastMCP( + "vibration", + instructions="Vibration signal analysis: FFT, envelope spectrum, bearing fault detection, and ISO 10816 severity assessment.", +) # --------------------------------------------------------------------------- diff --git a/src/servers/vibration/sample_data/generate_synthetic_vibration.py b/src/servers/vibration/sample_data/generate_synthetic_vibration.py index b6c4d4882..137678a15 100644 --- a/src/servers/vibration/sample_data/generate_synthetic_vibration.py +++ b/src/servers/vibration/sample_data/generate_synthetic_vibration.py @@ -32,6 +32,7 @@ python generate_synthetic_vibration.py # writes JSON to cwd python generate_synthetic_vibration.py --check # writes JSON + prints stats """ + from __future__ import annotations import argparse @@ -44,30 +45,30 @@ # --------------------------------------------------------------------------- # Machine / bearing parameters # --------------------------------------------------------------------------- -FS = 4096 # sampling rate [Hz] -DURATION = 1.0 # seconds -RPM = 1800 # shaft speed +FS = 4096 # sampling rate [Hz] +DURATION = 1.0 # seconds +RPM = 1800 # shaft speed F_SHAFT = RPM / 60 # shaft frequency [Hz] # SKF 6205-2RS (common small motor bearing) N_BALLS = 9 -BD = 7.94 # ball diameter [mm] -PD = 39.04 # pitch diameter [mm] -ALPHA = 0.0 # contact angle [rad] +BD = 7.94 # ball diameter [mm] +PD = 39.04 # pitch diameter [mm] +ALPHA = 0.0 # contact angle [rad] # Derived characteristic frequencies BPFO = N_BALLS / 2 * F_SHAFT * (1 - BD / PD * np.cos(ALPHA)) # ~107.5 Hz # Resonance and damping -F_RESONANCE = 3200.0 # structural resonance [Hz] -DAMPING = 5000.0 # exponential decay rate [1/s] (fast → sharp impulses) -IMPULSE_AMP = 2.0 # peak impulse amplitude [g] -LOAD_MOD = 0.5 # load-zone modulation depth (0 = none, 1 = full) +F_RESONANCE = 3200.0 # structural resonance [Hz] +DAMPING = 5000.0 # exponential decay rate [1/s] (fast → sharp impulses) +IMPULSE_AMP = 2.0 # peak impulse amplitude [g] +LOAD_MOD = 0.5 # load-zone modulation depth (0 = none, 1 = full) # Background -SHAFT_1X = 0.10 # 1× shaft amplitude [g] -SHAFT_2X = 0.04 # 2× shaft amplitude [g] -NOISE_STD = 0.02 # broadband noise σ [g] +SHAFT_1X = 0.10 # 1× shaft amplitude [g] +SHAFT_2X = 0.04 # 2× shaft amplitude [g] +NOISE_STD = 0.02 # broadband noise σ [g] # Time origin (arbitrary) T0 = datetime(2024, 1, 15, 0, 0, 0) @@ -82,8 +83,9 @@ def generate() -> tuple[np.ndarray, np.ndarray]: t = np.arange(n_samples) / FS # Shaft harmonics (healthy background) - shaft = SHAFT_1X * np.sin(2 * np.pi * F_SHAFT * t) + \ - SHAFT_2X * np.sin(2 * np.pi * 2 * F_SHAFT * t) + shaft = SHAFT_1X * np.sin(2 * np.pi * F_SHAFT * t) + SHAFT_2X * np.sin( + 2 * np.pi * 2 * F_SHAFT * t + ) # Bearing fault impulses at BPFO impulse_times = np.arange(0, DURATION, 1.0 / BPFO) @@ -93,8 +95,12 @@ def generate() -> tuple[np.ndarray, np.ndarray]: mask = dt >= 0 # Load-zone amplitude modulation amp = 1.0 + LOAD_MOD * np.cos(2 * np.pi * F_SHAFT * t_imp) - ring = amp * IMPULSE_AMP * np.exp(-DAMPING * dt[mask]) * \ - np.sin(2 * np.pi * F_RESONANCE * dt[mask]) + ring = ( + amp + * IMPULSE_AMP + * np.exp(-DAMPING * dt[mask]) + * np.sin(2 * np.pi * F_RESONANCE * dt[mask]) + ) bearing[mask] += ring noise = NOISE_STD * rng.standard_normal(n_samples) @@ -115,8 +121,9 @@ def to_couchdb_docs(t: np.ndarray, signal: np.ndarray) -> list[dict]: def main() -> None: parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--check", action="store_true", - help="Print signal statistics after generation") + parser.add_argument( + "--check", action="store_true", help="Print signal statistics after generation" + ) args = parser.parse_args() t, signal = generate() @@ -129,11 +136,12 @@ def main() -> None: print(f"Wrote {len(docs)} documents to {out}") if args.check: - rms = float(np.sqrt(np.mean(signal ** 2))) + rms = float(np.sqrt(np.mean(signal**2))) peak = float(np.max(np.abs(signal))) # Excess kurtosis with sample std (ddof=1), consistent with main.py - kurt = float(np.mean((signal - signal.mean()) ** 4) / - np.std(signal, ddof=1) ** 4 - 3) + kurt = float( + np.mean((signal - signal.mean()) ** 4) / np.std(signal, ddof=1) ** 4 - 3 + ) print(f" BPFO: {BPFO:.2f} Hz") print(f" f_shaft: {F_SHAFT:.1f} Hz") print(f" f_resonance: {F_RESONANCE:.1f} Hz") diff --git a/src/servers/vibration/tests/test_dsp.py b/src/servers/vibration/tests/test_dsp.py index 181d93d04..1d029345c 100644 --- a/src/servers/vibration/tests/test_dsp.py +++ b/src/servers/vibration/tests/test_dsp.py @@ -219,8 +219,9 @@ def test_basic(self): freqs = np.array(fft["frequencies"]) mags = np.array(fft["magnitude"]) shaft_freq = 50.0 # as if rpm=3000 - features = extract_shaft_features(freqs, mags, shaft_freq, - time_signal=COMPOSITE) + features = extract_shaft_features( + freqs, mags, shaft_freq, time_signal=COMPOSITE + ) assert features.f_shaft == 50.0 assert features.amp_1x > 0 diff --git a/src/servers/vibration/tests/test_mcp_e2e.py b/src/servers/vibration/tests/test_mcp_e2e.py index 8892da325..e2ef5079a 100644 --- a/src/servers/vibration/tests/test_mcp_e2e.py +++ b/src/servers/vibration/tests/test_mcp_e2e.py @@ -38,7 +38,11 @@ import anyio import pytest from mcp.client.session import ClientSession -from mcp.client.stdio import StdioServerParameters, get_default_environment, stdio_client +from mcp.client.stdio import ( + StdioServerParameters, + get_default_environment, + stdio_client, +) # --------------------------------------------------------------------------- # Constants @@ -64,12 +68,21 @@ def _find_repo_root(start: Path) -> Path: # LLM credentials that must not reach the test subprocess. # Prevents accidental billable API calls if server-side logic is ever changed. -_SENSITIVE_KEYS: frozenset[str] = frozenset({ - "WATSONX_APIKEY", "WATSONX_PROJECT_ID", "WATSONX_URL", - "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "LITELLM_API_KEY", - "LITELLM_BASE_URL", "COHERE_API_KEY", "AZURE_API_KEY", - "AZURE_API_BASE", "HUGGINGFACE_API_KEY", -}) +_SENSITIVE_KEYS: frozenset[str] = frozenset( + { + "WATSONX_APIKEY", + "WATSONX_PROJECT_ID", + "WATSONX_URL", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "LITELLM_API_KEY", + "LITELLM_BASE_URL", + "COHERE_API_KEY", + "AZURE_API_KEY", + "AZURE_API_BASE", + "HUGGINGFACE_API_KEY", + } +) # --------------------------------------------------------------------------- # Helpers @@ -173,7 +186,9 @@ class TestVibrationMCPProtocol: @pytest.mark.anyio async def test_sc01_tool_listing(self, vibration_session: ClientSession) -> None: """SC-01: Server starts and exposes expected tools over stdio.""" - tools = await asyncio.wait_for(vibration_session.list_tools(), timeout=_DEADLINE) + tools = await asyncio.wait_for( + vibration_session.list_tools(), timeout=_DEADLINE + ) names = {t.name for t in tools.tools} expected = { "get_vibration_data", @@ -188,7 +203,9 @@ async def test_sc01_tool_listing(self, vibration_session: ClientSession) -> None assert expected <= names, f"Missing tools: {expected - names}" @pytest.mark.anyio - async def test_sc02_static_tool_happy_path(self, vibration_session: ClientSession) -> None: + async def test_sc02_static_tool_happy_path( + self, vibration_session: ClientSession + ) -> None: """SC-02: list_known_bearings returns static database without CouchDB.""" result = await asyncio.wait_for( vibration_session.call_tool("list_known_bearings", {}), @@ -201,16 +218,26 @@ async def test_sc02_static_tool_happy_path(self, vibration_session: ClientSessio assert any("6205" in n for n in names), f"6205 not found in {names}" @pytest.mark.anyio - async def test_sc03_iso_severity_zone_classification(self, vibration_session: ClientSession) -> None: + async def test_sc03_iso_severity_zone_classification( + self, vibration_session: ClientSession + ) -> None: """SC-03: assess_vibration_severity classifies ISO 10816 zones correctly.""" - zone_d = _parse_result(await asyncio.wait_for( - vibration_session.call_tool("assess_vibration_severity", {"rms_velocity_mm_s": 50.0}), - timeout=_DEADLINE, - )) - zone_a = _parse_result(await asyncio.wait_for( - vibration_session.call_tool("assess_vibration_severity", {"rms_velocity_mm_s": 0.5}), - timeout=_DEADLINE, - )) + zone_d = _parse_result( + await asyncio.wait_for( + vibration_session.call_tool( + "assess_vibration_severity", {"rms_velocity_mm_s": 50.0} + ), + timeout=_DEADLINE, + ) + ) + zone_a = _parse_result( + await asyncio.wait_for( + vibration_session.call_tool( + "assess_vibration_severity", {"rms_velocity_mm_s": 0.5} + ), + timeout=_DEADLINE, + ) + ) assert zone_d.get("iso_zone") == "D", f"Expected D, got: {zone_d}" assert zone_a.get("iso_zone") == "A", f"Expected A, got: {zone_a}" diff --git a/src/servers/vibration/tests/test_tools.py b/src/servers/vibration/tests/test_tools.py index 8e081df76..fc3298ad6 100644 --- a/src/servers/vibration/tests/test_tools.py +++ b/src/servers/vibration/tests/test_tools.py @@ -16,8 +16,13 @@ # Helpers # --------------------------------------------------------------------------- -def _make_sine(freq_hz: float = 50.0, sr: float = 2048.0, - duration: float = 1.0, amplitude: float = 1.0) -> tuple: + +def _make_sine( + freq_hz: float = 50.0, + sr: float = 2048.0, + duration: float = 1.0, + amplitude: float = 1.0, +) -> tuple: """Generate a pure sine wave and store it; return (data_id, signal, sr).""" t = np.arange(0, duration, 1.0 / sr) sig = amplitude * np.sin(2 * np.pi * freq_hz * t) @@ -26,8 +31,9 @@ def _make_sine(freq_hz: float = 50.0, sr: float = 2048.0, return data_id, sig, sr -def _make_composite(freqs: list[float], sr: float = 4096.0, - duration: float = 2.0) -> str: +def _make_composite( + freqs: list[float], sr: float = 4096.0, duration: float = 2.0 +) -> str: """Composite signal with multiple sine components; returns data_id.""" t = np.arange(0, duration, 1.0 / sr) sig = np.zeros_like(t) @@ -56,16 +62,18 @@ async def test_basic_50hz(self): @pytest.mark.anyio async def test_missing_data_id(self): - result = await call_tool(mcp, "compute_fft_spectrum", - {"data_id": "nonexistent"}) + result = await call_tool( + mcp, "compute_fft_spectrum", {"data_id": "nonexistent"} + ) assert "error" in result @pytest.mark.anyio async def test_window_types(self): data_id, _, _ = _make_sine(100.0) for win in ("hann", "hamming", "blackman", "rectangular"): - result = await call_tool(mcp, "compute_fft_spectrum", - {"data_id": data_id, "window": win}) + result = await call_tool( + mcp, "compute_fft_spectrum", {"data_id": data_id, "window": win} + ) assert "error" not in result assert result["window"] == win @@ -79,16 +87,14 @@ class TestComputeEnvelopeSpectrum: @pytest.mark.anyio async def test_basic_run(self): data_id, _, _ = _make_sine(120.0, sr=4096.0) - result = await call_tool(mcp, "compute_envelope_spectrum", - {"data_id": data_id}) + result = await call_tool(mcp, "compute_envelope_spectrum", {"data_id": data_id}) assert "error" not in result assert "filter_band_hz" in result assert result["sample_rate_hz"] == 4096.0 @pytest.mark.anyio async def test_missing_data_id(self): - result = await call_tool(mcp, "compute_envelope_spectrum", - {"data_id": "nope"}) + result = await call_tool(mcp, "compute_envelope_spectrum", {"data_id": "nope"}) assert "error" in result @@ -100,22 +106,26 @@ async def test_missing_data_id(self): class TestAssessVibrationSeverity: @pytest.mark.anyio async def test_zone_a(self): - result = await call_tool(mcp, "assess_vibration_severity", - {"rms_velocity_mm_s": 0.5}) + result = await call_tool( + mcp, "assess_vibration_severity", {"rms_velocity_mm_s": 0.5} + ) assert result["iso_zone"] == "A" @pytest.mark.anyio async def test_zone_d(self): - result = await call_tool(mcp, "assess_vibration_severity", - {"rms_velocity_mm_s": 50.0}) + result = await call_tool( + mcp, "assess_vibration_severity", {"rms_velocity_mm_s": 50.0} + ) assert result["iso_zone"] == "D" @pytest.mark.anyio async def test_group_param(self): for grp in ("group1", "group2", "group3", "group4"): - result = await call_tool(mcp, "assess_vibration_severity", - {"rms_velocity_mm_s": 4.5, - "machine_group": grp}) + result = await call_tool( + mcp, + "assess_vibration_severity", + {"rms_velocity_mm_s": 4.5, "machine_group": grp}, + ) assert result["iso_zone"] in ("A", "B", "C", "D") @@ -142,13 +152,17 @@ async def test_returns_bearings(self): class TestCalculateBearingFrequencies: @pytest.mark.anyio async def test_basic(self): - result = await call_tool(mcp, "calculate_bearing_frequencies", { - "rpm": 1800, - "n_balls": 9, - "ball_diameter_mm": 7.94, - "pitch_diameter_mm": 39.04, - "contact_angle_deg": 0.0, - }) + result = await call_tool( + mcp, + "calculate_bearing_frequencies", + { + "rpm": 1800, + "n_balls": 9, + "ball_diameter_mm": 7.94, + "pitch_diameter_mm": 39.04, + "contact_angle_deg": 0.0, + }, + ) assert "bpfo_hz" in result assert "bpfi_hz" in result assert "bsf_hz" in result @@ -157,13 +171,17 @@ async def test_basic(self): @pytest.mark.anyio async def test_with_name(self): - result = await call_tool(mcp, "calculate_bearing_frequencies", { - "rpm": 3600, - "n_balls": 8, - "ball_diameter_mm": 10.0, - "pitch_diameter_mm": 46.0, - "bearing_name": "test-bearing", - }) + result = await call_tool( + mcp, + "calculate_bearing_frequencies", + { + "rpm": 3600, + "n_balls": 8, + "ball_diameter_mm": 10.0, + "pitch_diameter_mm": 46.0, + "bearing_name": "test-bearing", + }, + ) assert "bearing" in result assert result["bearing"] == "test-bearing" @@ -178,9 +196,13 @@ class TestDiagnoseVibration: async def test_no_rpm(self): """Without RPM we expect a partial result with a warning.""" data_id, _, _ = _make_sine(120.0, sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + }, + ) assert "error" not in result assert "warning" in result assert result["shaft_features"] is None @@ -188,10 +210,14 @@ async def test_no_rpm(self): @pytest.mark.anyio async def test_with_rpm(self): data_id = _make_composite([30, 60, 90], sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - "rpm": 1800.0, - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + "rpm": 1800.0, + }, + ) assert "error" not in result assert result["shaft_features"] is not None assert result["iso_10816"] is not None @@ -200,11 +226,15 @@ async def test_with_rpm(self): @pytest.mark.anyio async def test_with_bearing_designation(self): data_id = _make_composite([30, 60, 120], sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - "rpm": 1800.0, - "bearing_designation": "6205", - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + "rpm": 1800.0, + "bearing_designation": "6205", + }, + ) assert "error" not in result assert result["bearing_info_source"] is not None assert "database" in result["bearing_info_source"] @@ -212,20 +242,23 @@ async def test_with_bearing_designation(self): @pytest.mark.anyio async def test_with_custom_bearing_geometry(self): data_id = _make_composite([30, 60], sr=4096.0, duration=2.0) - result = await call_tool(mcp, "diagnose_vibration", { - "data_id": data_id, - "rpm": 1800.0, - "bearing_n_balls": 9, - "bearing_ball_dia_mm": 7.94, - "bearing_pitch_dia_mm": 39.04, - }) + result = await call_tool( + mcp, + "diagnose_vibration", + { + "data_id": data_id, + "rpm": 1800.0, + "bearing_n_balls": 9, + "bearing_ball_dia_mm": 7.94, + "bearing_pitch_dia_mm": 39.04, + }, + ) assert "error" not in result assert result["bearing_info_source"] == "custom geometry" @pytest.mark.anyio async def test_missing_data_id(self): - result = await call_tool(mcp, "diagnose_vibration", - {"data_id": "ghost"}) + result = await call_tool(mcp, "diagnose_vibration", {"data_id": "ghost"}) assert "error" in result @@ -238,12 +271,16 @@ class TestGetVibrationData: @requires_couchdb @pytest.mark.anyio async def test_fetch_integration(self): - result = await call_tool(mcp, "get_vibration_data", { - "site_name": "MAIN", - "asset_id": "Motor_01", - "sensor_name": "Vibration_X", - "start": "2024-01-15T00:00:00", - }) + result = await call_tool( + mcp, + "get_vibration_data", + { + "site_name": "MAIN", + "asset_id": "Motor_01", + "sensor_name": "Vibration_X", + "start": "2024-01-15T00:00:00", + }, + ) assert "error" not in result assert "data_id" in result @@ -257,8 +294,12 @@ class TestListVibrationSensors: @requires_couchdb @pytest.mark.anyio async def test_list_integration(self): - result = await call_tool(mcp, "list_vibration_sensors", { - "site_name": "MAIN", - "asset_id": "Chiller 6", - }) + result = await call_tool( + mcp, + "list_vibration_sensors", + { + "site_name": "MAIN", + "asset_id": "Chiller 6", + }, + ) assert "sensors" in result or "error" in result diff --git a/src/servers/wo/couch.py b/src/servers/wo/couch.py index 139ede740..cd694cdff 100644 --- a/src/servers/wo/couch.py +++ b/src/servers/wo/couch.py @@ -7,6 +7,7 @@ typed), so they can be unit-tested against an in-memory fake (see test_workorders.py) without a running CouchDB. """ + from __future__ import annotations from typing import Any, Dict, List, Optional @@ -22,13 +23,24 @@ class CouchError(Exception): class CouchClient: - def __init__(self, base_url: str, db: str, *, username: Optional[str] = None, - password: Optional[str] = None, timeout: float = 10.0): + def __init__( + self, + base_url: str, + db: str, + *, + username: Optional[str] = None, + password: Optional[str] = None, + timeout: float = 10.0, + ): if httpx is None: - raise CouchError("httpx is required for the real CouchClient (pip install httpx)") + raise CouchError( + "httpx is required for the real CouchClient (pip install httpx)" + ) self.db = db auth = (username, password) if username else None - self._c = httpx.AsyncClient(base_url=base_url.rstrip("/"), auth=auth, timeout=timeout) + self._c = httpx.AsyncClient( + base_url=base_url.rstrip("/"), auth=auth, timeout=timeout + ) async def aclose(self) -> None: await self._c.aclose() @@ -56,9 +68,15 @@ async def delete(self, doc_id: str, rev: str) -> Dict[str, Any]: return r.json() # ---- queries ---- - async def find(self, selector: Dict[str, Any], *, fields: Optional[List[str]] = None, - sort: Optional[List[Dict[str, str]]] = None, limit: int = 200, - skip: int = 0) -> List[Dict[str, Any]]: + async def find( + self, + selector: Dict[str, Any], + *, + fields: Optional[List[str]] = None, + sort: Optional[List[Dict[str, str]]] = None, + limit: int = 200, + skip: int = 0, + ) -> List[Dict[str, Any]]: body: Dict[str, Any] = {"selector": selector, "limit": limit, "skip": skip} if fields: body["fields"] = fields @@ -71,8 +89,11 @@ async def find(self, selector: Dict[str, Any], *, fields: Optional[List[str]] = async def view(self, ddoc: str, view: str, **params: Any) -> Dict[str, Any]: # CouchDB expects JSON-encoded key/startkey/endkey params. import json as _json - q = {k: (_json.dumps(v) if k in ("key", "startkey", "endkey") else v) - for k, v in params.items()} + + q = { + k: (_json.dumps(v) if k in ("key", "startkey", "endkey") else v) + for k, v in params.items() + } r = await self._c.get(f"/{self.db}/_design/{ddoc}/_view/{view}", params=q) r.raise_for_status() return r.json() @@ -94,4 +115,4 @@ async def next_wonum(self, site_id: str) -> str: return str(doc["value"]) except CouchError: continue - raise CouchError("could not allocate wonum (counter contention)") \ No newline at end of file + raise CouchError("could not allocate wonum (counter contention)") diff --git a/src/servers/wo/envelope.py b/src/servers/wo/envelope.py index 195600fa8..43d6882c6 100644 --- a/src/servers/wo/envelope.py +++ b/src/servers/wo/envelope.py @@ -4,14 +4,20 @@ error_code}`) so AssetOpsBench agents written against the real Maximo server work unchanged against this benchmark server. """ + from __future__ import annotations import time from typing import Any, Dict, Optional -def envelope(data: Any, *, cached: bool = False, duration_ms: int = 0, - record_count: Optional[int] = None) -> Dict[str, Any]: +def envelope( + data: Any, + *, + cached: bool = False, + duration_ms: int = 0, + record_count: Optional[int] = None, +) -> Dict[str, Any]: meta: Dict[str, Any] = {"cached": cached, "duration_ms": duration_ms} if record_count is not None: meta["record_count"] = record_count @@ -34,4 +40,4 @@ def __exit__(self, *exc) -> None: @property def ms(self) -> int: - return int((time.monotonic() - self._start) * 1000) \ No newline at end of file + return int((time.monotonic() - self._start) * 1000) diff --git a/src/servers/wo/main.py b/src/servers/wo/main.py index 2b92a73b7..add5dda14 100644 --- a/src/servers/wo/main.py +++ b/src/servers/wo/main.py @@ -21,21 +21,30 @@ from . import workorders as wo from .couch import CouchClient from .models import ( - ActualsVsPlannedResult, CostsResult, ErrorResult, KpiResult, ScheduleResult, - TasksResult, WorkOrderItem, WorkOrderMutationResult, WorkOrderResult, + ActualsVsPlannedResult, + CostsResult, + ErrorResult, + KpiResult, + ScheduleResult, + TasksResult, + WorkOrderItem, + WorkOrderMutationResult, + WorkOrderResult, WorkOrdersResult, ) load_dotenv() -_log_level = getattr(logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING) +_log_level = getattr( + logging, os.environ.get("LOG_LEVEL", "WARNING").upper(), logging.WARNING +) logging.basicConfig(level=_log_level) mcp = FastMCP( "wo", instructions="Work order lifecycle for industrial assets, backed by CouchDB. Query, " - "create, approve, assign, close, and cancel work orders; compute KPIs, " - "costs, and schedules. Documents use IBM Maximo mxwo field names.", + "create, approve, assign, close, and cancel work orders; compute KPIs, " + "costs, and schedules. Documents use IBM Maximo mxwo field names.", ) _db: Optional[CouchClient] = None @@ -65,45 +74,76 @@ def _failed(res: Dict[str, Any]) -> Optional[ErrorResult]: return None -def _mutation(res: Dict[str, Any], verb: str) -> Union[WorkOrderMutationResult, ErrorResult]: +def _mutation( + res: Dict[str, Any], verb: str +) -> Union[WorkOrderMutationResult, ErrorResult]: err = _failed(res) if err: return err doc = res["data"] return WorkOrderMutationResult( - wonum=doc.get("wonum"), siteid=doc.get("siteid"), status=doc.get("status"), + wonum=doc.get("wonum"), + siteid=doc.get("siteid"), + status=doc.get("status"), work_order=WorkOrderItem.model_validate(doc), - message=f"Work order {doc.get('wonum')} {verb}.") - - -async def list_workorders(site_id: Optional[str] = None, status: Optional[str] = None, - asset_num: Optional[str] = None, priority: Optional[int] = None, - date_from: Optional[str] = None, date_to: Optional[str] = None, - page_size: int = 50, page_num: int = 1) -> Union[WorkOrdersResult, ErrorResult]: + message=f"Work order {doc.get('wonum')} {verb}.", + ) + + +async def list_workorders( + site_id: Optional[str] = None, + status: Optional[str] = None, + asset_num: Optional[str] = None, + priority: Optional[int] = None, + date_from: Optional[str] = None, + date_to: Optional[str] = None, + page_size: int = 50, + page_num: int = 1, +) -> Union[WorkOrdersResult, ErrorResult]: """List work orders with optional filters (site, status, asset, priority, dates). status accepts OPEN / APPROVED_PENDING; page_size=0 returns all matches in one call.""" - res = await wo.list_workorders(db(), site_id, status, asset_num, priority, - date_from, date_to, page_size, page_num) + res = await wo.list_workorders( + db(), + site_id, + status, + asset_num, + priority, + date_from, + date_to, + page_size, + page_num, + ) err = _failed(res) if err: return err d = res["data"] items = [WorkOrderItem.model_validate(x) for x in d["workorders"]] - return WorkOrdersResult(site_id=site_id, status=status, total=d["totalCount"], - work_orders=items, message=f"Found {d['totalCount']} work order(s).") - - -async def get_workorder(wonum: str, site_id: str) -> Union[WorkOrderResult, ErrorResult]: + return WorkOrdersResult( + site_id=site_id, + status=status, + total=d["totalCount"], + work_orders=items, + message=f"Found {d['totalCount']} work order(s).", + ) + + +async def get_workorder( + wonum: str, site_id: str +) -> Union[WorkOrderResult, ErrorResult]: """Get a single work order by number and site.""" res = await wo.get_workorder(db(), wonum, site_id) err = _failed(res) if err: return err - return WorkOrderResult(work_order=WorkOrderItem.model_validate(res["data"]), - message=f"Work order {wonum} at {site_id}.") + return WorkOrderResult( + work_order=WorkOrderItem.model_validate(res["data"]), + message=f"Work order {wonum} at {site_id}.", + ) -async def get_workorder_tasks(wonum: str, site_id: str) -> Union[TasksResult, ErrorResult]: +async def get_workorder_tasks( + wonum: str, site_id: str +) -> Union[TasksResult, ErrorResult]: """List the child tasks of a parent work order.""" res = await wo.get_workorder_tasks(db(), wonum, site_id) err = _failed(res) @@ -111,50 +151,74 @@ async def get_workorder_tasks(wonum: str, site_id: str) -> Union[TasksResult, Er return err d = res["data"] tasks = [WorkOrderItem.model_validate(t) for t in d["tasks"]] - return TasksResult(parent_wonum=d["parent_wonum"], site_id=d["site_id"], total=len(tasks), - tasks=tasks, message=f"{len(tasks)} task(s) under {wonum}.") - - -async def get_workorder_costs(wonum: str, site_id: str) -> Union[CostsResult, ErrorResult]: + return TasksResult( + parent_wonum=d["parent_wonum"], + site_id=d["site_id"], + total=len(tasks), + tasks=tasks, + message=f"{len(tasks)} task(s) under {wonum}.", + ) + + +async def get_workorder_costs( + wonum: str, site_id: str +) -> Union[CostsResult, ErrorResult]: """Actual labor/material/service/tool cost breakdown for a work order.""" res = await wo.get_workorder_costs(db(), wonum, site_id) err = _failed(res) if err: return err - return CostsResult.model_validate({**res["data"], "message": f"Cost breakdown for {wonum}."}) + return CostsResult.model_validate( + {**res["data"], "message": f"Cost breakdown for {wonum}."} + ) -async def get_workorder_actuals_vs_planned(wonum: str, site_id: str) -> Union[ActualsVsPlannedResult, ErrorResult]: +async def get_workorder_actuals_vs_planned( + wonum: str, site_id: str +) -> Union[ActualsVsPlannedResult, ErrorResult]: """Estimated vs actual hours and cost variance for a work order.""" res = await wo.get_workorder_actuals_vs_planned(db(), wonum, site_id) err = _failed(res) if err: return err - return ActualsVsPlannedResult.model_validate({**res["data"], "message": f"Actuals vs planned for {wonum}."}) + return ActualsVsPlannedResult.model_validate( + {**res["data"], "message": f"Actuals vs planned for {wonum}."} + ) -async def get_workorder_kpis(site_id: str, period_months: int = 3) -> Union[KpiResult, ErrorResult]: +async def get_workorder_kpis( + site_id: str, period_months: int = 3 +) -> Union[KpiResult, ErrorResult]: """Site KPIs: totals, backlog, overdue, avg completion, priority/asset breakdowns.""" res = await wo.get_workorder_kpis(db(), site_id, period_months) err = _failed(res) if err: return err - return KpiResult.model_validate({**res["data"], "message": f"KPIs for {site_id} over {period_months} month(s)."}) + return KpiResult.model_validate( + {**res["data"], "message": f"KPIs for {site_id} over {period_months} month(s)."} + ) -async def get_schedule_calendar(site_id: str, date_from: Optional[str] = None, - date_to: Optional[str] = None, group_by: str = "date") -> Union[ScheduleResult, ErrorResult]: +async def get_schedule_calendar( + site_id: str, + date_from: Optional[str] = None, + date_to: Optional[str] = None, + group_by: str = "date", +) -> Union[ScheduleResult, ErrorResult]: """Scheduled (non-terminal) work orders in a date window, bucketed by day.""" res = await wo.get_schedule_calendar(db(), site_id, date_from, date_to, group_by) err = _failed(res) if err: return err d = res["data"] - return ScheduleResult.model_validate({**d, "message": f"{d['total_scheduled']} scheduled at {site_id}."}) + return ScheduleResult.model_validate( + {**d, "message": f"{d['total_scheduled']} scheduled at {site_id}."} + ) -async def get_my_assigned_workorders(labor_code: str, site_id: Optional[str] = None, - open_only: bool = True) -> Union[WorkOrdersResult, ErrorResult]: +async def get_my_assigned_workorders( + labor_code: str, site_id: Optional[str] = None, open_only: bool = True +) -> Union[WorkOrdersResult, ErrorResult]: """Work orders assigned to a given technician (labor code).""" res = await wo.get_my_assigned_workorders(db(), labor_code, site_id, open_only) err = _failed(res) @@ -162,59 +226,114 @@ async def get_my_assigned_workorders(labor_code: str, site_id: Optional[str] = N return err d = res["data"] items = [WorkOrderItem.model_validate(x) for x in d["workorders"]] - return WorkOrdersResult(site_id=site_id, labor_code=labor_code, total=d["totalCount"], - work_orders=items, message=f"{d['totalCount']} work order(s) for {labor_code}.") - - -async def generate_work_order(description: str, asset_num: str, site_id: str, - priority: int = 3, work_type: str = "CM", - reported_by: Optional[str] = None, location: Optional[str] = None, - notes: Optional[str] = None, wonum: Optional[str] = None, - aob_source: Optional[Dict[str, Any]] = None) -> Union[WorkOrderMutationResult, ErrorResult]: + return WorkOrdersResult( + site_id=site_id, + labor_code=labor_code, + total=d["totalCount"], + work_orders=items, + message=f"{d['totalCount']} work order(s) for {labor_code}.", + ) + + +async def generate_work_order( + description: str, + asset_num: str, + site_id: str, + priority: int = 3, + work_type: str = "CM", + reported_by: Optional[str] = None, + location: Optional[str] = None, + notes: Optional[str] = None, + wonum: Optional[str] = None, + aob_source: Optional[Dict[str, Any]] = None, +) -> Union[WorkOrderMutationResult, ErrorResult]: """Create a work order (status WAPPR). Attach aob_source provenance (agent/trigger/evidence).""" - res = await wo.create_workorder(db(), description=description, asset_num=asset_num, - site_id=site_id, priority=priority, work_type=work_type, - reported_by=reported_by, location=location, notes=notes, - wonum=wonum, aob_source=aob_source) + res = await wo.create_workorder( + db(), + description=description, + asset_num=asset_num, + site_id=site_id, + priority=priority, + work_type=work_type, + reported_by=reported_by, + location=location, + notes=notes, + wonum=wonum, + aob_source=aob_source, + ) return _mutation(res, "created (WAPPR)") -async def update_workorder(wonum: str, site_id: str, description: Optional[str] = None, - priority: Optional[int] = None, location: Optional[str] = None, - asset_num: Optional[str] = None, notes: Optional[str] = None, - failure_code: Optional[str] = None - ) -> Union[WorkOrderMutationResult, ErrorResult]: +async def update_workorder( + wonum: str, + site_id: str, + description: Optional[str] = None, + priority: Optional[int] = None, + location: Optional[str] = None, + asset_num: Optional[str] = None, + notes: Optional[str] = None, + failure_code: Optional[str] = None, +) -> Union[WorkOrderMutationResult, ErrorResult]: """Update mutable fields on a work order.""" - res = await wo.update_workorder(db(), wonum, site_id, description, priority, location, - asset_num, notes, failure_code) + res = await wo.update_workorder( + db(), + wonum, + site_id, + description, + priority, + location, + asset_num, + notes, + failure_code, + ) return _mutation(res, "updated") -async def approve_workorder(wonum: str, site_id: str) -> Union[WorkOrderMutationResult, ErrorResult]: +async def approve_workorder( + wonum: str, site_id: str +) -> Union[WorkOrderMutationResult, ErrorResult]: """Approve a work order (-> APPR).""" - return _mutation(await wo.approve_workorder(db(), wonum, site_id), "approved (APPR)") - - -async def assign_technician(wonum: str, site_id: str, labor_code: str, craft: Optional[str] = None, - start_date: Optional[str] = None, hours_planned: float = 8.0 - ) -> Union[WorkOrderMutationResult, ErrorResult]: + return _mutation( + await wo.approve_workorder(db(), wonum, site_id), "approved (APPR)" + ) + + +async def assign_technician( + wonum: str, + site_id: str, + labor_code: str, + craft: Optional[str] = None, + start_date: Optional[str] = None, + hours_planned: float = 8.0, +) -> Union[WorkOrderMutationResult, ErrorResult]: """Assign a technician (adds a wplabor line).""" - res = await wo.assign_technician(db(), wonum, site_id, labor_code, craft, start_date, hours_planned) + res = await wo.assign_technician( + db(), wonum, site_id, labor_code, craft, start_date, hours_planned + ) return _mutation(res, f"assigned to {labor_code}") -async def close_workorder(wonum: str, site_id: str, actual_hours: float = 0.0, - failure_code: Optional[str] = None, resolution_notes: Optional[str] = None - ) -> Union[WorkOrderMutationResult, ErrorResult]: +async def close_workorder( + wonum: str, + site_id: str, + actual_hours: float = 0.0, + failure_code: Optional[str] = None, + resolution_notes: Optional[str] = None, +) -> Union[WorkOrderMutationResult, ErrorResult]: """Close a work order (-> COMP) with actuals and resolution.""" - res = await wo.close_workorder(db(), wonum, site_id, actual_hours, failure_code, resolution_notes) + res = await wo.close_workorder( + db(), wonum, site_id, actual_hours, failure_code, resolution_notes + ) return _mutation(res, "closed (COMP)") -async def cancel_workorder(wonum: str, site_id: str, reason: Optional[str] = None - ) -> Union[WorkOrderMutationResult, ErrorResult]: +async def cancel_workorder( + wonum: str, site_id: str, reason: Optional[str] = None +) -> Union[WorkOrderMutationResult, ErrorResult]: """Cancel a work order (-> CAN).""" - return _mutation(await wo.cancel_workorder(db(), wonum, site_id, reason), "cancelled (CAN)") + return _mutation( + await wo.cancel_workorder(db(), wonum, site_id, reason), "cancelled (CAN)" + ) # --------------------------------------------------------------------------- # @@ -239,7 +358,9 @@ async def cancel_workorder(wonum: str, site_id: str, reason: Optional[str] = Non (cancel_workorder, "Cancel Work Order"), ] -_TOOLS = _READ_TOOLS if os.environ.get("AOB_READONLY") == "1" else _READ_TOOLS + _WRITE_TOOLS +_TOOLS = ( + _READ_TOOLS if os.environ.get("AOB_READONLY") == "1" else _READ_TOOLS + _WRITE_TOOLS +) for _fn, _title in _TOOLS: mcp.tool(title=_title)(_fn) @@ -249,4 +370,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/servers/wo/models.py b/src/servers/wo/models.py index 51da4a416..d1fbf144e 100644 --- a/src/servers/wo/models.py +++ b/src/servers/wo/models.py @@ -186,4 +186,4 @@ class ScheduleResult(_Lenient): total_scheduled: int by_date: Optional[List[ScheduleDay]] = None workorders: Optional[List[WorkOrderItem]] = None - message: str \ No newline at end of file + message: str diff --git a/src/servers/wo/tests/test_models_boundary.py b/src/servers/wo/tests/test_models_boundary.py index 48f57fbc4..7f970389c 100644 --- a/src/servers/wo/tests/test_models_boundary.py +++ b/src/servers/wo/tests/test_models_boundary.py @@ -3,13 +3,18 @@ Injects an in-memory fake CouchDB into main._db, so no server/CouchDB is needed. Run: PYTHONPATH=.. python tests/test_models_boundary.py """ + import asyncio from datetime import datetime, timezone from servers.wo.tests.test_workorders import FakeCouch # reuse the in-memory fake from servers.wo import main from servers.wo.models import ( - ErrorResult, WorkOrderItem, WorkOrderMutationResult, WorkOrderResult, WorkOrdersResult, + ErrorResult, + WorkOrderItem, + WorkOrderMutationResult, + WorkOrderResult, + WorkOrdersResult, ) T0 = datetime(2020, 4, 28, 9, 0, 0, tzinfo=timezone.utc) @@ -20,8 +25,14 @@ async def scenario(): main._db = db # inject fake # 1) Output is a Pydantic model, not a dict - created = await main.generate_work_order(description="Chiller 6 anomaly", asset_num="CHILLER6", - site_id="MAIN", priority=2, work_type="PdM", wonum="1000045") + created = await main.generate_work_order( + description="Chiller 6 anomaly", + asset_num="CHILLER6", + site_id="MAIN", + priority=2, + work_type="PdM", + wonum="1000045", + ) assert isinstance(created, WorkOrderMutationResult), type(created) assert created.status == "WAPPR" and created.work_order.assetnum == "CHILLER6" @@ -37,9 +48,16 @@ async def scenario(): assert isinstance(missing, ErrorResult) and "not found" in missing.error.lower() # 3) THE KEY CASE: a work order missing most columns still converts cleanly - partial = {"_id": "wo:MAIN:2000", "type": "workorder", "wonum": "2000", - "siteid": "MAIN", "status": "APPR", "worktype": "CM", - "description": "partial doc", "reportdate": "2020-01-01T00:00:00+00:00"} + partial = { + "_id": "wo:MAIN:2000", + "type": "workorder", + "wonum": "2000", + "siteid": "MAIN", + "status": "APPR", + "worktype": "CM", + "description": "partial doc", + "reportdate": "2020-01-01T00:00:00+00:00", + } # store it directly and read back through the typed boundary await db.put(partial) res = await main.get_workorder("2000", "MAIN") @@ -51,9 +69,17 @@ async def scenario(): assert wo.wplabor is None and wo.failurecode is None # 4) extra/unmodeled fields are preserved (extra='allow') - extra_doc = {"_id": "wo:MAIN:2001", "type": "workorder", "wonum": "2001", "siteid": "MAIN", - "status": "APPR", "worktype": "CM", "description": "x", - "reportdate": "2020-01-01T00:00:00+00:00", "custom_field": "kept"} + extra_doc = { + "_id": "wo:MAIN:2001", + "type": "workorder", + "wonum": "2001", + "siteid": "MAIN", + "status": "APPR", + "worktype": "CM", + "description": "x", + "reportdate": "2020-01-01T00:00:00+00:00", + "custom_field": "kept", + } await db.put(extra_doc) r2 = await main.get_workorder("2001", "MAIN") assert r2.work_order.model_dump().get("custom_field") == "kept" diff --git a/src/servers/wo/tests/test_workorders.py b/src/servers/wo/tests/test_workorders.py index 25a4b4b6d..26eab2f02 100644 --- a/src/servers/wo/tests/test_workorders.py +++ b/src/servers/wo/tests/test_workorders.py @@ -2,6 +2,7 @@ Run: python -m pytest tests/ -q (or just `python tests/test_workorders.py`) """ + from __future__ import annotations import asyncio @@ -59,7 +60,9 @@ def _match(self, doc, selector): return False if op == "$elemMatch": if not isinstance(val, list) or not any( - all(it.get(ik) == iv for ik, iv in arg.items()) for it in val): + all(it.get(ik) == iv for ik, iv in arg.items()) + for it in val + ): return False else: if val != cond: @@ -72,7 +75,7 @@ async def find(self, selector, fields=None, sort=None, limit=200, skip=0): key = list(sort[0].keys())[0] rev = sort[0][key] == "desc" rows.sort(key=lambda d: d.get(key) or "", reverse=rev) - return rows[skip:skip + limit] + return rows[skip : skip + limit] T0 = datetime(2020, 4, 28, 9, 0, 0, tzinfo=timezone.utc) @@ -83,16 +86,28 @@ async def scenario(): # create with provenance + pinned wonum for reproducibility r = await wo.create_workorder( - db, description="Investigate Chiller 6 anomaly", asset_num="CHILLER6", - site_id="MAIN", priority=2, work_type="PdM", reported_by="AGENT.TSFM", - wonum="1000045", now=T0, - aob_source={"agent": "tsfm", "trigger_type": "anomaly_detection", - "scenario_id": "WO-CHILLER6-ANOMALY-001"}) + db, + description="Investigate Chiller 6 anomaly", + asset_num="CHILLER6", + site_id="MAIN", + priority=2, + work_type="PdM", + reported_by="AGENT.TSFM", + wonum="1000045", + now=T0, + aob_source={ + "agent": "tsfm", + "trigger_type": "anomaly_detection", + "scenario_id": "WO-CHILLER6-ANOMALY-001", + }, + ) assert r["success"], r assert r["data"]["_id"] == "wo:MAIN:1000045" assert r["data"]["status"] == "WAPPR" assert "_rev" not in r["data"], "internal _rev must not leak" - assert r["data"]["reportdate"] == "2020-04-28T09:00:00+00:00", "clock must be injectable" + assert r["data"]["reportdate"] == "2020-04-28T09:00:00+00:00", ( + "clock must be injectable" + ) # get g = await wo.get_workorder(db, "1000045", "MAIN") @@ -105,23 +120,45 @@ async def scenario(): # partial update: failure_code set independently, other fields untouched u = await wo.update_workorder(db, "1000045", "MAIN", failure_code="BRG-WEAR") assert u["data"]["failurecode"] == "BRG-WEAR" - assert u["data"]["wopriority"] == 2, "untouched fields must survive a partial update" + assert u["data"]["wopriority"] == 2, ( + "untouched fields must survive a partial update" + ) # approve -> assign -> close - assert (await wo.approve_workorder(db, "1000045", "MAIN", now=T0))["data"]["status"] == "APPR" - a = await wo.assign_technician(db, "1000045", "MAIN", "HVACTECH1", craft="HVAC", - hours_planned=4, now=T0) + assert (await wo.approve_workorder(db, "1000045", "MAIN", now=T0))["data"][ + "status" + ] == "APPR" + a = await wo.assign_technician( + db, "1000045", "MAIN", "HVACTECH1", craft="HVAC", hours_planned=4, now=T0 + ) assert a["data"]["wplabor"][0]["laborcode"] == "HVACTECH1" - c = await wo.close_workorder(db, "1000045", "MAIN", actual_hours=3.5, - failure_code="SENSOR-DRIFT", now=T0) + c = await wo.close_workorder( + db, "1000045", "MAIN", actual_hours=3.5, failure_code="SENSOR-DRIFT", now=T0 + ) assert c["data"]["status"] == "COMP" and c["data"]["actlabhrs"] == 3.5 assert c["data"]["actfinish"] == "2020-04-28T09:00:00+00:00" # auto wonum allocation is sequential/deterministic after reset - w1 = (await wo.create_workorder(db, description="PM", asset_num="AHU2", site_id="MAIN", - work_type="PM", now=T0))["data"]["wonum"] - w2 = (await wo.create_workorder(db, description="PM", asset_num="AHU3", site_id="MAIN", - work_type="PM", now=T0))["data"]["wonum"] + w1 = ( + await wo.create_workorder( + db, + description="PM", + asset_num="AHU2", + site_id="MAIN", + work_type="PM", + now=T0, + ) + )["data"]["wonum"] + w2 = ( + await wo.create_workorder( + db, + description="PM", + asset_num="AHU3", + site_id="MAIN", + work_type="PM", + now=T0, + ) + )["data"]["wonum"] assert int(w2) == int(w1) + 1, (w1, w2) # list + filters @@ -131,16 +168,21 @@ async def scenario(): assert w1 in open_nums and w2 in open_nums # validation errors - bad = await wo.create_workorder(db, description="x", asset_num="A", site_id="S", priority=9) + bad = await wo.create_workorder( + db, description="x", asset_num="A", site_id="S", priority=9 + ) assert not bad["success"] and bad["error_code"] == "VALIDATION_ERROR" # my assigned (open_only excludes the closed one) - mine = await wo.get_my_assigned_workorders(db, "HVACTECH1", site_id="MAIN", open_only=True) + mine = await wo.get_my_assigned_workorders( + db, "HVACTECH1", site_id="MAIN", open_only=True + ) assert mine["data"]["totalCount"] == 0 # kpis - k = await wo.get_workorder_kpis(db, "MAIN", period_months=3, - now=datetime(2020, 5, 1, tzinfo=timezone.utc)) + k = await wo.get_workorder_kpis( + db, "MAIN", period_months=3, now=datetime(2020, 5, 1, tzinfo=timezone.utc) + ) assert k["data"]["completed"] == 1 and k["data"]["total_workorders"] >= 3 print("ALL ASSERTIONS PASSED") diff --git a/src/servers/wo/workorders.py b/src/servers/wo/workorders.py index 784ad74f6..39bf35fcd 100644 --- a/src/servers/wo/workorders.py +++ b/src/servers/wo/workorders.py @@ -15,6 +15,7 @@ binds a real client; tests bind an in-memory fake. Each returns the same `{success, data, metadata}` / `{success, error, error_code}` envelope as Maximo MCP. """ + from __future__ import annotations from datetime import datetime, timezone, timedelta @@ -45,10 +46,17 @@ def _public(doc: Dict[str, Any]) -> Dict[str, Any]: # --------------------------------------------------------------------------- # # Read tools # --------------------------------------------------------------------------- # -async def list_workorders(db, site_id: Optional[str] = None, status: Optional[str] = None, - asset_num: Optional[str] = None, priority: Optional[int] = None, - date_from: Optional[str] = None, date_to: Optional[str] = None, - page_size: int = 50, page_num: int = 1) -> Dict[str, Any]: +async def list_workorders( + db, + site_id: Optional[str] = None, + status: Optional[str] = None, + asset_num: Optional[str] = None, + priority: Optional[int] = None, + date_from: Optional[str] = None, + date_to: Optional[str] = None, + page_size: int = 50, + page_num: int = 1, +) -> Dict[str, Any]: """List work orders with optional filters (site, status, asset, priority, date window). `status` accepts a single value, or the pseudo-values OPEN / APPROVED_PENDING. @@ -82,15 +90,18 @@ async def list_workorders(db, site_id: Optional[str] = None, status: Optional[st # No Mango `sort` — that requires a matching index and 400s without one. # Sort client-side instead (robust to missing reportdate / indexes). docs = await db.find(sel, limit=1000000) - docs.sort(key=lambda d: (d.get("reportdate") or ""), reverse=True) + docs.sort(key=lambda d: d.get("reportdate") or "", reverse=True) total = len(docs) if not page_size: # page_size=0 (or None) → return everything page = [_public(d) for d in docs] else: start = (page_num - 1) * page_size - page = [_public(d) for d in docs[start:start + page_size]] - return envelope({"workorders": page, "totalCount": total}, - duration_ms=t_ms(t), record_count=len(page)) + page = [_public(d) for d in docs[start : start + page_size]] + return envelope( + {"workorders": page, "totalCount": total}, + duration_ms=t_ms(t), + record_count=len(page), + ) async def get_workorder(db, wonum: str, site_id: str) -> Dict[str, Any]: @@ -98,19 +109,31 @@ async def get_workorder(db, wonum: str, site_id: str) -> Dict[str, Any]: with Timer() as t: doc = await db.get(_doc_id(site_id, wonum)) if not doc: - return error(f"Work order '{wonum}' not found in site '{site_id}'", "NOT_FOUND") + return error( + f"Work order '{wonum}' not found in site '{site_id}'", "NOT_FOUND" + ) return envelope(_public(doc), duration_ms=t_ms(t)) async def get_workorder_tasks(db, wonum: str, site_id: str) -> Dict[str, Any]: """List child task rows whose `parent` references this work order.""" with Timer() as t: - docs = await db.find({"type": "workorder", "parent": wonum, "siteid": site_id.upper()}, - limit=1000) - docs.sort(key=lambda d: (d.get("taskid") or 0)) # sort client-side (no index needed) - return envelope({"parent_wonum": wonum, "site_id": site_id, - "tasks": [_public(d) for d in docs]}, - duration_ms=t_ms(t), record_count=len(docs)) + docs = await db.find( + {"type": "workorder", "parent": wonum, "siteid": site_id.upper()}, + limit=1000, + ) + docs.sort( + key=lambda d: d.get("taskid") or 0 + ) # sort client-side (no index needed) + return envelope( + { + "parent_wonum": wonum, + "site_id": site_id, + "tasks": [_public(d) for d in docs], + }, + duration_ms=t_ms(t), + record_count=len(docs), + ) async def get_workorder_costs(db, wonum: str, site_id: str) -> Dict[str, Any]: @@ -118,58 +141,106 @@ async def get_workorder_costs(db, wonum: str, site_id: str) -> Dict[str, Any]: with Timer() as t: wo = await db.get(_doc_id(site_id, wonum)) if not wo: - return error(f"Work order '{wonum}' not found in site '{site_id}'.", "NOT_FOUND") + return error( + f"Work order '{wonum}' not found in site '{site_id}'.", "NOT_FOUND" + ) f = lambda n: float(wo.get(n) or 0) - labor, material, service, tool = f("actlabcost"), f("actmatcost"), f("actservcost"), f("acttoolcost") + labor, material, service, tool = ( + f("actlabcost"), + f("actmatcost"), + f("actservcost"), + f("acttoolcost"), + ) total = f("acttotalcost") or (labor + material + service + tool) - breakdown = [{"category": c, "amount": round(a, 2), - "share_pct": round((a / total) * 100, 1) if total else 0} - for c, a in (("labor", labor), ("material", material), - ("service", service), ("tool", tool))] - return envelope({"wonum": wonum, "site_id": site_id, "status": wo.get("status"), - "assetnum": wo.get("assetnum"), "location": wo.get("location"), - "actual_hours": f("actlabhrs"), "total_cost": round(total, 2), - "breakdown": breakdown}, duration_ms=t_ms(t)) - - -async def get_workorder_actuals_vs_planned(db, wonum: str, site_id: str) -> Dict[str, Any]: + breakdown = [ + { + "category": c, + "amount": round(a, 2), + "share_pct": round((a / total) * 100, 1) if total else 0, + } + for c, a in ( + ("labor", labor), + ("material", material), + ("service", service), + ("tool", tool), + ) + ] + return envelope( + { + "wonum": wonum, + "site_id": site_id, + "status": wo.get("status"), + "assetnum": wo.get("assetnum"), + "location": wo.get("location"), + "actual_hours": f("actlabhrs"), + "total_cost": round(total, 2), + "breakdown": breakdown, + }, + duration_ms=t_ms(t), + ) + + +async def get_workorder_actuals_vs_planned( + db, wonum: str, site_id: str +) -> Dict[str, Any]: """Estimated vs actual hours and cost variance for one work order.""" with Timer() as t: wo = await db.get(_doc_id(site_id, wonum)) if not wo: - return error(f"Work order '{wonum}' not found in site '{site_id}'.", "NOT_FOUND") + return error( + f"Work order '{wonum}' not found in site '{site_id}'.", "NOT_FOUND" + ) f = lambda n: float(wo.get(n) or 0) def var(est, act): - return {"estimated": round(est, 2), "actual": round(act, 2), - "variance_abs": round(act - est, 2), - "variance_pct": round(((act - est) / est) * 100, 1) if est else None, - "over_budget": act > est} - - est_total = f("esttotalcost") or (f("estlabcost") + f("estmatcost") + f("estservcost") + f("esttoolcost")) - act_total = f("acttotalcost") or (f("actlabcost") + f("actmatcost") + f("actservcost") + f("acttoolcost")) - return envelope({"wonum": wonum, "site_id": site_id, "status": wo.get("status"), - "worktype": wo.get("worktype"), - "labor_hours": var(f("estlabhrs"), f("actlabhrs")), - "labor_cost": var(f("estlabcost"), f("actlabcost")), - "material_cost": var(f("estmatcost"), f("actmatcost")), - "service_cost": var(f("estservcost"), f("actservcost")), - "tool_cost": var(f("esttoolcost"), f("acttoolcost")), - "total_cost": var(est_total, act_total)}, duration_ms=t_ms(t)) - - -async def get_workorder_kpis(db, site_id: str, period_months: int = 3, - now: Optional[datetime] = None) -> Dict[str, Any]: + return { + "estimated": round(est, 2), + "actual": round(act, 2), + "variance_abs": round(act - est, 2), + "variance_pct": round(((act - est) / est) * 100, 1) if est else None, + "over_budget": act > est, + } + + est_total = f("esttotalcost") or ( + f("estlabcost") + f("estmatcost") + f("estservcost") + f("esttoolcost") + ) + act_total = f("acttotalcost") or ( + f("actlabcost") + f("actmatcost") + f("actservcost") + f("acttoolcost") + ) + return envelope( + { + "wonum": wonum, + "site_id": site_id, + "status": wo.get("status"), + "worktype": wo.get("worktype"), + "labor_hours": var(f("estlabhrs"), f("actlabhrs")), + "labor_cost": var(f("estlabcost"), f("actlabcost")), + "material_cost": var(f("estmatcost"), f("actmatcost")), + "service_cost": var(f("estservcost"), f("actservcost")), + "tool_cost": var(f("esttoolcost"), f("acttoolcost")), + "total_cost": var(est_total, act_total), + }, + duration_ms=t_ms(t), + ) + + +async def get_workorder_kpis( + db, site_id: str, period_months: int = 3, now: Optional[datetime] = None +) -> Dict[str, Any]: """Site KPIs over a period: totals, backlog, overdue, avg completion, priority + asset breakdowns.""" with Timer() as t: now = now or datetime.now(timezone.utc) cutoff = _iso(now - timedelta(days=period_months * 30)) now_str = _iso(now) - docs = await db.find({"type": "workorder", "siteid": site_id.upper()}, limit=10000) + docs = await db.find( + {"type": "workorder", "siteid": site_id.upper()}, limit=10000 + ) wos = [w for w in docs if (w.get("reportdate") or "") >= cutoff] completed = [w for w in wos if w.get("status") == "COMP"] backlog = [w for w in wos if w.get("status") not in TERMINAL] - overdue = [w for w in backlog if w.get("targcompdate") and w["targcompdate"] < now_str] + overdue = [ + w for w in backlog if w.get("targcompdate") and w["targcompdate"] < now_str + ] times = [] for w in completed: @@ -184,27 +255,44 @@ async def get_workorder_kpis(db, site_id: str, period_months: int = 3, prio: Dict[str, int] = {} assets: Dict[str, int] = {} for w in wos: - prio[str(w.get("wopriority", "Unknown"))] = prio.get(str(w.get("wopriority", "Unknown")), 0) + 1 + prio[str(w.get("wopriority", "Unknown"))] = ( + prio.get(str(w.get("wopriority", "Unknown")), 0) + 1 + ) a = w.get("assetnum", "UNKNOWN") assets[a] = assets.get(a, 0) + 1 top = sorted(assets.items(), key=lambda x: x[1], reverse=True)[:5] - return envelope({"site_id": site_id, "period_months": period_months, - "total_workorders": len(wos), "completed": len(completed), - "backlog": len(backlog), "overdue": len(overdue), - "avg_completion_hrs": avg_hrs, "priority_breakdown": prio, - "top_assets_by_wo_count": [{"asset": a, "count": c} for a, c in top]}, - duration_ms=t_ms(t)) - - -async def get_schedule_calendar(db, site_id: str, date_from: Optional[str] = None, - date_to: Optional[str] = None, group_by: str = "date", - now: Optional[datetime] = None) -> Dict[str, Any]: + return envelope( + { + "site_id": site_id, + "period_months": period_months, + "total_workorders": len(wos), + "completed": len(completed), + "backlog": len(backlog), + "overdue": len(overdue), + "avg_completion_hrs": avg_hrs, + "priority_breakdown": prio, + "top_assets_by_wo_count": [{"asset": a, "count": c} for a, c in top], + }, + duration_ms=t_ms(t), + ) + + +async def get_schedule_calendar( + db, + site_id: str, + date_from: Optional[str] = None, + date_to: Optional[str] = None, + group_by: str = "date", + now: Optional[datetime] = None, +) -> Dict[str, Any]: """Scheduled (non-terminal) work orders in a date window, optionally bucketed by day.""" with Timer() as t: now = now or datetime.now(timezone.utc) date_from = date_from or now.strftime("%Y-%m-%d") date_to = date_to or (now + timedelta(days=14)).strftime("%Y-%m-%d") - docs = await db.find({"type": "workorder", "siteid": site_id.upper()}, limit=10000) + docs = await db.find( + {"type": "workorder", "siteid": site_id.upper()}, limit=10000 + ) in_win = [] for w in docs: if (w.get("status") or "") in TERMINAL: @@ -221,22 +309,36 @@ async def get_schedule_calendar(db, site_id: str, date_from: Optional[str] = Non for w in in_win: day = (w.get("schedstart") or w.get("targstartdate"))[:10] buckets.setdefault(day, []).append(_public(w)) - payload = {"site_id": site_id, "date_from": date_from, "date_to": date_to, - "total_scheduled": len(in_win), - "by_date": [{"date": d, "count": len(r), "workorders": r} - for d, r in sorted(buckets.items())]} + payload = { + "site_id": site_id, + "date_from": date_from, + "date_to": date_to, + "total_scheduled": len(in_win), + "by_date": [ + {"date": d, "count": len(r), "workorders": r} + for d, r in sorted(buckets.items()) + ], + } else: - payload = {"site_id": site_id, "date_from": date_from, "date_to": date_to, - "total_scheduled": len(in_win), "workorders": [_public(w) for w in in_win]} + payload = { + "site_id": site_id, + "date_from": date_from, + "date_to": date_to, + "total_scheduled": len(in_win), + "workorders": [_public(w) for w in in_win], + } return envelope(payload, duration_ms=t_ms(t), record_count=len(in_win)) -async def get_my_assigned_workorders(db, labor_code: str, site_id: Optional[str] = None, - open_only: bool = True) -> Dict[str, Any]: +async def get_my_assigned_workorders( + db, labor_code: str, site_id: Optional[str] = None, open_only: bool = True +) -> Dict[str, Any]: """Work orders with a `wplabor` line for the given labor (technician).""" with Timer() as t: - docs = await db.find({"type": "workorder", "wplabor": {"$elemMatch": {"laborcode": labor_code}}}, - limit=10000) + docs = await db.find( + {"type": "workorder", "wplabor": {"$elemMatch": {"laborcode": labor_code}}}, + limit=10000, + ) out = [] for w in docs: if site_id and (w.get("siteid") or "").upper() != site_id.upper(): @@ -244,23 +346,36 @@ async def get_my_assigned_workorders(db, labor_code: str, site_id: Optional[str] if open_only and (w.get("status") or "").upper() in TERMINAL: continue out.append(_public(w)) - return envelope({"labor_code": labor_code, "workorders": out, "totalCount": len(out)}, - duration_ms=t_ms(t), record_count=len(out)) + return envelope( + {"labor_code": labor_code, "workorders": out, "totalCount": len(out)}, + duration_ms=t_ms(t), + record_count=len(out), + ) # --------------------------------------------------------------------------- # # Write tools # --------------------------------------------------------------------------- # -async def create_workorder(db, description: str, asset_num: str, site_id: str, - priority: int = 3, work_type: str = "CM", - reported_by: Optional[str] = None, location: Optional[str] = None, - notes: Optional[str] = None, wonum: Optional[str] = None, - aob_source: Optional[Dict[str, Any]] = None, - now: Optional[datetime] = None) -> Dict[str, Any]: +async def create_workorder( + db, + description: str, + asset_num: str, + site_id: str, + priority: int = 3, + work_type: str = "CM", + reported_by: Optional[str] = None, + location: Optional[str] = None, + notes: Optional[str] = None, + wonum: Optional[str] = None, + aob_source: Optional[Dict[str, Any]] = None, + now: Optional[datetime] = None, +) -> Dict[str, Any]: """Create a new work order (status WAPPR). Optionally pin `wonum` for reproducible ids and attach `aob_source` provenance (the agent/trigger that generated it).""" if not description or not asset_num or not site_id: - return error("description, asset_num, and site_id are required", "VALIDATION_ERROR") + return error( + "description, asset_num, and site_id are required", "VALIDATION_ERROR" + ) if not 1 <= priority <= 5: return error("priority must be between 1 and 5", "VALIDATION_ERROR") if work_type not in WORKTYPES: @@ -270,10 +385,17 @@ async def create_workorder(db, description: str, asset_num: str, site_id: str, now = now or datetime.now(timezone.utc) won = wonum or await db.next_wonum(site_id) doc: Dict[str, Any] = { - "_id": _doc_id(site_id, won), "type": "workorder", "schema_version": "1.0.0", - "wonum": won, "siteid": site_id.upper(), "description": description[:100], - "assetnum": asset_num, "wopriority": priority, "worktype": work_type, - "status": "WAPPR", "reportdate": _iso(now), + "_id": _doc_id(site_id, won), + "type": "workorder", + "schema_version": "1.0.0", + "wonum": won, + "siteid": site_id.upper(), + "description": description[:100], + "assetnum": asset_num, + "wopriority": priority, + "worktype": work_type, + "status": "WAPPR", + "reportdate": _iso(now), } if reported_by: doc["reportedby"] = reported_by @@ -293,10 +415,17 @@ async def generate_work_order(db, **kwargs) -> Dict[str, Any]: return await create_workorder(db, **kwargs) -async def update_workorder(db, wonum: str, site_id: str, description: Optional[str] = None, - priority: Optional[int] = None, location: Optional[str] = None, - asset_num: Optional[str] = None, notes: Optional[str] = None, - failure_code: Optional[str] = None) -> Dict[str, Any]: +async def update_workorder( + db, + wonum: str, + site_id: str, + description: Optional[str] = None, + priority: Optional[int] = None, + location: Optional[str] = None, + asset_num: Optional[str] = None, + notes: Optional[str] = None, + failure_code: Optional[str] = None, +) -> Dict[str, Any]: """Update mutable fields on an existing work order.""" with Timer() as t: doc = await db.get(_doc_id(site_id, wonum)) @@ -319,9 +448,14 @@ async def update_workorder(db, wonum: str, site_id: str, description: Optional[s return envelope(_public(doc), duration_ms=t_ms(t)) -async def _change_status(db, wonum: str, site_id: str, new_status: str, - extra: Optional[Dict[str, Any]] = None, - now: Optional[datetime] = None) -> Dict[str, Any]: +async def _change_status( + db, + wonum: str, + site_id: str, + new_status: str, + extra: Optional[Dict[str, Any]] = None, + now: Optional[datetime] = None, +) -> Dict[str, Any]: with Timer() as t: doc = await db.get(_doc_id(site_id, wonum)) if not doc: @@ -334,21 +468,35 @@ async def _change_status(db, wonum: str, site_id: str, new_status: str, return envelope(_public(doc), duration_ms=t_ms(t)) -async def approve_workorder(db, wonum: str, site_id: str, now: Optional[datetime] = None) -> Dict[str, Any]: +async def approve_workorder( + db, wonum: str, site_id: str, now: Optional[datetime] = None +) -> Dict[str, Any]: """Approve a work order (status → APPR).""" return await _change_status(db, wonum, site_id, "APPR", now=now) -async def cancel_workorder(db, wonum: str, site_id: str, reason: Optional[str] = None, - now: Optional[datetime] = None) -> Dict[str, Any]: +async def cancel_workorder( + db, + wonum: str, + site_id: str, + reason: Optional[str] = None, + now: Optional[datetime] = None, +) -> Dict[str, Any]: """Cancel a work order (status → CAN).""" extra = {"description_longdescription": reason} if reason else None return await _change_status(db, wonum, site_id, "CAN", extra=extra, now=now) -async def assign_technician(db, wonum: str, site_id: str, labor_code: str, - craft: Optional[str] = None, start_date: Optional[str] = None, - hours_planned: float = 8.0, now: Optional[datetime] = None) -> Dict[str, Any]: +async def assign_technician( + db, + wonum: str, + site_id: str, + labor_code: str, + craft: Optional[str] = None, + start_date: Optional[str] = None, + hours_planned: float = 8.0, + now: Optional[datetime] = None, +) -> Dict[str, Any]: """Append a planned-labor (`wplabor`) line assigning a technician to the work order.""" if not all([wonum, site_id, labor_code]): return error("wonum, site_id, and labor_code are required", "VALIDATION_ERROR") @@ -356,8 +504,11 @@ async def assign_technician(db, wonum: str, site_id: str, labor_code: str, doc = await db.get(_doc_id(site_id, wonum)) if not doc: return error(f"Work order '{wonum}' not found", "NOT_FOUND") - line: Dict[str, Any] = {"laborcode": labor_code, "laborhrs": hours_planned, - "startdate": start_date or _iso(now or datetime.now(timezone.utc))} + line: Dict[str, Any] = { + "laborcode": labor_code, + "laborhrs": hours_planned, + "startdate": start_date or _iso(now or datetime.now(timezone.utc)), + } if craft: line["craft"] = craft doc.setdefault("wplabor", []).append(line) @@ -365,9 +516,15 @@ async def assign_technician(db, wonum: str, site_id: str, labor_code: str, return envelope(_public(doc), duration_ms=t_ms(t)) -async def close_workorder(db, wonum: str, site_id: str, actual_hours: float = 0.0, - failure_code: Optional[str] = None, resolution_notes: Optional[str] = None, - now: Optional[datetime] = None) -> Dict[str, Any]: +async def close_workorder( + db, + wonum: str, + site_id: str, + actual_hours: float = 0.0, + failure_code: Optional[str] = None, + resolution_notes: Optional[str] = None, + now: Optional[datetime] = None, +) -> Dict[str, Any]: """Close a work order (status → COMP), recording actual hours, failure code, resolution, and stamping `actfinish`.""" now = now or datetime.now(timezone.utc) @@ -381,4 +538,4 @@ async def close_workorder(db, wonum: str, site_id: str, actual_hours: float = 0. def t_ms(timer: Timer) -> int: # Timer.ms is only set on __exit__; inside the block fall back to 0. - return getattr(timer, "ms", 0) \ No newline at end of file + return getattr(timer, "ms", 0)