diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 820ebcc..2c61195 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -208,7 +208,21 @@ docs(README): add contributing guide ## How to Add a New Tool All tools are located in the `src/workspace/tools/` directory and inherit from -the `BaseTool` base class. +the `BaseTool` base class. All tool methods must return a `ToolResult` object +(located in `src/models/tools/tool_result.py`). + +### ToolResult Model + +`ToolResult` is the unified tool execution result wrapper: + +| Field | Description | +| :------------ | :---------------------------------------------- | +| `success` | Whether execution succeeded (`bool`) | +| `func_name` | Name of the tool method (`str`) | +| `func_kwargs` | Dictionary of invocation parameters (`dict`) | +| `data` | Return data on success (`Any`) | +| `error` | Error message on failure (`str\|None`) | +| `response` | Auto-generated XML string (for LLM consumption) | ### Step Overview @@ -217,6 +231,7 @@ the `BaseTool` base class. 2. **Inherit from BaseTool**: ```python + from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -231,21 +246,31 @@ the `BaseTool` base class. ) self.func = self.your_method self.params = BaseTool.extract_params(self.your_method) + self.param_descriptions = { + "param1": "Parameter description", + "param2": "Parameter description" + } @BaseTool.handle_tool_exceptions - def your_method(self, param1: str, param2: int = 0) -> str: + def your_method(self, param1: str, param2: int = 0) -> ToolResult: """ Tool description -- will be generated as LLM-readable documentation. - - Parameters - ---------- - param1: Parameter description - param2: Parameter description (with default value) """ # Path operations must be validated through PathValidator path = self.workspace.path_validator.validate(param1) - # ... tool logic ... - return f'{path}x{param2}' + + # Note: Return ToolResult, not raw data + # On success, use make_success_response + return self.make_success_response( + kwargs=locals().copy(), + data=f'{path}x{param2}' + ) + + # On failure, use make_failed_response + # return self.make_failed_response( + # kwargs=locals().copy(), + # error="Specific error description" + # ) ``` 3. **Special handling for write operations**: If the tool involves writing @@ -254,17 +279,59 @@ the `BaseTool` base class. `self._validate_mtime(path)` - Generate a diff and record a `PENDING_AUDIT` snapshot instead of writing directly to disk + - Still return via + `self.make_success_response(kwargs=locals().copy(), data=...)` or + `self.make_failed_response(kwargs=locals().copy(), error=...)` - Refer to the implementations of `WriteTool` and `EditTool` 4. **Register the tool**: Import and instantiate your tool in the `register()` method of `src/core/tool_registry.py`: + ```python + def register(self, workspace: Workspace) -> None: + from src.workspace.tools.your_tool import YourTool + + self._workspace = workspace + + for cls in ( + # ... other existing tool classes ... + YourTool, + ): + try: + tool = cls(workspace) + if tool.func is None or tool.params is None: + warnings.warn(f"Tool {tool.name} has no registered function callback and parameters", stacklevel=2) + continue + self._tools[tool.name] = tool + self._set_tool_category(tool) + except ValueError: + pass + ``` + 5. **Add tests**: Create corresponding test files under `tests/workspace/tools/`. At minimum, cover: - - Normal execution paths + - Normal execution paths (assert `result.success is True`, check + `result.data`) + - Failure scenarios (assert `result.success is False`, check `result.error`) - Parameter validation (e.g., empty parameters, invalid values) - Path security (e.g., out-of-bounds access) + Test example: + + ```python + def test_your_tool_success(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="valid_path", param2=42) + assert result.success is True + assert result.data is not None + + def test_your_tool_failure(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="../outside_path", param2=42) + assert result.success is False + assert "WorkspaceBoundaryError" in result.error + ``` + --- ## Testing Requirements diff --git a/CONTRIBUTING_ZH.md b/CONTRIBUTING_ZH.md index 680f0c9..2bb7b66 100644 --- a/CONTRIBUTING_ZH.md +++ b/CONTRIBUTING_ZH.md @@ -192,7 +192,22 @@ docs(README): 添加贡献指南 / add contributing guide ## 如何添加新工具 -所有工具位于 `src/workspace/tools/` 目录,继承 `BaseTool` 基类. +所有工具位于 `src/workspace/tools/` 目录,继承 `BaseTool` +基类. 所有工具方法必须返回 `ToolResult` 对象(位于 +`src/models/tools/tool_result.py`). + +### ToolResult 模型 + +`ToolResult` 是统一的工具执行结果包装器: + +| 字段 | 说明 | +| :------------ | :--------------------------------------- | +| `success` | 执行是否成功 (`bool`) | +| `func_name` | 工具方法名 (`str`) | +| `func_kwargs` | 调用参数字典 (`dict`) | +| `data` | 成功时的返回数据 (`Any`) | +| `error` | 失败时的错误消息 (`str\|None`) | +| `response` | 自动生成的 XML 格式字符串(用于 LLM 消费) | ### 步骤概述 @@ -201,6 +216,7 @@ docs(README): 添加贡献指南 / add contributing guide 2. **继承 BaseTool**: ```python + from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -215,36 +231,87 @@ docs(README): 添加贡献指南 / add contributing guide ) self.func = self.your_method self.params = BaseTool.extract_params(self.your_method) + self.param_descriptions = { + "param1": "参数说明", + "param2": "参数说明" + } @BaseTool.handle_tool_exceptions - def your_method(self, param1: str, param2: int = 0) -> str: + def your_method(self, param1: str, param2: int = 0) -> ToolResult: """ 工具描述 -- 会生成为 LLM 可读的文档. - - Parameters - ---------- - param1: 参数说明 - param2: 参数说明(带默认值) """ # 路径操作必须通过 PathValidator 验证 path = self.workspace.path_validator.validate(param1) - # ... 工具逻辑 ... - return f'{path}x{param2}' + + # 注意: 返回 ToolResult 而非原始数据 + # 成功时使用 make_success_response + return self.make_success_response( + kwargs=locals().copy(), + data=f'{path}x{param2}' + ) + + # 失败时使用 make_failed_response + # return self.make_failed_response( + # kwargs=locals().copy(), + # error="具体的错误描述" + # ) ``` 3. **写入操作的特殊处理**: 如果工具涉及写入(`write_permission=True`),需要: - 通过 `self._validate_mtime(path)` 检查文件是否被外部修改 - 生成 diff 并记录 `PENDING_AUDIT` 快照,而非直接写入磁盘 + - 返回格式仍然使用 + `self.make_success_response(kwargs=locals().copy(), data=...)` 或 + `self.make_failed_response(kwargs=locals().copy(), error=...)` - 参考 `WriteTool` 和 `EditTool` 的实现 4. **注册工具**: 在 `src/core/tool_registry.py` 的 `register()` 方法中导入并实例化你的工具: + ```python + def register(self, workspace: Workspace) -> None: + from src.workspace.tools.your_tool import YourTool + + self._workspace = workspace + + for cls in ( + # ... 其他已有的工具类 ... + YourTool, + ): + try: + tool = cls(workspace) + if tool.func is None or tool.params is None: + warnings.warn(f"工具{tool.name}没有注册功能回调和参数", stacklevel=2) + continue + self._tools[tool.name] = tool + self._set_tool_category(tool) + except ValueError: + pass + ``` + 5. **补充测试**: 在 `tests/workspace/tools/` 下创建对应的测试文件. 至少覆盖: - - 正常执行路径 + - 正常执行路径(断言 `result.success is True`, 检查 `result.data`) + - 失败场景(断言 `result.success is False`, 检查 `result.error`) - 参数验证(如空参数、非法值) - 路径安全(如越界访问) + 测试示例: + + ```python + def test_your_tool_success(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="valid_path", param2=42) + assert result.success is True + assert result.data is not None + + def test_your_tool_failure(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="../outside_path", param2=42) + assert result.success is False + assert "WorkspaceBoundaryError" in result.error + ``` + --- ## 测试要求 diff --git a/README.md b/README.md index 396bb4c..2675310 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ LLM chat interfaces. Paste LLM-generated tool calls (in XML format), review and audit dangerous operations, and manage sessions with full history tracking -- all running locally on your machine. -> **Version**: 0.4.1 | **Python**: >=3.14 +> **Version**: 0.5.0 | **Python**: >=3.14 --- @@ -125,8 +125,7 @@ ManualAid registers 12 tools for LLM use via XML function calls: | -------------- | ------------------------------------------------ | | `ls` | List directory contents | | `glob` | Find files by glob pattern | -| `read` | Read file contents (with optional line limit) | -| `read_lines` | Read specific line range from a file | +| `read` | Read file contents with optional line range | | `stat` | Get file/directory metadata (size, mtime, lines) | | `exact_search` | Exact string search with case/whole-word options | | `regex_search` | Regex search with context display | diff --git a/README_ZH.md b/README_ZH.md index 2a4dc7e..3479c05 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -8,7 +8,7 @@ ManualAid 提供了一个基于 Textual 的 TUI 控制台,在剪贴板和 LLM 聊天界面之间架起桥梁. 粘贴 LLM 生成的工具调用(XML 格式),审查和审计危险操作,并通过完整的历史追踪管理会话 -- 一切都在本地运行. -> **版本**: 0.4.1 | **Python**: >=3.14 +> **版本**: 0.5.0 | **Python**: >=3.14 --- @@ -112,8 +112,7 @@ ManualAid 注册了 12 个工具供 LLM 通过 XML 函数调用使用: | -------------- | ----------------------------------------- | | `ls` | 列出目录内容 | | `glob` | 通过 glob 模式查找文件 | -| `read` | 读取文件内容(可选行数限制) | -| `read_lines` | 读取文件中指定范围的行 | +| `read` | 读取文件内容,支持指定行范围 | | `stat` | 获取文件/目录元数据(大小、修改时间、行数) | | `exact_search` | 精确字符串搜索,支持大小写/全词匹配 | | `regex_search` | 正则表达式搜索,支持上下文显示 | diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index b9b7b89..70e2416 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -8,6 +8,81 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] - 2026-05-05 + +### Added + +- **Structured Tool Result**: Introduced `ToolResult` data class as the unified + return type for all tools, replacing inconsistent string and list responses. + The class includes `success`, `data`, `error`, and `response` attributes with + built-in result compression and standardized XML formatting. All tools now + return `ToolResult` objects, enabling consistent upstream error handling + ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). +- **File Pattern Filtering for `exact_search`**: Added `file_pattern` parameter + to `exact_search` (default `"*"`), matching the existing `regex_search` + behavior. Allows filtering search scope by file extension or glob pattern + ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **Auto-Categorization of Tools**: Tools are now automatically classified as + read-only or write based on the `write_permission` attribute, eliminating + manual category registration and reducing maintenance overhead + ([#135, #139](https://github.com/SunYanbox/ManualAid/issues/135)). +- **Range Reading in `read` Tool**: The `read` tool now supports precise line + range reading via `start`, `end` (supports negative indexing), and `context` + parameters, replacing the coarse `max_lines` approach. Display header now + shows the actual line range read (`[Lines start-end / total_lines]`) + ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). +- **Parameter Descriptions**: Introduced `param_descriptions` dictionary in + `BaseTool` allowing each tool to provide human-readable parameter + descriptions. Parameter documentation format changed from inline XML to + Markdown list items (`- **name** (type, required/optional): description`) + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **File Size Limit**: Added configurable max file size limit + (`MAX_READ_FILE_SIZE`, default 10MB) to the `read` tool to prevent + out-of-memory errors when reading large files + ([#130, #141](https://github.com/SunYanbox/ManualAid/issues/130)). + +### Changed + +- **Tool Path Parameter Unification**: Renamed path parameters across all tools + to a consistent `path` name — `file_path` (read, write, edit) and + `folder_path` (ls, glob) are now uniformly `path`. This reduces LLM confusion + and injection token length + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **Tool Injection Optimization**: Removed redundant docstring `Parameters` + sections from tool functions, shortened tool descriptions, and streamlined + parameter documentation format. Combined with parameter unification, these + changes significantly reduce system prompt injection length, lowering LLM + hallucination risk + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **Symbol Search Performance**: Replaced per-pattern file traversal with a + single-pass multi-pattern search via `search_content_multi_pattern` API, + eliminating N× I/O overhead. Results are now parsed as structured `list[dict]` + instead of regex-parsing formatted text, fixing the "format-then-parse" + anti-pattern + ([#132, #137](https://github.com/SunYanbox/ManualAid/issues/132)). +- **Exception Handling Consolidation**: The `handle_tool_exceptions` decorator + now uniformly wraps all exceptions into `ToolResult(success=False, error=...)` + objects. Removed `ToolErrorResponse` dependency; error messages are now + formatted as `ClassName: Message` + ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). + +### Fixed + +- **`limit` Semantics in Search Tools**: Corrected the `limit` parameter in both + `exact_search` and `regex_search` to count individual match results rather + than files scanned, aligning behavior with user expectations + ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **Redundant Warnings in Input Parser**: Removed stale `warnings.warn` calls + and the unused `import warnings` dependency from the input parser + ([#138](https://github.com/SunYanbox/ManualAid/issues/138)). + +### Removed + +- **`read_lines` Tool**: Merged into the enhanced `read` tool with range-reading + support. All `read_lines` functionality is now accessible via `read` with + `start`/`end`/`context` parameters + ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). + ## [0.4.1] - 2026-05-04 ### Added @@ -176,6 +251,7 @@ and this project adheres to _Initial release features and history._ +[0.5.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.5.0 [0.4.1]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.1 [0.4.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.0 [0.3.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.3.0 diff --git a/docs/CHANGELOG_ZH.md b/docs/CHANGELOG_ZH.md index d8fd008..1540562 100644 --- a/docs/CHANGELOG_ZH.md +++ b/docs/CHANGELOG_ZH.md @@ -8,6 +8,60 @@ [Keep a Changelog](https://keepachangelog.com/zh-CN/1.0.0/). 并采用 [语义化版本](https://semver.org/lang/Chinese/). +## [0.5.0] - 2026-05-05 + +### 新增 + +- **结构化工具返回结果**: 引入 `ToolResult` + 数据类作为所有工具的统一返回类型,替代以往不一致的字符串和列表响应. 该类包含 + `success`、`data`、`error`、`response` + 属性,内置结果压缩与标准化 XML 格式化功能. 所有工具方法现均返回 `ToolResult` + 对象,使上游调用方能进行一致的错误处理 ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). +- **`exact_search` 文件模式过滤**: 为 `exact_search` 新增 `file_pattern` + 参数(默认 `"*"`),与已有的 `regex_search` + 行为对齐. 支持按文件扩展名或通配符模式过滤搜索范围 ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **工具自动分类**: 工具现根据 `write_permission` + 属性自动归类为只读或可写,无需手动注册分类,减少维护成本 ([#135, #139](https://github.com/SunYanbox/ManualAid/issues/135)). +- **`read` 工具范围读取**: `read` 工具现支持通过 `start`、`end`(支持负数索引) 和 + `context` 参数进行精确的行范围读取,替代原有的粗粒度 `max_lines` + 方式. 显示头部现展示实际读取的行范围 (`[行 start-end / 共 total_lines 行]`) + ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). +- **参数描述机制**: 在 `BaseTool` 中引入 `param_descriptions` + 字典,允许每个工具为参数提供可读描述. 参数文档格式从内联 XML 转为 Markdown 列表项 (`- **名称** (类型, 必需/可选): 描述`) + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **文件大小限制**: 为 `read` + 工具添加了可配置的最大文件大小限制 (`MAX_READ_FILE_SIZE`,默认 10MB),防止读取大文件时内存溢出 ([#130, #141](https://github.com/SunYanbox/ManualAid/issues/130)). + +### 更改 + +- **工具路径参数统一**: 将所有工具中的路径参数统一重命名为 `path`——原 + `file_path`(read、write、edit)和 `folder_path`(ls、glob)现统一使用 + `path`. 此举减少 LLM 混淆并缩短注入 Token 长度 ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **工具注入长度优化**: 移除工具函数中冗余的 Docstring `Parameters` + 段落、缩短工具描述、精简参数文档格式. 配合参数统一,这些变更显著缩短了系统提示注入长度,降低 LLM 幻觉风险 ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **符号搜索性能重构**: 用单次遍历多模式搜索(`search_content_multi_pattern` + API)替代原有的逐模式文件遍历, 消除了 N 倍 I/O 开销. 搜索结果现解析为结构化的 + `list[dict]` + 而非对格式化文本做正则解析,修复了 "先格式化再解析"的反模式 ([#132, #137](https://github.com/SunYanbox/ManualAid/issues/132)). +- **异常处理整合**: `handle_tool_exceptions` 装饰器现统一将所有异常封装为 + `ToolResult(success=False, error=...)` 对象. 移除了 `ToolErrorResponse` + 依赖,错误消息格式化为 `ClassName: Message` + ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). + +### 修复 + +- **搜索工具 `limit` 语义**: 修正了 `exact_search` 和 `regex_search` 中 `limit` + 参数的计数逻辑,从统计"扫描文件数"改为统计"匹配结果数",使行为符合用户预期 ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **输入解析器冗余警告**: 清理了输入解析器中已过时的 `warnings.warn` + 调用及未使用的 `import warnings` + 依赖 ([#138](https://github.com/SunYanbox/ManualAid/issues/138)). + +### 移除 + +- **`read_lines` 工具**: 已合并至增强后的 `read` 工具. 所有 `read_lines` + 功能现可通过 `read` 的 `start`/`end`/`context` + 参数访问 ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). + ## [0.4.1] - 2026-05-04 ### 新增 @@ -144,6 +198,7 @@ _初始发布的功能和历史记录._ +[0.5.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.5.0 [0.4.1]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.1 [0.4.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.0 [0.3.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.3.0 diff --git a/pyproject.toml b/pyproject.toml index 5d32ecb..ed36fd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ManualAid" -version = "0.4.1" +version = "0.5.0" description = "" requires-python = ">=3.14" dependencies = [ diff --git a/src/console/commands/systems/quit_cmd.py b/src/console/commands/systems/quit_cmd.py index fe574c3..ec199ac 100644 --- a/src/console/commands/systems/quit_cmd.py +++ b/src/console/commands/systems/quit_cmd.py @@ -16,7 +16,7 @@ def __init__(self): def execute(self, context: CommandContext) -> CommandResult: # Close session explicitly before exit (atexit handler is a fallback, # but explicit is more reliable when sys.exit triggers early shutdown). - session_id = getattr(context.tool_registry, "_current_session_id", None) + session_id = getattr(context.tool_registry, "session_id", None) if session_id is not None and hasattr(context.workspace, "db"): context.workspace.db.close_session(session_id) context.console.print("[bold]Goodbye![/bold]") diff --git a/src/console/handlers/tool_handler.py b/src/console/handlers/tool_handler.py index 874a827..0faaf67 100644 --- a/src/console/handlers/tool_handler.py +++ b/src/console/handlers/tool_handler.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import time from typing import TYPE_CHECKING @@ -100,35 +99,9 @@ def handle(self, parsed_input: CommandParseResult) -> bool: start = time.perf_counter() # 执行 - try: - response = self.tool_registry.execute(func_name, **func_kwargs) - if isinstance(response, str): - response_str = response - elif isinstance(response, (dict, list, tuple)): - response_str = json.dumps(response) - else: - response_str = f"{response.__class__.__name__}({response})" - - temp_result = [ - "", - f"", - response_str, - "", - "", - ] - func_result = str.join("\n", temp_result) - except Exception as e: - import traceback - - error = ( - f"执行工具{func_name}(参数={parms})时出现错误: " - f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})" - ) - self.console.print(f"[red]{error}[/red]") - func_result = "\n".join(["", "", error, "", ""]) - self.console.print(f"[red]{error}[/red]") - - collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=func_result) + response = self.tool_registry.execute(func_name, **func_kwargs) + + collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=response.response) result = "" diff --git a/src/console/main.py b/src/console/main.py index 927a3f2..0d1c65c 100644 --- a/src/console/main.py +++ b/src/console/main.py @@ -85,7 +85,6 @@ def init_workspace(start_path: str | None = None) -> Workspace | None: session_id = workspace.db.create_session(name=f"session_{time.strftime('%Y%m%d_%H%M%S')}") tool_registry.set_session_id(session_id) - workspace._current_session_id = session_id # Start a daemon heartbeat thread to periodically persist session duration # and guard against accidental deletion flag diff --git a/src/constants/__init__.py b/src/constants/__init__.py index 3d26edf..3d18726 100644 --- a/src/constants/__init__.py +++ b/src/constants/__init__.py @@ -1 +1 @@ -__version__ = "0.4.1" +__version__ = "0.5.0" diff --git a/src/core/input_parser.py b/src/core/input_parser.py index 61b2d32..620cb00 100644 --- a/src/core/input_parser.py +++ b/src/core/input_parser.py @@ -1,7 +1,6 @@ import html import re import shlex -import warnings from src.console.commands.command_registry import CommandRegistry from src.models.commands import CommandParseResult @@ -53,11 +52,6 @@ def parse_func_call(content: str, warns: list[str]) -> tuple[str, dict]: """ 从 标签中提取函数名和参数,使用健壮的回退机制. """ - warnings.warn( - "当参数存在`$", "", inner).strip() diff --git a/src/core/tool_registry.py b/src/core/tool_registry.py index c8d485d..55f8650 100644 --- a/src/core/tool_registry.py +++ b/src/core/tool_registry.py @@ -12,6 +12,7 @@ import nest_asyncio from src.console.result_manager import _to_string +from src.models.tools.tool_result import ToolResult from src.utils.string_snapshot import truncate_string from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -55,36 +56,6 @@ def __init__(self): # 配置常量 - 从环境变量读取,带默认值 self.MAX_DOC_LENGTH = int(os.getenv("TOOL_MAX_DOC_LENGTH", "360")) self.MAX_FUNC_NAME_LENGTH = int(os.getenv("TOOL_MAX_FUNC_NAME_LENGTH", "80")) - self.MAX_RESULT_LENGTH = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000")) - self.LIST_TRUNCATE_THRESHOLD = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100")) - self.DICT_TRUNCATE_THRESHOLD = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100")) - - # 验证配置值 - self._validate_config() - - def _validate_config(self) -> None: - """验证配置值确保在合理范围内""" - if self.MAX_RESULT_LENGTH < 10: - warnings.warn( - f"TOOL_MAX_RESULT_LENGTH 过小({self.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2 - ) - self.MAX_RESULT_LENGTH = 100 - - if self.LIST_TRUNCATE_THRESHOLD < 10: - warnings.warn( - f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({self.LIST_TRUNCATE_THRESHOLD}),建议至少为50", - UserWarning, - stacklevel=2, - ) - self.LIST_TRUNCATE_THRESHOLD = 50 - - if self.DICT_TRUNCATE_THRESHOLD < 10: - warnings.warn( - f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({self.DICT_TRUNCATE_THRESHOLD}),建议至少为50", - UserWarning, - stacklevel=2, - ) - self.DICT_TRUNCATE_THRESHOLD = 50 def _validate_tool_info(self, name: str, doc: str) -> None: """验证工具信息并发出警告""" @@ -96,16 +67,14 @@ def _validate_tool_info(self, name: str, doc: str) -> None: if len(name) > self.MAX_FUNC_NAME_LENGTH: warnings.warn(f"工具名称 '{name}' 超过 {self.MAX_FUNC_NAME_LENGTH} 字符", UserWarning, stacklevel=3) - def _set_tool_category(self, tool_name: str) -> None: - """根据工具名称设置分类.""" - if tool_name in {"glob", "ls", "regex_search", "exact_search", "stat", "read", "read_lines", "symbol_ref"}: - self._tool_categories[tool_name] = "query" - elif tool_name in {"write", "edit", "confirm_edit"}: - self._tool_categories[tool_name] = "edit" - elif tool_name == "git": - self._tool_categories[tool_name] = "dangerous" + def _set_tool_category(self, tool: BaseTool) -> None: + """根据工具的 write_permission 属性设置分类.""" + if tool.name == "git": + self._tool_categories[tool.name] = "dangerous" + elif tool.write_permission: + self._tool_categories[tool.name] = "write" else: - self._tool_categories[tool_name] = "query" + self._tool_categories[tool.name] = "query" def register(self, workspace: Workspace) -> None: """为工作区注册工具""" @@ -114,7 +83,6 @@ def register(self, workspace: Workspace) -> None: from src.workspace.tools.git_tool import GitTool from src.workspace.tools.glob_tool import GlobTool from src.workspace.tools.ls_tool import LsTool - from src.workspace.tools.read_lines_tool import ReadLinesTool from src.workspace.tools.read_tool import ReadTool from src.workspace.tools.regex_search_tool import RegexSearchTool from src.workspace.tools.stat_tool import StatTool @@ -127,7 +95,6 @@ def register(self, workspace: Workspace) -> None: ExactSearchTool, GlobTool, LsTool, - ReadLinesTool, ReadTool, RegexSearchTool, WriteTool, @@ -142,33 +109,11 @@ def register(self, workspace: Workspace) -> None: warnings.warn(f"工具{tool.name}没有注册功能回调和参数", stacklevel=2) continue self._tools[tool.name] = tool - self._set_tool_category(tool.name) + self._set_tool_category(tool) except ValueError: pass - def _compress_result(self, result: Any) -> Any: - """压缩过长的结果""" - result_length = len(result) - if isinstance(result, str): - if result_length > self.MAX_RESULT_LENGTH: - return ( - result[: self.MAX_RESULT_LENGTH] - + f"... [字符串结果已截断 显示的字符数: {self.LIST_TRUNCATE_THRESHOLD} / {result_length}]" - ) - elif isinstance(result, (list, tuple)): - if result_length > self.LIST_TRUNCATE_THRESHOLD: - return [ - *list(result[: self.LIST_TRUNCATE_THRESHOLD]), - f"... [列表已截断 显示的项: {self.LIST_TRUNCATE_THRESHOLD} / {result_length}]", - ] - elif isinstance(result, dict) and result_length > self.DICT_TRUNCATE_THRESHOLD: - compressed = {k: result[k] for k in list(result.keys())[: self.DICT_TRUNCATE_THRESHOLD]} - compressed["..."] = f"[字典已截断 显示的项: {self.DICT_TRUNCATE_THRESHOLD} / {result_length}]" - return compressed - - return result - - def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any: + def execute(self, func_name: str, *args: Any, **kwargs: Any) -> ToolResult: """ 执行工具函数 @@ -180,37 +125,48 @@ def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any: Returns: 函数执行结果(自动压缩过长的结果) """ - if func_name in self._tools: - tool = self._tools[func_name] - kwargs = tool.convert_args(kwargs) + try: + if func_name in self._tools: + tool = self._tools[func_name] + kwargs = tool.convert_args(kwargs) + + start_time = time.perf_counter() - start_time = time.perf_counter() - status = "success" - try: if inspect.iscoroutinefunction(tool.func): coro = tool.func(*args, **kwargs) # 已有事件循环 try: loop = asyncio.get_running_loop() nest_asyncio.apply() - result = loop.run_until_complete(coro) + raw_result = loop.run_until_complete(coro) except RuntimeError: # pragma: no cover // pytest内置事件循环, 测不到这里 # 没有运行中的事件循环 - result = asyncio.run(coro) + raw_result = asyncio.run(coro) else: # 同步函数 - result = tool.func(*args, **kwargs) - except Exception as e: - status = "error" - result = f'' - - duration_ms = (time.perf_counter() - start_time) * 1000 - self._log_tool_call(func_name, kwargs, duration_ms, status) - self._record_tool_call_summary(func_name, kwargs, result) - - return self._compress_result(result) - else: - raise ValueError(f"未找到工具: {func_name}") + raw_result = tool.func(*args, **kwargs) + + # 统一解包 ToolResult + result = ( + raw_result + if (isinstance(raw_result, ToolResult)) + else ( + ToolResult( + success=False, + func_name=func_name, + func_kwargs=kwargs, + error=f"错误的工具返回值类型: {raw_result.__class__.__name__}", + ) + ) + ) + duration_ms = (time.perf_counter() - start_time) * 1000 + self._log_tool_call(func_name, kwargs, duration_ms, result.status) + self._record_tool_call_summary(func_name, kwargs, result.response) + return result + else: + raise ValueError(f"未找到工具: {func_name}") + except Exception as e: + return ToolResult(success=False, data=kwargs, error=str(e), func_name=func_name, func_kwargs=kwargs) def generate_markdown(self) -> str: """ @@ -238,8 +194,10 @@ def get_tool_info(self, name: str) -> BaseTool | None: return self._tools.get(name) def set_session_id(self, session_id: int) -> None: - """设置当前会话 ID""" + """设置当前会话 ID,并同步到关联的 Workspace.""" self._current_session_id = session_id + if self._workspace is not None: + self._workspace.session_id = session_id @staticmethod def _compute_kwargs_json(kwargs: dict) -> str: @@ -250,13 +208,13 @@ def _compute_kwargs_json(kwargs: dict) -> str: return json.dumps(truncated, sort_keys=True, default=str) def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, status: str) -> str | None: - session_id = getattr(self, "_current_session_id", None) - if session_id is None: + if self._current_session_id is None: return None try: kwargs_json = self._compute_kwargs_json(kwargs) workspace = getattr(self, "_workspace", None) if workspace is not None: + session_id = self._current_session_id # Determine audit_status based on tool category category = self._tool_categories.get(func_name, "query") audit_status = "none" @@ -268,7 +226,7 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu command_str = kwargs.get("command_str", "") if not GitTool.is_safe_command(command_str): audit_status = "PENDING_AUDIT" - elif category == "edit": + elif category == "write": audit_status = "none" # Snapshot has its own PENDING_AUDIT workspace.db.log_tool_call( @@ -284,12 +242,12 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu return f"ToolRegistry(sync_tools={len(self._tools)})" def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) -> None: - session_id = getattr(self, "_current_session_id", None) - if session_id is None: + if self._current_session_id is None: return # Exclude write tools - if func_name in {"write", "edit", "confirm_edit"}: + tool = self._tools.get(func_name) + if tool is not None and tool.write_permission: return try: @@ -297,6 +255,6 @@ def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) - workspace = getattr(self, "_workspace", None) if workspace is not None: result_str = _to_string(result) - workspace.db.record_tool_call_summary(session_id, func_name, kwargs_json, result_str) + workspace.db.record_tool_call_summary(self._current_session_id, func_name, kwargs_json, result_str) except Exception: pass diff --git a/src/models/tools/tool_result.py b/src/models/tools/tool_result.py new file mode 100644 index 0000000..763e911 --- /dev/null +++ b/src/models/tools/tool_result.py @@ -0,0 +1,143 @@ +import json +import os +import warnings +from typing import Any, ClassVar + +from src.utils.string_snapshot import truncate_params_string + + +def to_xml_string(func_name: str, params: dict, data: Any = None, err: str | None = None) -> str: + params = params.copy() + params.pop("self", None) + + try: + messages: list[str] = [] + if data is not None: + messages.append("") + if isinstance(data, str): + messages.append(data) + elif isinstance(data, (dict, list, tuple)): + messages.append(json.dumps(data)) + else: + messages.append(f"{data.__class__.__name__}({data})") + messages.append("") + if err: + messages.extend( + [ + "", + err, + "", + ] + ) + if data is None and not err: + messages.append("没有任何工具调用数据或错误详情, 请提示用户检查工具是否正常") + + temp_result = [ + "", + f"", + *messages, + "", + "", + ] + func_result = str.join("\n", temp_result) + except Exception as e: + import traceback + + func_result = "\n".join( + [ + "", + f"", + f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})", + err if err else "", + "", + "", + ] + ) + return func_result + + +class ToolResult: + """工具执行结果的结构化包装,显式区分成功与失败. + + 所有被 @handle_tool_exceptions 装饰的工具方法均返回此类型, + 调用方可通过 success 标志可靠判断执行状态,无需依赖隐式类型约定. + """ + + __slots__ = ("data", "error", "func_kwargs", "func_name", "response", "success") + + HAD_VALIDATE: ClassVar[bool] = False + MAX_RESULT_LENGTH: ClassVar[int] = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000")) + LIST_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100")) + DICT_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100")) + + def __init__( + self, success: bool, func_name: str, func_kwargs: dict, data: Any = None, error: str | None = None + ) -> None: + self.success: bool = success + self.func_name: str = func_name + self.func_kwargs: dict = func_kwargs + self.data: Any = data + self.error: str | None = error + self.response: str = to_xml_string(self.func_name, self.func_kwargs, self.data, self.error) + ToolResult._validate_config() + + def __repr__(self) -> str: + if self.success: + return f"ToolResult(success=True, data={self.data!r})" + return f"ToolResult(success=False, error={self.error!r})" + + @property + def status(self) -> str: + return "success" if self.success else "error" + + @classmethod + def _compress_result(cls, result: Any) -> Any: + """压缩过长的结果""" + result_length = len(result) + if isinstance(result, str): + if result_length > cls.MAX_RESULT_LENGTH: + return ( + result[: cls.MAX_RESULT_LENGTH] + + f"... [字符串结果已截断 显示的字符数: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]" + ) + elif isinstance(result, (list, tuple)): + if result_length > cls.LIST_TRUNCATE_THRESHOLD: + return [ + *list(result[: cls.LIST_TRUNCATE_THRESHOLD]), + f"... [列表已截断 显示的项: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]", + ] + elif isinstance(result, dict) and result_length > cls.DICT_TRUNCATE_THRESHOLD: + compressed = {k: result[k] for k in list(result.keys())[: cls.DICT_TRUNCATE_THRESHOLD]} + compressed["..."] = f"[字典已截断 显示的项: {cls.DICT_TRUNCATE_THRESHOLD} / {result_length}]" + return compressed + + return result + + @classmethod + def _validate_config(cls) -> None: + """验证配置值确保在合理范围内""" + if cls.HAD_VALIDATE: + return None + if cls.MAX_RESULT_LENGTH < 10: + warnings.warn( + f"TOOL_MAX_RESULT_LENGTH 过小({cls.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2 + ) + cls.MAX_RESULT_LENGTH = 100 + + if cls.LIST_TRUNCATE_THRESHOLD < 10: + warnings.warn( + f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({cls.LIST_TRUNCATE_THRESHOLD}),建议至少为50", + UserWarning, + stacklevel=2, + ) + cls.LIST_TRUNCATE_THRESHOLD = 50 + + if cls.DICT_TRUNCATE_THRESHOLD < 10: + warnings.warn( + f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({cls.DICT_TRUNCATE_THRESHOLD}),建议至少为50", + UserWarning, + stacklevel=2, + ) + cls.DICT_TRUNCATE_THRESHOLD = 50 + cls.HAD_VALIDATE = True + return None diff --git a/src/workspace/tools/base_tool.py b/src/workspace/tools/base_tool.py index 8d456a9..70e6ab4 100644 --- a/src/workspace/tools/base_tool.py +++ b/src/workspace/tools/base_tool.py @@ -4,26 +4,25 @@ from typing import Any from src.core.file_tracker import FileTracker +from src.models.tools.tool_result import ToolResult from src.workspace.workspace import Workspace -def build_param_doc(name: str, params: dict[str, Any]) -> str: - """Generate a concise XML parameter doc.""" +def build_param_list_item(name: str, params: dict[str, Any], description: str = "") -> str: + """Generate a Markdown list item describing a parameter.""" from src.constants.prompts import clean_type_annotation - result = f' Any: @@ -110,8 +109,9 @@ def __init__( self.write_permission: bool = write_permission self.name: str = name self.doc: str = doc - self.func: Callable[..., Any] | None = None + self.func: Callable[..., ToolResult] | None = None self.params: dict[str, Any] | None = None + self.param_descriptions: dict[str, str] = {} def to_doc(self) -> str: """转换为模型可读文档格式""" @@ -119,7 +119,8 @@ def to_doc(self) -> str: if self.params and len(self.params) > 0: lines.append(" ") for name, param in self.params.items(): - lines.append(f" {build_param_doc(name, param)}") + desc = self.param_descriptions.get(name, "") + lines.append(f" {build_param_list_item(name, param, desc)}") lines.append(" ") else: lines.append(" ") @@ -139,7 +140,7 @@ def to_func_call(self) -> str: return func_call @staticmethod - def extract_params(func: Callable[..., Any]) -> dict[str, Any]: + def extract_params(func: Callable[..., ToolResult]) -> dict[str, Any]: """提取函数参数信息""" sig = inspect.signature(func) params = {} @@ -181,7 +182,7 @@ def _record_read_meta(self, resolved_path: Path) -> None: try: meta = FileTracker.get_file_meta(resolved_path) if meta: - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id if session_id is not None: rel_path = str(resolved_path.relative_to(self.workspace.root_path)) self.workspace.db.record_file_read( @@ -195,7 +196,7 @@ def _validate_mtime(self, resolved_path: Path) -> str | None: if not resolved_path.exists(): return None - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id if session_id is None: return None @@ -224,25 +225,60 @@ def _generate_diff(old_content: str, new_content: str, file_path: str) -> str: diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}") return "".join(diff) + @classmethod + def make_tool_result_response( + cls, success: bool, kwargs: dict, data: Any = None, error: str | None = None + ) -> ToolResult: + return ToolResult(success=success, func_name=cls.__name__, func_kwargs=kwargs, data=data, error=error) + + @classmethod + def make_success_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return cls.make_tool_result_response(success=True, kwargs=kwargs, data=data, error=error) + + @classmethod + def make_failed_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return cls.make_tool_result_response(success=False, kwargs=kwargs, data=data, error=error) + @staticmethod - def handle_tool_exceptions(func): - """工具方法异常处理装饰器.""" + def handle_tool_exceptions(func) -> Callable[..., ToolResult]: + """工具方法异常处理装饰器 —— 将异常转换为 ToolResult 失败结果""" from functools import wraps - from src.models.tool_error_response import ToolErrorResponse from src.workspace.path_validator import PathNotFoundError, WorkspaceBoundaryError @wraps(func) def wrapper(self, *args, **kwargs): try: - return func(self, *args, **kwargs) + raw = func(self, *args, **kwargs) + # 如果工具内部已返回 ToolResult, 直接透传 + if isinstance(raw, ToolResult): + return raw + # 否则包装为成功结果 + return ToolResult(success=True, func_name=func.__name__, func_kwargs=kwargs, data=raw) except PathNotFoundError as err1: - return ToolErrorResponse(self.__class__.__name__, err1).to_str() + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err1.__class__.__name__}: {err1}", + ) except WorkspaceBoundaryError as err2: - return ToolErrorResponse(self.__class__.__name__, err2).to_str() + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err2.__class__.__name__}: {err2}", + ) except PermissionError as err3: - return ToolErrorResponse(self.__class__.__name__, err3).to_str() + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err3.__class__.__name__}: {err3}", + ) except Exception as err: - return ToolErrorResponse(self.__class__.__name__, err).to_str() + return ToolResult( + success=False, func_name=func.__name__, func_kwargs=kwargs, error=f"{err.__class__.__name__}: {err}" + ) return wrapper diff --git a/src/workspace/tools/edit_tool.py b/src/workspace/tools/edit_tool.py index 15553ee..b33c154 100644 --- a/src/workspace/tools/edit_tool.py +++ b/src/workspace/tools/edit_tool.py @@ -3,6 +3,7 @@ from pathlib import Path from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.utils.binary_detector import is_binary_file from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -19,61 +20,55 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "edit", self.edit.__doc__, write_permission=True) self.func = self.edit self.params = BaseTool.extract_params(self.edit) + self.param_descriptions = { + "path": "文件路径", + "old_string": "待替换的字符串", + "new_string": "替换后的字符串", + "max_replacements": "最大替换次数(1~100)", + "context_before": "匹配前的上下文文本", + "context_after": "匹配后的上下文文本", + } @BaseTool.handle_tool_exceptions def edit( self, - file_path: str, + path: str, old_string: str, new_string: str, max_replacements: int = 10, context_before: str = "", context_after: str = "", - ) -> str: + ) -> ToolResult: """ - 在文件中进行安全的字符串替换(仅预览,不修改磁盘) - - 执行 dry-run 替换,生成 diff,记录 PENDING_AUDIT 快照. - 批准后由 AuditCommitter 执行实际写入. - - Parameters - ---------- - file_path: 文件路径 - old_string: 待替换的字符串(不能为空) - new_string: 替换后的字符串 - max_replacements: 最大替换次数(默认 10,最大 100) - context_before: 匹配前的上下文文本(可选,用于校验) - context_after: 匹配后的上下文文本(可选,用于校验) + 通过在文件中进行安全的字符串替换编辑文件 """ # 1. 参数校验 if not old_string: - return ToolErrorResponse(self.__class__.__name__, ValueError("old_string 不能为空")).to_str() + return self.make_failed_response(locals().copy(), error=f"{ValueError('old_string 不能为空')}") if max_replacements < 1: - return ToolErrorResponse(self.__class__.__name__, ValueError("max_replacements 必须 >= 1")).to_str() + return self.make_failed_response(locals().copy(), error=f"{ValueError('max_replacements 必须 >= 1')}") if max_replacements > 100: max_replacements = 100 # 2. 路径解析 - source_file_path = Path(file_path) - resolved_path: Path = self.workspace.path_validator.resolve_path(source_file_path) + source_path = Path(path) + resolved_path: Path = self.workspace.path_validator.resolve_path(source_path) if not resolved_path.is_file(): - return ToolErrorResponse( - self.__class__.__name__, - FileNotFoundError(f"文件不存在: {resolved_path}"), - ).to_str() + return self.make_failed_response( + locals().copy(), error=f"{FileNotFoundError(f'文件不存在: {resolved_path}')}" + ) if is_binary_file(resolved_path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"禁止编辑二进制文件: {resolved_path}"), - ).to_str() + return self.make_failed_response( + locals().copy(), error=f"{ValueError(f'禁止编辑二进制文件: {resolved_path}')}" + ) # 3. mtime 校验 mtime_error = self._validate_mtime(resolved_path) if mtime_error: - return mtime_error + return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") # 4. 读取文件内容 old_content = resolved_path.read_text(encoding="utf-8") @@ -91,12 +86,17 @@ def edit( if context_before or context_after: ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count) if ctx_error: - return ctx_error + return self.make_failed_response( + locals().copy(), error=f"无法修改上下文不匹配的字符串:\n{ctx_error}" + ) idx += len(old_string) if count == 0: - return f"No changes made: old_string not found in file.\nFile: {file_path}\nSearching for: '{old_string}'" + return self.make_failed_response( + locals().copy(), + error=f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'", + ) # 6. 执行替换(生成新内容) new_content = old_content.replace(old_string, new_string, count) @@ -110,7 +110,7 @@ def edit( old_hash = FileTracker.compute_checksum_from_string(old_content) new_hash = FileTracker.compute_checksum_from_string(new_content) - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id snapshot_id = self.workspace.db.record_file_snapshot( rel_path, old_hash, @@ -122,12 +122,16 @@ def edit( ) # 9. 返回预览 - return ( - f"[Edit Preview]\n" - f"File: {rel_path}\n" - f"Snapshot ID: {snapshot_id}\n" - f"Replacements: {count}\n" - f"Diff:\n{diff_content}" + return self.make_success_response( + locals().copy(), + ( + "修改已推送到审核系统\n" + f"[Edit Preview]\n" + f"File: {rel_path}\n" + f"Snapshot ID: {snapshot_id}\n" + f"Replacements: {count}\n" + f"Diff:\n{diff_content}" + ), ) @staticmethod diff --git a/src/workspace/tools/exact_search_tool.py b/src/workspace/tools/exact_search_tool.py index b23eaa6..873df3a 100644 --- a/src/workspace/tools/exact_search_tool.py +++ b/src/workspace/tools/exact_search_tool.py @@ -2,6 +2,7 @@ import re from pathlib import Path +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -76,6 +77,15 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "exact_search", self.exact_search.__doc__) self.func = self.exact_search self.params = BaseTool.extract_params(self.exact_search) + self.param_descriptions = { + "pattern": "搜索字符串", + "path": "搜索文件或文件夹路径", + "case_sensitive": "是否大小写敏感", + "whole_word": "是否全词匹配", + "file_pattern": "文件匹配模式,支持通配符", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + } @BaseTool.handle_tool_exceptions def exact_search( @@ -84,22 +94,12 @@ def exact_search( path: str = ".", case_sensitive: bool = True, whole_word: bool = True, + file_pattern: str = "*", limit: int = 256, ignore: list[str] | None = None, - ) -> str: + ) -> ToolResult: """ - 精确搜索字符串(支持大小写敏感/全词匹配) - - Args: - pattern: 搜索字符串 - path: 搜索路径,默认为当前目录 - case_sensitive: 是否大小写敏感,默认为True - whole_word: 是否全词匹配,默认为True - limit: 最大匹配数量限制,默认为256 - ignore: 忽略匹配正则的文件或文件夹列表 - - Returns: - 格式化的搜索结果字符串 + 精确搜索字符串 """ # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) @@ -117,16 +117,18 @@ def exact_search( # 搜索结果 results = [] file_count = 0 + total_matches = 0 + warnings = [""] # 确定要搜索的文件列表(支持单文件或目录) - files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob("*")) + files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern)) # 遍历所有文件 for file_path in files_to_search: if not file_path.is_file(): continue # 检查是否达到限制 - if len(results) >= limit: + if total_matches >= limit: break # 检查是否应该忽略 @@ -152,9 +154,17 @@ def exact_search( if file_matches: results.append({"file": str(file_path), "matches": file_matches}) file_count += 1 + total_matches += len(file_matches) - except OSError, UnicodeDecodeError, PermissionError: + except (OSError, UnicodeDecodeError, PermissionError) as e: + warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") continue # 跳过无法读取的文件 + warnings.append("") + # 格式化输出 - return _format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word) + return self.make_success_response( + kwargs=locals().copy(), + data=_format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word), + error="\n".join(warnings) if len(warnings) > 2 else None, + ) diff --git a/src/workspace/tools/git_tool.py b/src/workspace/tools/git_tool.py index cf96864..a0cc2eb 100644 --- a/src/workspace/tools/git_tool.py +++ b/src/workspace/tools/git_tool.py @@ -4,7 +4,7 @@ import shlex import subprocess -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -55,60 +55,65 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "git", self.git.__doc__, read_permission=True) self.func = self.git self.params = BaseTool.extract_params(self.git) + self.param_descriptions = { + "command_str": "Git 子命令及其参数,如 'status'、'diff --cached'、'log --oneline -5'", + } - def git(self, command_str: str) -> str: + def git(self, command_str: str) -> ToolResult: """ - 执行 Git 命令(白名单限制) - - Parameters - ---------- - command_str: Git 子命令及其参数,如 "status"、"diff --cached"、"log --oneline -5" + 执行 Git 命令 """ if not command_str or not command_str.strip(): - return ToolErrorResponse(self.__class__.__name__, ValueError("command_str 不能为空")).to_str() + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("command_str 不能为空"))) try: tokens = shlex.split(command_str) except ValueError as e: - return ToolErrorResponse(self.__class__.__name__, e).to_str() + return self.make_failed_response(kwargs=locals().copy(), error=str(e)) if not tokens: - return ToolErrorResponse(self.__class__.__name__, ValueError("无法解析命令")).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"无法解析命令: `{command_str}`")) + ) base_command = tokens[0] # 1. 白名单检查 if base_command not in _ALLOWED_COMMANDS: allowed_list = ", ".join(sorted(_ALLOWED_COMMANDS)) - return ( - f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n" - f"Allowed commands: {allowed_list}" + return self.make_failed_response( + kwargs=locals().copy(), + error=( + f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n" + f"Allowed commands: {allowed_list}" + ), ) # 2. 拦截正则检查 for pattern in _BLOCKED_PATTERNS: if pattern.search(command_str): - return ( - f"ERROR: The command was blocked by security policy.\n" - f"Pattern matched: {pattern.pattern}\n" - f"Command: {command_str}" + return self.make_failed_response( + kwargs=locals().copy(), + error=( + f"ERROR: The command was blocked by security policy.\n" + f"Pattern matched: {pattern.pattern}\n" + f"Command: {command_str}" + ), ) # 3. restore 安全检查 — 必须指定文件路径 if base_command == "restore": non_flag_args = [t for t in tokens[1:] if not t.startswith("-")] if not non_flag_args: - return ToolErrorResponse( - self.__class__.__name__, - ValueError("restore 需要指定文件路径,不允许裸 restore"), - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError("restore 需要指定文件路径,不允许裸 restore")) + ) for arg in non_flag_args: stripped = arg.strip() if stripped in (".", "*", "all") or stripped.startswith("*"): - return ToolErrorResponse( - self.__class__.__name__, - ValueError("restore 需要指定具体文件路径,不允许使用通配符"), - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError("restore 需要指定具体文件路径,不允许使用通配符")) + ) # 4. 执行命令 try: @@ -122,22 +127,19 @@ def git(self, command_str: str) -> str: env=env, ) except FileNotFoundError: - return ToolErrorResponse( - self.__class__.__name__, - OSError("Git 未安装或不在系统 PATH 中"), - ).to_str() - except subprocess.TimeoutExpired: - return ToolErrorResponse( - self.__class__.__name__, - TimeoutError("Git 命令执行超时(30 秒)"), - ).to_str() + return self.make_failed_response(kwargs=locals().copy(), error=str(OSError("Git 未安装或不在系统 PATH 中"))) + except subprocess.TimeoutExpired as time_out_exception: + return self.make_failed_response( + kwargs=locals().copy(), error=f"TimeoutExpired(Git 命令执行超时: {time_out_exception})" + ) # 5. 处理输出 if result.returncode != 0: stderr = (result.stderr or "").strip() - if stderr: - return f"Git command failed (exit code {result.returncode}):\n{stderr}" - return f"Git command failed (exit code {result.returncode})" + return self.make_failed_response( + kwargs=locals().copy(), + error=f"Git command failed (exit code {result.returncode})" + f":\n{stderr}" if stderr else "", + ) # Combine stdout and stderr output_parts = [] @@ -166,9 +168,13 @@ def git(self, command_str: str) -> str: if err: output_parts.append(err.rstrip("\n")) if not output_parts and _result2.returncode != 0: - return f"Git command failed (exit code {_result2.returncode})" + return self.make_failed_response( + kwargs=locals().copy(), error=f"Git command failed (exit code {_result2.returncode})" + ) - return "\n".join(output_parts) if output_parts else "(no output)" + return self.make_success_response( + kwargs=locals().copy(), data="\n".join(output_parts) if output_parts else "(no output)" + ) @staticmethod def is_safe_command(command_str: str) -> bool: diff --git a/src/workspace/tools/glob_tool.py b/src/workspace/tools/glob_tool.py index 55ddabb..df158c0 100644 --- a/src/workspace/tools/glob_tool.py +++ b/src/workspace/tools/glob_tool.py @@ -1,6 +1,6 @@ from itertools import islice -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -10,27 +10,25 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "glob", self.glob.__doc__) self.func = self.glob self.params = BaseTool.extract_params(self.glob) + self.param_descriptions = { + "pattern": "通配符", + "path": "目录路径", + "max_ret": "最多返回多少条检索结果", + } @BaseTool.handle_tool_exceptions - def glob(self, pattern: str, folder_path: str = ".", max_ret: int = 1000) -> list[str]: + def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> ToolResult: """ 在工作区内按通配符模式匹配并列出所有路径,带[Folder]或[File]的类型标记. 失败时返回错误信息 - - Parameters - ---------- - pattern: 通配符 - folder_path: 目录路径 - max_ret: 最多返回多少条检索结果 - - Returns - ------- - 检索到的文件或文件夹的相对路径 """ - root_path = self.workspace.path_validator.validate(folder_path) + root_path = self.workspace.path_validator.validate(path) if not root_path.is_dir(): - return ToolErrorResponse(self.__class__.__name__, f"{root_path}不是一个文件夹路径").to_str() + return self.make_failed_response(kwargs=locals().copy(), error=f"{root_path}不是一个文件夹路径") - return [ - f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in islice(root_path.glob(pattern), max_ret) - ] + return self.make_success_response( + kwargs=locals().copy(), + data=[ + f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" + for item in islice(root_path.glob(pattern), max_ret) + ], + ) diff --git a/src/workspace/tools/ls_tool.py b/src/workspace/tools/ls_tool.py index 7f92edb..33f035d 100644 --- a/src/workspace/tools/ls_tool.py +++ b/src/workspace/tools/ls_tool.py @@ -1,6 +1,6 @@ from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -10,16 +10,22 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "ls", self.ls.__doc__) self.func = self.ls self.params = BaseTool.extract_params(self.ls) + self.param_descriptions = { + "path": "目录路径", + } @BaseTool.handle_tool_exceptions - def ls(self, folder_path: str = ".") -> list[str] | str: + def ls(self, path: str = ".") -> ToolResult: """ 列出指定目录下的文件和文件夹. 返回相对路径列表, 并标记[Folder]或[File] """ - path: Path = self.workspace.path_validator.validate(folder_path) - if not path.is_dir(): - return ToolErrorResponse(self.__class__.__name__, f'参数错误: "{path}"不是一个目录').to_str() - return [ - f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in path.iterdir() - ] + folder_path: Path = self.workspace.path_validator.validate(path) + if not folder_path.is_dir(): + return self.make_failed_response(kwargs=locals().copy(), error=f'参数错误: "{folder_path}"不是一个目录') + return self.make_success_response( + kwargs=locals().copy(), + data=[ + f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" + for item in folder_path.iterdir() + ], + ) diff --git a/src/workspace/tools/read_lines_tool.py b/src/workspace/tools/read_lines_tool.py deleted file mode 100644 index a00c41f..0000000 --- a/src/workspace/tools/read_lines_tool.py +++ /dev/null @@ -1,71 +0,0 @@ -from pathlib import Path - -from src.models.tool_error_response import ToolErrorResponse -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class ReadLinesTool(BaseTool): - def __init__(self, workspace: Workspace): - super().__init__(workspace, "read_lines", self.read_lines.__doc__) - self.func = self.read_lines - self.params = BaseTool.extract_params(self.read_lines) - - @BaseTool.handle_tool_exceptions - def read_lines(self, file_path: str, start: int, end: int, context: int = 2, encoding: str = "utf-8") -> str: - """ - 读取文件的指定行范围(行号从1开始), 可指定上下文行数扩展返回的实际行数范围,返回带行号的格式化内容 - - Parameters - ---------- - file_path: 文件路径 - start: 开始行数 - end: 结束行数 - context: 扩展结果行数范围 行数范围最终为(start-context, end+context) - encoding: 编码 - """ - path: Path = self.workspace.path_validator.validate(file_path) - - if not path.is_file(): - return ToolErrorResponse(self.__class__.__name__, ValueError(f"读取文件{path}时未读取到完整文件")).to_str() - - if is_binary_file(path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"无法读取二进制文件: {path}. 请使用二进制安全工具或转换为 base64."), - ).to_str() - - with open(path, encoding=encoding) as f: - lines = f.readlines() - - total_lines = len(lines) - - context = max(0, context) - - start -= context - end += context - - # 验证行号 - if start < 1: - start = 1 - if start > total_lines: - return f"错误:起始行 {start} 超过文件总行数 ({total_lines})" - - end = total_lines if end is None else min(end, total_lines) - - if end < start: - return f"错误:结束行 {end} 小于起始行 {start}" - - result_lines = [] - for i in range(start - 1, end): - line_num = i + 1 - content = lines[i].rstrip("\n\r") - result_lines.append(f"{line_num:6d} | {content}") - - header = f"\n[文件: {path}]\n[行 {start}-{end} / 共 {total_lines} 行]\n" - separator = "-" * 80 + "\n" - - self._record_read_meta(path) - - return header + separator + "\n".join(result_lines) diff --git a/src/workspace/tools/read_tool.py b/src/workspace/tools/read_tool.py index ba9b789..7bac9d8 100644 --- a/src/workspace/tools/read_tool.py +++ b/src/workspace/tools/read_tool.py @@ -1,55 +1,108 @@ +import os from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.utils.binary_detector import is_binary_file from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace +# 最大文件读取大小, 超过此大小的文件将被拒绝读取(默认 10MB) +MAX_FILE_SIZE = int(os.getenv("TOOL_READ_MAX_FILE_SIZE", str(10 * 1024 * 1024))) + + +def _resolve_index(idx: int, total: int) -> int: + """Resolve a 1-based or negative index to a clamped 1-based line number.""" + if idx < 0: + idx = total + 1 + idx + if idx < 1: + return 1 + if idx > total: + return total + return idx + class ReadTool(BaseTool): def __init__(self, workspace: Workspace): super().__init__(workspace, "read", self.read.__doc__) self.func = self.read self.params = BaseTool.extract_params(self.read) + self.param_descriptions = { + "path": "文件路径", + "start": "起始行号(1开始; 负数表示倒数, -1=最后一行)", + "end": "结束行号(1开始; 负数表示倒数, -1=最后一行)", + "context": "扩展结果行数范围 行数范围最终为(start-context, end+context)", + "encoding": "编码", + } @BaseTool.handle_tool_exceptions - def read(self, file_path: str, max_lines: int = 0, encoding: str = "utf-8") -> str: + def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> ToolResult: """ - 读取文件内容,可限制最大行数,返回文件内容字符串(带行号) - - Parameters - ---------- - file_path: 文件路径 - max_lines: 最大行数(0表示不限制) - encoding: 编码 + 读取文件内容, 返回带行号的格式化内容 """ - path: Path = self.workspace.path_validator.validate(file_path) + file_path: Path = self.workspace.path_validator.validate(path) + + if not file_path.is_file(): + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"读取文件{file_path}时未读取到完整文件")) + ) - if not path.is_file(): - return ToolErrorResponse(self.__class__.__name__, ValueError(f"读取文件{path}时未读取到完整文件")).to_str() + if is_binary_file(file_path): + return self.make_failed_response( + kwargs=locals().copy(), + error=str(ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64.")), + ) - if is_binary_file(path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"无法读取二进制文件: {path}. 请使用二进制安全工具或转换为 base64."), - ).to_str() + file_size = file_path.stat().st_size + if file_size > MAX_FILE_SIZE: + return self.make_failed_response( + kwargs=locals().copy(), + error=str( + ValueError( + f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. " + f"请使用范围参数 (start/end) 分批读取." + ) + ), + ) - with open(path, encoding=encoding) as f: + with open(file_path, encoding=encoding) as f: lines = f.readlines() total_lines = len(lines) - if max_lines > 0: - lines = lines[:max_lines] + if total_lines == 0: + header = f"\n[文件: {file_path}]\n[行 0-0 / 共 0 行]\n" + separator = "-" * 80 + "\n" + self._record_read_meta(file_path) + return self.make_success_response(kwargs=locals().copy(), data=header + separator) + + context = max(0, context) + + actual_start = _resolve_index(start, total_lines) - context + actual_end = _resolve_index(end, total_lines) + context + + if actual_start < 1: + actual_start = 1 + if actual_end > total_lines: + actual_end = total_lines + + if actual_end < actual_start: + return self.make_failed_response( + kwargs=locals().copy(), + error=( + f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} " + f"(原始参数: start={start}, end={end}, context={context})" + ), + ) result_lines = [] - for i, line in enumerate(lines, 1): - content = line.rstrip("\n\r") - result_lines.append(f"{i:6d} | {content}") + for i in range(actual_start - 1, actual_end): + line_num = i + 1 + content = lines[i].rstrip("\n\r") + result_lines.append(f"{line_num:6d} | {content}") - header = f"\n[文件: {path}]\n[行 1-{len(lines)} / 共 {total_lines} 行]\n" + header = f"\n[文件: {file_path}]\n[行 {actual_start}-{actual_end} / 共 {total_lines} 行]\n" separator = "-" * 80 + "\n" - self._record_read_meta(path) + self._record_read_meta(file_path) - return header + separator + "\n".join(result_lines) + return self.make_success_response(kwargs=locals().copy(), data=header + separator + "\n".join(result_lines)) diff --git a/src/workspace/tools/regex_search_tool.py b/src/workspace/tools/regex_search_tool.py index 59114ad..a8b6549 100644 --- a/src/workspace/tools/regex_search_tool.py +++ b/src/workspace/tools/regex_search_tool.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -103,6 +103,14 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "regex_search", self.regex_search.__doc__) self.func = self.regex_search self.params = BaseTool.extract_params(self.regex_search) + self.param_descriptions = { + "pattern": "正则表达式模式", + "path": "搜索文件或文件夹路径", + "context": "显示匹配行的上下文行数", + "file_pattern": "文件匹配模式,支持通配符", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + } @BaseTool.handle_tool_exceptions def regex_search( @@ -113,18 +121,9 @@ def regex_search( file_pattern: str = "*", limit: int = 256, ignore: list[str] | None = None, - ) -> str: + ) -> ToolResult: """ 使用正则表达式搜索文件内容, 支持上下文显示、文件过滤和忽略路径, 返回匹配详情; 适合代码与文档探索 - - Parameters - ---------- - pattern: 正则表达式模式 - path: 搜索路径,默认为当前目录 - context: 显示匹配行的上下文行数,默认为3 - file_pattern: 文件匹配模式,支持通配符,默认为"*" - limit: 最大匹配数量限制,默认为256 - ignore: 忽略匹配正则的文件或文件夹列表 """ # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) @@ -133,7 +132,7 @@ def regex_search( try: regex = re.compile(pattern) except re.error as e: - return ToolErrorResponse(self.__class__.__name__, f"无效的正则表达式: {e}").to_str() + return self.make_failed_response(kwargs=locals().copy(), error=f"无效的正则表达式: {e}") # 收集忽略模式 ignore_patterns = [] @@ -145,6 +144,8 @@ def regex_search( # 搜索结果 results = [] file_count = 0 + total_matches = 0 + warnings = [""] # 确定要搜索的文件列表(支持单文件或目录) files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern)) @@ -154,7 +155,7 @@ def regex_search( if not file_path.is_file(): continue # 检查是否达到限制 - if len(results) >= limit: + if total_matches >= limit: break # 检查是否应该忽略该文件或文件夹 @@ -181,9 +182,17 @@ def regex_search( if file_results: results.append({"file": str(file_path), "matches": file_results}) file_count += 1 + total_matches += len(file_results) - except OSError, UnicodeDecodeError, PermissionError: + except (OSError, UnicodeDecodeError, PermissionError) as e: + warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") continue # 跳过无法读取的文件 + warnings.append("") + # 格式化输出 - return _format_regex_results(results, pattern, limit, file_count) + return self.make_success_response( + kwargs=locals().copy(), + data=_format_regex_results(results, pattern, limit, file_count), + error="\n".join(warnings) if len(warnings) > 2 else None, + ) diff --git a/src/workspace/tools/stat_tool.py b/src/workspace/tools/stat_tool.py index 41981cf..f6caeae 100644 --- a/src/workspace/tools/stat_tool.py +++ b/src/workspace/tools/stat_tool.py @@ -2,6 +2,7 @@ from datetime import datetime from pathlib import Path +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -13,17 +14,14 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "stat", self.stat.__doc__) self.func = self.stat self.params = BaseTool.extract_params(self.stat) + self.param_descriptions = { + "path": "文件或目录路径", + } @BaseTool.handle_tool_exceptions - def stat(self, path: str = ".") -> str: + def stat(self, path: str = ".") -> ToolResult: """ 获取工作区内文件或目录的详细信息,包括大小、行数(仅文件)、修改时间、权限等 - - Args: - path: 文件或目录路径,默认为当前目录 - - Returns: - 格式化的详细信息字符串 """ # 验证路径 target_path: Path = self.workspace.path_validator.validate(path) @@ -145,4 +143,4 @@ def format_timestamp(timestamp: float) -> str: except PermissionError: output.append("目录内容: 无法访问") - return "\n".join(output) + return self.make_success_response(kwargs=locals().copy(), data="\n".join(output)) diff --git a/src/workspace/tools/symbol_ref_tool.py b/src/workspace/tools/symbol_ref_tool.py index 6c0c4df..4f42291 100644 --- a/src/workspace/tools/symbol_ref_tool.py +++ b/src/workspace/tools/symbol_ref_tool.py @@ -1,16 +1,12 @@ """符号引用查找工具 - 查找函数、类、变量等的定义和引用""" import re -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace -# 默认排除的目录(与 workspace.py 保持一致, 后续改为从项目配置加载) -DEFAULT_EXCLUDED_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build", ".idea", ".vscode"} - def _get_file_pattern_by_language(language: str) -> str: """根据语言获取默认的文件匹配模式""" @@ -124,9 +120,9 @@ def _generate_patterns(symbol_name: str, language: str, include_def: bool, inclu return unique_patterns -def _search_single_pattern( +def _search_all_patterns( workspace: Workspace, - pattern_info: dict, + patterns: list[dict], search_path: str, symbol_name: str, context_lines: int, @@ -134,112 +130,124 @@ def _search_single_pattern( file_pattern: str, ignore: list[str] | None, ) -> list[dict]: - """使用 Workspace.search_content 执行单个模式的搜索(并发版本)""" - results = [] + """单次文件遍历搜索所有模式, 返回按文件分组且带上下文的结果. + + 相比旧实现的优势: + - 文件系统只遍历 1 次(旧: N 次, N=模式数量) + - 每个文件只读取 1 次(旧: N 次) + - 直接使用结构化数据, 无需 格式化→正则解析 的反模式 + - 真正实现上下文行读取(旧实现仅将匹配行本身作为上下文) + """ + # 编译所有正则模式 + compiled_patterns: list[tuple[re.Pattern, str]] = [] + for p in patterns: + try: + regex = re.compile(p["pattern"], re.IGNORECASE) + compiled_patterns.append((regex, p["type"])) + except re.error: + continue - try: - # 使用 workspace 的并发搜索能力 - search_result = workspace.search_content( - pattern=pattern_info["pattern"], - folder_path=search_path, - file_pattern=file_pattern, - max_workers=4, # 并发线程数 - case_sensitive=False, - ) + if not compiled_patterns: + return [] - # 解析 search_content 的返回结果 - if not search_result or "未找到匹配" in search_result: - return results - - # 解析结果格式: search_content 返回的是格式化的字符串 - # 格式: [文件] path\n----\n 行号 | 内容 - lines = search_result.split("\n") - current_file = None - current_matches = [] - - for line in lines: - if line.startswith("[文件] "): - # 保存上一个文件的结果 - if current_file and current_matches: - results.append({"file": current_file, "matches": current_matches, "type": pattern_info["type"]}) - current_file = line[6:] # 去掉 "[文件] " - current_matches = [] - elif line.startswith("----"): - continue - elif line.strip() and current_file: - # 解析匹配行: " 行号 | 内容" - match = re.match(r"\s+(\d+)\s*\|\s*(.*)", line) - if match: - line_num = int(match.group(1)) - content = match.group(2) - - # 收集上下文(这里简化处理,因为 search_content 不直接提供上下文) - # 为了保持兼容性,我们构建一个简化的上下文 - current_matches.append( - { - "line_num": line_num, - "content": content, - "context": [{"line_num": line_num, "content": content, "is_match": True}], - "match_type": pattern_info["type"], - "symbol_name": symbol_name, - } - ) + # 单次遍历搜索: 所有模式在一次文件遍历中完成 + matches = workspace.search_content_multi_pattern( + patterns=compiled_patterns, + folder_path=search_path, + file_pattern=file_pattern, + max_workers=4, + ignore=ignore, + ) - # 保存最后一个文件的结果 - if current_file and current_matches: - results.append({"file": current_file, "matches": current_matches, "type": pattern_info["type"]}) + if not matches: + return [] - except Exception: - # 如果搜索失败,静默跳过 - pass + # 应用 limit 截断 + matches = matches[:limit] - return results + # 按文件分组并构建带上下文的结果 + return _build_results_with_context(matches, context_lines, symbol_name, workspace.root_path) -def _search_patterns_concurrent( - workspace: Workspace, - patterns: list[dict], - search_path: str, - symbol_name: str, +def _build_results_with_context( + matches: list[dict], context_lines: int, - limit: int, - file_pattern: str, - ignore: list[str] | None, + symbol_name: str, + root_path: Path, ) -> list[dict]: - """并发执行多个模式的搜索""" - all_results = [] - - # 使用线程池并发执行多个模式的搜索 - with ThreadPoolExecutor(max_workers=min(len(patterns), 4)) as executor: - future_to_pattern = { - executor.submit( - _search_single_pattern, - workspace, - pattern_info, - search_path, - symbol_name, - context_lines, - limit - len(all_results), - file_pattern, - ignore, - ): pattern_info - for pattern_info in patterns - } + """从扁平匹配列表构建按文件分组、带上下文的结果""" + context_lines = max(0, context_lines) + + # 按文件分组, 保留首次出现的顺序 + file_matches: dict[str, list[dict]] = {} + file_order: list[str] = [] + for m in matches: + f = m["file"] + if f not in file_matches: + file_matches[f] = [] + file_order.append(f) + file_matches[f].append(m) + + results = [] + for file_rel in file_order: + file_match_list = file_matches[file_rel] + + # 收集需要读取的行号范围(匹配行 ± 上下文行) + needed_lines: set[int] = set() + for m in file_match_list: + for delta in range(-context_lines, context_lines + 1): + needed_lines.add(m["line_num"] + delta) + + # 仅读取需要的行(不加载整个文件到内存) + line_cache: dict[int, str] = {} + file_full_path = root_path / file_rel + try: + if needed_lines: + max_needed = max(needed_lines) + with open(file_full_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + if line_num in needed_lines: + line_cache[line_num] = line.rstrip("\n\r") + if line_num >= max_needed: + break + except UnicodeDecodeError, PermissionError, OSError: + pass + + # 为每个匹配项构建上下文 + built_matches = [] + for m in file_match_list: + match_line_num = m["line_num"] + context = [] + for delta in range(-context_lines, context_lines + 1): + ctx_line_num = match_line_num + delta + if ctx_line_num in line_cache: + context.append( + { + "line_num": ctx_line_num, + "content": line_cache[ctx_line_num], + "is_match": delta == 0, + } + ) + + built_matches.append( + { + "line_num": match_line_num, + "content": m["content"], + "context": context, + "match_type": m["pattern_type"], + "symbol_name": symbol_name, + } + ) - for future in as_completed(future_to_pattern): - try: - results = future.result() - all_results.extend(results) - if len(all_results) >= limit: - # 达到限制,取消剩余任务 - for f in future_to_pattern: - f.cancel() - break - except Exception: - # 单个模式失败不影响其他模式 - pass + results.append( + { + "file": file_rel, + "matches": built_matches, + "type": file_match_list[0]["pattern_type"], + } + ) - return all_results[:limit] + return results def _detect_language(search_path: Path, specified_lang: str) -> str: @@ -347,6 +355,17 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "symbol_ref", self.symbol_ref.__doc__, read_permission=True, write_permission=False) self.func = self.symbol_ref self.params = BaseTool.extract_params(self.symbol_ref) + self.param_descriptions = { + "symbol_name": "要查找的符号名称(如函数名、类名、变量名)", + "path": "搜索文件或文件夹路径", + "language": "语言类型(auto/python/javascript/typescript/markdown/general)", + "include_definitions": "是否包含定义位置", + "include_references": "是否包含引用位置", + "context_lines": "显示匹配行的上下文行数", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + "file_pattern": "文件匹配模式(如 *.py),默认根据语言自动选择", + } @BaseTool.handle_tool_exceptions def symbol_ref( @@ -360,31 +379,14 @@ def symbol_ref( limit: int = 256, ignore: list[str] | None = None, file_pattern: str | None = None, - ) -> str: + ) -> ToolResult: """ - 查找符号(函数、类、变量等)的定义和引用位置 - - 支持自动识别语言类型,生成智能搜索模式来定位符号的定义和所有使用位置. - 适用于代码探索、重构影响分析、理解代码结构等场景. - - Args: - symbol_name: 要查找的符号名称(如函数名、类名、变量名) - path: 搜索路径,默认为当前目录 - language: 语言类型(auto/python/javascript/typescript/markdown/general),默认auto - include_definitions: 是否包含定义位置,默认True - include_references: 是否包含引用位置,默认True - context_lines: 显示匹配行的上下文行数,默认2 - limit: 最大匹配数量限制,默认为256 - ignore: 忽略匹配正则的文件或文件夹列表 - file_pattern: 文件匹配模式(如 *.py),默认根据语言自动选择 - - Returns: - 格式化的符号引用搜索结果 + 查找符号(函数、类、变量等)的定义和引用位置, 适用于代码探索、重构影响分析、理解代码结构等场景. """ # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) if not search_path.exists(): - return ToolErrorResponse(self.__class__.__name__, f"路径不存在: {path}").to_str() + return self.make_failed_response(kwargs=locals().copy(), error=f"路径不存在: {path}") # 自动检测语言 detected_lang = _detect_language(search_path, language) @@ -397,10 +399,12 @@ def symbol_ref( patterns = _generate_patterns(symbol_name, detected_lang, include_definitions, include_references) if not patterns: - return f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})" + return self.make_failed_response( + kwargs=locals().copy(), error=f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})" + ) # 使用并发搜索执行所有模式 - all_results = _search_patterns_concurrent( + all_results = _search_all_patterns( self.workspace, patterns, path, @@ -412,4 +416,6 @@ def symbol_ref( ) # 格式化输出 - return _format_results(all_results, symbol_name, detected_lang, limit) + return self.make_success_response( + kwargs=locals().copy(), data=_format_results(all_results, symbol_name, detected_lang, limit) + ) diff --git a/src/workspace/tools/write_tool.py b/src/workspace/tools/write_tool.py index 1a0935e..fe1a09c 100644 --- a/src/workspace/tools/write_tool.py +++ b/src/workspace/tools/write_tool.py @@ -1,7 +1,7 @@ from pathlib import Path from src.core.file_tracker import FileTracker -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.utils.binary_detector import is_binary_file from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -12,50 +12,48 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "write", self.write.__doc__, write_permission=True) self.func = self.write self.params = BaseTool.extract_params(self.write) + self.param_descriptions = { + "path": "文件路径", + "content": "写入内容", + } @BaseTool.handle_tool_exceptions - def write(self, file_path: str, content: str = "") -> str: + def write(self, path: str, content: str = "") -> ToolResult: """ - 写入文件内容,如文件不存在则创建(含父目录) - - Parameters - ---------- - file_path: 文件路径 - content: 写入内容 + 写入文件内容, 如文件不存在则创建(含父目录) """ - source_file_path = Path(file_path) - file_path: Path = self.workspace.path_validator.resolve_path(source_file_path) + source_path = Path(path) + path: Path = self.workspace.path_validator.resolve_path(source_path) - if file_path.exists() and file_path.is_dir(): - return ToolErrorResponse( - self.__class__.__name__, ValueError(f"路径 {file_path} 是一个目录,无法写入") - ).to_str() + if path.exists() and path.is_dir(): + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"路径 {path} 是一个目录,无法写入")) + ) - if is_binary_file(file_path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"禁止写入二进制文件: {file_path}"), - ).to_str() + if is_binary_file(path): + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"禁止写入二进制文件: {path}")) + ) - mtime_error = self._validate_mtime(file_path) + mtime_error = self._validate_mtime(path) if mtime_error: - return mtime_error + return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") old_content = "" old_meta = None - if file_path.exists() and file_path.is_file(): - old_meta = FileTracker.get_file_meta(file_path) + if path.exists() and path.is_file(): + old_meta = FileTracker.get_file_meta(path) try: - old_content = file_path.read_text(encoding="utf-8") + old_content = path.read_text(encoding="utf-8") except Exception: old_content = "" - rel_path = str(file_path.relative_to(self.workspace.root_path)) + rel_path = str(path.relative_to(self.workspace.root_path)) old_hash = old_meta.get("checksum") if old_meta else None new_hash = FileTracker.compute_checksum_from_string(content) diff_content = self._generate_diff(old_content, content, rel_path) - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id snapshot_id = self.workspace.db.record_file_snapshot( rel_path, old_hash, @@ -66,4 +64,13 @@ def write(self, file_path: str, content: str = "") -> str: pending_content=content, ) - return f"[Write Preview]\nFile: {rel_path}\nSnapshot ID: {snapshot_id}\nDiff:\n{diff_content}" + return self.make_success_response( + kwargs=locals().copy(), + data=( + f"修改已推送到审核系统\n" + f"[Write Preview]\n" + f"File: {rel_path}\n" + f"Snapshot ID: {snapshot_id}\n" + f"Diff:\n{diff_content}" + ), + ) diff --git a/src/workspace/workspace.py b/src/workspace/workspace.py index 538a280..56699f6 100644 --- a/src/workspace/workspace.py +++ b/src/workspace/workspace.py @@ -7,6 +7,9 @@ from src.models.tool_error_response import ToolErrorResponse from src.workspace.path_validator import PathNotFoundError, PathValidator, WorkspaceBoundaryError +# 默认排除的目录 后续改为从项目配置加载 +DEFAULT_EXCLUDED_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build", ".idea", ".vscode"} + def _highlight_matches(line: str, regex: re.Pattern) -> str: """ @@ -59,6 +62,15 @@ def db(self): self._db = DatabaseManager(str(self.root_path)) return self._db + @property + def session_id(self) -> int | None: + """当前会话 ID 的公开 getter —— 工具类通过此接口访问, 避免直接访问私有属性.""" + return self._current_session_id + + @session_id.setter + def session_id(self, value: int | None) -> None: + self._current_session_id = value + def search_content( self, pattern: str, @@ -73,7 +85,7 @@ def search_content( path = self.path_validator.validate(folder_path) # 初始化排除目录集合 - exclude_set = set(exclude_dirs or [".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build"]) + exclude_set = set(exclude_dirs or DEFAULT_EXCLUDED_DIRS) # 编译正则表达式 flags = 0 if case_sensitive else re.IGNORECASE @@ -176,3 +188,116 @@ def _search_in_file(self, file_path: Path, regex: re.Pattern) -> list[tuple]: pass return results + + def search_content_multi_pattern( + self, + patterns: list[tuple[re.Pattern, str]], + folder_path: str = ".", + file_pattern: str = "*", + max_workers: int = 4, + ignore: list[str] | None = None, + ) -> list[dict]: + """单次文件遍历匹配多个正则模式, 直接返回结构化数据. + + 与 search_content 的区别: + - 接受多个已编译的正则 + 类型标签, 一次遍历全部匹配 + - 返回结构化 list[dict] 而非格式化字符串, 消除下游正则解析反模式 + - 不做高亮处理(高亮是展示层关注点, 不应混入数据层) + + Args: + patterns: [(compiled_regex, type_label), ...] 已编译正则及其类型标签 + folder_path: 搜索起始路径 + file_pattern: 文件通配符(如 "*.py") + max_workers: 并发读取文件的线程数 + ignore: 忽略路径正则列表 + + Returns: + [{"file": str, "line_num": int, "content": str, "pattern_type": str}, ...] + 按文件路径 → 行号排序, 同一行多个模式匹配则每个各一条记录 + """ + try: + path = self.path_validator.validate(folder_path) + + # 预编译 ignore 正则 + ignore_res: list[re.Pattern] = [] + if ignore: + for ign in ignore: + try: + ignore_res.append(re.compile(ign)) + except re.error: + continue + + # 收集文件(一次遍历) + files_to_search: list[Path] = [] + if path.is_file(): + files_to_search = [path] + else: + for file_path in path.rglob(file_pattern): + if file_path.is_file(): + if any(p.name in DEFAULT_EXCLUDED_DIRS for p in file_path.parents): + continue + rel = str(file_path.relative_to(self.root_path)) + if any(ir.search(rel) for ir in ignore_res): + continue + files_to_search.append(file_path) + + if not files_to_search: + return [] + + # 并发搜索所有文件, 每个文件内一次读取、一次测试所有模式 + all_matches: list[dict] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(self._search_multi_in_file, file_path, patterns): file_path + for file_path in files_to_search + } + for future in as_completed(futures): + try: + file_results = future.result() + if file_results: + all_matches.extend(file_results) + except Exception: + pass + + # 按文件路径 → 行号排序 + all_matches.sort(key=lambda m: (m["file"], m["line_num"])) + return all_matches + + except PathNotFoundError, WorkspaceBoundaryError, PermissionError: + return [] + except Exception: + return [] + + def _search_multi_in_file(self, file_path: Path, patterns: list[tuple[re.Pattern, str]]) -> list[dict]: + """在单个文件中一次读取、逐行测试所有模式. + + Args: + file_path: 文件绝对路径 + patterns: [(compiled_regex, type_label), ...] + + Returns: + 匹配项列表, 同一行多个模式匹配时每个各返回一条 + """ + results: list[dict] = [] + relative_path = str(file_path.relative_to(self.root_path)) + + try: + with open(file_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + stripped = line.rstrip("\n\r") + for regex, pattern_type in patterns: + if regex.search(stripped): + results.append( + { + "file": relative_path, + "line_num": line_num, + "content": stripped, + "pattern_type": pattern_type, + } + ) + except UnicodeDecodeError, PermissionError: + pass + except Exception: + pass + + return results diff --git a/tests/core/test_tool_registry.py b/tests/core/test_tool_registry.py index e94f43e..ab1e2e2 100644 --- a/tests/core/test_tool_registry.py +++ b/tests/core/test_tool_registry.py @@ -33,31 +33,6 @@ def isolate_tool_registry(): DatabaseManager.reset_instances() -def test_validate_config(): - """测试配置验证 - 一次性测试所有阈值""" - config = ToolRegistry() - - # 设置所有值为过小 - config.MAX_RESULT_LENGTH = 5 - config.LIST_TRUNCATE_THRESHOLD = 3 - config.DICT_TRUNCATE_THRESHOLD = 2 - - # 验证触发3个警告 - with pytest.warns(UserWarning) as record: - config._validate_config() - - # 验证警告数量和内容 - assert len(record) == 3 - assert "TOOL_MAX_RESULT_LENGTH" in str(record[0].message) - assert "TOOL_LIST_TRUNCATE_THRESHOLD" in str(record[1].message) - assert "TOOL_DICT_TRUNCATE_THRESHOLD" in str(record[2].message) - - # 验证所有值都被修正 - assert config.MAX_RESULT_LENGTH == 100 - assert config.LIST_TRUNCATE_THRESHOLD == 50 - assert config.DICT_TRUNCATE_THRESHOLD == 50 - - def test_tool_registry_singleton(): """测试单例模式""" registry1 = ToolRegistry() @@ -67,14 +42,6 @@ def test_tool_registry_singleton(): assert id(registry1) == id(registry2) -def test_execute_nonexistent_tool(): - """测试执行不存在的工具""" - registry = ToolRegistry() - - with pytest.raises(ValueError, match="未找到工具: nonexistent"): - registry.execute("nonexistent") - - def test_validate_tool_info(): """测试工具信息验证""" registry = ToolRegistry() @@ -95,51 +62,6 @@ def test_validate_tool_info(): assert any(f"超过 {MAX_DOC_LENGTH} 字符" in str(warning.message) for warning in w) -def test_compress_result_string(): - """测试字符串结果压缩""" - registry = ToolRegistry() - - long_string = "x" * (MAX_RESULT_LENGTH + 10000) - - compressed = registry._compress_result(long_string) - assert "结果已截断" in compressed - - -def test_compress_result_list(): - """测试列表结果压缩""" - registry = ToolRegistry() - - long_list = list(range(150)) - - compressed = registry._compress_result(long_list) - assert len(compressed) == 101 - assert "列表已截断" in compressed[-1] - - -def test_compress_result_dict(): - """测试字典结果压缩""" - registry = ToolRegistry() - - long_dict = {f"key_{i}": f"value_{i}" for i in range(150)} - - compressed = registry._compress_result(long_dict) - assert len(compressed) == 101 - assert "字典已截断" in compressed["..."] - - -def test_no_compress_short_results(): - """测试不对短结果进行压缩""" - registry = ToolRegistry() - - short_string = "short" - short_list = [1, 2, 3] - short_dict = {"a": 1, "b": 2} - - assert registry._compress_result(short_string) == short_string - assert registry._compress_result(short_list) == short_list - assert registry._compress_result(short_dict) == short_dict - - class TestToolCategorization: """测试工具分类逻辑(需要 workspace 注册工具).""" @@ -156,13 +78,13 @@ def setup(self, tmp_path): def test_query_tools_category(self): registry = ToolRegistry() - for name in ("glob", "ls", "regex_search", "stat", "read", "read_lines", "symbol_ref"): + for name in ("glob", "ls", "regex_search", "stat", "read", "symbol_ref"): assert registry._tool_categories.get(name) == "query", f"{name} should be query" - def test_edit_tools_category(self): + def test_write_tools_category(self): registry = ToolRegistry() for name in ("write", "edit"): - assert registry._tool_categories.get(name) == "edit", f"{name} should be edit" + assert registry._tool_categories.get(name) == "write", f"{name} should be write" def test_git_tool_category(self): registry = ToolRegistry() diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/tools/__init__.py b/tests/models/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/tools/test_tool_result.py b/tests/models/tools/test_tool_result.py new file mode 100644 index 0000000..660e098 --- /dev/null +++ b/tests/models/tools/test_tool_result.py @@ -0,0 +1,63 @@ +import pytest + +from src.models.tools.tool_result import ToolResult + + +def test_compress_result_list(): + """测试列表结果压缩""" + long_list = list(range(150)) + + compressed = ToolResult._compress_result(long_list) + assert len(compressed) == 101 + assert "列表已截断" in compressed[-1] + + +def test_compress_result_dict(): + """测试字典结果压缩""" + long_dict = {f"key_{i}": f"value_{i}" for i in range(150)} + + compressed = ToolResult._compress_result(long_dict) + assert len(compressed) == 101 + assert "字典已截断" in compressed["..."] + + +def test_no_compress_short_results(): + """测试不对短结果进行压缩""" + short_string = "short" + short_list = [1, 2, 3] + short_dict = {"a": 1, "b": 2} + + assert ToolResult._compress_result(short_string) == short_string + assert ToolResult._compress_result(short_list) == short_list + assert ToolResult._compress_result(short_dict) == short_dict + + +def test_validate_config(): + """测试配置验证 - 一次性测试所有阈值""" + # 设置所有值为过小 + ToolResult.MAX_RESULT_LENGTH = 5 + ToolResult.LIST_TRUNCATE_THRESHOLD = 3 + ToolResult.DICT_TRUNCATE_THRESHOLD = 2 + + # 验证触发3个警告 + with pytest.warns(UserWarning) as record: + ToolResult._validate_config() + + # 验证警告数量和内容 + assert len(record) == 3 + assert "TOOL_MAX_RESULT_LENGTH" in str(record[0].message) + assert "TOOL_LIST_TRUNCATE_THRESHOLD" in str(record[1].message) + assert "TOOL_DICT_TRUNCATE_THRESHOLD" in str(record[2].message) + + # 验证所有值都被修正 + assert ToolResult.MAX_RESULT_LENGTH == 100 + assert ToolResult.LIST_TRUNCATE_THRESHOLD == 50 + assert ToolResult.DICT_TRUNCATE_THRESHOLD == 50 + + +def test_compress_result_string(): + """测试字符串结果压缩""" + long_string = "x" * (ToolResult.MAX_RESULT_LENGTH + 10000) + + compressed = ToolResult._compress_result(long_string) + assert "结果已截断" in compressed diff --git a/tests/workspace/tools/test_base_tool.py b/tests/workspace/tools/test_base_tool.py index d8cd286..8d43395 100644 --- a/tests/workspace/tools/test_base_tool.py +++ b/tests/workspace/tools/test_base_tool.py @@ -1,5 +1,4 @@ -from src.workspace.tools.base_tool import BaseTool, build_param_doc -from src.workspace.workspace import Workspace +from src.workspace.tools.base_tool import BaseTool def test_extract_params(): @@ -41,61 +40,3 @@ def class_method(cls, b: str) -> str: assert "cls" not in params2 assert "a" in params1 assert "b" in params2 - - -def test_build_param_doc(): - """测试参数文档生成使用简洁类型和required属性""" - - params_required = {"required": True, "annotation": ""} - result = build_param_doc("file_path", params_required) - assert 'type="string"' in result - assert 'required="true"' in result - assert "", "default": "0"} - result = build_param_doc("max_lines", params_optional) - assert 'type="integer"' in result - assert 'required="false"' in result - assert 'default="0"' in result - assert " str: - """Sample function for testing""" - return f"{a}_{b}" - - class MockTool(BaseTool): - def __init__(self, workspace: Workspace | None): - super().__init__(workspace, "mock", "测试工具") - self.func = sample_func - self.params = BaseTool.extract_params(sample_func) - - tool = MockTool(None) - doc = tool.to_doc() - - # 使用新格式标签 - assert doc.startswith('') - assert "测试工具" in doc - assert "" in doc - assert "" in doc - assert doc.endswith("") - - # 验证参数格式 - assert 'type="string"' in doc - assert 'type="integer"' in doc - assert 'required="true"' in doc - assert 'required="false"' in doc - - # 验证没有旧格式标记 - assert "" not in doc - assert "" not in doc - assert "