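From 5e4411727c9d7cc5a5c4c06ee82a4bd51231c46e Mon Sep 17 00:00:00 2001
From: Suntion <149924916+SunYanbox@users.noreply.github.com>
Date: Mon, 4 May 2026 14:22:22 +0800
Subject: [PATCH 1/8] =?UTF-8?q?feat(#119,=20#127)=20Merge=20similar=20tool?=
 =?UTF-8?q?=20parameters=20and=20functions=20to=20reduce=20injection=20len?=
 =?UTF-8?q?gth=20and=20mitigate=20LLM=20hallucinations=20/=20=E5=90=88?=
 =?UTF-8?q?=E5=B9=B6=E7=9B=B8=E4=BC=BC=E5=B7=A5=E5=85=B7=E7=9A=84=E5=8F=82?=
 =?UTF-8?q?=E6=95=B0=E4=B8=8E=E5=8A=9F=E8=83=BD=EF=BC=8C=E7=BC=A9=E7=9F=AD?=
 =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=B3=A8=E5=85=A5=E9=95=BF=E5=BA=A6=EF=BC=8C?=
 =?UTF-8?q?=E9=99=8D=E4=BD=8E=E6=A8=A1=E5=9E=8B=E5=B9=BB=E8=A7=89=20(#128)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor(read): merge read_lines into the read tool and support range reads (#119)

- Refactoring: unify the file-reading logic and remove the redundant tool class
  * Delete the `ReadLinesTool` class and its `read_lines` method from `src/workspace/tools/read_lines_tool.py`
  * Extend the `ReadTool` class in `src/workspace/tools/read_tool.py`, changing the parameters of `read` from `max_lines` to `start`, `end`, `context`
  * Add the module-level private function `_resolve_index` to `read_tool.py` to handle negative indices and bounds clamping
  * Update the return format of `ReadTool.read` so that it shows the line range actually read: `[行 start-end / 共 total_lines 行]`
- Breaking change: retire the `read_lines` tool interface
  * Remove the `read_lines` tool-name registration; existing calls must migrate to the `read` tool
  * Change the API signature of the `read` tool: the old `max_lines` parameter is removed and replaced by `start` (first line), `end` (last line, negative values supported), and `context` (number of context lines)
  * Affected module: `src/core/tool_registry.py` no longer imports or registers `ReadLinesTool`
- Documentation: sync the tool list and the test cases
  * Update `README.md` and `README_ZH.md`, merging the `read_lines` entry into the description of `read`
  * Rework the category assertions in `tests/core/test_tool_registry.py`, removing `read_lines`
  * Rename the test class and method: `TestReadLinesToolMtimeRecording` -> `TestReadToolRangeRecording`, `test_read_lines_records_mtime` -> `test_read_range_records_mtime`
  * Update `tests/workspace/tools/test_binary_protection.py`, moving the binary-protection tests for `read_lines` over to the `read` tool

* refactor(tools): rework the parameter-description logic and trim the docs (#127)

- Refactoring: switch the generated tool-parameter documentation from XML to a Markdown list
  * Add `build_param_list_item` to replace `build_param_doc`; it generates list items of the form `- **name** (type, required/optional)`
  * Introduce a `self.param_descriptions` dict in `BaseTool` so that subclasses can supply a custom (Chinese) description for each parameter
  * Update `to_doc` to call the new function and append the custom description as a suffix
- New feature: add detailed Chinese parameter descriptions to all core tools
  * `symbol_ref_tool.py`: define descriptions for nine parameters, including `symbol_name`, `path`, `language`, and `include_definitions`
  * `ls_tool.py`, `git_tool.py`, `read_tool.py`, `write_tool.py`, `stat_tool.py`, `glob_tool.py`, `regex_search_tool.py`, `exact_search_tool.py`, `edit_tool.py`: add a clear Chinese explanation for each tool's input parameters
- Documentation: remove redundant docstring comment blocks
  * Delete the duplicated `Args` or `Parameters` sections from several tool functions (such as `symbol_ref`, `ls`, `git`, `read`, `write`, `stat`, `glob`, `regex_search`, `exact_search`, `edit`)
  * Remove the `Returns` sections that are no longer used, relying on the dynamically generated documentation instead
- Test maintenance: update the unit tests to match the new implementation
  * Remove the import of the retired function `build_param_doc` and its test case `test_build_param_doc`
  * Delete the `test_to_doc_new_format` test, which targeted the old XML output format, and keep the parameter-extraction checks

* docs(contributing): update the tool development guide for the new parameter-description mechanism

- Documentation: sync the code examples in `CONTRIBUTING_ZH.md` and `CONTRIBUTING.md`
  * Remove the redundant `Parameters` or `Args` sections from function docstrings so that they do not duplicate the dynamically generated documentation
  * Add a `self.param_descriptions` dict assignment to the tool class initializer, showing how to describe each parameter in Chinese or English
  * Update the structure of the example code to separate static descriptions (via `param_descriptions`) from the dynamic generation logic

* refactor(tools): trim tool parameter descriptions and docstring summaries

- Refactoring: drop default-value details from the parameter descriptions and use generic descriptions throughout
  * `symbol_ref_tool.py`: update the descriptions of `path`, `language`, `include_definitions`, and others, removing the "defaults to ..." suffix
  * `ls_tool.py`, `stat_tool.py`: simplify the descriptions of `folder_path` and `path`, removing the default-path hints
  * `regex_search_tool.py`, `exact_search_tool.py`: clean up redundant information in `path`, `context`, `file_pattern`, `limit`, and others
  * `edit_tool.py`: tighten the descriptions of `old_string`, `max_replacements`, `context_before`, `context_after`
- Documentation: compress the function docstrings, removing redundant line breaks and explanatory text
  * `symbol_ref_tool.py`: merge the multi-line description into a single line, keeping the core use cases
  * `git_tool.py`, `read_tool.py`, `write_tool.py`: shorten the summaries and remove unnecessary punctuation and supplementary notes
  * `exact_search_tool.py`: reduce the complex parenthetical to plain text
  * `edit_tool.py`: delete the detailed explanation of dry-run, diff generation, and the audit flow, keeping only the core function definition

* refactor(tools): unify the file path parameter name as `path` (#127)

- Refactoring: rename the file/directory path parameter of several tools from a tool-specific name to `path`
  * `ls_tool.py`: rename `folder_path` to `path`; update the internal variable references and the error message
  * `read_tool.py`: rename `file_path` to `path`; update every local variable involved in path validation, the binary check, file reading, and metadata recording
  * `write_tool.py`: rename `file_path` to `path`; adjust the variables used in path resolution, the existence check, mtime validation, and diff generation
  * `glob_tool.py`: rename `folder_path` to `path`; simplify the root-path validation
  * `edit_tool.py`: rename `file_path` to `path`; update the variable references in source-path resolution, the file-existence check, and the error messages
- Breaking change: the API signatures of the tool functions changed
  * Affected functions: `ls`, `read`, `write`, `glob`, `edit`
  * Callers must rename the corresponding path argument to `path`; for example, `tool.ls(folder_path=...)` becomes `tool.ls(path=...)`

* refactor(tools): adjust error-handling statement formatting to match the code style

- Refactoring: unify the call format of `ToolErrorResponse` to fit the line-length rules
  * `read_tool.py`: split the error return for the failed `is_file()` check across several lines, since the renamed `file_path` variable makes the statement longer
  * `write_tool.py`: merge the error return for the failed `exists() and is_dir()` check into a single line, removing the indentation caused by the multi-line parentheses
---
 CONTRIBUTING.md                               |  9 +--
 CONTRIBUTING_ZH.md                            |  9 +--
 README.md                                     |  3 +-
 README_ZH.md                                  |  3 +-
 src/core/tool_registry.py                     |  4 +-
 src/workspace/tools/base_tool.py              | 28 +++----
 src/workspace/tools/edit_tool.py              | 30 ++++----
 src/workspace/tools/exact_search_tool.py      | 21 +++--
 src/workspace/tools/git_tool.py               |  9 +--
 src/workspace/tools/glob_tool.py              | 19 ++---
 src/workspace/tools/ls_tool.py                | 13 ++--
 src/workspace/tools/read_lines_tool.py        | 71 -----------------
 src/workspace/tools/read_tool.py              | 76 ++++++++++++++-----
 src/workspace/tools/regex_search_tool.py      | 17 ++---
 src/workspace/tools/stat_tool.py              |  9 +--
 src/workspace/tools/symbol_ref_tool.py        | 30 +++-----
 src/workspace/tools/write_tool.py             | 37 +++++----
 tests/core/test_tool_registry.py              |  2 +-
 tests/workspace/tools/test_base_tool.py       | 61 +--------------
 .../workspace/tools/test_binary_protection.py | 24 +++---
 tests/workspace/tools/test_read_tool.py       | 10 +--
 21 files changed, 180 insertions(+), 305 deletions(-)
 delete mode 100644 src/workspace/tools/read_lines_tool.py

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 820ebcc..acc9fae 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -231,16 +231,15 @@ the `BaseTool` base class.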
) self.func = self.your_method self.params = BaseTool.extract_params(self.your_method) + self.param_descriptions = { + "param1": "Parameter description", + "param2": "Parameter description" + } @BaseTool.handle_tool_exceptions def your_method(self, param1: str, param2: int = 0) -> str: """ Tool description -- will be generated as LLM-readable documentation. - - Parameters - ---------- - param1: Parameter description - param2: Parameter description (with default value) """ # Path operations must be validated through PathValidator path = self.workspace.path_validator.validate(param1) diff --git a/CONTRIBUTING_ZH.md b/CONTRIBUTING_ZH.md index 680f0c9..053ac9b 100644 --- a/CONTRIBUTING_ZH.md +++ b/CONTRIBUTING_ZH.md @@ -215,16 +215,15 @@ docs(README): 添加贡献指南 / add contributing guide ) self.func = self.your_method self.params = BaseTool.extract_params(self.your_method) + self.param_descriptions = { + "param1": "参数说明", + "param2": "参数说明" + } @BaseTool.handle_tool_exceptions def your_method(self, param1: str, param2: int = 0) -> str: """ 工具描述 -- 会生成为 LLM 可读的文档. 
- - Parameters - ---------- - param1: 参数说明 - param2: 参数说明(带默认值) """ # 路径操作必须通过 PathValidator 验证 path = self.workspace.path_validator.validate(param1) diff --git a/README.md b/README.md index 396bb4c..71c73f6 100644 --- a/README.md +++ b/README.md @@ -125,8 +125,7 @@ ManualAid registers 12 tools for LLM use via XML function calls: | -------------- | ------------------------------------------------ | | `ls` | List directory contents | | `glob` | Find files by glob pattern | -| `read` | Read file contents (with optional line limit) | -| `read_lines` | Read specific line range from a file | +| `read` | Read file contents with optional line range | | `stat` | Get file/directory metadata (size, mtime, lines) | | `exact_search` | Exact string search with case/whole-word options | | `regex_search` | Regex search with context display | diff --git a/README_ZH.md b/README_ZH.md index 2a4dc7e..51b375d 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -112,8 +112,7 @@ ManualAid 注册了 12 个工具供 LLM 通过 XML 函数调用使用: | -------------- | ----------------------------------------- | | `ls` | 列出目录内容 | | `glob` | 通过 glob 模式查找文件 | -| `read` | 读取文件内容(可选行数限制) | -| `read_lines` | 读取文件中指定范围的行 | +| `read` | 读取文件内容,支持指定行范围 | | `stat` | 获取文件/目录元数据(大小、修改时间、行数) | | `exact_search` | 精确字符串搜索,支持大小写/全词匹配 | | `regex_search` | 正则表达式搜索,支持上下文显示 | diff --git a/src/core/tool_registry.py b/src/core/tool_registry.py index c8d485d..29269e9 100644 --- a/src/core/tool_registry.py +++ b/src/core/tool_registry.py @@ -98,7 +98,7 @@ def _validate_tool_info(self, name: str, doc: str) -> None: def _set_tool_category(self, tool_name: str) -> None: """根据工具名称设置分类.""" - if tool_name in {"glob", "ls", "regex_search", "exact_search", "stat", "read", "read_lines", "symbol_ref"}: + if tool_name in {"glob", "ls", "regex_search", "exact_search", "stat", "read", "symbol_ref"}: self._tool_categories[tool_name] = "query" elif tool_name in {"write", "edit", "confirm_edit"}: self._tool_categories[tool_name] = "edit" @@ -114,7 +114,6 @@ def 
register(self, workspace: Workspace) -> None: from src.workspace.tools.git_tool import GitTool from src.workspace.tools.glob_tool import GlobTool from src.workspace.tools.ls_tool import LsTool - from src.workspace.tools.read_lines_tool import ReadLinesTool from src.workspace.tools.read_tool import ReadTool from src.workspace.tools.regex_search_tool import RegexSearchTool from src.workspace.tools.stat_tool import StatTool @@ -127,7 +126,6 @@ def register(self, workspace: Workspace) -> None: ExactSearchTool, GlobTool, LsTool, - ReadLinesTool, ReadTool, RegexSearchTool, WriteTool, diff --git a/src/workspace/tools/base_tool.py b/src/workspace/tools/base_tool.py index 8d456a9..367df61 100644 --- a/src/workspace/tools/base_tool.py +++ b/src/workspace/tools/base_tool.py @@ -7,23 +7,21 @@ from src.workspace.workspace import Workspace -def build_param_doc(name: str, params: dict[str, Any]) -> str: - """Generate a concise XML parameter doc.""" +def build_param_list_item(name: str, params: dict[str, Any], description: str = "") -> str: + """Generate a Markdown list item describing a parameter.""" from src.constants.prompts import clean_type_annotation - result = f' Any: @@ -112,6 +110,7 @@ def __init__( self.doc: str = doc self.func: Callable[..., Any] | None = None self.params: dict[str, Any] | None = None + self.param_descriptions: dict[str, str] = {} def to_doc(self) -> str: """转换为模型可读文档格式""" @@ -119,7 +118,8 @@ def to_doc(self) -> str: if self.params and len(self.params) > 0: lines.append(" ") for name, param in self.params.items(): - lines.append(f" {build_param_doc(name, param)}") + desc = self.param_descriptions.get(name, "") + lines.append(f" {build_param_list_item(name, param, desc)}") lines.append(" ") else: lines.append(" ") diff --git a/src/workspace/tools/edit_tool.py b/src/workspace/tools/edit_tool.py index 15553ee..001826a 100644 --- a/src/workspace/tools/edit_tool.py +++ b/src/workspace/tools/edit_tool.py @@ -19,11 +19,19 @@ def __init__(self, workspace: 
Workspace): super().__init__(workspace, "edit", self.edit.__doc__, write_permission=True) self.func = self.edit self.params = BaseTool.extract_params(self.edit) + self.param_descriptions = { + "path": "文件路径", + "old_string": "待替换的字符串", + "new_string": "替换后的字符串", + "max_replacements": "最大替换次数(1~100)", + "context_before": "匹配前的上下文文本", + "context_after": "匹配后的上下文文本", + } @BaseTool.handle_tool_exceptions def edit( self, - file_path: str, + path: str, old_string: str, new_string: str, max_replacements: int = 10, @@ -31,19 +39,7 @@ def edit( context_after: str = "", ) -> str: """ - 在文件中进行安全的字符串替换(仅预览,不修改磁盘) - - 执行 dry-run 替换,生成 diff,记录 PENDING_AUDIT 快照. - 批准后由 AuditCommitter 执行实际写入. - - Parameters - ---------- - file_path: 文件路径 - old_string: 待替换的字符串(不能为空) - new_string: 替换后的字符串 - max_replacements: 最大替换次数(默认 10,最大 100) - context_before: 匹配前的上下文文本(可选,用于校验) - context_after: 匹配后的上下文文本(可选,用于校验) + 通过在文件中进行安全的字符串替换编辑文件 """ # 1. 参数校验 if not old_string: @@ -55,8 +51,8 @@ def edit( max_replacements = 100 # 2. 路径解析 - source_file_path = Path(file_path) - resolved_path: Path = self.workspace.path_validator.resolve_path(source_file_path) + source_path = Path(path) + resolved_path: Path = self.workspace.path_validator.resolve_path(source_path) if not resolved_path.is_file(): return ToolErrorResponse( @@ -96,7 +92,7 @@ def edit( idx += len(old_string) if count == 0: - return f"No changes made: old_string not found in file.\nFile: {file_path}\nSearching for: '{old_string}'" + return f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'" # 6. 
执行替换(生成新内容) new_content = old_content.replace(old_string, new_string, count) diff --git a/src/workspace/tools/exact_search_tool.py b/src/workspace/tools/exact_search_tool.py index b23eaa6..6b6b76b 100644 --- a/src/workspace/tools/exact_search_tool.py +++ b/src/workspace/tools/exact_search_tool.py @@ -76,6 +76,14 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "exact_search", self.exact_search.__doc__) self.func = self.exact_search self.params = BaseTool.extract_params(self.exact_search) + self.param_descriptions = { + "pattern": "搜索字符串", + "path": "搜索文件或文件夹路径", + "case_sensitive": "是否大小写敏感", + "whole_word": "是否全词匹配", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + } @BaseTool.handle_tool_exceptions def exact_search( @@ -88,18 +96,7 @@ def exact_search( ignore: list[str] | None = None, ) -> str: """ - 精确搜索字符串(支持大小写敏感/全词匹配) - - Args: - pattern: 搜索字符串 - path: 搜索路径,默认为当前目录 - case_sensitive: 是否大小写敏感,默认为True - whole_word: 是否全词匹配,默认为True - limit: 最大匹配数量限制,默认为256 - ignore: 忽略匹配正则的文件或文件夹列表 - - Returns: - 格式化的搜索结果字符串 + 精确搜索字符串 """ # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) diff --git a/src/workspace/tools/git_tool.py b/src/workspace/tools/git_tool.py index cf96864..af1e0c2 100644 --- a/src/workspace/tools/git_tool.py +++ b/src/workspace/tools/git_tool.py @@ -55,14 +55,13 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "git", self.git.__doc__, read_permission=True) self.func = self.git self.params = BaseTool.extract_params(self.git) + self.param_descriptions = { + "command_str": "Git 子命令及其参数,如 'status'、'diff --cached'、'log --oneline -5'", + } def git(self, command_str: str) -> str: """ - 执行 Git 命令(白名单限制) - - Parameters - ---------- - command_str: Git 子命令及其参数,如 "status"、"diff --cached"、"log --oneline -5" + 执行 Git 命令 """ if not command_str or not command_str.strip(): return ToolErrorResponse(self.__class__.__name__, ValueError("command_str 不能为空")).to_str() diff --git 
a/src/workspace/tools/glob_tool.py b/src/workspace/tools/glob_tool.py index 55ddabb..7ee0c08 100644 --- a/src/workspace/tools/glob_tool.py +++ b/src/workspace/tools/glob_tool.py @@ -10,23 +10,18 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "glob", self.glob.__doc__) self.func = self.glob self.params = BaseTool.extract_params(self.glob) + self.param_descriptions = { + "pattern": "通配符", + "path": "目录路径", + "max_ret": "最多返回多少条检索结果", + } @BaseTool.handle_tool_exceptions - def glob(self, pattern: str, folder_path: str = ".", max_ret: int = 1000) -> list[str]: + def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> list[str]: """ 在工作区内按通配符模式匹配并列出所有路径,带[Folder]或[File]的类型标记. 失败时返回错误信息 - - Parameters - ---------- - pattern: 通配符 - folder_path: 目录路径 - max_ret: 最多返回多少条检索结果 - - Returns - ------- - 检索到的文件或文件夹的相对路径 """ - root_path = self.workspace.path_validator.validate(folder_path) + root_path = self.workspace.path_validator.validate(path) if not root_path.is_dir(): return ToolErrorResponse(self.__class__.__name__, f"{root_path}不是一个文件夹路径").to_str() diff --git a/src/workspace/tools/ls_tool.py b/src/workspace/tools/ls_tool.py index 7f92edb..8b8ee34 100644 --- a/src/workspace/tools/ls_tool.py +++ b/src/workspace/tools/ls_tool.py @@ -10,16 +10,19 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "ls", self.ls.__doc__) self.func = self.ls self.params = BaseTool.extract_params(self.ls) + self.param_descriptions = { + "path": "目录路径", + } @BaseTool.handle_tool_exceptions - def ls(self, folder_path: str = ".") -> list[str] | str: + def ls(self, path: str = ".") -> list[str] | str: """ 列出指定目录下的文件和文件夹. 
返回相对路径列表, 并标记[Folder]或[File] """ - path: Path = self.workspace.path_validator.validate(folder_path) - if not path.is_dir(): - return ToolErrorResponse(self.__class__.__name__, f'参数错误: "{path}"不是一个目录').to_str() + folder_path: Path = self.workspace.path_validator.validate(path) + if not folder_path.is_dir(): + return ToolErrorResponse(self.__class__.__name__, f'参数错误: "{folder_path}"不是一个目录').to_str() return [ f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in path.iterdir() + for item in folder_path.iterdir() ] diff --git a/src/workspace/tools/read_lines_tool.py b/src/workspace/tools/read_lines_tool.py deleted file mode 100644 index a00c41f..0000000 --- a/src/workspace/tools/read_lines_tool.py +++ /dev/null @@ -1,71 +0,0 @@ -from pathlib import Path - -from src.models.tool_error_response import ToolErrorResponse -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class ReadLinesTool(BaseTool): - def __init__(self, workspace: Workspace): - super().__init__(workspace, "read_lines", self.read_lines.__doc__) - self.func = self.read_lines - self.params = BaseTool.extract_params(self.read_lines) - - @BaseTool.handle_tool_exceptions - def read_lines(self, file_path: str, start: int, end: int, context: int = 2, encoding: str = "utf-8") -> str: - """ - 读取文件的指定行范围(行号从1开始), 可指定上下文行数扩展返回的实际行数范围,返回带行号的格式化内容 - - Parameters - ---------- - file_path: 文件路径 - start: 开始行数 - end: 结束行数 - context: 扩展结果行数范围 行数范围最终为(start-context, end+context) - encoding: 编码 - """ - path: Path = self.workspace.path_validator.validate(file_path) - - if not path.is_file(): - return ToolErrorResponse(self.__class__.__name__, ValueError(f"读取文件{path}时未读取到完整文件")).to_str() - - if is_binary_file(path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"无法读取二进制文件: {path}. 
请使用二进制安全工具或转换为 base64."), - ).to_str() - - with open(path, encoding=encoding) as f: - lines = f.readlines() - - total_lines = len(lines) - - context = max(0, context) - - start -= context - end += context - - # 验证行号 - if start < 1: - start = 1 - if start > total_lines: - return f"错误:起始行 {start} 超过文件总行数 ({total_lines})" - - end = total_lines if end is None else min(end, total_lines) - - if end < start: - return f"错误:结束行 {end} 小于起始行 {start}" - - result_lines = [] - for i in range(start - 1, end): - line_num = i + 1 - content = lines[i].rstrip("\n\r") - result_lines.append(f"{line_num:6d} | {content}") - - header = f"\n[文件: {path}]\n[行 {start}-{end} / 共 {total_lines} 行]\n" - separator = "-" * 80 + "\n" - - self._record_read_meta(path) - - return header + separator + "\n".join(result_lines) diff --git a/src/workspace/tools/read_tool.py b/src/workspace/tools/read_tool.py index ba9b789..5561e16 100644 --- a/src/workspace/tools/read_tool.py +++ b/src/workspace/tools/read_tool.py @@ -6,50 +6,84 @@ from src.workspace.workspace import Workspace +def _resolve_index(idx: int, total: int) -> int: + """Resolve a 1-based or negative index to a clamped 1-based line number.""" + if idx < 0: + idx = total + 1 + idx + if idx < 1: + return 1 + if idx > total: + return total + return idx + + class ReadTool(BaseTool): def __init__(self, workspace: Workspace): super().__init__(workspace, "read", self.read.__doc__) self.func = self.read self.params = BaseTool.extract_params(self.read) + self.param_descriptions = { + "path": "文件路径", + "start": "起始行号(1开始; 负数表示倒数, -1=最后一行)", + "end": "结束行号(1开始; 负数表示倒数, -1=最后一行)", + "context": "扩展结果行数范围 行数范围最终为(start-context, end+context)", + "encoding": "编码", + } @BaseTool.handle_tool_exceptions - def read(self, file_path: str, max_lines: int = 0, encoding: str = "utf-8") -> str: + def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> str: """ - 读取文件内容,可限制最大行数,返回文件内容字符串(带行号) - - Parameters - ---------- - 
file_path: 文件路径 - max_lines: 最大行数(0表示不限制) - encoding: 编码 + 读取文件内容, 返回带行号的格式化内容 """ - path: Path = self.workspace.path_validator.validate(file_path) + file_path: Path = self.workspace.path_validator.validate(path) - if not path.is_file(): - return ToolErrorResponse(self.__class__.__name__, ValueError(f"读取文件{path}时未读取到完整文件")).to_str() + if not file_path.is_file(): + return ToolErrorResponse( + self.__class__.__name__, ValueError(f"读取文件{file_path}时未读取到完整文件") + ).to_str() - if is_binary_file(path): + if is_binary_file(file_path): return ToolErrorResponse( self.__class__.__name__, - ValueError(f"无法读取二进制文件: {path}. 请使用二进制安全工具或转换为 base64."), + ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64."), ).to_str() - with open(path, encoding=encoding) as f: + with open(file_path, encoding=encoding) as f: lines = f.readlines() total_lines = len(lines) - if max_lines > 0: - lines = lines[:max_lines] + if total_lines == 0: + header = f"\n[文件: {file_path}]\n[行 0-0 / 共 0 行]\n" + separator = "-" * 80 + "\n" + self._record_read_meta(file_path) + return header + separator + + context = max(0, context) + + actual_start = _resolve_index(start, total_lines) - context + actual_end = _resolve_index(end, total_lines) + context + + if actual_start < 1: + actual_start = 1 + if actual_end > total_lines: + actual_end = total_lines + + if actual_end < actual_start: + return ( + f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} " + f"(原始参数: start={start}, end={end}, context={context})" + ) result_lines = [] - for i, line in enumerate(lines, 1): - content = line.rstrip("\n\r") - result_lines.append(f"{i:6d} | {content}") + for i in range(actual_start - 1, actual_end): + line_num = i + 1 + content = lines[i].rstrip("\n\r") + result_lines.append(f"{line_num:6d} | {content}") - header = f"\n[文件: {path}]\n[行 1-{len(lines)} / 共 {total_lines} 行]\n" + header = f"\n[文件: {file_path}]\n[行 {actual_start}-{actual_end} / 共 {total_lines} 行]\n" separator = "-" * 80 + "\n" - self._record_read_meta(path) + 
self._record_read_meta(file_path) return header + separator + "\n".join(result_lines) diff --git a/src/workspace/tools/regex_search_tool.py b/src/workspace/tools/regex_search_tool.py index 59114ad..f6d85cb 100644 --- a/src/workspace/tools/regex_search_tool.py +++ b/src/workspace/tools/regex_search_tool.py @@ -103,6 +103,14 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "regex_search", self.regex_search.__doc__) self.func = self.regex_search self.params = BaseTool.extract_params(self.regex_search) + self.param_descriptions = { + "pattern": "正则表达式模式", + "path": "搜索文件或文件夹路径", + "context": "显示匹配行的上下文行数", + "file_pattern": "文件匹配模式,支持通配符", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + } @BaseTool.handle_tool_exceptions def regex_search( @@ -116,15 +124,6 @@ def regex_search( ) -> str: """ 使用正则表达式搜索文件内容, 支持上下文显示、文件过滤和忽略路径, 返回匹配详情; 适合代码与文档探索 - - Parameters - ---------- - pattern: 正则表达式模式 - path: 搜索路径,默认为当前目录 - context: 显示匹配行的上下文行数,默认为3 - file_pattern: 文件匹配模式,支持通配符,默认为"*" - limit: 最大匹配数量限制,默认为256 - ignore: 忽略匹配正则的文件或文件夹列表 """ # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) diff --git a/src/workspace/tools/stat_tool.py b/src/workspace/tools/stat_tool.py index 41981cf..bf8d28f 100644 --- a/src/workspace/tools/stat_tool.py +++ b/src/workspace/tools/stat_tool.py @@ -13,17 +13,14 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "stat", self.stat.__doc__) self.func = self.stat self.params = BaseTool.extract_params(self.stat) + self.param_descriptions = { + "path": "文件或目录路径", + } @BaseTool.handle_tool_exceptions def stat(self, path: str = ".") -> str: """ 获取工作区内文件或目录的详细信息,包括大小、行数(仅文件)、修改时间、权限等 - - Args: - path: 文件或目录路径,默认为当前目录 - - Returns: - 格式化的详细信息字符串 """ # 验证路径 target_path: Path = self.workspace.path_validator.validate(path) diff --git a/src/workspace/tools/symbol_ref_tool.py b/src/workspace/tools/symbol_ref_tool.py index 6c0c4df..f13deb8 100644 --- a/src/workspace/tools/symbol_ref_tool.py +++ 
b/src/workspace/tools/symbol_ref_tool.py @@ -347,6 +347,17 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "symbol_ref", self.symbol_ref.__doc__, read_permission=True, write_permission=False) self.func = self.symbol_ref self.params = BaseTool.extract_params(self.symbol_ref) + self.param_descriptions = { + "symbol_name": "要查找的符号名称(如函数名、类名、变量名)", + "path": "搜索文件或文件夹路径", + "language": "语言类型(auto/python/javascript/typescript/markdown/general)", + "include_definitions": "是否包含定义位置", + "include_references": "是否包含引用位置", + "context_lines": "显示匹配行的上下文行数", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + "file_pattern": "文件匹配模式(如 *.py),默认根据语言自动选择", + } @BaseTool.handle_tool_exceptions def symbol_ref( @@ -362,24 +373,7 @@ def symbol_ref( file_pattern: str | None = None, ) -> str: """ - 查找符号(函数、类、变量等)的定义和引用位置 - - 支持自动识别语言类型,生成智能搜索模式来定位符号的定义和所有使用位置. - 适用于代码探索、重构影响分析、理解代码结构等场景. - - Args: - symbol_name: 要查找的符号名称(如函数名、类名、变量名) - path: 搜索路径,默认为当前目录 - language: 语言类型(auto/python/javascript/typescript/markdown/general),默认auto - include_definitions: 是否包含定义位置,默认True - include_references: 是否包含引用位置,默认True - context_lines: 显示匹配行的上下文行数,默认2 - limit: 最大匹配数量限制,默认为256 - ignore: 忽略匹配正则的文件或文件夹列表 - file_pattern: 文件匹配模式(如 *.py),默认根据语言自动选择 - - Returns: - 格式化的符号引用搜索结果 + 查找符号(函数、类、变量等)的定义和引用位置, 适用于代码探索、重构影响分析、理解代码结构等场景. 
""" # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) diff --git a/src/workspace/tools/write_tool.py b/src/workspace/tools/write_tool.py index 1a0935e..5310fcc 100644 --- a/src/workspace/tools/write_tool.py +++ b/src/workspace/tools/write_tool.py @@ -12,45 +12,42 @@ def __init__(self, workspace: Workspace): super().__init__(workspace, "write", self.write.__doc__, write_permission=True) self.func = self.write self.params = BaseTool.extract_params(self.write) + self.param_descriptions = { + "path": "文件路径", + "content": "写入内容", + } @BaseTool.handle_tool_exceptions - def write(self, file_path: str, content: str = "") -> str: + def write(self, path: str, content: str = "") -> str: """ - 写入文件内容,如文件不存在则创建(含父目录) - - Parameters - ---------- - file_path: 文件路径 - content: 写入内容 + 写入文件内容, 如文件不存在则创建(含父目录) """ - source_file_path = Path(file_path) - file_path: Path = self.workspace.path_validator.resolve_path(source_file_path) + source_path = Path(path) + path: Path = self.workspace.path_validator.resolve_path(source_path) - if file_path.exists() and file_path.is_dir(): - return ToolErrorResponse( - self.__class__.__name__, ValueError(f"路径 {file_path} 是一个目录,无法写入") - ).to_str() + if path.exists() and path.is_dir(): + return ToolErrorResponse(self.__class__.__name__, ValueError(f"路径 {path} 是一个目录,无法写入")).to_str() - if is_binary_file(file_path): + if is_binary_file(path): return ToolErrorResponse( self.__class__.__name__, - ValueError(f"禁止写入二进制文件: {file_path}"), + ValueError(f"禁止写入二进制文件: {path}"), ).to_str() - mtime_error = self._validate_mtime(file_path) + mtime_error = self._validate_mtime(path) if mtime_error: return mtime_error old_content = "" old_meta = None - if file_path.exists() and file_path.is_file(): - old_meta = FileTracker.get_file_meta(file_path) + if path.exists() and path.is_file(): + old_meta = FileTracker.get_file_meta(path) try: - old_content = file_path.read_text(encoding="utf-8") + old_content = path.read_text(encoding="utf-8") except 
Exception: old_content = "" - rel_path = str(file_path.relative_to(self.workspace.root_path)) + rel_path = str(path.relative_to(self.workspace.root_path)) old_hash = old_meta.get("checksum") if old_meta else None new_hash = FileTracker.compute_checksum_from_string(content) diff_content = self._generate_diff(old_content, content, rel_path) diff --git a/tests/core/test_tool_registry.py b/tests/core/test_tool_registry.py index e94f43e..e1b42a2 100644 --- a/tests/core/test_tool_registry.py +++ b/tests/core/test_tool_registry.py @@ -156,7 +156,7 @@ def setup(self, tmp_path): def test_query_tools_category(self): registry = ToolRegistry() - for name in ("glob", "ls", "regex_search", "stat", "read", "read_lines", "symbol_ref"): + for name in ("glob", "ls", "regex_search", "stat", "read", "symbol_ref"): assert registry._tool_categories.get(name) == "query", f"{name} should be query" def test_edit_tools_category(self): diff --git a/tests/workspace/tools/test_base_tool.py b/tests/workspace/tools/test_base_tool.py index d8cd286..8d43395 100644 --- a/tests/workspace/tools/test_base_tool.py +++ b/tests/workspace/tools/test_base_tool.py @@ -1,5 +1,4 @@ -from src.workspace.tools.base_tool import BaseTool, build_param_doc -from src.workspace.workspace import Workspace +from src.workspace.tools.base_tool import BaseTool def test_extract_params(): @@ -41,61 +40,3 @@ def class_method(cls, b: str) -> str: assert "cls" not in params2 assert "a" in params1 assert "b" in params2 - - -def test_build_param_doc(): - """测试参数文档生成使用简洁类型和required属性""" - - params_required = {"required": True, "annotation": ""} - result = build_param_doc("file_path", params_required) - assert 'type="string"' in result - assert 'required="true"' in result - assert "", "default": "0"} - result = build_param_doc("max_lines", params_optional) - assert 'type="integer"' in result - assert 'required="false"' in result - assert 'default="0"' in result - assert " str: - """Sample function for testing""" - return f"{a}_{b}" 
- class MockTool(BaseTool):
- def __init__(self, workspace: Workspace | None):
- super().__init__(workspace, "mock", "测试工具")
- self.func = sample_func
- self.params = BaseTool.extract_params(sample_func)
-
- tool = MockTool(None)
- doc = tool.to_doc()
-
- # 使用新格式标签
- assert doc.startswith('')
- assert "测试工具" in doc
- assert "" in doc
- assert "" in doc
- assert doc.endswith("")
-
- # 验证参数格式
- assert 'type="string"' in doc
- assert 'type="integer"' in doc
- assert 'required="true"' in doc
- assert 'required="false"' in doc
-
- # 验证没有旧格式标记
- assert "" not in doc
- assert "" not in doc
- assert "
Date: Mon, 4 May 2026 17:04:54 +0800
Subject: [PATCH 2/8] =?UTF-8?q?fix(symbol=5Fref=5Ftool):=20=E4=BF=AE?=
 =?UTF-8?q?=E5=A4=8D=E7=AC=A6=E5=8F=B7=E6=90=9C=E7=B4=A2=E6=80=A7=E8=83=BD?=
 =?UTF-8?q?=E7=81=BE=E9=9A=BE=E5=8F=8A=E6=96=87=E6=9C=AC=E5=BE=80=E8=BF=94?=
 =?UTF-8?q?=E5=8F=8D=E6=A8=A1=E5=BC=8F=20(#132)=20(#137)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fix: eliminate the repeated filesystem traversal and the I/O bottleneck caused by multi-pattern search
  * Remove `_search_single_pattern` and `_search_patterns_concurrent`; introduce `_search_all_patterns`, which matches every regex in a single file traversal
  * Call the `Workspace.search_content_multi_pattern` API instead of the old `search_content`, so that each file is no longer opened N times (N = number of patterns)
  * Optimize `_build_results_with_context` to read only the needed lines into a line cache instead of loading whole files, reducing the risk of exhausting memory on large files
- Refactoring: correct the context-line retrieval and the data-format handling
  * Replace the parsing of unstructured text with a native `list[dict]` structure for match results, eliminating the "format → regex-parse" anti-pattern
  * Fix the context-line calculation so that it returns the original content of exactly `context_lines` lines before and after each match
  * Precompile the regexes and filter out invalid patterns to make the search more efficient
- Breaking change: the return structure of the underlying search interface changed
  * `symbol_ref_tool` no longer depends on the formatted string output of `Workspace.search_content`; it uses the structured list of dicts returned by `search_content_multi_pattern` instead
  * Remove the code that parsed the `[文件] path\n----\n` text format; downstream code must adapt to the new `file`, `line_num`, `content`, `pattern_type` fields
---
 src/workspace/tools/symbol_ref_tool.py | 212 +++++++++++++------------
 src/workspace/workspace.py             | 118 +++++++++++++-
 2 files changed, 227 insertions(+), 103 deletions(-)

diff --git a/src/workspace/tools/symbol_ref_tool.py b/src/workspace/tools/symbol_ref_tool.py
index f13deb8..8d45d14 100644
--- a/src/workspace/tools/symbol_ref_tool.py +++ b/src/workspace/tools/symbol_ref_tool.py @@ -1,16 +1,12 @@ """符号引用查找工具 - 查找函数、类、变量等的定义和引用""" import re -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from src.models.tool_error_response import ToolErrorResponse from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace -# 默认排除的目录(与 workspace.py 保持一致, 后续改为从项目配置加载) -DEFAULT_EXCLUDED_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build", ".idea", ".vscode"} - def _get_file_pattern_by_language(language: str) -> str: """根据语言获取默认的文件匹配模式""" @@ -124,9 +120,9 @@ def _generate_patterns(symbol_name: str, language: str, include_def: bool, inclu return unique_patterns -def _search_single_pattern( +def _search_all_patterns( workspace: Workspace, - pattern_info: dict, + patterns: list[dict], search_path: str, symbol_name: str, context_lines: int, @@ -134,112 +130,124 @@ def _search_single_pattern( file_pattern: str, ignore: list[str] | None, ) -> list[dict]: - """使用 Workspace.search_content 执行单个模式的搜索(并发版本)""" - results = [] + """单次文件遍历搜索所有模式, 返回按文件分组且带上下文的结果. 
+ + 相比旧实现的优势: + - 文件系统只遍历 1 次(旧: N 次, N=模式数量) + - 每个文件只读取 1 次(旧: N 次) + - 直接使用结构化数据, 无需 格式化→正则解析 的反模式 + - 真正实现上下文行读取(旧实现仅将匹配行本身作为上下文) + """ + # 编译所有正则模式 + compiled_patterns: list[tuple[re.Pattern, str]] = [] + for p in patterns: + try: + regex = re.compile(p["pattern"], re.IGNORECASE) + compiled_patterns.append((regex, p["type"])) + except re.error: + continue - try: - # 使用 workspace 的并发搜索能力 - search_result = workspace.search_content( - pattern=pattern_info["pattern"], - folder_path=search_path, - file_pattern=file_pattern, - max_workers=4, # 并发线程数 - case_sensitive=False, - ) + if not compiled_patterns: + return [] - # 解析 search_content 的返回结果 - if not search_result or "未找到匹配" in search_result: - return results - - # 解析结果格式: search_content 返回的是格式化的字符串 - # 格式: [文件] path\n----\n 行号 | 内容 - lines = search_result.split("\n") - current_file = None - current_matches = [] - - for line in lines: - if line.startswith("[文件] "): - # 保存上一个文件的结果 - if current_file and current_matches: - results.append({"file": current_file, "matches": current_matches, "type": pattern_info["type"]}) - current_file = line[6:] # 去掉 "[文件] " - current_matches = [] - elif line.startswith("----"): - continue - elif line.strip() and current_file: - # 解析匹配行: " 行号 | 内容" - match = re.match(r"\s+(\d+)\s*\|\s*(.*)", line) - if match: - line_num = int(match.group(1)) - content = match.group(2) - - # 收集上下文(这里简化处理,因为 search_content 不直接提供上下文) - # 为了保持兼容性,我们构建一个简化的上下文 - current_matches.append( - { - "line_num": line_num, - "content": content, - "context": [{"line_num": line_num, "content": content, "is_match": True}], - "match_type": pattern_info["type"], - "symbol_name": symbol_name, - } - ) + # 单次遍历搜索: 所有模式在一次文件遍历中完成 + matches = workspace.search_content_multi_pattern( + patterns=compiled_patterns, + folder_path=search_path, + file_pattern=file_pattern, + max_workers=4, + ignore=ignore, + ) - # 保存最后一个文件的结果 - if current_file and current_matches: - results.append({"file": current_file, "matches": current_matches, 
"type": pattern_info["type"]}) + if not matches: + return [] - except Exception: - # 如果搜索失败,静默跳过 - pass + # 应用 limit 截断 + matches = matches[:limit] - return results + # 按文件分组并构建带上下文的结果 + return _build_results_with_context(matches, context_lines, symbol_name, workspace.root_path) -def _search_patterns_concurrent( - workspace: Workspace, - patterns: list[dict], - search_path: str, - symbol_name: str, +def _build_results_with_context( + matches: list[dict], context_lines: int, - limit: int, - file_pattern: str, - ignore: list[str] | None, + symbol_name: str, + root_path: Path, ) -> list[dict]: - """并发执行多个模式的搜索""" - all_results = [] - - # 使用线程池并发执行多个模式的搜索 - with ThreadPoolExecutor(max_workers=min(len(patterns), 4)) as executor: - future_to_pattern = { - executor.submit( - _search_single_pattern, - workspace, - pattern_info, - search_path, - symbol_name, - context_lines, - limit - len(all_results), - file_pattern, - ignore, - ): pattern_info - for pattern_info in patterns - } + """从扁平匹配列表构建按文件分组、带上下文的结果""" + context_lines = max(0, context_lines) + + # 按文件分组, 保留首次出现的顺序 + file_matches: dict[str, list[dict]] = {} + file_order: list[str] = [] + for m in matches: + f = m["file"] + if f not in file_matches: + file_matches[f] = [] + file_order.append(f) + file_matches[f].append(m) + + results = [] + for file_rel in file_order: + file_match_list = file_matches[file_rel] + + # 收集需要读取的行号范围(匹配行 ± 上下文行) + needed_lines: set[int] = set() + for m in file_match_list: + for delta in range(-context_lines, context_lines + 1): + needed_lines.add(m["line_num"] + delta) + + # 仅读取需要的行(不加载整个文件到内存) + line_cache: dict[int, str] = {} + file_full_path = root_path / file_rel + try: + if needed_lines: + max_needed = max(needed_lines) + with open(file_full_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + if line_num in needed_lines: + line_cache[line_num] = line.rstrip("\n\r") + if line_num >= max_needed: + break + except UnicodeDecodeError, PermissionError, OSError: + pass + 
+ # 为每个匹配项构建上下文 + built_matches = [] + for m in file_match_list: + match_line_num = m["line_num"] + context = [] + for delta in range(-context_lines, context_lines + 1): + ctx_line_num = match_line_num + delta + if ctx_line_num in line_cache: + context.append( + { + "line_num": ctx_line_num, + "content": line_cache[ctx_line_num], + "is_match": delta == 0, + } + ) - for future in as_completed(future_to_pattern): - try: - results = future.result() - all_results.extend(results) - if len(all_results) >= limit: - # 达到限制,取消剩余任务 - for f in future_to_pattern: - f.cancel() - break - except Exception: - # 单个模式失败不影响其他模式 - pass + built_matches.append( + { + "line_num": match_line_num, + "content": m["content"], + "context": context, + "match_type": m["pattern_type"], + "symbol_name": symbol_name, + } + ) - return all_results[:limit] + results.append( + { + "file": file_rel, + "matches": built_matches, + "type": file_match_list[0]["pattern_type"], + } + ) + + return results def _detect_language(search_path: Path, specified_lang: str) -> str: @@ -394,7 +402,7 @@ def symbol_ref( return f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})" # 使用并发搜索执行所有模式 - all_results = _search_patterns_concurrent( + all_results = _search_all_patterns( self.workspace, patterns, path, diff --git a/src/workspace/workspace.py b/src/workspace/workspace.py index 538a280..bf588c6 100644 --- a/src/workspace/workspace.py +++ b/src/workspace/workspace.py @@ -7,6 +7,9 @@ from src.models.tool_error_response import ToolErrorResponse from src.workspace.path_validator import PathNotFoundError, PathValidator, WorkspaceBoundaryError +# 默认排除的目录 后续改为从项目配置加载 +DEFAULT_EXCLUDED_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build", ".idea", ".vscode"} + def _highlight_matches(line: str, regex: re.Pattern) -> str: """ @@ -73,7 +76,7 @@ def search_content( path = self.path_validator.validate(folder_path) # 初始化排除目录集合 - exclude_set = set(exclude_dirs or [".git", "__pycache__", "node_modules", 
".venv", "venv", "dist", "build"]) + exclude_set = set(exclude_dirs or DEFAULT_EXCLUDED_DIRS) # 编译正则表达式 flags = 0 if case_sensitive else re.IGNORECASE @@ -176,3 +179,116 @@ def _search_in_file(self, file_path: Path, regex: re.Pattern) -> list[tuple]: pass return results + + def search_content_multi_pattern( + self, + patterns: list[tuple[re.Pattern, str]], + folder_path: str = ".", + file_pattern: str = "*", + max_workers: int = 4, + ignore: list[str] | None = None, + ) -> list[dict]: + """单次文件遍历匹配多个正则模式, 直接返回结构化数据. + + 与 search_content 的区别: + - 接受多个已编译的正则 + 类型标签, 一次遍历全部匹配 + - 返回结构化 list[dict] 而非格式化字符串, 消除下游正则解析反模式 + - 不做高亮处理(高亮是展示层关注点, 不应混入数据层) + + Args: + patterns: [(compiled_regex, type_label), ...] 已编译正则及其类型标签 + folder_path: 搜索起始路径 + file_pattern: 文件通配符(如 "*.py") + max_workers: 并发读取文件的线程数 + ignore: 忽略路径正则列表 + + Returns: + [{"file": str, "line_num": int, "content": str, "pattern_type": str}, ...] + 按文件路径 → 行号排序, 同一行多个模式匹配则每个各一条记录 + """ + try: + path = self.path_validator.validate(folder_path) + + # 预编译 ignore 正则 + ignore_res: list[re.Pattern] = [] + if ignore: + for ign in ignore: + try: + ignore_res.append(re.compile(ign)) + except re.error: + continue + + # 收集文件(一次遍历) + files_to_search: list[Path] = [] + if path.is_file(): + files_to_search = [path] + else: + for file_path in path.rglob(file_pattern): + if file_path.is_file(): + if any(p.name in DEFAULT_EXCLUDED_DIRS for p in file_path.parents): + continue + rel = str(file_path.relative_to(self.root_path)) + if any(ir.search(rel) for ir in ignore_res): + continue + files_to_search.append(file_path) + + if not files_to_search: + return [] + + # 并发搜索所有文件, 每个文件内一次读取、一次测试所有模式 + all_matches: list[dict] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(self._search_multi_in_file, file_path, patterns): file_path + for file_path in files_to_search + } + for future in as_completed(futures): + try: + file_results = future.result() + if file_results: + 
all_matches.extend(file_results) + except Exception: + pass + + # 按文件路径 → 行号排序 + all_matches.sort(key=lambda m: (m["file"], m["line_num"])) + return all_matches + + except PathNotFoundError, WorkspaceBoundaryError, PermissionError: + return [] + except Exception: + return [] + + def _search_multi_in_file(self, file_path: Path, patterns: list[tuple[re.Pattern, str]]) -> list[dict]: + """在单个文件中一次读取、逐行测试所有模式. + + Args: + file_path: 文件绝对路径 + patterns: [(compiled_regex, type_label), ...] + + Returns: + 匹配项列表, 同一行多个模式匹配时每个各返回一条 + """ + results: list[dict] = [] + relative_path = str(file_path.relative_to(self.root_path)) + + try: + with open(file_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + stripped = line.rstrip("\n\r") + for regex, pattern_type in patterns: + if regex.search(stripped): + results.append( + { + "file": relative_path, + "line_num": line_num, + "content": stripped, + "pattern_type": pattern_type, + } + ) + except UnicodeDecodeError, PermissionError: + pass + except Exception: + pass + + return results
From 3343e380b125abed3a053f87afb375f08e7eca36 Mon Sep 17 00:00:00 2001
From: Suntion <149924916+SunYanbox@users.noreply.github.com>
Date: Mon, 4 May 2026 17:05:04 +0800
Subject: [PATCH 3/8] =?UTF-8?q?ref(core):=20=E7=A7=BB=E9=99=A4=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E8=A7=A3=E6=9E=90=E5=99=A8=E4=B8=AD=E7=9A=84=E5=86=97?= =?UTF-8?q?=E4=BD=99=E8=AD=A6=E5=91=8A=20(#138)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fix: clean up redundant warning logic for an issue that has already been resolved
  * Delete the `warnings.warn` call for the `` tag in `src/core/input_parser.py`
  * Remove the `import warnings` dependency that is no longer needed
---
 src/core/input_parser.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/core/input_parser.py b/src/core/input_parser.py
index 61b2d32..620cb00 100644
--- a/src/core/input_parser.py
+++ b/src/core/input_parser.py
@@ -1,7 +1,6 @@ import html import re import shlex -import warnings from src.console.commands.command_registry import CommandRegistry from
src.models.commands import CommandParseResult @@ -53,11 +52,6 @@ def parse_func_call(content: str, warns: list[str]) -> tuple[str, dict]: """ 从 标签中提取函数名和参数,使用健壮的回退机制. """ - warnings.warn( - "当参数存在`$", "", inner).strip() From 2ef5c7f33b8e5e5a6df077b015d02fd38e2e9e90 Mon Sep 17 00:00:00 2001 From: Suntion <149924916+SunYanbox@users.noreply.github.com> Date: Mon, 4 May 2026 17:45:12 +0800 Subject: [PATCH 4/8] =?UTF-8?q?feat(#135):=20=E4=BD=BF=E7=94=A8=20write=5F?= =?UTF-8?q?permission=20=E8=87=AA=E5=8A=A8=E6=8E=A8=E6=96=AD=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E5=88=86=E7=B1=BB=20/=20Auto-categorize=20tools=20via?= =?UTF-8?q?=20write=5Fpermission=20attribute=20(#139)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/tool_registry.py | 23 +++++++++++------------ tests/core/test_tool_registry.py | 4 ++-- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/core/tool_registry.py b/src/core/tool_registry.py index 29269e9..f6cf41f 100644 --- a/src/core/tool_registry.py +++ b/src/core/tool_registry.py @@ -96,16 +96,14 @@ def _validate_tool_info(self, name: str, doc: str) -> None: if len(name) > self.MAX_FUNC_NAME_LENGTH: warnings.warn(f"工具名称 '{name}' 超过 {self.MAX_FUNC_NAME_LENGTH} 字符", UserWarning, stacklevel=3) - def _set_tool_category(self, tool_name: str) -> None: - """根据工具名称设置分类.""" - if tool_name in {"glob", "ls", "regex_search", "exact_search", "stat", "read", "symbol_ref"}: - self._tool_categories[tool_name] = "query" - elif tool_name in {"write", "edit", "confirm_edit"}: - self._tool_categories[tool_name] = "edit" - elif tool_name == "git": - self._tool_categories[tool_name] = "dangerous" + def _set_tool_category(self, tool: BaseTool) -> None: + """根据工具的 write_permission 属性设置分类.""" + if tool.name == "git": + self._tool_categories[tool.name] = "dangerous" + elif tool.write_permission: + self._tool_categories[tool.name] = "write" else: - self._tool_categories[tool_name] = "query" + 
self._tool_categories[tool.name] = "query" def register(self, workspace: Workspace) -> None: """为工作区注册工具""" @@ -140,7 +138,7 @@ def register(self, workspace: Workspace) -> None: warnings.warn(f"工具{tool.name}没有注册功能回调和参数", stacklevel=2) continue self._tools[tool.name] = tool - self._set_tool_category(tool.name) + self._set_tool_category(tool) except ValueError: pass @@ -266,7 +264,7 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu command_str = kwargs.get("command_str", "") if not GitTool.is_safe_command(command_str): audit_status = "PENDING_AUDIT" - elif category == "edit": + elif category == "write": audit_status = "none" # Snapshot has its own PENDING_AUDIT workspace.db.log_tool_call( @@ -287,7 +285,8 @@ def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) - return # Exclude write tools - if func_name in {"write", "edit", "confirm_edit"}: + tool = self._tools.get(func_name) + if tool is not None and tool.write_permission: return try: diff --git a/tests/core/test_tool_registry.py b/tests/core/test_tool_registry.py index e1b42a2..ad0ee62 100644 --- a/tests/core/test_tool_registry.py +++ b/tests/core/test_tool_registry.py @@ -159,10 +159,10 @@ def test_query_tools_category(self): for name in ("glob", "ls", "regex_search", "stat", "read", "symbol_ref"): assert registry._tool_categories.get(name) == "query", f"{name} should be query" - def test_edit_tools_category(self): + def test_write_tools_category(self): registry = ToolRegistry() for name in ("write", "edit"): - assert registry._tool_categories.get(name) == "edit", f"{name} should be edit" + assert registry._tool_categories.get(name) == "write", f"{name} should be write" def test_git_tool_category(self): registry = ToolRegistry() From d433e40bc1ae49829408beac103761583843e76f Mon Sep 17 00:00:00 2001 From: Suntion <149924916+SunYanbox@users.noreply.github.com> Date: Mon, 4 May 2026 18:19:07 +0800 Subject: [PATCH 5/8] 
=?UTF-8?q?feat(#134):=20=E4=B8=BA=20exact=5Fsearch=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20file=5Fpattern=20=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E5=B9=B6=E4=BF=AE=E5=A4=8D=20limit=20=E8=AF=AD=E4=B9=89=20/=20?= =?UTF-8?q?Add=20file=5Fpattern=20parameter=20and=20fix=20limit=20semantic?= =?UTF-8?q?s=20(#140)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 exact_search 添加 file_pattern 参数,默认 "*" 保持向后兼容 - 将 limit 语义从"扫描文件数"改为"返回匹配项数",同步修复 regex_search - Replace hardcoded rglob("*") with rglob(file_pattern) - Change limit to count individual matches instead of files scanned --- src/workspace/tools/exact_search_tool.py | 8 ++++++-- src/workspace/tools/regex_search_tool.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/workspace/tools/exact_search_tool.py b/src/workspace/tools/exact_search_tool.py index 6b6b76b..c03b55b 100644 --- a/src/workspace/tools/exact_search_tool.py +++ b/src/workspace/tools/exact_search_tool.py @@ -81,6 +81,7 @@ def __init__(self, workspace: Workspace): "path": "搜索文件或文件夹路径", "case_sensitive": "是否大小写敏感", "whole_word": "是否全词匹配", + "file_pattern": "文件匹配模式,支持通配符", "limit": "最大匹配数量限制", "ignore": "忽略匹配正则的文件或文件夹列表", } @@ -92,6 +93,7 @@ def exact_search( path: str = ".", case_sensitive: bool = True, whole_word: bool = True, + file_pattern: str = "*", limit: int = 256, ignore: list[str] | None = None, ) -> str: @@ -114,16 +116,17 @@ def exact_search( # 搜索结果 results = [] file_count = 0 + total_matches = 0 # 确定要搜索的文件列表(支持单文件或目录) - files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob("*")) + files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern)) # 遍历所有文件 for file_path in files_to_search: if not file_path.is_file(): continue # 检查是否达到限制 - if len(results) >= limit: + if total_matches >= limit: break # 检查是否应该忽略 @@ -149,6 +152,7 @@ def exact_search( if file_matches: results.append({"file": str(file_path), "matches": 
file_matches}) file_count += 1 + total_matches += len(file_matches) except OSError, UnicodeDecodeError, PermissionError: continue # 跳过无法读取的文件 diff --git a/src/workspace/tools/regex_search_tool.py b/src/workspace/tools/regex_search_tool.py index f6d85cb..621e8ee 100644 --- a/src/workspace/tools/regex_search_tool.py +++ b/src/workspace/tools/regex_search_tool.py @@ -144,6 +144,7 @@ def regex_search( # 搜索结果 results = [] file_count = 0 + total_matches = 0 # 确定要搜索的文件列表(支持单文件或目录) files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern)) @@ -153,7 +154,7 @@ def regex_search( if not file_path.is_file(): continue # 检查是否达到限制 - if len(results) >= limit: + if total_matches >= limit: break # 检查是否应该忽略该文件或文件夹 @@ -180,6 +181,7 @@ def regex_search( if file_results: results.append({"file": str(file_path), "matches": file_results}) file_count += 1 + total_matches += len(file_results) except OSError, UnicodeDecodeError, PermissionError: continue # 跳过无法读取的文件 From 6ce747a1d7aea700984116e269da77c4efe1058b Mon Sep 17 00:00:00 2001 From: Suntion <149924916+SunYanbox@users.noreply.github.com> Date: Mon, 4 May 2026 18:23:17 +0800 Subject: [PATCH 6/8] =?UTF-8?q?fix(read=5Ftool):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=A4=A7=E5=B0=8F=E9=99=90=E5=88=B6=E9=98=B2?= =?UTF-8?q?=E6=AD=A2=E8=AF=BB=E5=8F=96=E5=A4=A7=E6=96=87=E4=BB=B6=E6=97=B6?= =?UTF-8?q?OOM=20/=20Add=20file=20size=20limit=20to=20prevent=20=E2=80=A6?= =?UTF-8?q?=20(#141)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(read_tool): 添加文件大小限制防止读取大文件时OOM / Add file size limit to prevent OOM on large file reads (#130) * Fix comment formatting for max file size Removed space before the default file size comment. 
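
A minimal sketch of the size guard described above, written as a hypothetical standalone helper `check_file_size` (an assumption for illustration; the patch performs the check inline in `ReadTool.read` and returns a `ToolErrorResponse`). The environment variable name and the 10 MB default are the ones used in the patch:

```python
import os
import tempfile
from pathlib import Path
from typing import Optional

# Same default as the patch: 10 MB, overridable through the environment.
MAX_FILE_SIZE = int(os.getenv("TOOL_READ_MAX_FILE_SIZE", str(10 * 1024 * 1024)))


def check_file_size(path: Path, limit: int = MAX_FILE_SIZE) -> Optional[str]:
    """Return an error message if the file exceeds the limit, else None.

    Hypothetical helper for illustration only; not part of the patch.
    """
    size = path.stat().st_size  # reads metadata only, never the file body
    if size > limit:
        return f"file too large ({size} bytes), limit is {limit} bytes: {path}"
    return None


with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.txt"
    sample.write_text("x" * 100, encoding="utf-8")  # exactly 100 bytes
    small_ok = check_file_size(sample, limit=1024)  # under the limit
    too_big = check_file_size(sample, limit=10)     # over the limit
```

Because the check relies on `stat()` before the file is opened, an oversized file is rejected without any of its content being loaded into memory.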
--- src/workspace/tools/read_tool.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/workspace/tools/read_tool.py b/src/workspace/tools/read_tool.py index 5561e16..3bf0c07 100644 --- a/src/workspace/tools/read_tool.py +++ b/src/workspace/tools/read_tool.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from src.models.tool_error_response import ToolErrorResponse @@ -5,6 +6,9 @@ from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace +# 最大文件读取大小, 超过此大小的文件将被拒绝读取(默认 10MB) +MAX_FILE_SIZE = int(os.getenv("TOOL_READ_MAX_FILE_SIZE", str(10 * 1024 * 1024))) + def _resolve_index(idx: int, total: int) -> int: """Resolve a 1-based or negative index to a clamped 1-based line number.""" @@ -48,6 +52,16 @@ def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encod ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64."), ).to_str() + file_size = file_path.stat().st_size + if file_size > MAX_FILE_SIZE: + return ToolErrorResponse( + self.__class__.__name__, + ValueError( + f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. " + f"请使用范围参数 (start/end) 分批读取." 
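+ ), + ).to_str() + with open(file_path, encoding=encoding) as f: lines = f.readlines()
From e1c7f3aaef2d0c6c35f13dc899dbcd72d97c481a Mon Sep 17 00:00:00 2001
From: Suntion <149924916+SunYanbox@users.noreply.github.com>
Date: Tue, 5 May 2026 14:54:06 +0800
Subject: [PATCH 7/8] =?UTF-8?q?Feat/unify=20tool=20call=20result=20/=20?= =?UTF-8?q?=E7=BB=9F=E4=B8=80=E5=B7=A5=E5=85=B7=E8=BF=94=E5=9B=9E=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E7=B1=BB=E5=9E=8B=20(#142)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat(tools): rework tool exception handling and unify the return structure (#133)
  - New feature: introduce the structured result class `ToolResult`
    * Define the `ToolResult` class with `success`, `data`, and `error` attributes, using `__slots__` as an optimization
    * Implement the `handle_tool_exceptions` decorator, which wraps exceptions (such as `PathNotFoundError` and `PermissionError`) uniformly as `ToolResult(success=False, error=...)`
    * Ensure tool methods return a standard `ToolResult` object whether they run synchronously or asynchronously
  - Fix: unify how the session ID is accessed
    * Remove direct access to the private attribute `workspace._current_session_id`
    * Expose `session_id` on the `Workspace` class as a getter/setter property for external callers
    * Update `FileTracker`, `WriteTool`, `EditTool`, `QuitCmd`, and related modules to use `self.workspace.session_id`
  - Refactor: upgrade the `ToolRegistry` execution logic
    * Modify the `call_tool` method to unpack `ToolResult` and check its status
    * Update the `set_session_id` method so that it automatically syncs `session_id` to the associated `Workspace` instance
    * Correct how `_log_tool_call` and `_record_tool_summary` obtain the session ID, using the internal variable directly instead of `getattr`
* refactor(tools): return `ToolResult` objects from all tools and rework test assertions
  - Refactor: standardize the usage and internal structure of `ToolResult`
    * Reorder the `__slots__` attributes in `src/workspace/tools/base_tool.py` to `(data, error, success)` to match the order of the initializer parameters
    * Simplify result handling in `src/core/tool_registry.py`: remove the redundant `if-else` branch and unpack `raw_result` with a single ternary expression
    * Ensure the return values of all tool methods (ReadTool, WriteTool, EditTool) strictly follow the `ToolResult` convention, with content accessed through `result.data` or `result.error`
  - Fix: update test cases for the new return structure
    * Correct the file-read assertions in `tests/workspace/tools/test_read_tool.py` to access `result.data` instead of checking a string
    * Update `tests/workspace/tools/test_write_tool.py` and `test_edit_tool.py` to verify the `result.data` field instead of checking the returned string directly
    * Adjust the binary-file protection assertions in `tests/workspace/tools/test_binary_protection.py` so that they access `result.data` or `result.error` correctly
* refactor(tool): rework the tool result model and execution flow
  - New feature: introduce a `ToolResult` model class to replace the previous inline implementation
    * Add `src/models/tools/tool_result.py`, which defines the `ToolResult` class
    * Integrate the `_compress_result` method into `ToolResult`, with automatic truncation for strings, lists, and dicts
    * Add the `to_xml_string` helper function for generating a standardized XML response
    * Add the static factory methods `make_tool_result_response`, `make_success_response`, and `make_failed_response` to `BaseTool`
  - Fix: unify exception handling and return types
    * Modify the `BaseTool.handle_tool_exceptions` decorator so that caught exceptions (`PathNotFoundError`, `WorkspaceBoundaryError`, etc.) always produce a `ToolResult` object containing `func_name` and `func_kwargs`
    * Remove the `ToolErrorResponse` dependency and format error messages directly as `ClassName: Message`
    * Update the signature of `ToolRegistry.execute` so that it must return a `ToolResult` rather than raw data
    * Strengthen type checking in `ToolRegistry.execute`, raising an explicit error for return values that are not a `ToolResult`
  - Refactor: relocate configuration management and compression logic
    * Move the constants `MAX_RESULT_LENGTH`, `LIST_TRUNCATE_THRESHOLD`, `DICT_TRUNCATE_THRESHOLD`, and others from `ToolRegistry` to class attributes of `ToolResult`
    * Delete the `_validate_config` and `_compress_result` methods from `ToolRegistry`
    * Change the argument passed to `ToolRegistry._log_tool_call` and `_record_tool_call_summary` from `output` to `result.response`
  - Docs: extend unit test coverage
    * Add `tests/models/tools/test_tool_result.py`, covering the behavior of `_compress_result` on different data structures
    * Delete the duplicated compression test cases from `tests/core/test_tool_registry.py`
* feat(core): rework the tool execution flow and exception handling
  - New feature: unify exception capture for tool calls
    * Add a global try-except block to `ToolRegistry.execute`, so that a missing tool or a runtime error is wrapped and returned as a `ToolResult` object
    * Remove the redundant `json.dumps` serialization and manual traceback printing from `tool_handler.py`
    * Change how `tool_handler.py` obtains the result: it now reads the standardized response data directly through `response.response`
  - Fix: correct import paths and migrate test cases
    * Fix the incorrect relative import of `truncate_params_string` in `src/models/tools/tool_result.py`, replacing it with the absolute import `from src.utils.string_snapshot import ...`
    * Fix the import path of `ToolResult` in `src/core/tool_registry.py`, changing `models.tools.tool_result` to `src.models.tools.tool_result`
    * Move the configuration validation test `test_validate_config` from `tests/core/test_tool_registry.py` to `tests/models/tools/test_tool_result.py`
    * Delete the `test_execute_nonexistent_tool` test case from `tests/core/test_tool_registry.py`, as it no longer applies
  - Refactor: simplify the code structure and improve robustness
    * Remove the explicit `else` branch for `ValueError` inside the `execute` method in `tool_registry.py`, relying on the outer exception handler to cover all error scenarios
    * Remove the dependency on the `json` module from `tool_handler.py`, avoiding unnecessary string serialization
* refactor(workspace/tools): bring all tool responses in line with the ToolResult standard (#133)
  - Refactor: change the return type of every tool module from a string or list to a `ToolResult` object
    * Remove the dependency on `src.models.tool_error_response.ToolErrorResponse` and adopt `src.models.tools.tool_result.ToolResult` throughout
    * Update the type annotations of `extract_params` and `handle_tool_exceptions` in `BaseTool` to fit the new return structure
    * Replace all error handling with calls to `self.make_failed_response(kwargs=locals().copy(), error=...)`
    * Replace all success returns with calls to `self.make_success_response(kwargs=locals().copy(), data=...)`
  - New feature: improve exception handling and warnings in the search tools
    * `regex_search_tool.py` and `exact_search_tool.py` now collect file-read errors in a `warnings` list and return them as execution warnings through the `error` parameter
    * Fix the capture syntax error in the search tools so that `(OSError, UnicodeDecodeError, PermissionError)` is written in the correct tuple form
  - Breaking change: the return type of tool functions has changed
    * The return type of all tool methods (such as `ls`, `read`, `write`, `git`, `stat`, `glob`, `edit`, `symbol_ref`, `regex_search`, `exact_search`) changes from `str` or `list[str]` to `ToolResult`
    * Callers must update their parsing logic to extract the `data` or `error` field from `ToolResult` instead of parsing string content directly
* fix(tools): improve exception handling and the empty-data check
  - Fix: improve how Git command timeouts are caught and reported
    * In `src/workspace/tools/git_tool.py`, change the handling of `subprocess.TimeoutExpired` to include the original exception object `time_out_exception` in the error message, replacing the fixed text "Git 命令执行超时 (30 秒)"
  - Refactor: unify the null check in the tool result model
    * In `src/models/tools/tool_result.py`, change `if data:` and `if not data` to the explicit `if data is not None:` and `if data is None:`, so that falsy values such as booleans are not misjudged as missing
  - Docs: change the edit tool's return message to a multi-line string
    * In `src/workspace/tools/edit_tool.py`, remove the comma separators inside the response tuple, turning what was a single-line tuple into a multi-line string containing newlines
* docs(tools): standardize the tool development workflow and the ToolResult return convention
  - New feature: unified tool execution result model and documentation examples
    * Introduce the `ToolResult` class (in `src/models/tools/tool_result.py`) as the unified wrapper
    * Define the `ToolResult` fields: `success`, `func_name`, `func_kwargs`, `data`, `error`, `response`
    * Update tool method signatures to require a `ToolResult` return type instead of raw data
    * Add usage examples for the `make_success_response` and `make_failed_response` helper methods
  - Refactor: clarify the description of write operations and registration
    * State explicitly that write operations must check via `self._validate_mtime(path)` and record a `PENDING_AUDIT` snapshot
    * Update the implementation details and exception handling of the `register()` method in `src/core/tool_registry.py`
    * Add a warning when a tool is missing its callback or parameters (`tool.func is None or tool.params is None`)
  - Docs: tighten test coverage requirements
    * Add assertion examples for normal paths and failure scenarios under `tests/workspace/tools/`
    * Require explicit verification of the `result.success` status and of the `result.data` / `result.error` content
    * Provide a boundary-safety test snippet based on `WorkspaceBoundaryError`
* test(workspace): assert on structured attributes of tool execution results
  - Fix: upgrade test assertions from string matching to checks on the structured attributes `result.success`, `result.data`, and `result.error`
    * `TestGitAllowedCommands`: update the assertions after `git` method calls to use `result.response` instead of operating on the string directly
    * `TestGitBlockedCommands`: change the blocked-command tests to verify the error through `result.success is False` and `result.error`
    * `TestBinaryProtection`: correct the assertions on how `ReadTool`, `WriteTool`, and `EditTool` handle binary files
    * `TestEditTool`: adjust the assertions for scenarios such as a failed context match and a missing file
    * `TestWriteTool`: update the assertions for the external-modification detection scenario to verify the `FILE_MODIFIED_EXTERNALLY` error code
---
 CONTRIBUTING.md | 78 +++++++++-
 CONTRIBUTING_ZH.md | 78 +++++++++-
 src/console/commands/systems/quit_cmd.py | 2 +-
 src/console/handlers/tool_handler.py | 33 +---
 src/console/main.py | 1 -
 src/core/tool_registry.py | 119 +++++----------
 src/models/tools/tool_result.py | 143 ++++++++++++++++++
 src/workspace/tools/base_tool.py | 60 ++++++--
 src/workspace/tools/edit_tool.py | 50 +++---
 src/workspace/tools/exact_search_tool.py | 15 +-
 src/workspace/tools/git_tool.py | 75 ++++-----
 src/workspace/tools/glob_tool.py | 17 ++-
 src/workspace/tools/ls_tool.py | 17 ++-
 src/workspace/tools/read_tool.py | 45 +++---
 src/workspace/tools/regex_search_tool.py | 18 ++-
 src/workspace/tools/stat_tool.py | 5 +-
 src/workspace/tools/symbol_ref_tool.py | 14 +-
 src/workspace/tools/write_tool.py | 30 ++--
 src/workspace/workspace.py | 9 ++
 tests/core/test_tool_registry.py | 78 ----------
 tests/models/__init__.py | 0
 tests/models/tools/__init__.py | 0
 tests/models/tools/test_tool_result.py | 63 ++++++++
 .../workspace/tools/test_binary_protection.py | 35 +++--
 tests/workspace/tools/test_edit_tool.py | 54 ++++---
 tests/workspace/tools/test_git_tool.py | 65 +++++---
 tests/workspace/tools/test_read_tool.py | 6 +-
 tests/workspace/tools/test_write_tool.py | 11 +-
 28 files changed, 727 insertions(+), 394 deletions(-)
 create mode 100644 src/models/tools/tool_result.py
 create mode 100644 tests/models/__init__.py
 create mode 100644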
tests/models/tools/__init__.py create mode 100644 tests/models/tools/test_tool_result.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index acc9fae..2c61195 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -208,7 +208,21 @@ docs(README): add contributing guide ## How to Add a New Tool All tools are located in the `src/workspace/tools/` directory and inherit from -the `BaseTool` base class. +the `BaseTool` base class. All tool methods must return a `ToolResult` object +(located in `src/models/tools/tool_result.py`). + +### ToolResult Model + +`ToolResult` is the unified tool execution result wrapper: + +| Field | Description | +| :------------ | :---------------------------------------------- | +| `success` | Whether execution succeeded (`bool`) | +| `func_name` | Name of the tool method (`str`) | +| `func_kwargs` | Dictionary of invocation parameters (`dict`) | +| `data` | Return data on success (`Any`) | +| `error` | Error message on failure (`str\|None`) | +| `response` | Auto-generated XML string (for LLM consumption) | ### Step Overview @@ -217,6 +231,7 @@ the `BaseTool` base class. 2. **Inherit from BaseTool**: ```python + from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -237,14 +252,25 @@ the `BaseTool` base class. } @BaseTool.handle_tool_exceptions - def your_method(self, param1: str, param2: int = 0) -> str: + def your_method(self, param1: str, param2: int = 0) -> ToolResult: """ Tool description -- will be generated as LLM-readable documentation. """ # Path operations must be validated through PathValidator path = self.workspace.path_validator.validate(param1) - # ... tool logic ... 
- return f'{path}x{param2}' + + # Note: Return ToolResult, not raw data + # On success, use make_success_response + return self.make_success_response( + kwargs=locals().copy(), + data=f'{path}x{param2}' + ) + + # On failure, use make_failed_response + # return self.make_failed_response( + # kwargs=locals().copy(), + # error="Specific error description" + # ) ``` 3. **Special handling for write operations**: If the tool involves writing @@ -253,17 +279,59 @@ the `BaseTool` base class. `self._validate_mtime(path)` - Generate a diff and record a `PENDING_AUDIT` snapshot instead of writing directly to disk + - Still return via + `self.make_success_response(kwargs=locals().copy(), data=...)` or + `self.make_failed_response(kwargs=locals().copy(), error=...)` - Refer to the implementations of `WriteTool` and `EditTool` 4. **Register the tool**: Import and instantiate your tool in the `register()` method of `src/core/tool_registry.py`: + ```python + def register(self, workspace: Workspace) -> None: + from src.workspace.tools.your_tool import YourTool + + self._workspace = workspace + + for cls in ( + # ... other existing tool classes ... + YourTool, + ): + try: + tool = cls(workspace) + if tool.func is None or tool.params is None: + warnings.warn(f"Tool {tool.name} has no registered function callback and parameters", stacklevel=2) + continue + self._tools[tool.name] = tool + self._set_tool_category(tool) + except ValueError: + pass + ``` + 5. **Add tests**: Create corresponding test files under `tests/workspace/tools/`. 
At minimum, cover: - - Normal execution paths + - Normal execution paths (assert `result.success is True`, check + `result.data`) + - Failure scenarios (assert `result.success is False`, check `result.error`) - Parameter validation (e.g., empty parameters, invalid values) - Path security (e.g., out-of-bounds access) + Test example: + + ```python + def test_your_tool_success(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="valid_path", param2=42) + assert result.success is True + assert result.data is not None + + def test_your_tool_failure(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="../outside_path", param2=42) + assert result.success is False + assert "WorkspaceBoundaryError" in result.error + ``` + --- ## Testing Requirements diff --git a/CONTRIBUTING_ZH.md b/CONTRIBUTING_ZH.md index 053ac9b..2bb7b66 100644 --- a/CONTRIBUTING_ZH.md +++ b/CONTRIBUTING_ZH.md @@ -192,7 +192,22 @@ docs(README): 添加贡献指南 / add contributing guide ## 如何添加新工具 -所有工具位于 `src/workspace/tools/` 目录,继承 `BaseTool` 基类. +所有工具位于 `src/workspace/tools/` 目录,继承 `BaseTool` +基类. 所有工具方法必须返回 `ToolResult` 对象(位于 +`src/models/tools/tool_result.py`). + +### ToolResult 模型 + +`ToolResult` 是统一的工具执行结果包装器: + +| 字段 | 说明 | +| :------------ | :--------------------------------------- | +| `success` | 执行是否成功 (`bool`) | +| `func_name` | 工具方法名 (`str`) | +| `func_kwargs` | 调用参数字典 (`dict`) | +| `data` | 成功时的返回数据 (`Any`) | +| `error` | 失败时的错误消息 (`str\|None`) | +| `response` | 自动生成的 XML 格式字符串(用于 LLM 消费) | ### 步骤概述 @@ -201,6 +216,7 @@ docs(README): 添加贡献指南 / add contributing guide 2. 
**继承 BaseTool**: ```python + from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -221,29 +237,81 @@ docs(README): 添加贡献指南 / add contributing guide } @BaseTool.handle_tool_exceptions - def your_method(self, param1: str, param2: int = 0) -> str: + def your_method(self, param1: str, param2: int = 0) -> ToolResult: """ 工具描述 -- 会生成为 LLM 可读的文档. """ # 路径操作必须通过 PathValidator 验证 path = self.workspace.path_validator.validate(param1) - # ... 工具逻辑 ... - return f'{path}x{param2}' + + # 注意: 返回 ToolResult 而非原始数据 + # 成功时使用 make_success_response + return self.make_success_response( + kwargs=locals().copy(), + data=f'{path}x{param2}' + ) + + # 失败时使用 make_failed_response + # return self.make_failed_response( + # kwargs=locals().copy(), + # error="具体的错误描述" + # ) ``` 3. **写入操作的特殊处理**: 如果工具涉及写入(`write_permission=True`),需要: - 通过 `self._validate_mtime(path)` 检查文件是否被外部修改 - 生成 diff 并记录 `PENDING_AUDIT` 快照,而非直接写入磁盘 + - 返回格式仍然使用 + `self.make_success_response(kwargs=locals().copy(), data=...)` 或 + `self.make_failed_response(kwargs=locals().copy(), error=...)` - 参考 `WriteTool` 和 `EditTool` 的实现 4. **注册工具**: 在 `src/core/tool_registry.py` 的 `register()` 方法中导入并实例化你的工具: + ```python + def register(self, workspace: Workspace) -> None: + from src.workspace.tools.your_tool import YourTool + + self._workspace = workspace + + for cls in ( + # ... 其他已有的工具类 ... + YourTool, + ): + try: + tool = cls(workspace) + if tool.func is None or tool.params is None: + warnings.warn(f"工具{tool.name}没有注册功能回调和参数", stacklevel=2) + continue + self._tools[tool.name] = tool + self._set_tool_category(tool) + except ValueError: + pass + ``` + 5. **补充测试**: 在 `tests/workspace/tools/` 下创建对应的测试文件. 
至少覆盖: - - 正常执行路径 + - 正常执行路径(断言 `result.success is True`, 检查 `result.data`) + - 失败场景(断言 `result.success is False`, 检查 `result.error`) - 参数验证(如空参数、非法值) - 路径安全(如越界访问) + 测试示例: + + ```python + def test_your_tool_success(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="valid_path", param2=42) + assert result.success is True + assert result.data is not None + + def test_your_tool_failure(workspace): + tool = YourTool(workspace) + result = tool.your_method(param1="../outside_path", param2=42) + assert result.success is False + assert "WorkspaceBoundaryError" in result.error + ``` + --- ## 测试要求 diff --git a/src/console/commands/systems/quit_cmd.py b/src/console/commands/systems/quit_cmd.py index fe574c3..ec199ac 100644 --- a/src/console/commands/systems/quit_cmd.py +++ b/src/console/commands/systems/quit_cmd.py @@ -16,7 +16,7 @@ def __init__(self): def execute(self, context: CommandContext) -> CommandResult: # Close session explicitly before exit (atexit handler is a fallback, # but explicit is more reliable when sys.exit triggers early shutdown). 
- session_id = getattr(context.tool_registry, "_current_session_id", None) + session_id = getattr(context.tool_registry, "session_id", None) if session_id is not None and hasattr(context.workspace, "db"): context.workspace.db.close_session(session_id) context.console.print("[bold]Goodbye![/bold]") diff --git a/src/console/handlers/tool_handler.py b/src/console/handlers/tool_handler.py index 874a827..0faaf67 100644 --- a/src/console/handlers/tool_handler.py +++ b/src/console/handlers/tool_handler.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import time from typing import TYPE_CHECKING @@ -100,35 +99,9 @@ def handle(self, parsed_input: CommandParseResult) -> bool: start = time.perf_counter() # 执行 - try: - response = self.tool_registry.execute(func_name, **func_kwargs) - if isinstance(response, str): - response_str = response - elif isinstance(response, (dict, list, tuple)): - response_str = json.dumps(response) - else: - response_str = f"{response.__class__.__name__}({response})" - - temp_result = [ - "", - f"", - response_str, - "", - "", - ] - func_result = str.join("\n", temp_result) - except Exception as e: - import traceback - - error = ( - f"执行工具{func_name}(参数={parms})时出现错误: " - f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})" - ) - self.console.print(f"[red]{error}[/red]") - func_result = "\n".join(["", "", error, "", ""]) - self.console.print(f"[red]{error}[/red]") - - collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=func_result) + response = self.tool_registry.execute(func_name, **func_kwargs) + + collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=response.response) result = "" diff --git a/src/console/main.py b/src/console/main.py index 927a3f2..0d1c65c 100644 --- a/src/console/main.py +++ b/src/console/main.py @@ -85,7 +85,6 @@ def init_workspace(start_path: str | None = None) -> Workspace | None: session_id = 
workspace.db.create_session(name=f"session_{time.strftime('%Y%m%d_%H%M%S')}") tool_registry.set_session_id(session_id) - workspace._current_session_id = session_id # Start a daemon heartbeat thread to periodically persist session duration # and guard against accidental deletion flag diff --git a/src/core/tool_registry.py b/src/core/tool_registry.py index f6cf41f..55f8650 100644 --- a/src/core/tool_registry.py +++ b/src/core/tool_registry.py @@ -12,6 +12,7 @@ import nest_asyncio from src.console.result_manager import _to_string +from src.models.tools.tool_result import ToolResult from src.utils.string_snapshot import truncate_string from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -55,36 +56,6 @@ def __init__(self): # 配置常量 - 从环境变量读取,带默认值 self.MAX_DOC_LENGTH = int(os.getenv("TOOL_MAX_DOC_LENGTH", "360")) self.MAX_FUNC_NAME_LENGTH = int(os.getenv("TOOL_MAX_FUNC_NAME_LENGTH", "80")) - self.MAX_RESULT_LENGTH = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000")) - self.LIST_TRUNCATE_THRESHOLD = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100")) - self.DICT_TRUNCATE_THRESHOLD = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100")) - - # 验证配置值 - self._validate_config() - - def _validate_config(self) -> None: - """验证配置值确保在合理范围内""" - if self.MAX_RESULT_LENGTH < 10: - warnings.warn( - f"TOOL_MAX_RESULT_LENGTH 过小({self.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2 - ) - self.MAX_RESULT_LENGTH = 100 - - if self.LIST_TRUNCATE_THRESHOLD < 10: - warnings.warn( - f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({self.LIST_TRUNCATE_THRESHOLD}),建议至少为50", - UserWarning, - stacklevel=2, - ) - self.LIST_TRUNCATE_THRESHOLD = 50 - - if self.DICT_TRUNCATE_THRESHOLD < 10: - warnings.warn( - f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({self.DICT_TRUNCATE_THRESHOLD}),建议至少为50", - UserWarning, - stacklevel=2, - ) - self.DICT_TRUNCATE_THRESHOLD = 50 def _validate_tool_info(self, name: str, doc: str) -> None: """验证工具信息并发出警告""" @@ -142,29 +113,7 @@ def 
register(self, workspace: Workspace) -> None: except ValueError: pass - def _compress_result(self, result: Any) -> Any: - """压缩过长的结果""" - result_length = len(result) - if isinstance(result, str): - if result_length > self.MAX_RESULT_LENGTH: - return ( - result[: self.MAX_RESULT_LENGTH] - + f"... [字符串结果已截断 显示的字符数: {self.LIST_TRUNCATE_THRESHOLD} / {result_length}]" - ) - elif isinstance(result, (list, tuple)): - if result_length > self.LIST_TRUNCATE_THRESHOLD: - return [ - *list(result[: self.LIST_TRUNCATE_THRESHOLD]), - f"... [列表已截断 显示的项: {self.LIST_TRUNCATE_THRESHOLD} / {result_length}]", - ] - elif isinstance(result, dict) and result_length > self.DICT_TRUNCATE_THRESHOLD: - compressed = {k: result[k] for k in list(result.keys())[: self.DICT_TRUNCATE_THRESHOLD]} - compressed["..."] = f"[字典已截断 显示的项: {self.DICT_TRUNCATE_THRESHOLD} / {result_length}]" - return compressed - - return result - - def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any: + def execute(self, func_name: str, *args: Any, **kwargs: Any) -> ToolResult: """ 执行工具函数 @@ -176,37 +125,48 @@ def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any: Returns: 函数执行结果(自动压缩过长的结果) """ - if func_name in self._tools: - tool = self._tools[func_name] - kwargs = tool.convert_args(kwargs) + try: + if func_name in self._tools: + tool = self._tools[func_name] + kwargs = tool.convert_args(kwargs) + + start_time = time.perf_counter() - start_time = time.perf_counter() - status = "success" - try: if inspect.iscoroutinefunction(tool.func): coro = tool.func(*args, **kwargs) # 已有事件循环 try: loop = asyncio.get_running_loop() nest_asyncio.apply() - result = loop.run_until_complete(coro) + raw_result = loop.run_until_complete(coro) except RuntimeError: # pragma: no cover // pytest内置事件循环, 测不到这里 # 没有运行中的事件循环 - result = asyncio.run(coro) + raw_result = asyncio.run(coro) else: # 同步函数 - result = tool.func(*args, **kwargs) - except Exception as e: - status = "error" - result = f'' - - duration_ms = 
(time.perf_counter() - start_time) * 1000 - self._log_tool_call(func_name, kwargs, duration_ms, status) - self._record_tool_call_summary(func_name, kwargs, result) - - return self._compress_result(result) - else: - raise ValueError(f"未找到工具: {func_name}") + raw_result = tool.func(*args, **kwargs) + + # 统一解包 ToolResult + result = ( + raw_result + if (isinstance(raw_result, ToolResult)) + else ( + ToolResult( + success=False, + func_name=func_name, + func_kwargs=kwargs, + error=f"错误的工具返回值类型: {raw_result.__class__.__name__}", + ) + ) + ) + duration_ms = (time.perf_counter() - start_time) * 1000 + self._log_tool_call(func_name, kwargs, duration_ms, result.status) + self._record_tool_call_summary(func_name, kwargs, result.response) + return result + else: + raise ValueError(f"未找到工具: {func_name}") + except Exception as e: + return ToolResult(success=False, data=kwargs, error=str(e), func_name=func_name, func_kwargs=kwargs) def generate_markdown(self) -> str: """ @@ -234,8 +194,10 @@ def get_tool_info(self, name: str) -> BaseTool | None: return self._tools.get(name) def set_session_id(self, session_id: int) -> None: - """设置当前会话 ID""" + """设置当前会话 ID,并同步到关联的 Workspace.""" self._current_session_id = session_id + if self._workspace is not None: + self._workspace.session_id = session_id @staticmethod def _compute_kwargs_json(kwargs: dict) -> str: @@ -246,13 +208,13 @@ def _compute_kwargs_json(kwargs: dict) -> str: return json.dumps(truncated, sort_keys=True, default=str) def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, status: str) -> str | None: - session_id = getattr(self, "_current_session_id", None) - if session_id is None: + if self._current_session_id is None: return None try: kwargs_json = self._compute_kwargs_json(kwargs) workspace = getattr(self, "_workspace", None) if workspace is not None: + session_id = self._current_session_id # Determine audit_status based on tool category category = self._tool_categories.get(func_name, "query") 
audit_status = "none" @@ -280,8 +242,7 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu return f"ToolRegistry(sync_tools={len(self._tools)})" def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) -> None: - session_id = getattr(self, "_current_session_id", None) - if session_id is None: + if self._current_session_id is None: return # Exclude write tools @@ -294,6 +255,6 @@ def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) - workspace = getattr(self, "_workspace", None) if workspace is not None: result_str = _to_string(result) - workspace.db.record_tool_call_summary(session_id, func_name, kwargs_json, result_str) + workspace.db.record_tool_call_summary(self._current_session_id, func_name, kwargs_json, result_str) except Exception: pass diff --git a/src/models/tools/tool_result.py b/src/models/tools/tool_result.py new file mode 100644 index 0000000..763e911 --- /dev/null +++ b/src/models/tools/tool_result.py @@ -0,0 +1,143 @@ +import json +import os +import warnings +from typing import Any, ClassVar + +from src.utils.string_snapshot import truncate_params_string + + +def to_xml_string(func_name: str, params: dict, data: Any = None, err: str | None = None) -> str: + params = params.copy() + params.pop("self", None) + + try: + messages: list[str] = [] + if data is not None: + messages.append("") + if isinstance(data, str): + messages.append(data) + elif isinstance(data, (dict, list, tuple)): + messages.append(json.dumps(data)) + else: + messages.append(f"{data.__class__.__name__}({data})") + messages.append("") + if err: + messages.extend( + [ + "", + err, + "", + ] + ) + if data is None and not err: + messages.append("没有任何工具调用数据或错误详情, 请提示用户检查工具是否正常") + + temp_result = [ + "", + f"", + *messages, + "", + "", + ] + func_result = str.join("\n", temp_result) + except Exception as e: + import traceback + + func_result = "\n".join( + [ + "", + f"", + f"Error={e.__class__.__name__}({e}, 
{traceback.format_exc()})", + err if err else "", + "", + "", + ] + ) + return func_result + + +class ToolResult: + """工具执行结果的结构化包装,显式区分成功与失败. + + 所有被 @handle_tool_exceptions 装饰的工具方法均返回此类型, + 调用方可通过 success 标志可靠判断执行状态,无需依赖隐式类型约定. + """ + + __slots__ = ("data", "error", "func_kwargs", "func_name", "response", "success") + + HAD_VALIDATE: ClassVar[bool] = False + MAX_RESULT_LENGTH: ClassVar[int] = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000")) + LIST_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100")) + DICT_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100")) + + def __init__( + self, success: bool, func_name: str, func_kwargs: dict, data: Any = None, error: str | None = None + ) -> None: + self.success: bool = success + self.func_name: str = func_name + self.func_kwargs: dict = func_kwargs + self.data: Any = data + self.error: str | None = error + self.response: str = to_xml_string(self.func_name, self.func_kwargs, self.data, self.error) + ToolResult._validate_config() + + def __repr__(self) -> str: + if self.success: + return f"ToolResult(success=True, data={self.data!r})" + return f"ToolResult(success=False, error={self.error!r})" + + @property + def status(self) -> str: + return "success" if self.success else "error" + + @classmethod + def _compress_result(cls, result: Any) -> Any: + """压缩过长的结果""" + result_length = len(result) + if isinstance(result, str): + if result_length > cls.MAX_RESULT_LENGTH: + return ( + result[: cls.MAX_RESULT_LENGTH] + + f"... [字符串结果已截断 显示的字符数: {cls.MAX_RESULT_LENGTH} / {result_length}]" + ) + elif isinstance(result, (list, tuple)): + if result_length > cls.LIST_TRUNCATE_THRESHOLD: + return [ + *list(result[: cls.LIST_TRUNCATE_THRESHOLD]), + f"... 
[列表已截断 显示的项: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]", + ] + elif isinstance(result, dict) and result_length > cls.DICT_TRUNCATE_THRESHOLD: + compressed = {k: result[k] for k in list(result.keys())[: cls.DICT_TRUNCATE_THRESHOLD]} + compressed["..."] = f"[字典已截断 显示的项: {cls.DICT_TRUNCATE_THRESHOLD} / {result_length}]" + return compressed + + return result + + @classmethod + def _validate_config(cls) -> None: + """验证配置值确保在合理范围内""" + if cls.HAD_VALIDATE: + return None + if cls.MAX_RESULT_LENGTH < 10: + warnings.warn( + f"TOOL_MAX_RESULT_LENGTH 过小({cls.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2 + ) + cls.MAX_RESULT_LENGTH = 100 + + if cls.LIST_TRUNCATE_THRESHOLD < 10: + warnings.warn( + f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({cls.LIST_TRUNCATE_THRESHOLD}),建议至少为50", + UserWarning, + stacklevel=2, + ) + cls.LIST_TRUNCATE_THRESHOLD = 50 + + if cls.DICT_TRUNCATE_THRESHOLD < 10: + warnings.warn( + f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({cls.DICT_TRUNCATE_THRESHOLD}),建议至少为50", + UserWarning, + stacklevel=2, + ) + cls.DICT_TRUNCATE_THRESHOLD = 50 + cls.HAD_VALIDATE = True + return None diff --git a/src/workspace/tools/base_tool.py b/src/workspace/tools/base_tool.py index 367df61..70e6ab4 100644 --- a/src/workspace/tools/base_tool.py +++ b/src/workspace/tools/base_tool.py @@ -4,6 +4,7 @@ from typing import Any from src.core.file_tracker import FileTracker +from src.models.tools.tool_result import ToolResult from src.workspace.workspace import Workspace @@ -108,7 +109,7 @@ def __init__( self.write_permission: bool = write_permission self.name: str = name self.doc: str = doc - self.func: Callable[..., Any] | None = None + self.func: Callable[..., ToolResult] | None = None self.params: dict[str, Any] | None = None self.param_descriptions: dict[str, str] = {} @@ -139,7 +140,7 @@ def to_func_call(self) -> str: return func_call @staticmethod - def extract_params(func: Callable[..., Any]) -> dict[str, Any]: + def extract_params(func: Callable[..., ToolResult]) -> dict[str, 
Any]: """提取函数参数信息""" sig = inspect.signature(func) params = {} @@ -181,7 +182,7 @@ def _record_read_meta(self, resolved_path: Path) -> None: try: meta = FileTracker.get_file_meta(resolved_path) if meta: - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id if session_id is not None: rel_path = str(resolved_path.relative_to(self.workspace.root_path)) self.workspace.db.record_file_read( @@ -195,7 +196,7 @@ def _validate_mtime(self, resolved_path: Path) -> str | None: if not resolved_path.exists(): return None - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id if session_id is None: return None @@ -224,25 +225,60 @@ def _generate_diff(old_content: str, new_content: str, file_path: str) -> str: diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}") return "".join(diff) + @classmethod + def make_tool_result_response( + cls, success: bool, kwargs: dict, data: Any = None, error: str | None = None + ) -> ToolResult: + return ToolResult(success=success, func_name=cls.__name__, func_kwargs=kwargs, data=data, error=error) + + @classmethod + def make_success_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return cls.make_tool_result_response(success=True, kwargs=kwargs, data=data, error=error) + + @classmethod + def make_failed_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return cls.make_tool_result_response(success=False, kwargs=kwargs, data=data, error=error) + @staticmethod - def handle_tool_exceptions(func): - """工具方法异常处理装饰器.""" + def handle_tool_exceptions(func) -> Callable[..., ToolResult]: + """工具方法异常处理装饰器 —— 将异常转换为 ToolResult 失败结果""" from functools import wraps - from src.models.tool_error_response import ToolErrorResponse from src.workspace.path_validator import PathNotFoundError, WorkspaceBoundaryError @wraps(func) def wrapper(self, *args, **kwargs): try: - 
return func(self, *args, **kwargs) + raw = func(self, *args, **kwargs) + # 如果工具内部已返回 ToolResult, 直接透传 + if isinstance(raw, ToolResult): + return raw + # 否则包装为成功结果 + return ToolResult(success=True, func_name=func.__name__, func_kwargs=kwargs, data=raw) except PathNotFoundError as err1: - return ToolErrorResponse(self.__class__.__name__, err1).to_str() + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err1.__class__.__name__}: {err1}", + ) except WorkspaceBoundaryError as err2: - return ToolErrorResponse(self.__class__.__name__, err2).to_str() + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err2.__class__.__name__}: {err2}", + ) except PermissionError as err3: - return ToolErrorResponse(self.__class__.__name__, err3).to_str() + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err3.__class__.__name__}: {err3}", + ) except Exception as err: - return ToolErrorResponse(self.__class__.__name__, err).to_str() + return ToolResult( + success=False, func_name=func.__name__, func_kwargs=kwargs, error=f"{err.__class__.__name__}: {err}" + ) return wrapper diff --git a/src/workspace/tools/edit_tool.py b/src/workspace/tools/edit_tool.py index 001826a..b33c154 100644 --- a/src/workspace/tools/edit_tool.py +++ b/src/workspace/tools/edit_tool.py @@ -3,6 +3,7 @@ from pathlib import Path from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.utils.binary_detector import is_binary_file from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -37,16 +38,16 @@ def edit( max_replacements: int = 10, context_before: str = "", context_after: str = "", - ) -> str: + ) -> ToolResult: """ 通过在文件中进行安全的字符串替换编辑文件 """ # 1. 
参数校验 if not old_string: - return ToolErrorResponse(self.__class__.__name__, ValueError("old_string 不能为空")).to_str() + return self.make_failed_response(locals().copy(), error=f"{ValueError('old_string 不能为空')}") if max_replacements < 1: - return ToolErrorResponse(self.__class__.__name__, ValueError("max_replacements 必须 >= 1")).to_str() + return self.make_failed_response(locals().copy(), error=f"{ValueError('max_replacements 必须 >= 1')}") if max_replacements > 100: max_replacements = 100 @@ -55,21 +56,19 @@ def edit( resolved_path: Path = self.workspace.path_validator.resolve_path(source_path) if not resolved_path.is_file(): - return ToolErrorResponse( - self.__class__.__name__, - FileNotFoundError(f"文件不存在: {resolved_path}"), - ).to_str() + return self.make_failed_response( + locals().copy(), error=f"{FileNotFoundError(f'文件不存在: {resolved_path}')}" + ) if is_binary_file(resolved_path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"禁止编辑二进制文件: {resolved_path}"), - ).to_str() + return self.make_failed_response( + locals().copy(), error=f"{ValueError(f'禁止编辑二进制文件: {resolved_path}')}" + ) # 3. mtime 校验 mtime_error = self._validate_mtime(resolved_path) if mtime_error: - return mtime_error + return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") # 4. 读取文件内容 old_content = resolved_path.read_text(encoding="utf-8") @@ -87,12 +86,17 @@ def edit( if context_before or context_after: ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count) if ctx_error: - return ctx_error + return self.make_failed_response( + locals().copy(), error=f"无法修改上下文不匹配的字符串:\n{ctx_error}" + ) idx += len(old_string) if count == 0: - return f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'" + return self.make_failed_response( + locals().copy(), + error=f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'", + ) # 6. 
执行替换(生成新内容) new_content = old_content.replace(old_string, new_string, count) @@ -106,7 +110,7 @@ def edit( old_hash = FileTracker.compute_checksum_from_string(old_content) new_hash = FileTracker.compute_checksum_from_string(new_content) - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id snapshot_id = self.workspace.db.record_file_snapshot( rel_path, old_hash, @@ -118,12 +122,16 @@ def edit( ) # 9. 返回预览 - return ( - f"[Edit Preview]\n" - f"File: {rel_path}\n" - f"Snapshot ID: {snapshot_id}\n" - f"Replacements: {count}\n" - f"Diff:\n{diff_content}" + return self.make_success_response( + locals().copy(), + ( + "修改已推送到审核系统\n" + f"[Edit Preview]\n" + f"File: {rel_path}\n" + f"Snapshot ID: {snapshot_id}\n" + f"Replacements: {count}\n" + f"Diff:\n{diff_content}" + ), ) @staticmethod diff --git a/src/workspace/tools/exact_search_tool.py b/src/workspace/tools/exact_search_tool.py index c03b55b..873df3a 100644 --- a/src/workspace/tools/exact_search_tool.py +++ b/src/workspace/tools/exact_search_tool.py @@ -2,6 +2,7 @@ import re from pathlib import Path +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -96,7 +97,7 @@ def exact_search( file_pattern: str = "*", limit: int = 256, ignore: list[str] | None = None, - ) -> str: + ) -> ToolResult: """ 精确搜索字符串 """ @@ -117,6 +118,7 @@ def exact_search( results = [] file_count = 0 total_matches = 0 + warnings = [""] # 确定要搜索的文件列表(支持单文件或目录) files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern)) @@ -154,8 +156,15 @@ def exact_search( file_count += 1 total_matches += len(file_matches) - except OSError, UnicodeDecodeError, PermissionError: + except (OSError, UnicodeDecodeError, PermissionError) as e: + warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") continue # 跳过无法读取的文件 + warnings.append("") + # 格式化输出 - return _format_exact_results(results, pattern, limit, 
file_count, case_sensitive, whole_word) + return self.make_success_response( + kwargs=locals().copy(), + data=_format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word), + error="\n".join(warnings) if len(warnings) > 2 else None, + ) diff --git a/src/workspace/tools/git_tool.py b/src/workspace/tools/git_tool.py index af1e0c2..a0cc2eb 100644 --- a/src/workspace/tools/git_tool.py +++ b/src/workspace/tools/git_tool.py @@ -4,7 +4,7 @@ import shlex import subprocess -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -59,55 +59,61 @@ def __init__(self, workspace: Workspace): "command_str": "Git 子命令及其参数,如 'status'、'diff --cached'、'log --oneline -5'", } - def git(self, command_str: str) -> str: + def git(self, command_str: str) -> ToolResult: """ 执行 Git 命令 """ if not command_str or not command_str.strip(): - return ToolErrorResponse(self.__class__.__name__, ValueError("command_str 不能为空")).to_str() + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("command_str 不能为空"))) try: tokens = shlex.split(command_str) except ValueError as e: - return ToolErrorResponse(self.__class__.__name__, e).to_str() + return self.make_failed_response(kwargs=locals().copy(), error=str(e)) if not tokens: - return ToolErrorResponse(self.__class__.__name__, ValueError("无法解析命令")).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"无法解析命令: `{command_str}`")) + ) base_command = tokens[0] # 1. 
白名单检查 if base_command not in _ALLOWED_COMMANDS: allowed_list = ", ".join(sorted(_ALLOWED_COMMANDS)) - return ( - f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n" - f"Allowed commands: {allowed_list}" + return self.make_failed_response( + kwargs=locals().copy(), + error=( + f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n" + f"Allowed commands: {allowed_list}" + ), ) # 2. 拦截正则检查 for pattern in _BLOCKED_PATTERNS: if pattern.search(command_str): - return ( - f"ERROR: The command was blocked by security policy.\n" - f"Pattern matched: {pattern.pattern}\n" - f"Command: {command_str}" + return self.make_failed_response( + kwargs=locals().copy(), + error=( + f"ERROR: The command was blocked by security policy.\n" + f"Pattern matched: {pattern.pattern}\n" + f"Command: {command_str}" + ), ) # 3. restore 安全检查 — 必须指定文件路径 if base_command == "restore": non_flag_args = [t for t in tokens[1:] if not t.startswith("-")] if not non_flag_args: - return ToolErrorResponse( - self.__class__.__name__, - ValueError("restore 需要指定文件路径,不允许裸 restore"), - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError("restore 需要指定文件路径,不允许裸 restore")) + ) for arg in non_flag_args: stripped = arg.strip() if stripped in (".", "*", "all") or stripped.startswith("*"): - return ToolErrorResponse( - self.__class__.__name__, - ValueError("restore 需要指定具体文件路径,不允许使用通配符"), - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError("restore 需要指定具体文件路径,不允许使用通配符")) + ) # 4. 
执行命令 try: @@ -121,22 +127,19 @@ def git(self, command_str: str) -> str: env=env, ) except FileNotFoundError: - return ToolErrorResponse( - self.__class__.__name__, - OSError("Git 未安装或不在系统 PATH 中"), - ).to_str() - except subprocess.TimeoutExpired: - return ToolErrorResponse( - self.__class__.__name__, - TimeoutError("Git 命令执行超时(30 秒)"), - ).to_str() + return self.make_failed_response(kwargs=locals().copy(), error=str(OSError("Git 未安装或不在系统 PATH 中"))) + except subprocess.TimeoutExpired as time_out_exception: + return self.make_failed_response( + kwargs=locals().copy(), error=f"TimeoutExpired(Git 命令执行超时: {time_out_exception})" + ) # 5. 处理输出 if result.returncode != 0: stderr = (result.stderr or "").strip() - if stderr: - return f"Git command failed (exit code {result.returncode}):\n{stderr}" - return f"Git command failed (exit code {result.returncode})" + return self.make_failed_response( + kwargs=locals().copy(), + error=f"Git command failed (exit code {result.returncode})" + (f":\n{stderr}" if stderr else ""), + ) # Combine stdout and stderr output_parts = [] @@ -165,9 +168,13 @@ def git(self, command_str: str) -> str: if err: output_parts.append(err.rstrip("\n")) if not output_parts and _result2.returncode != 0: - return f"Git command failed (exit code {_result2.returncode})" + return self.make_failed_response( + kwargs=locals().copy(), error=f"Git command failed (exit code {_result2.returncode})" + ) - return "\n".join(output_parts) if output_parts else "(no output)" + return self.make_success_response( + kwargs=locals().copy(), data="\n".join(output_parts) if output_parts else "(no output)" + ) @staticmethod def is_safe_command(command_str: str) -> bool: diff --git a/src/workspace/tools/glob_tool.py b/src/workspace/tools/glob_tool.py index 7ee0c08..df158c0 100644 --- a/src/workspace/tools/glob_tool.py +++ b/src/workspace/tools/glob_tool.py @@ -1,6 +1,6 @@ from itertools import islice -from src.models.tool_error_response import ToolErrorResponse +from 
src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -17,15 +17,18 @@ def __init__(self, workspace: Workspace): } @BaseTool.handle_tool_exceptions - def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> list[str]: + def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> ToolResult: """ 在工作区内按通配符模式匹配并列出所有路径,带[Folder]或[File]的类型标记. 失败时返回错误信息 """ root_path = self.workspace.path_validator.validate(path) if not root_path.is_dir(): - return ToolErrorResponse(self.__class__.__name__, f"{root_path}不是一个文件夹路径").to_str() + return self.make_failed_response(kwargs=locals().copy(), error=f"{root_path}不是一个文件夹路径") - return [ - f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in islice(root_path.glob(pattern), max_ret) - ] + return self.make_success_response( + kwargs=locals().copy(), + data=[ + f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" + for item in islice(root_path.glob(pattern), max_ret) + ], + ) diff --git a/src/workspace/tools/ls_tool.py b/src/workspace/tools/ls_tool.py index 8b8ee34..33f035d 100644 --- a/src/workspace/tools/ls_tool.py +++ b/src/workspace/tools/ls_tool.py @@ -1,6 +1,6 @@ from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -15,14 +15,17 @@ def __init__(self, workspace: Workspace): } @BaseTool.handle_tool_exceptions - def ls(self, path: str = ".") -> list[str] | str: + def ls(self, path: str = ".") -> ToolResult: """ 列出指定目录下的文件和文件夹. 
返回相对路径列表, 并标记[Folder]或[File] """ folder_path: Path = self.workspace.path_validator.validate(path) if not folder_path.is_dir(): - return ToolErrorResponse(self.__class__.__name__, f'参数错误: "{folder_path}"不是一个目录').to_str() - return [ - f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in folder_path.iterdir() - ] + return self.make_failed_response(kwargs=locals().copy(), error=f'参数错误: "{folder_path}"不是一个目录') + return self.make_success_response( + kwargs=locals().copy(), + data=[ + f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" + for item in folder_path.iterdir() + ], + ) diff --git a/src/workspace/tools/read_tool.py b/src/workspace/tools/read_tool.py index 3bf0c07..7bac9d8 100644 --- a/src/workspace/tools/read_tool.py +++ b/src/workspace/tools/read_tool.py @@ -1,7 +1,7 @@ import os from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.utils.binary_detector import is_binary_file from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -35,32 +35,34 @@ def __init__(self, workspace: Workspace): } @BaseTool.handle_tool_exceptions - def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> str: + def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> ToolResult: """ 读取文件内容, 返回带行号的格式化内容 """ file_path: Path = self.workspace.path_validator.validate(path) if not file_path.is_file(): - return ToolErrorResponse( - self.__class__.__name__, ValueError(f"读取文件{file_path}时未读取到完整文件") - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"读取文件{file_path}时未读取到完整文件")) + ) if is_binary_file(file_path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"无法读取二进制文件: {file_path}. 
请使用二进制安全工具或转换为 base64."), - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), + error=str(ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64.")), + ) file_size = file_path.stat().st_size if file_size > MAX_FILE_SIZE: - return ToolErrorResponse( - self.__class__.__name__, - ValueError( - f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. " - f"请使用范围参数 (start/end) 分批读取." + return self.make_failed_response( + kwargs=locals().copy(), + error=str( + ValueError( + f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. " + f"请使用范围参数 (start/end) 分批读取." + ) ), - ).to_str() + ) with open(file_path, encoding=encoding) as f: lines = f.readlines() @@ -71,7 +73,7 @@ def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encod header = f"\n[文件: {file_path}]\n[行 0-0 / 共 0 行]\n" separator = "-" * 80 + "\n" self._record_read_meta(file_path) - return header + separator + return self.make_success_response(kwargs=locals().copy(), data=header + separator) context = max(0, context) @@ -84,9 +86,12 @@ def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encod actual_end = total_lines if actual_end < actual_start: - return ( - f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} " - f"(原始参数: start={start}, end={end}, context={context})" + return self.make_failed_response( + kwargs=locals().copy(), + error=( + f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} " + f"(原始参数: start={start}, end={end}, context={context})" + ), ) result_lines = [] @@ -100,4 +105,4 @@ def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encod self._record_read_meta(file_path) - return header + separator + "\n".join(result_lines) + return self.make_success_response(kwargs=locals().copy(), data=header + separator + "\n".join(result_lines)) diff --git a/src/workspace/tools/regex_search_tool.py b/src/workspace/tools/regex_search_tool.py index 621e8ee..a8b6549 100644 --- 
a/src/workspace/tools/regex_search_tool.py +++ b/src/workspace/tools/regex_search_tool.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -121,7 +121,7 @@ def regex_search( file_pattern: str = "*", limit: int = 256, ignore: list[str] | None = None, - ) -> str: + ) -> ToolResult: """ 使用正则表达式搜索文件内容, 支持上下文显示、文件过滤和忽略路径, 返回匹配详情; 适合代码与文档探索 """ @@ -132,7 +132,7 @@ def regex_search( try: regex = re.compile(pattern) except re.error as e: - return ToolErrorResponse(self.__class__.__name__, f"无效的正则表达式: {e}").to_str() + return self.make_failed_response(kwargs=locals().copy(), error=f"无效的正则表达式: {e}") # 收集忽略模式 ignore_patterns = [] @@ -145,6 +145,7 @@ def regex_search( results = [] file_count = 0 total_matches = 0 + warnings = [""] # 确定要搜索的文件列表(支持单文件或目录) files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern)) @@ -183,8 +184,15 @@ def regex_search( file_count += 1 total_matches += len(file_results) - except OSError, UnicodeDecodeError, PermissionError: + except (OSError, UnicodeDecodeError, PermissionError) as e: + warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") continue # 跳过无法读取的文件 + warnings.append("") + # 格式化输出 - return _format_regex_results(results, pattern, limit, file_count) + return self.make_success_response( + kwargs=locals().copy(), + data=_format_regex_results(results, pattern, limit, file_count), + error="\n".join(warnings) if len(warnings) > 2 else None, + ) diff --git a/src/workspace/tools/stat_tool.py b/src/workspace/tools/stat_tool.py index bf8d28f..f6caeae 100644 --- a/src/workspace/tools/stat_tool.py +++ b/src/workspace/tools/stat_tool.py @@ -2,6 +2,7 @@ from datetime import datetime from pathlib import Path +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import 
BaseTool from src.workspace.workspace import Workspace @@ -18,7 +19,7 @@ def __init__(self, workspace: Workspace): } @BaseTool.handle_tool_exceptions - def stat(self, path: str = ".") -> str: + def stat(self, path: str = ".") -> ToolResult: """ 获取工作区内文件或目录的详细信息,包括大小、行数(仅文件)、修改时间、权限等 """ @@ -142,4 +143,4 @@ def format_timestamp(timestamp: float) -> str: except PermissionError: output.append("目录内容: 无法访问") - return "\n".join(output) + return self.make_success_response(kwargs=locals().copy(), data="\n".join(output)) diff --git a/src/workspace/tools/symbol_ref_tool.py b/src/workspace/tools/symbol_ref_tool.py index 8d45d14..4f42291 100644 --- a/src/workspace/tools/symbol_ref_tool.py +++ b/src/workspace/tools/symbol_ref_tool.py @@ -3,7 +3,7 @@ import re from pathlib import Path -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -379,14 +379,14 @@ def symbol_ref( limit: int = 256, ignore: list[str] | None = None, file_pattern: str | None = None, - ) -> str: + ) -> ToolResult: """ 查找符号(函数、类、变量等)的定义和引用位置, 适用于代码探索、重构影响分析、理解代码结构等场景. 
""" # 验证搜索路径 search_path: Path = self.workspace.path_validator.validate(path) if not search_path.exists(): - return ToolErrorResponse(self.__class__.__name__, f"路径不存在: {path}").to_str() + return self.make_failed_response(kwargs=locals().copy(), error=f"路径不存在: {path}") # 自动检测语言 detected_lang = _detect_language(search_path, language) @@ -399,7 +399,9 @@ def symbol_ref( patterns = _generate_patterns(symbol_name, detected_lang, include_definitions, include_references) if not patterns: - return f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})" + return self.make_failed_response( + kwargs=locals().copy(), error=f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})" + ) # 使用并发搜索执行所有模式 all_results = _search_all_patterns( @@ -414,4 +416,6 @@ def symbol_ref( ) # 格式化输出 - return _format_results(all_results, symbol_name, detected_lang, limit) + return self.make_success_response( + kwargs=locals().copy(), data=_format_results(all_results, symbol_name, detected_lang, limit) + ) diff --git a/src/workspace/tools/write_tool.py b/src/workspace/tools/write_tool.py index 5310fcc..fe1a09c 100644 --- a/src/workspace/tools/write_tool.py +++ b/src/workspace/tools/write_tool.py @@ -1,7 +1,7 @@ from pathlib import Path from src.core.file_tracker import FileTracker -from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult from src.utils.binary_detector import is_binary_file from src.workspace.tools.base_tool import BaseTool from src.workspace.workspace import Workspace @@ -18,7 +18,7 @@ def __init__(self, workspace: Workspace): } @BaseTool.handle_tool_exceptions - def write(self, path: str, content: str = "") -> str: + def write(self, path: str, content: str = "") -> ToolResult: """ 写入文件内容, 如文件不存在则创建(含父目录) """ @@ -26,17 +26,18 @@ def write(self, path: str, content: str = "") -> str: path: Path = self.workspace.path_validator.resolve_path(source_path) if path.exists() and path.is_dir(): - return 
ToolErrorResponse(self.__class__.__name__, ValueError(f"路径 {path} 是一个目录,无法写入")).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"路径 {path} 是一个目录,无法写入")) + ) if is_binary_file(path): - return ToolErrorResponse( - self.__class__.__name__, - ValueError(f"禁止写入二进制文件: {path}"), - ).to_str() + return self.make_failed_response( + kwargs=locals().copy(), error=str(ValueError(f"禁止写入二进制文件: {path}")) + ) mtime_error = self._validate_mtime(path) if mtime_error: - return mtime_error + return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") old_content = "" old_meta = None @@ -52,7 +53,7 @@ def write(self, path: str, content: str = "") -> str: new_hash = FileTracker.compute_checksum_from_string(content) diff_content = self._generate_diff(old_content, content, rel_path) - session_id = self.workspace._current_session_id + session_id = self.workspace.session_id snapshot_id = self.workspace.db.record_file_snapshot( rel_path, old_hash, @@ -63,4 +64,13 @@ def write(self, path: str, content: str = "") -> str: pending_content=content, ) - return f"[Write Preview]\nFile: {rel_path}\nSnapshot ID: {snapshot_id}\nDiff:\n{diff_content}" + return self.make_success_response( + kwargs=locals().copy(), + data=( + f"修改已推送到审核系统\n" + f"[Write Preview]\n" + f"File: {rel_path}\n" + f"Snapshot ID: {snapshot_id}\n" + f"Diff:\n{diff_content}" + ), + ) diff --git a/src/workspace/workspace.py b/src/workspace/workspace.py index bf588c6..56699f6 100644 --- a/src/workspace/workspace.py +++ b/src/workspace/workspace.py @@ -62,6 +62,15 @@ def db(self): self._db = DatabaseManager(str(self.root_path)) return self._db + @property + def session_id(self) -> int | None: + """当前会话 ID 的公开 getter —— 工具类通过此接口访问, 避免直接访问私有属性.""" + return self._current_session_id + + @session_id.setter + def session_id(self, value: int | None) -> None: + self._current_session_id = value + def search_content( self, pattern: str, diff --git 
a/tests/core/test_tool_registry.py b/tests/core/test_tool_registry.py index ad0ee62..ab1e2e2 100644 --- a/tests/core/test_tool_registry.py +++ b/tests/core/test_tool_registry.py @@ -33,31 +33,6 @@ def isolate_tool_registry(): DatabaseManager.reset_instances() -def test_validate_config(): - """测试配置验证 - 一次性测试所有阈值""" - config = ToolRegistry() - - # 设置所有值为过小 - config.MAX_RESULT_LENGTH = 5 - config.LIST_TRUNCATE_THRESHOLD = 3 - config.DICT_TRUNCATE_THRESHOLD = 2 - - # 验证触发3个警告 - with pytest.warns(UserWarning) as record: - config._validate_config() - - # 验证警告数量和内容 - assert len(record) == 3 - assert "TOOL_MAX_RESULT_LENGTH" in str(record[0].message) - assert "TOOL_LIST_TRUNCATE_THRESHOLD" in str(record[1].message) - assert "TOOL_DICT_TRUNCATE_THRESHOLD" in str(record[2].message) - - # 验证所有值都被修正 - assert config.MAX_RESULT_LENGTH == 100 - assert config.LIST_TRUNCATE_THRESHOLD == 50 - assert config.DICT_TRUNCATE_THRESHOLD == 50 - - def test_tool_registry_singleton(): """测试单例模式""" registry1 = ToolRegistry() @@ -67,14 +42,6 @@ def test_tool_registry_singleton(): assert id(registry1) == id(registry2) -def test_execute_nonexistent_tool(): - """测试执行不存在的工具""" - registry = ToolRegistry() - - with pytest.raises(ValueError, match="未找到工具: nonexistent"): - registry.execute("nonexistent") - - def test_validate_tool_info(): """测试工具信息验证""" registry = ToolRegistry() @@ -95,51 +62,6 @@ def test_validate_tool_info(): assert any(f"超过 {MAX_DOC_LENGTH} 字符" in str(warning.message) for warning in w) -def test_compress_result_string(): - """测试字符串结果压缩""" - registry = ToolRegistry() - - long_string = "x" * (MAX_RESULT_LENGTH + 10000) - - compressed = registry._compress_result(long_string) - assert "结果已截断" in compressed - - -def test_compress_result_list(): - """测试列表结果压缩""" - registry = ToolRegistry() - - long_list = list(range(150)) - - compressed = registry._compress_result(long_list) - assert len(compressed) == 101 - assert "列表已截断" in compressed[-1] - - -def test_compress_result_dict(): - 
"""测试字典结果压缩""" - registry = ToolRegistry() - - long_dict = {f"key_{i}": f"value_{i}" for i in range(150)} - - compressed = registry._compress_result(long_dict) - assert len(compressed) == 101 - assert "字典已截断" in compressed["..."] - - -def test_no_compress_short_results(): - """测试不对短结果进行压缩""" - registry = ToolRegistry() - - short_string = "short" - short_list = [1, 2, 3] - short_dict = {"a": 1, "b": 2} - - assert registry._compress_result(short_string) == short_string - assert registry._compress_result(short_list) == short_list - assert registry._compress_result(short_dict) == short_dict - - class TestToolCategorization: """测试工具分类逻辑(需要 workspace 注册工具).""" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/tools/__init__.py b/tests/models/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/tools/test_tool_result.py b/tests/models/tools/test_tool_result.py new file mode 100644 index 0000000..660e098 --- /dev/null +++ b/tests/models/tools/test_tool_result.py @@ -0,0 +1,63 @@ +import pytest + +from src.models.tools.tool_result import ToolResult + + +def test_compress_result_list(): + """测试列表结果压缩""" + long_list = list(range(150)) + + compressed = ToolResult._compress_result(long_list) + assert len(compressed) == 101 + assert "列表已截断" in compressed[-1] + + +def test_compress_result_dict(): + """测试字典结果压缩""" + long_dict = {f"key_{i}": f"value_{i}" for i in range(150)} + + compressed = ToolResult._compress_result(long_dict) + assert len(compressed) == 101 + assert "字典已截断" in compressed["..."] + + +def test_no_compress_short_results(): + """测试不对短结果进行压缩""" + short_string = "short" + short_list = [1, 2, 3] + short_dict = {"a": 1, "b": 2} + + assert ToolResult._compress_result(short_string) == short_string + assert ToolResult._compress_result(short_list) == short_list + assert ToolResult._compress_result(short_dict) == short_dict + + +def test_validate_config(): 
+ """测试配置验证 - 一次性测试所有阈值""" + # 设置所有值为过小 + ToolResult.MAX_RESULT_LENGTH = 5 + ToolResult.LIST_TRUNCATE_THRESHOLD = 3 + ToolResult.DICT_TRUNCATE_THRESHOLD = 2 + + # 验证触发3个警告 + with pytest.warns(UserWarning) as record: + ToolResult._validate_config() + + # 验证警告数量和内容 + assert len(record) == 3 + assert "TOOL_MAX_RESULT_LENGTH" in str(record[0].message) + assert "TOOL_LIST_TRUNCATE_THRESHOLD" in str(record[1].message) + assert "TOOL_DICT_TRUNCATE_THRESHOLD" in str(record[2].message) + + # 验证所有值都被修正 + assert ToolResult.MAX_RESULT_LENGTH == 100 + assert ToolResult.LIST_TRUNCATE_THRESHOLD == 50 + assert ToolResult.DICT_TRUNCATE_THRESHOLD == 50 + + +def test_compress_result_string(): + """测试字符串结果压缩""" + long_string = "x" * (ToolResult.MAX_RESULT_LENGTH + 10000) + + compressed = ToolResult._compress_result(long_string) + assert "结果已截断" in compressed diff --git a/tests/workspace/tools/test_binary_protection.py b/tests/workspace/tools/test_binary_protection.py index cf1df3f..6845464 100644 --- a/tests/workspace/tools/test_binary_protection.py +++ b/tests/workspace/tools/test_binary_protection.py @@ -54,7 +54,8 @@ def test_read_binary_by_extension(self, workspace: Workspace, binary_ext_file: P tool = ReadTool(workspace) result = tool.read("image.png") - assert "二进制文件" in result + assert result.success is False + assert "二进制文件" in result.error def test_read_text_file_still_works(self, workspace: Workspace, text_file: Path): """文本文件读取应不受影响.""" @@ -63,8 +64,8 @@ def test_read_text_file_still_works(self, workspace: Workspace, text_file: Path) tool = ReadTool(workspace) result = tool.read("readme.txt") - assert "hello world" in result - assert "二进制文件" not in result + assert "hello world" in result.data + assert "二进制文件" not in result.data class TestReadRangeBinaryProtection: @@ -77,7 +78,8 @@ def test_read_range_binary_by_extension(self, workspace: Workspace, binary_ext_f tool = ReadTool(workspace) result = tool.read("image.png", start=1, end=10) - assert "二进制文件" in result + assert 
result.success is False + assert "二进制文件" in result.error def test_read_range_text_file_still_works(self, workspace: Workspace, text_file: Path): """文本文件读取应不受影响.""" @@ -86,8 +88,8 @@ def test_read_range_text_file_still_works(self, workspace: Workspace, text_file: tool = ReadTool(workspace) result = tool.read("readme.txt", start=1, end=2) - assert "hello world" in result - assert "二进制文件" not in result + assert "hello world" in result.data + assert "二进制文件" not in result.data class TestWriteBinaryProtection: @@ -100,7 +102,8 @@ def test_write_binary_by_extension_blocked(self, workspace: Workspace, binary_ex tool = WriteTool(workspace) result = tool.write("image.png", "malicious content") - assert "二进制文件" in result + assert result.success is False + assert "二进制文件" in result.error # 不应创建快照 rows = workspace.db.fetchall("SELECT * FROM file_snapshots") assert len(rows) == 0 @@ -112,8 +115,8 @@ def test_write_text_file_still_works(self, workspace: Workspace, text_file: Path tool = WriteTool(workspace) result = tool.write("readme.txt", "new content") - assert "Write Preview" in result - assert "二进制文件" not in result + assert "Write Preview" in result.data + assert "二进制文件" not in result.data def test_write_new_binary_ext_blocked(self, workspace: Workspace): """写入新的二进制扩展名文件(不存在)也应被阻止.""" @@ -122,7 +125,8 @@ def test_write_new_binary_ext_blocked(self, workspace: Workspace): tool = WriteTool(workspace) result = tool.write("new_app.exe", "fake exe content") - assert "二进制文件" in result + assert result.success is False + assert "二进制文件" in result.error def test_write_new_text_ext_allowed(self, workspace: Workspace): """写入新的文本扩展名文件应正常通过.""" @@ -131,8 +135,8 @@ def test_write_new_text_ext_allowed(self, workspace: Workspace): tool = WriteTool(workspace) result = tool.write("new_file.py", "print('hello')") - assert "Write Preview" in result - assert "二进制文件" not in result + assert "Write Preview" in result.data + assert "二进制文件" not in result.data class TestEditBinaryProtection: @@ -145,7 
+149,8 @@ def test_edit_binary_by_extension_blocked(self, workspace: Workspace, binary_ext tool = EditTool(workspace) result = tool.edit("image.png", "fake", "replaced") - assert "二进制文件" in result + assert result.success is False + assert "二进制文件" in result.error def test_edit_text_file_still_works(self, workspace: Workspace, text_file: Path): """编辑文本文件应不受影响.""" @@ -154,5 +159,5 @@ def test_edit_text_file_still_works(self, workspace: Workspace, text_file: Path) tool = EditTool(workspace) result = tool.edit("readme.txt", "hello", "hi") - assert "Edit Preview" in result - assert "二进制文件" not in result + assert "Edit Preview" in result.data + assert "二进制文件" not in result.data diff --git a/tests/workspace/tools/test_edit_tool.py b/tests/workspace/tools/test_edit_tool.py index 270982c..9f17d51 100644 --- a/tests/workspace/tools/test_edit_tool.py +++ b/tests/workspace/tools/test_edit_tool.py @@ -52,8 +52,8 @@ def test_simple_replacement(self, edit_tool, workspace): _create_file(workspace, "test.txt", "hello world") result = edit_tool.edit("test.txt", "world", "there") - assert "Edit Preview" in result - assert "Snapshot ID:" in result + assert "Edit Preview" in result.data + assert "Snapshot ID:" in result.data def test_does_not_write_to_disk(self, edit_tool, workspace): file = _create_file(workspace, "test.txt", "hello world") @@ -74,14 +74,14 @@ def test_diff_in_preview(self, edit_tool, workspace): _create_file(workspace, "test.txt", "line1\nline2\nline3") result = edit_tool.edit("test.txt", "line2", "modified") - assert "-line2" in result - assert "+modified" in result + assert "-line2" in result.data + assert "+modified" in result.data def test_multiple_replacements(self, edit_tool, workspace): _create_file(workspace, "test.txt", "a a a a a") result = edit_tool.edit("test.txt", "a", "b", max_replacements=3) - assert "Replacements: 3" in result + assert "Replacements: 3" in result.data # Verify pending content has exactly 3 replacements snap = 
workspace.db.fetchone("SELECT pending_content FROM file_snapshots") assert snap is not None @@ -91,8 +91,9 @@ def test_no_match_found(self, edit_tool, workspace): _create_file(workspace, "test.txt", "hello world") result = edit_tool.edit("test.txt", "nonexistent", "replacement") - assert "No changes made" in result - assert "old_string not found" in result + assert result.success is False + assert "No changes made" in result.error + assert "old_string not found" in result.error class TestEditMaxReplacements: @@ -100,19 +101,19 @@ def test_exceeds_max_replacements(self, edit_tool, workspace): _create_file(workspace, "test.txt", "a a a a a a a a a a a a") # 12 a's result = edit_tool.edit("test.txt", "a", "b", max_replacements=5) - assert "Replacements: 5" in result + assert "Replacements: 5" in result.data def test_max_replacements_default_10(self, edit_tool, workspace): _create_file(workspace, "test.txt", " ".join(["a"] * 20)) result = edit_tool.edit("test.txt", "a", "b") - assert "Replacements: 10" in result + assert "Replacements: 10" in result.data def test_max_replacements_capped_at_100(self, edit_tool, workspace): _create_file(workspace, "test.txt", "a " * 150) result = edit_tool.edit("test.txt", "a", "b", max_replacements=200) - assert "Replacements: 100" in result + assert "Replacements: 100" in result.data class TestEditContextValidation: @@ -120,33 +121,35 @@ def test_context_before_matches(self, edit_tool, workspace): _create_file(workspace, "test.txt", "prefix target suffix") result = edit_tool.edit("test.txt", "target", "replaced", context_before="prefix ") - assert "Edit Preview" in result + assert "Edit Preview" in result.data def test_context_before_mismatch(self, edit_tool, workspace): _create_file(workspace, "test.txt", "prefix target suffix") result = edit_tool.edit("test.txt", "target", "replaced", context_before="wrong ") - assert "context_before" in result - assert "mismatch" in result.lower() + assert result.success is False + assert 
"context_before" in result.error + assert "mismatch" in result.error.lower() def test_context_after_matches(self, edit_tool, workspace): _create_file(workspace, "test.txt", "prefix target suffix") result = edit_tool.edit("test.txt", "target", "replaced", context_after=" suffix") - assert "Edit Preview" in result + assert "Edit Preview" in result.data def test_context_after_mismatch(self, edit_tool, workspace): _create_file(workspace, "test.txt", "prefix target suffix") result = edit_tool.edit("test.txt", "target", "replaced", context_after=" wrong") - assert "context_after" in result - assert "mismatch" in result.lower() + assert result.success is False + assert "context_after" in result.error + assert "mismatch" in result.error.lower() def test_both_contexts_match(self, edit_tool, workspace): _create_file(workspace, "test.txt", "before target after") result = edit_tool.edit("test.txt", "target", "replaced", context_before="before ", context_after=" after") - assert "Edit Preview" in result + assert "Edit Preview" in result.data def test_context_with_multiple_matches(self, edit_tool, workspace): _create_file(workspace, "test.txt", "before X after\nignore\nbefore X after") @@ -154,8 +157,8 @@ def test_context_with_multiple_matches(self, edit_tool, workspace): "test.txt", "X", "Y", max_replacements=2, context_before="before ", context_after=" after" ) - assert "Edit Preview" in result - assert "Replacements: 2" in result + assert "Edit Preview" in result.data + assert "Replacements: 2" in result.data class TestEditMtimeValidation: @@ -168,13 +171,14 @@ def test_edit_modified_externally_fails(self, edit_tool, read_tool, workspace): file.write_text("modified externally", encoding="utf-8") result = edit_tool.edit("test.txt", "original", "replaced") - assert "FILE_MODIFIED_EXTERNALLY" in result + assert result.success is False + assert "FILE_MODIFIED_EXTERNALLY" in result.error def test_edit_no_prior_read_succeeds(self, edit_tool, workspace): _create_file(workspace, 
"test.txt", "original content") result = edit_tool.edit("test.txt", "original", "updated") - assert "Edit Preview" in result + assert "Edit Preview" in result.data class TestEditEdgeCases: @@ -182,15 +186,17 @@ def test_empty_old_string(self, edit_tool, workspace): _create_file(workspace, "test.txt", "content") result = edit_tool.edit("test.txt", "", "replacement") - assert "不能为空" in result or "empty" in result.lower() + assert result.success is False + assert "不能为空" in result.error or "empty" in result.error.lower() def test_nonexistent_file(self, edit_tool, workspace): result = edit_tool.edit("nonexistent.txt", "old", "new") - assert "不存在" in result or "not found" in result.lower() or "exists" in result.lower() + assert result.success is False + assert "不存在" in result.error or "not found" in result.error.lower() or "exists" in result.error.lower() def test_file_outside_workspace(self, edit_tool, workspace): result = edit_tool.edit("../outside.txt", "old", "new") - assert "越界" in result or "boundary" in result.lower() or "outside" in result.lower() + assert "越界" in result.error or "boundary" in result.error.lower() or "outside" in result.error.lower() def test_edit_snapshot_has_old_hash(self, edit_tool, workspace): _create_file(workspace, "test.txt", "original") diff --git a/tests/workspace/tools/test_git_tool.py b/tests/workspace/tools/test_git_tool.py index 44a61ad..19450dd 100644 --- a/tests/workspace/tools/test_git_tool.py +++ b/tests/workspace/tools/test_git_tool.py @@ -50,90 +50,105 @@ def git_tool(workspace: Workspace): class TestGitAllowedCommands: def test_status(self, git_tool): result = git_tool.git("status") - assert "nothing to commit" in result.lower() or "working tree clean" in result.lower() + assert "nothing to commit" in result.response.lower() or "working tree clean" in result.response.lower() def test_diff(self, git_tool): result = git_tool.git("diff") - assert "(no output)" in result or result == "" or result == "(no output)" + assert "(no output)" in
result.response or result.response == "" or result.response == "(no output)" def test_log(self, git_tool): result = git_tool.git("log --oneline -1") - assert "initial" in result.lower() + assert "initial" in result.response.lower() def test_show(self, git_tool): result = git_tool.git("show --stat") - assert "README" in result or "initial" in result.lower() + assert result.success is True + assert "README" in result.data or "initial" in result.data.lower() def test_add_and_commit(self, git_tool, git_repo: Path): (git_repo / "new_file.txt").write_text("content", encoding="utf-8") add_result = git_tool.git("add new_file.txt") - assert "failed" not in add_result.lower() + assert add_result.success is True commit_result = git_tool.git('commit -m "test commit"') - assert "commi" in commit_result.lower() or "file changed" in commit_result.lower() + assert commit_result.success is True + assert "commi" in commit_result.data.lower() or "file changed" in commit_result.data.lower() def test_branch(self, git_tool): result = git_tool.git("branch") - assert "*" in result or "main" in result or "master" in result + assert result.success is True + assert "*" in result.data or "main" in result.data or "master" in result.data def test_restore_specific_file(self, git_tool, git_repo: Path): (git_repo / "README.md").write_text("modified\n", encoding="utf-8") result = git_tool.git("restore README.md") - assert "failed" not in result.lower() + assert "failed" not in result.data.lower() class TestGitBlockedCommands: def test_push_blocked(self, git_tool): result = git_tool.git("push") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_remote_blocked(self, git_tool): result = git_tool.git("remote add origin https://example.com/repo.git") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or 
"ERROR" in result.error def test_reset_hard_blocked(self, git_tool): result = git_tool.git("reset --hard HEAD") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_branch_d_blocked(self, git_tool): result = git_tool.git("branch -D test") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_merge_blocked(self, git_tool): result = git_tool.git("merge test-branch") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_rebase_blocked(self, git_tool): result = git_tool.git("rebase main") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_clean_blocked(self, git_tool): result = git_tool.git("clean -fd") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_fetch_blocked(self, git_tool): result = git_tool.git("fetch origin") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error def test_pull_blocked(self, git_tool): result = git_tool.git("pull origin main") - assert "blocked" in result.lower() or "ERROR" in result + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error class TestGitRestoreSafety: def test_bare_restore_rejected(self, git_tool): result = git_tool.git("restore") - assert "需要指定文件路径" in result or "ERROR" in result or "restore" in result.lower() + assert result.success is False + assert "需要指定文件路径" in result.error or "ERROR" in 
result.error or "restore" in result.error.lower() def test_restore_dot_rejected(self, git_tool): result = git_tool.git("restore .") - assert "通配符" in result or "ERROR" in result or "restore" in result.lower() + assert result.success is False + assert "通配符" in result.error or "ERROR" in result.error or "restore" in result.error.lower() class TestGitUnknownCommand: def test_unknown_command_rejected(self, git_tool): result = git_tool.git("unknown-command") - assert "not in the allowed whitelist" in result.lower() or "ERROR" in result + assert result.success is False + assert "not in the allowed whitelist" in result.error.lower() or "ERROR" in result.error class TestGitIsSafeCommand: @@ -168,8 +183,14 @@ def test_empty_string(self): class TestGitInjection: def test_command_injection_via_semicolon(self, git_tool): result = git_tool.git("status; echo pwned") - assert "blocked" in result.lower() or "ERROR" in result or "not in the allowed whitelist" in result.lower() + assert result.success is False + assert ( + "blocked" in result.error.lower() + or "ERROR" in result.error + or "not in the allowed whitelist" in result.error.lower() + ) def test_invalid_shell_syntax(self, git_tool): result = git_tool.git("status $(whoami)") - assert "failed" not in result.lower() or "error" not in result.lower() + assert result.success is True + assert "working tree clean" in result.data.lower() or "nothing to commit" in result.data.lower() diff --git a/tests/workspace/tools/test_read_tool.py b/tests/workspace/tools/test_read_tool.py index 8ca926a..6cd710d 100644 --- a/tests/workspace/tools/test_read_tool.py +++ b/tests/workspace/tools/test_read_tool.py @@ -33,7 +33,7 @@ def test_read_records_mtime(self, workspace: Workspace, tmp_path: Path): tool = ReadTool(workspace) result = tool.read("test.txt") - assert "hello" in result + assert "hello" in result.data record = workspace.db.get_file_read_record(workspace._current_session_id, "test.txt") assert record is not None assert record[3] == 
pytest.approx(file.stat().st_mtime, abs=0.01) @@ -59,7 +59,7 @@ def test_read_nonexistent_file_no_db_record(self, workspace: Workspace): tool = ReadTool(workspace) result = tool.read("nonexistent.txt") - assert "error" in result.lower() or "Error" in result + assert "error" in result.error.lower() or "Error" in result.error record = workspace.db.get_file_read_record(workspace._current_session_id, "nonexistent.txt") assert record is None @@ -88,6 +88,6 @@ def test_read_range_records_mtime(self, workspace: Workspace, tmp_path: Path): tool = ReadTool(workspace) result = tool.read("test.txt", start=1, end=2) - assert "line1" in result + assert "line1" in result.data record = workspace.db.get_file_read_record(workspace._current_session_id, "test.txt") assert record is not None diff --git a/tests/workspace/tools/test_write_tool.py b/tests/workspace/tools/test_write_tool.py index 7bcc937..49f0d44 100644 --- a/tests/workspace/tools/test_write_tool.py +++ b/tests/workspace/tools/test_write_tool.py @@ -40,8 +40,8 @@ def read_tool(workspace: Workspace): class TestWritePreview: def test_write_returns_preview(self, write_tool, tmp_path: Path): result = write_tool.write("new_file.txt", "hello") - assert "Write Preview" in result - assert "Snapshot ID:" in result + assert "Write Preview" in result.data + assert "Snapshot ID:" in result.data def test_write_does_not_write_to_disk(self, write_tool, tmp_path: Path): write_tool.write("new_file.txt", "hello") @@ -85,7 +85,7 @@ def test_write_after_read_shows_diff(self, write_tool, read_tool, tmp_path: Path read_tool.read("test.txt") result = write_tool.write("test.txt", "line1\nmodified\nline3") - assert "Write Preview" in result + assert "Write Preview" in result.data rows = write_tool.workspace.db.fetchall("SELECT diff_content FROM file_snapshots") assert len(rows) == 1 assert "-line2" in rows[0][0] @@ -111,14 +111,15 @@ def test_write_modified_externally_fails(self, write_tool, read_tool, tmp_path: time.sleep(0.1) result = 
write_tool.write("test.txt", "should fail") - assert "FILE_MODIFIED_EXTERNALLY" in result + assert result.success is False + assert "FILE_MODIFIED_EXTERNALLY" in result.error def test_write_no_prior_read_succeeds(self, write_tool, tmp_path: Path): file = tmp_path / "test.txt" file.write_text("existing content", encoding="utf-8") result = write_tool.write("test.txt", "new content") - assert "Write Preview" in result + assert "Write Preview" in result.data class TestWriteSnapshotPendingContent:
From 0f20fa920cef879025e3134829ea33a9c5d58187 Mon Sep 17 00:00:00 2001
From: Suntion <149924916+SunYanbox@users.noreply.github.com>
Date: Tue, 5 May 2026 15:39:35 +0800
Subject: [PATCH 8/8] Releases/v0.5.0 (#143)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* chore(project): bump the project version to 0.5.0

- Docs: sync the version information in the README files
  * Update the version shown in `README_ZH.md` to 0.5.0
  * Update the version shown in `README.md` to 0.5.0
- Config: update project metadata and constant definitions
  * In `pyproject.toml`, update the `version` field under `[project]` from "0.4.1" to "0.5.0"
  * In `src/constants/__init__.py`, update the `__version__` variable from "0.4.1" to "0.5.0"

* docs(changelog): record the v0.5.0 tool refactoring and search optimization changes

- Added: extend tool capabilities and the parameter mechanism
  * Introduce the `ToolResult` data class to unify the return type of all tools, with `success`, `data`, `error`, and `response` attributes
  * `exact_search` gains a `file_pattern` parameter for file pattern filtering
  * Categorize tools automatically based on `write_permission`, with no manual registration
  * The `read` tool supports precise line-range reading via the `start`, `end`, and `context` parameters
  * `BaseTool` introduces the `param_descriptions` dictionary; the parameter documentation format becomes a Markdown list
  * The `read` tool adds the `MAX_READ_FILE_SIZE` setting to prevent out-of-memory errors on large files
- Fixed: correct the search logic and clean up redundant code
  * Correct the semantics of the `limit` parameter in `exact_search` and `regex_search` so that it counts match results
  * Remove the obsolete `warnings.warn` call and unused dependencies from the input parser
- Refactored: improve performance and simplify system prompt injection
  * Use the `search_content_multi_pattern` API instead of iterating pattern by pattern, eliminating the I/O overhead and returning a structured `list[dict]`
  * Unify the path parameter name of all tools as `path`, replacing the former `file_path` and `folder_path`
  * Trim the tool docstrings and remove redundant paragraphs, reducing the risk of LLM hallucinations
  * The `handle_tool_exceptions` decorator now uniformly wraps exceptions into `ToolResult` objects, removing the `ToolErrorResponse` dependency
- Breaking changes: remove the legacy tool and rename parameters
  * Delete the `read_lines` tool; its functionality is merged into the `read` tool, so callers must migrate to the `start`/`end`/`context` parameters of `read`
  * Rename the `file_path` and `folder_path` parameters of all tools to `path`; upstream callers must update the parameter names accordingly
---
 README.md                 |  2 +-
 README_ZH.md              |  2 +-
 docs/CHANGELOG.md         | 76 +++++++++++++++++++++++++++++++++++++++
 docs/CHANGELOG_ZH.md      | 55 ++++++++++++++++++++++++++++
 pyproject.toml            |  2 +-
 src/constants/__init__.py |  2 +-
 6 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 71c73f6..2675310 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ LLM chat interfaces. Paste LLM-generated tool calls (in XML format), review and audit dangerous operations, and manage sessions with full history tracking -- all running locally on your machine.
-> **Version**: 0.4.1 | **Python**: >=3.14
+> **Version**: 0.5.0 | **Python**: >=3.14
 ---
diff --git a/README_ZH.md b/README_ZH.md
index 51b375d..3479c05 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -8,7 +8,7 @@ ManualAid 提供了一个基于 Textual 的 TUI 控制台,在剪贴板和 LLM 聊天界面之间架起桥梁. 粘贴 LLM 生成的工具调用(XML 格式),审查和审计危险操作,并通过完整的历史追踪管理会话 -- 一切都在本地运行.
-> **版本**: 0.4.1 | **Python**: >=3.14
+> **版本**: 0.5.0 | **Python**: >=3.14
 ---
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index b9b7b89..70e2416 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -8,6 +8,81 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.5.0] - 2026-05-05
+
+### Added
+
+- **Structured Tool Result**: Introduced `ToolResult` data class as the unified
+  return type for all tools, replacing inconsistent string and list responses.
+  The class includes `success`, `data`, `error`, and `response` attributes with
+  built-in result compression and standardized XML formatting. All tools now
+  return `ToolResult` objects, enabling consistent upstream error handling
+  ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)).
+- **File Pattern Filtering for `exact_search`**: Added `file_pattern` parameter + to `exact_search` (default `"*"`), matching the existing `regex_search` + behavior. Allows filtering search scope by file extension or glob pattern + ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **Auto-Categorization of Tools**: Tools are now automatically classified as + read-only or write based on the `write_permission` attribute, eliminating + manual category registration and reducing maintenance overhead + ([#135, #139](https://github.com/SunYanbox/ManualAid/issues/135)). +- **Range Reading in `read` Tool**: The `read` tool now supports precise line + range reading via `start`, `end` (supports negative indexing), and `context` + parameters, replacing the coarse `max_lines` approach. Display header now + shows the actual line range read (`[Lines start-end / total_lines]`) + ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). +- **Parameter Descriptions**: Introduced `param_descriptions` dictionary in + `BaseTool` allowing each tool to provide human-readable parameter + descriptions. Parameter documentation format changed from inline XML to + Markdown list items (`- **name** (type, required/optional): description`) + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **File Size Limit**: Added configurable max file size limit + (`MAX_READ_FILE_SIZE`, default 10MB) to the `read` tool to prevent + out-of-memory errors when reading large files + ([#130, #141](https://github.com/SunYanbox/ManualAid/issues/130)). + +### Changed + +- **Tool Path Parameter Unification**: Renamed path parameters across all tools + to a consistent `path` name — `file_path` (read, write, edit) and + `folder_path` (ls, glob) are now uniformly `path`. This reduces LLM confusion + and injection token length + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). 
+- **Tool Injection Optimization**: Removed redundant docstring `Parameters` + sections from tool functions, shortened tool descriptions, and streamlined + parameter documentation format. Combined with parameter unification, these + changes significantly reduce system prompt injection length, lowering LLM + hallucination risk + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **Symbol Search Performance**: Replaced per-pattern file traversal with a + single-pass multi-pattern search via `search_content_multi_pattern` API, + eliminating N× I/O overhead. Results are now parsed as structured `list[dict]` + instead of regex-parsing formatted text, fixing the "format-then-parse" + anti-pattern + ([#132, #137](https://github.com/SunYanbox/ManualAid/issues/132)). +- **Exception Handling Consolidation**: The `handle_tool_exceptions` decorator + now uniformly wraps all exceptions into `ToolResult(success=False, error=...)` + objects. Removed `ToolErrorResponse` dependency; error messages are now + formatted as `ClassName: Message` + ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). + +### Fixed + +- **`limit` Semantics in Search Tools**: Corrected the `limit` parameter in both + `exact_search` and `regex_search` to count individual match results rather + than files scanned, aligning behavior with user expectations + ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **Redundant Warnings in Input Parser**: Removed stale `warnings.warn` calls + and the unused `import warnings` dependency from the input parser + ([#138](https://github.com/SunYanbox/ManualAid/issues/138)). + +### Removed + +- **`read_lines` Tool**: Merged into the enhanced `read` tool with range-reading + support. All `read_lines` functionality is now accessible via `read` with + `start`/`end`/`context` parameters + ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). 
+ ## [0.4.1] - 2026-05-04 ### Added @@ -176,6 +251,7 @@ and this project adheres to _Initial release features and history._ +[0.5.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.5.0 [0.4.1]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.1 [0.4.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.0 [0.3.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.3.0 diff --git a/docs/CHANGELOG_ZH.md b/docs/CHANGELOG_ZH.md index d8fd008..1540562 100644 --- a/docs/CHANGELOG_ZH.md +++ b/docs/CHANGELOG_ZH.md @@ -8,6 +8,60 @@ [Keep a Changelog](https://keepachangelog.com/zh-CN/1.0.0/). 并采用 [语义化版本](https://semver.org/lang/Chinese/). +## [0.5.0] - 2026-05-05 + +### 新增 + +- **结构化工具返回结果**: 引入 `ToolResult` + 数据类作为所有工具的统一返回类型,替代以往不一致的字符串和列表响应. 该类包含 + `success`、`data`、`error`、`response` + 属性,内置结果压缩与标准化 XML 格式化功能. 所有工具方法现均返回 `ToolResult` + 对象,使上游调用方能进行一致的错误处理 ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). +- **`exact_search` 文件模式过滤**: 为 `exact_search` 新增 `file_pattern` + 参数(默认 `"*"`),与已有的 `regex_search` + 行为对齐. 支持按文件扩展名或通配符模式过滤搜索范围 ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **工具自动分类**: 工具现根据 `write_permission` + 属性自动归类为只读或可写,无需手动注册分类,减少维护成本 ([#135, #139](https://github.com/SunYanbox/ManualAid/issues/135)). +- **`read` 工具范围读取**: `read` 工具现支持通过 `start`、`end`(支持负数索引) 和 + `context` 参数进行精确的行范围读取,替代原有的粗粒度 `max_lines` + 方式. 显示头部现展示实际读取的行范围 (`[行 start-end / 共 total_lines 行]`) + ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). +- **参数描述机制**: 在 `BaseTool` 中引入 `param_descriptions` + 字典,允许每个工具为参数提供可读描述. 参数文档格式从内联 XML 转为 Markdown 列表项 (`- **名称** (类型, 必需/可选): 描述`) + ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **文件大小限制**: 为 `read` + 工具添加了可配置的最大文件大小限制 (`MAX_READ_FILE_SIZE`,默认 10MB),防止读取大文件时内存溢出 ([#130, #141](https://github.com/SunYanbox/ManualAid/issues/130)). 
+ +### 更改 + +- **工具路径参数统一**: 将所有工具中的路径参数统一重命名为 `path`——原 + `file_path`(read、write、edit)和 `folder_path`(ls、glob)现统一使用 + `path`. 此举减少 LLM 混淆并缩短注入 Token 长度 ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **工具注入长度优化**: 移除工具函数中冗余的 Docstring `Parameters` + 段落、缩短工具描述、精简参数文档格式. 配合参数统一,这些变更显著缩短了系统提示注入长度,降低 LLM 幻觉风险 ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)). +- **符号搜索性能重构**: 用单次遍历多模式搜索(`search_content_multi_pattern` + API)替代原有的逐模式文件遍历, 消除了 N 倍 I/O 开销. 搜索结果现解析为结构化的 + `list[dict]` + 而非对格式化文本做正则解析,修复了 "先格式化再解析"的反模式 ([#132, #137](https://github.com/SunYanbox/ManualAid/issues/132)). +- **异常处理整合**: `handle_tool_exceptions` 装饰器现统一将所有异常封装为 + `ToolResult(success=False, error=...)` 对象. 移除了 `ToolErrorResponse` + 依赖,错误消息格式化为 `ClassName: Message` + ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)). + +### 修复 + +- **搜索工具 `limit` 语义**: 修正了 `exact_search` 和 `regex_search` 中 `limit` + 参数的计数逻辑,从统计"扫描文件数"改为统计"匹配结果数",使行为符合用户预期 ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)). +- **输入解析器冗余警告**: 清理了输入解析器中已过时的 `warnings.warn` + 调用及未使用的 `import warnings` + 依赖 ([#138](https://github.com/SunYanbox/ManualAid/issues/138)). + +### 移除 + +- **`read_lines` 工具**: 已合并至增强后的 `read` 工具. 所有 `read_lines` + 功能现可通过 `read` 的 `start`/`end`/`context` + 参数访问 ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)). 
+ ## [0.4.1] - 2026-05-04 ### 新增 @@ -144,6 +198,7 @@ _初始发布的功能和历史记录._ +[0.5.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.5.0 [0.4.1]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.1 [0.4.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.0 [0.3.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.3.0 diff --git a/pyproject.toml b/pyproject.toml index 5d32ecb..ed36fd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ManualAid" -version = "0.4.1" +version = "0.5.0" description = "" requires-python = ">=3.14" dependencies = [ diff --git a/src/constants/__init__.py b/src/constants/__init__.py index 3d26edf..3d18726 100644 --- a/src/constants/__init__.py +++ b/src/constants/__init__.py @@ -1 +1 @@ -__version__ = "0.4.1" +__version__ = "0.5.0"
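Reviewer note on the `ToolResult` contract recorded in the 0.5.0 changelog: the test changes in this series (`result.data`, `result.error`, `result.success is False`) and the `ClassName: Message` error format can be illustrated with a minimal sketch. The field names `success`, `data`, `error`, and `response` come from the changelog; everything else here is an assumption for illustration, not the project's implementation. In particular, the body of `response` is a plain-text placeholder (the changelog says the real class does result compression and XML formatting), and `read_text_file` is a hypothetical tool function.

```python
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Optional


@dataclass
class ToolResult:
    """Illustrative stand-in; only the attribute names follow the changelog."""

    success: bool
    data: Any = None
    error: Optional[str] = None

    @property
    def response(self) -> str:
        # Assumption: the real class emits compressed, XML-formatted output.
        # This sketch returns plain text so the example stays self-contained.
        return str(self.data) if self.success else f"error: {self.error}"


def handle_tool_exceptions(func: Callable[..., ToolResult]) -> Callable[..., ToolResult]:
    """Wrap any exception into ToolResult(success=False, error='ClassName: Message')."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> ToolResult:
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            return ToolResult(success=False, error=f"{type(exc).__name__}: {exc}")

    return wrapper


@handle_tool_exceptions
def read_text_file(path: str) -> ToolResult:
    # Hypothetical tool body: read a file and hand back its text as `data`.
    with open(path, encoding="utf-8") as fh:
        return ToolResult(success=True, data=fh.read())


missing = read_text_file("no_such_file_for_this_sketch.txt")
```

With this shape, callers branch on `missing.success` and read `missing.error` instead of searching a returned string for the word "error", which is the pattern the updated tests in this series assert against.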