diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 820ebcc..2c61195 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -208,7 +208,21 @@ docs(README): add contributing guide
## How to Add a New Tool
All tools are located in the `src/workspace/tools/` directory and inherit from
-the `BaseTool` base class.
+the `BaseTool` base class. All tool methods must return a `ToolResult` object
+(located in `src/models/tools/tool_result.py`).
+
+### ToolResult Model
+
+`ToolResult` is the unified tool execution result wrapper:
+
+| Field | Description |
+| :------------ | :---------------------------------------------- |
+| `success` | Whether execution succeeded (`bool`) |
+| `func_name` | Name of the tool method (`str`) |
+| `func_kwargs` | Dictionary of invocation parameters (`dict`) |
+| `data` | Return data on success (`Any`) |
+| `error` | Error message on failure (`str\|None`) |
+| `response` | Auto-generated XML string (for LLM consumption) |
### Step Overview
@@ -217,6 +231,7 @@ the `BaseTool` base class.
2. **Inherit from BaseTool**:
```python
+ from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -231,21 +246,31 @@ the `BaseTool` base class.
)
self.func = self.your_method
self.params = BaseTool.extract_params(self.your_method)
+ self.param_descriptions = {
+ "param1": "Parameter description",
+ "param2": "Parameter description"
+ }
@BaseTool.handle_tool_exceptions
- def your_method(self, param1: str, param2: int = 0) -> str:
+ def your_method(self, param1: str, param2: int = 0) -> ToolResult:
"""
Tool description -- will be generated as LLM-readable documentation.
-
- Parameters
- ----------
- param1: Parameter description
- param2: Parameter description (with default value)
"""
# Path operations must be validated through PathValidator
path = self.workspace.path_validator.validate(param1)
- # ... tool logic ...
- return f'{path}x{param2}'
+
+ # Note: Return ToolResult, not raw data
+ # On success, use make_success_response
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=f'{path}x{param2}'
+ )
+
+ # On failure, use make_failed_response
+ # return self.make_failed_response(
+ # kwargs=locals().copy(),
+ # error="Specific error description"
+ # )
```
3. **Special handling for write operations**: If the tool involves writing
@@ -254,17 +279,59 @@ the `BaseTool` base class.
`self._validate_mtime(path)`
- Generate a diff and record a `PENDING_AUDIT` snapshot instead of writing
directly to disk
+ - Still return via
+ `self.make_success_response(kwargs=locals().copy(), data=...)` or
+ `self.make_failed_response(kwargs=locals().copy(), error=...)`
- Refer to the implementations of `WriteTool` and `EditTool`
4. **Register the tool**: Import and instantiate your tool in the `register()`
method of `src/core/tool_registry.py`:
+ ```python
+ def register(self, workspace: Workspace) -> None:
+ from src.workspace.tools.your_tool import YourTool
+
+ self._workspace = workspace
+
+ for cls in (
+ # ... other existing tool classes ...
+ YourTool,
+ ):
+ try:
+ tool = cls(workspace)
+ if tool.func is None or tool.params is None:
+ warnings.warn(f"Tool {tool.name} has no registered function callback and parameters", stacklevel=2)
+ continue
+ self._tools[tool.name] = tool
+ self._set_tool_category(tool)
+ except ValueError:
+ pass
+ ```
+
5. **Add tests**: Create corresponding test files under
`tests/workspace/tools/`. At minimum, cover:
- - Normal execution paths
+ - Normal execution paths (assert `result.success is True`, check
+ `result.data`)
+ - Failure scenarios (assert `result.success is False`, check `result.error`)
- Parameter validation (e.g., empty parameters, invalid values)
- Path security (e.g., out-of-bounds access)
+ Test example:
+
+ ```python
+ def test_your_tool_success(workspace):
+ tool = YourTool(workspace)
+ result = tool.your_method(param1="valid_path", param2=42)
+ assert result.success is True
+ assert result.data is not None
+
+ def test_your_tool_failure(workspace):
+ tool = YourTool(workspace)
+ result = tool.your_method(param1="../outside_path", param2=42)
+ assert result.success is False
+ assert "WorkspaceBoundaryError" in result.error
+ ```
+
---
## Testing Requirements
diff --git a/CONTRIBUTING_ZH.md b/CONTRIBUTING_ZH.md
index 680f0c9..2bb7b66 100644
--- a/CONTRIBUTING_ZH.md
+++ b/CONTRIBUTING_ZH.md
@@ -192,7 +192,22 @@ docs(README): 添加贡献指南 / add contributing guide
## 如何添加新工具
-所有工具位于 `src/workspace/tools/` 目录,继承 `BaseTool` 基类.
+所有工具位于 `src/workspace/tools/` 目录,继承 `BaseTool`
+基类. 所有工具方法必须返回 `ToolResult` 对象(位于
+`src/models/tools/tool_result.py`).
+
+### ToolResult 模型
+
+`ToolResult` 是统一的工具执行结果包装器:
+
+| 字段 | 说明 |
+| :------------ | :--------------------------------------- |
+| `success` | 执行是否成功 (`bool`) |
+| `func_name` | 工具方法名 (`str`) |
+| `func_kwargs` | 调用参数字典 (`dict`) |
+| `data` | 成功时的返回数据 (`Any`) |
+| `error` | 失败时的错误消息 (`str\|None`) |
+| `response` | 自动生成的 XML 格式字符串(用于 LLM 消费) |
### 步骤概述
@@ -201,6 +216,7 @@ docs(README): 添加贡献指南 / add contributing guide
2. **继承 BaseTool**:
```python
+ from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -215,36 +231,87 @@ docs(README): 添加贡献指南 / add contributing guide
)
self.func = self.your_method
self.params = BaseTool.extract_params(self.your_method)
+ self.param_descriptions = {
+ "param1": "参数说明",
+ "param2": "参数说明"
+ }
@BaseTool.handle_tool_exceptions
- def your_method(self, param1: str, param2: int = 0) -> str:
+ def your_method(self, param1: str, param2: int = 0) -> ToolResult:
"""
工具描述 -- 会生成为 LLM 可读的文档.
-
- Parameters
- ----------
- param1: 参数说明
- param2: 参数说明(带默认值)
"""
# 路径操作必须通过 PathValidator 验证
path = self.workspace.path_validator.validate(param1)
- # ... 工具逻辑 ...
- return f'{path}x{param2}'
+
+ # 注意: 返回 ToolResult 而非原始数据
+ # 成功时使用 make_success_response
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=f'{path}x{param2}'
+ )
+
+ # 失败时使用 make_failed_response
+ # return self.make_failed_response(
+ # kwargs=locals().copy(),
+ # error="具体的错误描述"
+ # )
```
3. **写入操作的特殊处理**: 如果工具涉及写入(`write_permission=True`),需要:
- 通过 `self._validate_mtime(path)` 检查文件是否被外部修改
- 生成 diff 并记录 `PENDING_AUDIT` 快照,而非直接写入磁盘
+ - 返回格式仍然使用
+ `self.make_success_response(kwargs=locals().copy(), data=...)` 或
+ `self.make_failed_response(kwargs=locals().copy(), error=...)`
- 参考 `WriteTool` 和 `EditTool` 的实现
4. **注册工具**: 在 `src/core/tool_registry.py` 的 `register()`
方法中导入并实例化你的工具:
+ ```python
+ def register(self, workspace: Workspace) -> None:
+ from src.workspace.tools.your_tool import YourTool
+
+ self._workspace = workspace
+
+ for cls in (
+ # ... 其他已有的工具类 ...
+ YourTool,
+ ):
+ try:
+ tool = cls(workspace)
+ if tool.func is None or tool.params is None:
+ warnings.warn(f"工具{tool.name}没有注册功能回调和参数", stacklevel=2)
+ continue
+ self._tools[tool.name] = tool
+ self._set_tool_category(tool)
+ except ValueError:
+ pass
+ ```
+
5. **补充测试**: 在 `tests/workspace/tools/` 下创建对应的测试文件. 至少覆盖:
- - 正常执行路径
+ - 正常执行路径(断言 `result.success is True`, 检查 `result.data`)
+ - 失败场景(断言 `result.success is False`, 检查 `result.error`)
- 参数验证(如空参数、非法值)
- 路径安全(如越界访问)
+ 测试示例:
+
+ ```python
+ def test_your_tool_success(workspace):
+ tool = YourTool(workspace)
+ result = tool.your_method(param1="valid_path", param2=42)
+ assert result.success is True
+ assert result.data is not None
+
+ def test_your_tool_failure(workspace):
+ tool = YourTool(workspace)
+ result = tool.your_method(param1="../outside_path", param2=42)
+ assert result.success is False
+ assert "WorkspaceBoundaryError" in result.error
+ ```
+
---
## 测试要求
diff --git a/README.md b/README.md
index 396bb4c..2675310 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ LLM chat interfaces. Paste LLM-generated tool calls (in XML format), review and
audit dangerous operations, and manage sessions with full history tracking --
all running locally on your machine.
-> **Version**: 0.4.1 | **Python**: >=3.14
+> **Version**: 0.5.0 | **Python**: >=3.14
---
@@ -125,8 +125,7 @@ ManualAid registers 12 tools for LLM use via XML function calls:
| -------------- | ------------------------------------------------ |
| `ls` | List directory contents |
| `glob` | Find files by glob pattern |
-| `read` | Read file contents (with optional line limit) |
-| `read_lines` | Read specific line range from a file |
+| `read` | Read file contents with optional line range |
| `stat` | Get file/directory metadata (size, mtime, lines) |
| `exact_search` | Exact string search with case/whole-word options |
| `regex_search` | Regex search with context display |
diff --git a/README_ZH.md b/README_ZH.md
index 2a4dc7e..3479c05 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -8,7 +8,7 @@
ManualAid 提供了一个基于 Textual 的 TUI 控制台,在剪贴板和 LLM 聊天界面之间架起桥梁. 粘贴 LLM 生成的工具调用(XML 格式),审查和审计危险操作,并通过完整的历史追踪管理会话 -- 一切都在本地运行.
-> **版本**: 0.4.1 | **Python**: >=3.14
+> **版本**: 0.5.0 | **Python**: >=3.14
---
@@ -112,8 +112,7 @@ ManualAid 注册了 12 个工具供 LLM 通过 XML 函数调用使用:
| -------------- | ----------------------------------------- |
| `ls` | 列出目录内容 |
| `glob` | 通过 glob 模式查找文件 |
-| `read` | 读取文件内容(可选行数限制) |
-| `read_lines` | 读取文件中指定范围的行 |
+| `read` | 读取文件内容,支持指定行范围 |
| `stat` | 获取文件/目录元数据(大小、修改时间、行数) |
| `exact_search` | 精确字符串搜索,支持大小写/全词匹配 |
| `regex_search` | 正则表达式搜索,支持上下文显示 |
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index b9b7b89..70e2416 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -8,6 +8,81 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.5.0] - 2026-05-05
+
+### Added
+
+- **Structured Tool Result**: Introduced `ToolResult` data class as the unified
+ return type for all tools, replacing inconsistent string and list responses.
+ The class includes `success`, `data`, `error`, and `response` attributes with
+ built-in result compression and standardized XML formatting. All tools now
+ return `ToolResult` objects, enabling consistent upstream error handling
+ ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)).
+- **File Pattern Filtering for `exact_search`**: Added `file_pattern` parameter
+ to `exact_search` (default `"*"`), matching the existing `regex_search`
+ behavior. Allows filtering search scope by file extension or glob pattern
+ ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)).
+- **Auto-Categorization of Tools**: Tools are now automatically classified as
+ read-only or write based on the `write_permission` attribute, eliminating
+ manual category registration and reducing maintenance overhead
+ ([#135, #139](https://github.com/SunYanbox/ManualAid/issues/135)).
+- **Range Reading in `read` Tool**: The `read` tool now supports precise line
+ range reading via `start`, `end` (supports negative indexing), and `context`
+ parameters, replacing the coarse `max_lines` approach. Display header now
+ shows the actual line range read (`[Lines start-end / total_lines]`)
+ ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)).
+- **Parameter Descriptions**: Introduced `param_descriptions` dictionary in
+ `BaseTool` allowing each tool to provide human-readable parameter
+ descriptions. Parameter documentation format changed from inline XML to
+ Markdown list items (`- **name** (type, required/optional): description`)
+ ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)).
+- **File Size Limit**: Added configurable max file size limit
+ (`MAX_READ_FILE_SIZE`, default 10MB) to the `read` tool to prevent
+ out-of-memory errors when reading large files
+ ([#130, #141](https://github.com/SunYanbox/ManualAid/issues/130)).
+
+### Changed
+
+- **Tool Path Parameter Unification**: Renamed path parameters across all tools
+ to a consistent `path` name — `file_path` (read, write, edit) and
+ `folder_path` (ls, glob) are now uniformly `path`. This reduces LLM confusion
+ and injection token length
+ ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)).
+- **Tool Injection Optimization**: Removed redundant docstring `Parameters`
+ sections from tool functions, shortened tool descriptions, and streamlined
+ parameter documentation format. Combined with parameter unification, these
+ changes significantly reduce system prompt injection length, lowering LLM
+ hallucination risk
+ ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)).
+- **Symbol Search Performance**: Replaced per-pattern file traversal with a
+ single-pass multi-pattern search via `search_content_multi_pattern` API,
+ eliminating N× I/O overhead. Results are now parsed as structured `list[dict]`
+ instead of regex-parsing formatted text, fixing the "format-then-parse"
+ anti-pattern
+ ([#132, #137](https://github.com/SunYanbox/ManualAid/issues/132)).
+- **Exception Handling Consolidation**: The `handle_tool_exceptions` decorator
+ now uniformly wraps all exceptions into `ToolResult(success=False, error=...)`
+ objects. Removed `ToolErrorResponse` dependency; error messages are now
+ formatted as `ClassName: Message`
+ ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)).
+
+### Fixed
+
+- **`limit` Semantics in Search Tools**: Corrected the `limit` parameter in both
+ `exact_search` and `regex_search` to count individual match results rather
+ than files scanned, aligning behavior with user expectations
+ ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)).
+- **Redundant Warnings in Input Parser**: Removed stale `warnings.warn` calls
+ and the unused `import warnings` dependency from the input parser
+ ([#138](https://github.com/SunYanbox/ManualAid/issues/138)).
+
+### Removed
+
+- **`read_lines` Tool**: Merged into the enhanced `read` tool with range-reading
+ support. All `read_lines` functionality is now accessible via `read` with
+ `start`/`end`/`context` parameters
+ ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)).
+
## [0.4.1] - 2026-05-04
### Added
@@ -176,6 +251,7 @@ and this project adheres to
_Initial release features and history._
+[0.5.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.5.0
[0.4.1]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.1
[0.4.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.0
[0.3.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.3.0
diff --git a/docs/CHANGELOG_ZH.md b/docs/CHANGELOG_ZH.md
index d8fd008..1540562 100644
--- a/docs/CHANGELOG_ZH.md
+++ b/docs/CHANGELOG_ZH.md
@@ -8,6 +8,60 @@
[Keep a Changelog](https://keepachangelog.com/zh-CN/1.0.0/). 并采用
[语义化版本](https://semver.org/lang/Chinese/).
+## [0.5.0] - 2026-05-05
+
+### 新增
+
+- **结构化工具返回结果**: 引入 `ToolResult`
+ 数据类作为所有工具的统一返回类型,替代以往不一致的字符串和列表响应. 该类包含
+ `success`、`data`、`error`、`response`
+ 属性,内置结果压缩与标准化 XML 格式化功能. 所有工具方法现均返回 `ToolResult`
+ 对象,使上游调用方能进行一致的错误处理 ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)).
+- **`exact_search` 文件模式过滤**: 为 `exact_search` 新增 `file_pattern`
+ 参数(默认 `"*"`),与已有的 `regex_search`
+ 行为对齐. 支持按文件扩展名或通配符模式过滤搜索范围 ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)).
+- **工具自动分类**: 工具现根据 `write_permission`
+ 属性自动归类为只读或可写,无需手动注册分类,减少维护成本 ([#135, #139](https://github.com/SunYanbox/ManualAid/issues/135)).
+- **`read` 工具范围读取**: `read` 工具现支持通过 `start`、`end`(支持负数索引) 和
+ `context` 参数进行精确的行范围读取,替代原有的粗粒度 `max_lines`
+ 方式. 显示头部现展示实际读取的行范围 (`[行 start-end / 共 total_lines 行]`)
+ ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)).
+- **参数描述机制**: 在 `BaseTool` 中引入 `param_descriptions`
+ 字典,允许每个工具为参数提供可读描述. 参数文档格式从内联 XML 转为 Markdown 列表项 (`- **名称** (类型, 必需/可选): 描述`)
+ ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)).
+- **文件大小限制**: 为 `read`
+ 工具添加了可配置的最大文件大小限制 (`MAX_READ_FILE_SIZE`,默认 10MB),防止读取大文件时内存溢出 ([#130, #141](https://github.com/SunYanbox/ManualAid/issues/130)).
+
+### 更改
+
+- **工具路径参数统一**: 将所有工具中的路径参数统一重命名为 `path`——原
+ `file_path`(read、write、edit)和 `folder_path`(ls、glob)现统一使用
+ `path`. 此举减少 LLM 混淆并缩短注入 Token 长度 ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)).
+- **工具注入长度优化**: 移除工具函数中冗余的 Docstring `Parameters`
+ 段落、缩短工具描述、精简参数文档格式. 配合参数统一,这些变更显著缩短了系统提示注入长度,降低 LLM 幻觉风险 ([#127, #128](https://github.com/SunYanbox/ManualAid/issues/127)).
+- **符号搜索性能重构**: 用单次遍历多模式搜索(`search_content_multi_pattern`
+ API)替代原有的逐模式文件遍历, 消除了 N 倍 I/O 开销. 搜索结果现解析为结构化的
+ `list[dict]`
+ 而非对格式化文本做正则解析,修复了 "先格式化再解析"的反模式 ([#132, #137](https://github.com/SunYanbox/ManualAid/issues/132)).
+- **异常处理整合**: `handle_tool_exceptions` 装饰器现统一将所有异常封装为
+ `ToolResult(success=False, error=...)` 对象. 移除了 `ToolErrorResponse`
+ 依赖,错误消息格式化为 `ClassName: Message`
+ ([#133, #142](https://github.com/SunYanbox/ManualAid/issues/133)).
+
+### 修复
+
+- **搜索工具 `limit` 语义**: 修正了 `exact_search` 和 `regex_search` 中 `limit`
+ 参数的计数逻辑,从统计"扫描文件数"改为统计"匹配结果数",使行为符合用户预期 ([#134, #140](https://github.com/SunYanbox/ManualAid/issues/134)).
+- **输入解析器冗余警告**: 清理了输入解析器中已过时的 `warnings.warn`
+ 调用及未使用的 `import warnings`
+ 依赖 ([#138](https://github.com/SunYanbox/ManualAid/issues/138)).
+
+### 移除
+
+- **`read_lines` 工具**: 已合并至增强后的 `read` 工具. 所有 `read_lines`
+ 功能现可通过 `read` 的 `start`/`end`/`context`
+ 参数访问 ([#119, #128](https://github.com/SunYanbox/ManualAid/issues/119)).
+
## [0.4.1] - 2026-05-04
### 新增
@@ -144,6 +198,7 @@
_初始发布的功能和历史记录._
+[0.5.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.5.0
[0.4.1]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.1
[0.4.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.4.0
[0.3.0]: https://github.com/SunYanbox/ManualAid/releases/tag/v0.3.0
diff --git a/pyproject.toml b/pyproject.toml
index 5d32ecb..ed36fd5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ManualAid"
-version = "0.4.1"
+version = "0.5.0"
description = ""
requires-python = ">=3.14"
dependencies = [
diff --git a/src/console/commands/systems/quit_cmd.py b/src/console/commands/systems/quit_cmd.py
index fe574c3..ec199ac 100644
--- a/src/console/commands/systems/quit_cmd.py
+++ b/src/console/commands/systems/quit_cmd.py
@@ -16,7 +16,7 @@ def __init__(self):
def execute(self, context: CommandContext) -> CommandResult:
# Close session explicitly before exit (atexit handler is a fallback,
# but explicit is more reliable when sys.exit triggers early shutdown).
- session_id = getattr(context.tool_registry, "_current_session_id", None)
+        session_id = getattr(context.workspace, "session_id", None)
if session_id is not None and hasattr(context.workspace, "db"):
context.workspace.db.close_session(session_id)
context.console.print("[bold]Goodbye![/bold]")
diff --git a/src/console/handlers/tool_handler.py b/src/console/handlers/tool_handler.py
index 874a827..0faaf67 100644
--- a/src/console/handlers/tool_handler.py
+++ b/src/console/handlers/tool_handler.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import json
import time
from typing import TYPE_CHECKING
@@ -100,35 +99,9 @@ def handle(self, parsed_input: CommandParseResult) -> bool:
start = time.perf_counter()
# 执行
- try:
- response = self.tool_registry.execute(func_name, **func_kwargs)
- if isinstance(response, str):
- response_str = response
- elif isinstance(response, (dict, list, tuple)):
- response_str = json.dumps(response)
- else:
- response_str = f"{response.__class__.__name__}({response})"
-
- temp_result = [
- "",
- f"",
- response_str,
- "",
- "",
- ]
- func_result = str.join("\n", temp_result)
- except Exception as e:
- import traceback
-
- error = (
- f"执行工具{func_name}(参数={parms})时出现错误: "
- f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})"
- )
- self.console.print(f"[red]{error}[/red]")
- func_result = "\n".join(["", "", error, "", ""])
- self.console.print(f"[red]{error}[/red]")
-
- collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=func_result)
+ response = self.tool_registry.execute(func_name, **func_kwargs)
+
+ collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=response.response)
result = ""
diff --git a/src/console/main.py b/src/console/main.py
index 927a3f2..0d1c65c 100644
--- a/src/console/main.py
+++ b/src/console/main.py
@@ -85,7 +85,6 @@ def init_workspace(start_path: str | None = None) -> Workspace | None:
session_id = workspace.db.create_session(name=f"session_{time.strftime('%Y%m%d_%H%M%S')}")
tool_registry.set_session_id(session_id)
- workspace._current_session_id = session_id
# Start a daemon heartbeat thread to periodically persist session duration
# and guard against accidental deletion flag
diff --git a/src/constants/__init__.py b/src/constants/__init__.py
index 3d26edf..3d18726 100644
--- a/src/constants/__init__.py
+++ b/src/constants/__init__.py
@@ -1 +1 @@
-__version__ = "0.4.1"
+__version__ = "0.5.0"
diff --git a/src/core/input_parser.py b/src/core/input_parser.py
index 61b2d32..620cb00 100644
--- a/src/core/input_parser.py
+++ b/src/core/input_parser.py
@@ -1,7 +1,6 @@
import html
import re
import shlex
-import warnings
from src.console.commands.command_registry import CommandRegistry
from src.models.commands import CommandParseResult
@@ -53,11 +52,6 @@ def parse_func_call(content: str, warns: list[str]) -> tuple[str, dict]:
"""
从 标签中提取函数名和参数,使用健壮的回退机制.
"""
- warnings.warn(
- "当参数存在`$", "", inner).strip()
diff --git a/src/core/tool_registry.py b/src/core/tool_registry.py
index c8d485d..55f8650 100644
--- a/src/core/tool_registry.py
+++ b/src/core/tool_registry.py
@@ -12,6 +12,7 @@
import nest_asyncio
from src.console.result_manager import _to_string
+from src.models.tools.tool_result import ToolResult
from src.utils.string_snapshot import truncate_string
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -55,36 +56,6 @@ def __init__(self):
# 配置常量 - 从环境变量读取,带默认值
self.MAX_DOC_LENGTH = int(os.getenv("TOOL_MAX_DOC_LENGTH", "360"))
self.MAX_FUNC_NAME_LENGTH = int(os.getenv("TOOL_MAX_FUNC_NAME_LENGTH", "80"))
- self.MAX_RESULT_LENGTH = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000"))
- self.LIST_TRUNCATE_THRESHOLD = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100"))
- self.DICT_TRUNCATE_THRESHOLD = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100"))
-
- # 验证配置值
- self._validate_config()
-
- def _validate_config(self) -> None:
- """验证配置值确保在合理范围内"""
- if self.MAX_RESULT_LENGTH < 10:
- warnings.warn(
- f"TOOL_MAX_RESULT_LENGTH 过小({self.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2
- )
- self.MAX_RESULT_LENGTH = 100
-
- if self.LIST_TRUNCATE_THRESHOLD < 10:
- warnings.warn(
- f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({self.LIST_TRUNCATE_THRESHOLD}),建议至少为50",
- UserWarning,
- stacklevel=2,
- )
- self.LIST_TRUNCATE_THRESHOLD = 50
-
- if self.DICT_TRUNCATE_THRESHOLD < 10:
- warnings.warn(
- f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({self.DICT_TRUNCATE_THRESHOLD}),建议至少为50",
- UserWarning,
- stacklevel=2,
- )
- self.DICT_TRUNCATE_THRESHOLD = 50
def _validate_tool_info(self, name: str, doc: str) -> None:
"""验证工具信息并发出警告"""
@@ -96,16 +67,14 @@ def _validate_tool_info(self, name: str, doc: str) -> None:
if len(name) > self.MAX_FUNC_NAME_LENGTH:
warnings.warn(f"工具名称 '{name}' 超过 {self.MAX_FUNC_NAME_LENGTH} 字符", UserWarning, stacklevel=3)
- def _set_tool_category(self, tool_name: str) -> None:
- """根据工具名称设置分类."""
- if tool_name in {"glob", "ls", "regex_search", "exact_search", "stat", "read", "read_lines", "symbol_ref"}:
- self._tool_categories[tool_name] = "query"
- elif tool_name in {"write", "edit", "confirm_edit"}:
- self._tool_categories[tool_name] = "edit"
- elif tool_name == "git":
- self._tool_categories[tool_name] = "dangerous"
+ def _set_tool_category(self, tool: BaseTool) -> None:
+ """根据工具的 write_permission 属性设置分类."""
+ if tool.name == "git":
+ self._tool_categories[tool.name] = "dangerous"
+ elif tool.write_permission:
+ self._tool_categories[tool.name] = "write"
else:
- self._tool_categories[tool_name] = "query"
+ self._tool_categories[tool.name] = "query"
def register(self, workspace: Workspace) -> None:
"""为工作区注册工具"""
@@ -114,7 +83,6 @@ def register(self, workspace: Workspace) -> None:
from src.workspace.tools.git_tool import GitTool
from src.workspace.tools.glob_tool import GlobTool
from src.workspace.tools.ls_tool import LsTool
- from src.workspace.tools.read_lines_tool import ReadLinesTool
from src.workspace.tools.read_tool import ReadTool
from src.workspace.tools.regex_search_tool import RegexSearchTool
from src.workspace.tools.stat_tool import StatTool
@@ -127,7 +95,6 @@ def register(self, workspace: Workspace) -> None:
ExactSearchTool,
GlobTool,
LsTool,
- ReadLinesTool,
ReadTool,
RegexSearchTool,
WriteTool,
@@ -142,33 +109,11 @@ def register(self, workspace: Workspace) -> None:
warnings.warn(f"工具{tool.name}没有注册功能回调和参数", stacklevel=2)
continue
self._tools[tool.name] = tool
- self._set_tool_category(tool.name)
+ self._set_tool_category(tool)
except ValueError:
pass
- def _compress_result(self, result: Any) -> Any:
- """压缩过长的结果"""
- result_length = len(result)
- if isinstance(result, str):
- if result_length > self.MAX_RESULT_LENGTH:
- return (
- result[: self.MAX_RESULT_LENGTH]
- + f"... [字符串结果已截断 显示的字符数: {self.LIST_TRUNCATE_THRESHOLD} / {result_length}]"
- )
- elif isinstance(result, (list, tuple)):
- if result_length > self.LIST_TRUNCATE_THRESHOLD:
- return [
- *list(result[: self.LIST_TRUNCATE_THRESHOLD]),
- f"... [列表已截断 显示的项: {self.LIST_TRUNCATE_THRESHOLD} / {result_length}]",
- ]
- elif isinstance(result, dict) and result_length > self.DICT_TRUNCATE_THRESHOLD:
- compressed = {k: result[k] for k in list(result.keys())[: self.DICT_TRUNCATE_THRESHOLD]}
- compressed["..."] = f"[字典已截断 显示的项: {self.DICT_TRUNCATE_THRESHOLD} / {result_length}]"
- return compressed
-
- return result
-
- def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any:
+ def execute(self, func_name: str, *args: Any, **kwargs: Any) -> ToolResult:
"""
执行工具函数
@@ -180,37 +125,48 @@ def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any:
Returns:
函数执行结果(自动压缩过长的结果)
"""
- if func_name in self._tools:
- tool = self._tools[func_name]
- kwargs = tool.convert_args(kwargs)
+ try:
+ if func_name in self._tools:
+ tool = self._tools[func_name]
+ kwargs = tool.convert_args(kwargs)
+
+ start_time = time.perf_counter()
- start_time = time.perf_counter()
- status = "success"
- try:
if inspect.iscoroutinefunction(tool.func):
coro = tool.func(*args, **kwargs)
# 已有事件循环
try:
loop = asyncio.get_running_loop()
nest_asyncio.apply()
- result = loop.run_until_complete(coro)
+ raw_result = loop.run_until_complete(coro)
except RuntimeError: # pragma: no cover // pytest内置事件循环, 测不到这里
# 没有运行中的事件循环
- result = asyncio.run(coro)
+ raw_result = asyncio.run(coro)
else:
# 同步函数
- result = tool.func(*args, **kwargs)
- except Exception as e:
- status = "error"
- result = f''
-
- duration_ms = (time.perf_counter() - start_time) * 1000
- self._log_tool_call(func_name, kwargs, duration_ms, status)
- self._record_tool_call_summary(func_name, kwargs, result)
-
- return self._compress_result(result)
- else:
- raise ValueError(f"未找到工具: {func_name}")
+ raw_result = tool.func(*args, **kwargs)
+
+ # 统一解包 ToolResult
+ result = (
+ raw_result
+ if (isinstance(raw_result, ToolResult))
+ else (
+ ToolResult(
+ success=False,
+ func_name=func_name,
+ func_kwargs=kwargs,
+ error=f"错误的工具返回值类型: {raw_result.__class__.__name__}",
+ )
+ )
+ )
+ duration_ms = (time.perf_counter() - start_time) * 1000
+ self._log_tool_call(func_name, kwargs, duration_ms, result.status)
+ self._record_tool_call_summary(func_name, kwargs, result.response)
+ return result
+ else:
+ raise ValueError(f"未找到工具: {func_name}")
+ except Exception as e:
+ return ToolResult(success=False, data=kwargs, error=str(e), func_name=func_name, func_kwargs=kwargs)
def generate_markdown(self) -> str:
"""
@@ -238,8 +194,10 @@ def get_tool_info(self, name: str) -> BaseTool | None:
return self._tools.get(name)
def set_session_id(self, session_id: int) -> None:
- """设置当前会话 ID"""
+ """设置当前会话 ID,并同步到关联的 Workspace."""
self._current_session_id = session_id
+ if self._workspace is not None:
+ self._workspace.session_id = session_id
@staticmethod
def _compute_kwargs_json(kwargs: dict) -> str:
@@ -250,13 +208,13 @@ def _compute_kwargs_json(kwargs: dict) -> str:
return json.dumps(truncated, sort_keys=True, default=str)
def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, status: str) -> str | None:
- session_id = getattr(self, "_current_session_id", None)
- if session_id is None:
+ if self._current_session_id is None:
return None
try:
kwargs_json = self._compute_kwargs_json(kwargs)
workspace = getattr(self, "_workspace", None)
if workspace is not None:
+ session_id = self._current_session_id
# Determine audit_status based on tool category
category = self._tool_categories.get(func_name, "query")
audit_status = "none"
@@ -268,7 +226,7 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu
command_str = kwargs.get("command_str", "")
if not GitTool.is_safe_command(command_str):
audit_status = "PENDING_AUDIT"
- elif category == "edit":
+ elif category == "write":
audit_status = "none" # Snapshot has its own PENDING_AUDIT
workspace.db.log_tool_call(
@@ -284,12 +242,12 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu
return f"ToolRegistry(sync_tools={len(self._tools)})"
def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) -> None:
- session_id = getattr(self, "_current_session_id", None)
- if session_id is None:
+ if self._current_session_id is None:
return
# Exclude write tools
- if func_name in {"write", "edit", "confirm_edit"}:
+ tool = self._tools.get(func_name)
+ if tool is not None and tool.write_permission:
return
try:
@@ -297,6 +255,6 @@ def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) -
workspace = getattr(self, "_workspace", None)
if workspace is not None:
result_str = _to_string(result)
- workspace.db.record_tool_call_summary(session_id, func_name, kwargs_json, result_str)
+ workspace.db.record_tool_call_summary(self._current_session_id, func_name, kwargs_json, result_str)
except Exception:
pass
diff --git a/src/models/tools/tool_result.py b/src/models/tools/tool_result.py
new file mode 100644
index 0000000..763e911
--- /dev/null
+++ b/src/models/tools/tool_result.py
@@ -0,0 +1,143 @@
+import json
+import os
+import warnings
+from typing import Any, ClassVar
+
+from src.utils.string_snapshot import truncate_params_string
+
+
+def to_xml_string(func_name: str, params: dict, data: Any = None, err: str | None = None) -> str:
+ params = params.copy()
+ params.pop("self", None)
+
+ try:
+ messages: list[str] = []
+ if data is not None:
+ messages.append("")
+ if isinstance(data, str):
+ messages.append(data)
+ elif isinstance(data, (dict, list, tuple)):
+ messages.append(json.dumps(data))
+ else:
+ messages.append(f"{data.__class__.__name__}({data})")
+ messages.append("")
+ if err:
+ messages.extend(
+ [
+ "",
+ err,
+ "",
+ ]
+ )
+ if data is None and not err:
+ messages.append("没有任何工具调用数据或错误详情, 请提示用户检查工具是否正常")
+
+ temp_result = [
+ "",
+ f"",
+ *messages,
+ "",
+ "",
+ ]
+ func_result = str.join("\n", temp_result)
+ except Exception as e:
+ import traceback
+
+ func_result = "\n".join(
+ [
+ "",
+ f"",
+ f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})",
+ err if err else "",
+ "",
+ "",
+ ]
+ )
+ return func_result
+
+
+class ToolResult:
+ """工具执行结果的结构化包装,显式区分成功与失败.
+
+ 所有被 @handle_tool_exceptions 装饰的工具方法均返回此类型,
+ 调用方可通过 success 标志可靠判断执行状态,无需依赖隐式类型约定.
+ """
+
+ __slots__ = ("data", "error", "func_kwargs", "func_name", "response", "success")
+
+ HAD_VALIDATE: ClassVar[bool] = False
+ MAX_RESULT_LENGTH: ClassVar[int] = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000"))
+ LIST_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100"))
+ DICT_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100"))
+
+ def __init__(
+ self, success: bool, func_name: str, func_kwargs: dict, data: Any = None, error: str | None = None
+ ) -> None:
+ self.success: bool = success
+ self.func_name: str = func_name
+ self.func_kwargs: dict = func_kwargs
+ self.data: Any = data
+ self.error: str | None = error
+ self.response: str = to_xml_string(self.func_name, self.func_kwargs, self.data, self.error)
+ ToolResult._validate_config()
+
+ def __repr__(self) -> str:
+ if self.success:
+ return f"ToolResult(success=True, data={self.data!r})"
+ return f"ToolResult(success=False, error={self.error!r})"
+
+ @property
+ def status(self) -> str:
+ return "success" if self.success else "error"
+
+ @classmethod
+ def _compress_result(cls, result: Any) -> Any:
+ """压缩过长的结果"""
+        result_length = len(result) if hasattr(result, "__len__") else 0
+ if isinstance(result, str):
+ if result_length > cls.MAX_RESULT_LENGTH:
+ return (
+ result[: cls.MAX_RESULT_LENGTH]
+                    + f"... [字符串结果已截断 显示的字符数: {cls.MAX_RESULT_LENGTH} / {result_length}]"
+ )
+ elif isinstance(result, (list, tuple)):
+ if result_length > cls.LIST_TRUNCATE_THRESHOLD:
+ return [
+ *list(result[: cls.LIST_TRUNCATE_THRESHOLD]),
+ f"... [列表已截断 显示的项: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]",
+ ]
+ elif isinstance(result, dict) and result_length > cls.DICT_TRUNCATE_THRESHOLD:
+ compressed = {k: result[k] for k in list(result.keys())[: cls.DICT_TRUNCATE_THRESHOLD]}
+ compressed["..."] = f"[字典已截断 显示的项: {cls.DICT_TRUNCATE_THRESHOLD} / {result_length}]"
+ return compressed
+
+ return result
+
+ @classmethod
+ def _validate_config(cls) -> None:
+ """验证配置值确保在合理范围内"""
+ if cls.HAD_VALIDATE:
+ return None
+ if cls.MAX_RESULT_LENGTH < 10:
+ warnings.warn(
+ f"TOOL_MAX_RESULT_LENGTH 过小({cls.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2
+ )
+ cls.MAX_RESULT_LENGTH = 100
+
+ if cls.LIST_TRUNCATE_THRESHOLD < 10:
+ warnings.warn(
+ f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({cls.LIST_TRUNCATE_THRESHOLD}),建议至少为50",
+ UserWarning,
+ stacklevel=2,
+ )
+ cls.LIST_TRUNCATE_THRESHOLD = 50
+
+ if cls.DICT_TRUNCATE_THRESHOLD < 10:
+ warnings.warn(
+ f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({cls.DICT_TRUNCATE_THRESHOLD}),建议至少为50",
+ UserWarning,
+ stacklevel=2,
+ )
+ cls.DICT_TRUNCATE_THRESHOLD = 50
+ cls.HAD_VALIDATE = True
+ return None
diff --git a/src/workspace/tools/base_tool.py b/src/workspace/tools/base_tool.py
index 8d456a9..70e6ab4 100644
--- a/src/workspace/tools/base_tool.py
+++ b/src/workspace/tools/base_tool.py
@@ -4,26 +4,25 @@
from typing import Any
from src.core.file_tracker import FileTracker
+from src.models.tools.tool_result import ToolResult
from src.workspace.workspace import Workspace
-def build_param_doc(name: str, params: dict[str, Any]) -> str:
- """Generate a concise XML parameter doc."""
+def build_param_list_item(name: str, params: dict[str, Any], description: str = "") -> str:
+ """Generate a Markdown list item describing a parameter."""
from src.constants.prompts import clean_type_annotation
- result = f'"
- return result
+ default_val = params.get("default", "")
+ req_str = f"optional, default=`{default_val}`"
+    type_str = clean_type_annotation(str(params.get("type", "Any")))
+ desc_suffix = f": {description}" if description else ""
+
+ return f"- **{name}** ({type_str}, {req_str}){desc_suffix}"
def convert_param_type(value: str, expected_type: str) -> Any:
@@ -110,8 +109,9 @@ def __init__(
self.write_permission: bool = write_permission
self.name: str = name
self.doc: str = doc
- self.func: Callable[..., Any] | None = None
+ self.func: Callable[..., ToolResult] | None = None
self.params: dict[str, Any] | None = None
+ self.param_descriptions: dict[str, str] = {}
def to_doc(self) -> str:
"""转换为模型可读文档格式"""
@@ -119,7 +119,8 @@ def to_doc(self) -> str:
if self.params and len(self.params) > 0:
lines.append(" ")
for name, param in self.params.items():
- lines.append(f" {build_param_doc(name, param)}")
+ desc = self.param_descriptions.get(name, "")
+ lines.append(f" {build_param_list_item(name, param, desc)}")
lines.append(" ")
else:
lines.append(" ")
@@ -139,7 +140,7 @@ def to_func_call(self) -> str:
return func_call
@staticmethod
- def extract_params(func: Callable[..., Any]) -> dict[str, Any]:
+ def extract_params(func: Callable[..., ToolResult]) -> dict[str, Any]:
"""提取函数参数信息"""
sig = inspect.signature(func)
params = {}
@@ -181,7 +182,7 @@ def _record_read_meta(self, resolved_path: Path) -> None:
try:
meta = FileTracker.get_file_meta(resolved_path)
if meta:
- session_id = self.workspace._current_session_id
+ session_id = self.workspace.session_id
if session_id is not None:
rel_path = str(resolved_path.relative_to(self.workspace.root_path))
self.workspace.db.record_file_read(
@@ -195,7 +196,7 @@ def _validate_mtime(self, resolved_path: Path) -> str | None:
if not resolved_path.exists():
return None
- session_id = self.workspace._current_session_id
+ session_id = self.workspace.session_id
if session_id is None:
return None
@@ -224,25 +225,60 @@ def _generate_diff(old_content: str, new_content: str, file_path: str) -> str:
diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}")
return "".join(diff)
+ @classmethod
+ def make_tool_result_response(
+ cls, success: bool, kwargs: dict, data: Any = None, error: str | None = None
+ ) -> ToolResult:
+ return ToolResult(success=success, func_name=cls.__name__, func_kwargs=kwargs, data=data, error=error)
+
+ @classmethod
+ def make_success_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult:
+ return cls.make_tool_result_response(success=True, kwargs=kwargs, data=data, error=error)
+
+ @classmethod
+ def make_failed_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult:
+ return cls.make_tool_result_response(success=False, kwargs=kwargs, data=data, error=error)
+
@staticmethod
- def handle_tool_exceptions(func):
- """工具方法异常处理装饰器."""
+ def handle_tool_exceptions(func) -> Callable[..., ToolResult]:
+ """工具方法异常处理装饰器 —— 将异常转换为 ToolResult 失败结果"""
from functools import wraps
- from src.models.tool_error_response import ToolErrorResponse
from src.workspace.path_validator import PathNotFoundError, WorkspaceBoundaryError
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
- return func(self, *args, **kwargs)
+ raw = func(self, *args, **kwargs)
+ # 如果工具内部已返回 ToolResult, 直接透传
+ if isinstance(raw, ToolResult):
+ return raw
+ # 否则包装为成功结果
+ return ToolResult(success=True, func_name=func.__name__, func_kwargs=kwargs, data=raw)
except PathNotFoundError as err1:
- return ToolErrorResponse(self.__class__.__name__, err1).to_str()
+ return ToolResult(
+ success=False,
+ func_name=func.__name__,
+ func_kwargs=kwargs,
+ error=f"{err1.__class__.__name__}: {err1}",
+ )
except WorkspaceBoundaryError as err2:
- return ToolErrorResponse(self.__class__.__name__, err2).to_str()
+ return ToolResult(
+ success=False,
+ func_name=func.__name__,
+ func_kwargs=kwargs,
+ error=f"{err2.__class__.__name__}: {err2}",
+ )
except PermissionError as err3:
- return ToolErrorResponse(self.__class__.__name__, err3).to_str()
+ return ToolResult(
+ success=False,
+ func_name=func.__name__,
+ func_kwargs=kwargs,
+ error=f"{err3.__class__.__name__}: {err3}",
+ )
except Exception as err:
- return ToolErrorResponse(self.__class__.__name__, err).to_str()
+ return ToolResult(
+ success=False, func_name=func.__name__, func_kwargs=kwargs, error=f"{err.__class__.__name__}: {err}"
+ )
return wrapper
diff --git a/src/workspace/tools/edit_tool.py b/src/workspace/tools/edit_tool.py
index 15553ee..b33c154 100644
--- a/src/workspace/tools/edit_tool.py
+++ b/src/workspace/tools/edit_tool.py
@@ -3,6 +3,7 @@
from pathlib import Path
from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.utils.binary_detector import is_binary_file
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -19,61 +20,55 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "edit", self.edit.__doc__, write_permission=True)
self.func = self.edit
self.params = BaseTool.extract_params(self.edit)
+ self.param_descriptions = {
+ "path": "文件路径",
+ "old_string": "待替换的字符串",
+ "new_string": "替换后的字符串",
+ "max_replacements": "最大替换次数(1~100)",
+ "context_before": "匹配前的上下文文本",
+ "context_after": "匹配后的上下文文本",
+ }
@BaseTool.handle_tool_exceptions
def edit(
self,
- file_path: str,
+ path: str,
old_string: str,
new_string: str,
max_replacements: int = 10,
context_before: str = "",
context_after: str = "",
- ) -> str:
+ ) -> ToolResult:
"""
- 在文件中进行安全的字符串替换(仅预览,不修改磁盘)
-
- 执行 dry-run 替换,生成 diff,记录 PENDING_AUDIT 快照.
- 批准后由 AuditCommitter 执行实际写入.
-
- Parameters
- ----------
- file_path: 文件路径
- old_string: 待替换的字符串(不能为空)
- new_string: 替换后的字符串
- max_replacements: 最大替换次数(默认 10,最大 100)
- context_before: 匹配前的上下文文本(可选,用于校验)
- context_after: 匹配后的上下文文本(可选,用于校验)
+ 通过在文件中进行安全的字符串替换编辑文件
"""
# 1. 参数校验
if not old_string:
- return ToolErrorResponse(self.__class__.__name__, ValueError("old_string 不能为空")).to_str()
+ return self.make_failed_response(locals().copy(), error=f"{ValueError('old_string 不能为空')}")
if max_replacements < 1:
- return ToolErrorResponse(self.__class__.__name__, ValueError("max_replacements 必须 >= 1")).to_str()
+ return self.make_failed_response(locals().copy(), error=f"{ValueError('max_replacements 必须 >= 1')}")
if max_replacements > 100:
max_replacements = 100
# 2. 路径解析
- source_file_path = Path(file_path)
- resolved_path: Path = self.workspace.path_validator.resolve_path(source_file_path)
+ source_path = Path(path)
+ resolved_path: Path = self.workspace.path_validator.resolve_path(source_path)
if not resolved_path.is_file():
- return ToolErrorResponse(
- self.__class__.__name__,
- FileNotFoundError(f"文件不存在: {resolved_path}"),
- ).to_str()
+ return self.make_failed_response(
+ locals().copy(), error=f"{FileNotFoundError(f'文件不存在: {resolved_path}')}"
+ )
if is_binary_file(resolved_path):
- return ToolErrorResponse(
- self.__class__.__name__,
- ValueError(f"禁止编辑二进制文件: {resolved_path}"),
- ).to_str()
+ return self.make_failed_response(
+ locals().copy(), error=f"{ValueError(f'禁止编辑二进制文件: {resolved_path}')}"
+ )
# 3. mtime 校验
mtime_error = self._validate_mtime(resolved_path)
if mtime_error:
- return mtime_error
+ return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}")
# 4. 读取文件内容
old_content = resolved_path.read_text(encoding="utf-8")
@@ -91,12 +86,17 @@ def edit(
if context_before or context_after:
ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count)
if ctx_error:
- return ctx_error
+ return self.make_failed_response(
+ locals().copy(), error=f"无法修改上下文不匹配的字符串:\n{ctx_error}"
+ )
idx += len(old_string)
if count == 0:
- return f"No changes made: old_string not found in file.\nFile: {file_path}\nSearching for: '{old_string}'"
+ return self.make_failed_response(
+ locals().copy(),
+ error=f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'",
+ )
# 6. 执行替换(生成新内容)
new_content = old_content.replace(old_string, new_string, count)
@@ -110,7 +110,7 @@ def edit(
old_hash = FileTracker.compute_checksum_from_string(old_content)
new_hash = FileTracker.compute_checksum_from_string(new_content)
- session_id = self.workspace._current_session_id
+ session_id = self.workspace.session_id
snapshot_id = self.workspace.db.record_file_snapshot(
rel_path,
old_hash,
@@ -122,12 +122,16 @@ def edit(
)
# 9. 返回预览
- return (
- f"[Edit Preview]\n"
- f"File: {rel_path}\n"
- f"Snapshot ID: {snapshot_id}\n"
- f"Replacements: {count}\n"
- f"Diff:\n{diff_content}"
+ return self.make_success_response(
+ locals().copy(),
+ (
+ "修改已推送到审核系统\n"
+ f"[Edit Preview]\n"
+ f"File: {rel_path}\n"
+ f"Snapshot ID: {snapshot_id}\n"
+ f"Replacements: {count}\n"
+ f"Diff:\n{diff_content}"
+ ),
)
@staticmethod
diff --git a/src/workspace/tools/exact_search_tool.py b/src/workspace/tools/exact_search_tool.py
index b23eaa6..873df3a 100644
--- a/src/workspace/tools/exact_search_tool.py
+++ b/src/workspace/tools/exact_search_tool.py
@@ -2,6 +2,7 @@
import re
from pathlib import Path
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -76,6 +77,15 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "exact_search", self.exact_search.__doc__)
self.func = self.exact_search
self.params = BaseTool.extract_params(self.exact_search)
+ self.param_descriptions = {
+ "pattern": "搜索字符串",
+ "path": "搜索文件或文件夹路径",
+ "case_sensitive": "是否大小写敏感",
+ "whole_word": "是否全词匹配",
+ "file_pattern": "文件匹配模式,支持通配符",
+ "limit": "最大匹配数量限制",
+ "ignore": "忽略匹配正则的文件或文件夹列表",
+ }
@BaseTool.handle_tool_exceptions
def exact_search(
@@ -84,22 +94,12 @@ def exact_search(
path: str = ".",
case_sensitive: bool = True,
whole_word: bool = True,
+ file_pattern: str = "*",
limit: int = 256,
ignore: list[str] | None = None,
- ) -> str:
+ ) -> ToolResult:
"""
- 精确搜索字符串(支持大小写敏感/全词匹配)
-
- Args:
- pattern: 搜索字符串
- path: 搜索路径,默认为当前目录
- case_sensitive: 是否大小写敏感,默认为True
- whole_word: 是否全词匹配,默认为True
- limit: 最大匹配数量限制,默认为256
- ignore: 忽略匹配正则的文件或文件夹列表
-
- Returns:
- 格式化的搜索结果字符串
+ 精确搜索字符串
"""
# 验证搜索路径
search_path: Path = self.workspace.path_validator.validate(path)
@@ -117,16 +117,18 @@ def exact_search(
# 搜索结果
results = []
file_count = 0
+ total_matches = 0
+ warnings = [""]
# 确定要搜索的文件列表(支持单文件或目录)
- files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob("*"))
+ files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern))
# 遍历所有文件
for file_path in files_to_search:
if not file_path.is_file():
continue
# 检查是否达到限制
- if len(results) >= limit:
+ if total_matches >= limit:
break
# 检查是否应该忽略
@@ -152,9 +154,17 @@ def exact_search(
if file_matches:
results.append({"file": str(file_path), "matches": file_matches})
file_count += 1
+ total_matches += len(file_matches)
- except OSError, UnicodeDecodeError, PermissionError:
+ except (OSError, UnicodeDecodeError, PermissionError) as e:
+ warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}")
continue # 跳过无法读取的文件
+ warnings.append("")
+
# 格式化输出
- return _format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word)
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=_format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word),
+ error="\n".join(warnings) if len(warnings) > 2 else None,
+ )
diff --git a/src/workspace/tools/git_tool.py b/src/workspace/tools/git_tool.py
index cf96864..a0cc2eb 100644
--- a/src/workspace/tools/git_tool.py
+++ b/src/workspace/tools/git_tool.py
@@ -4,7 +4,7 @@
import shlex
import subprocess
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -55,60 +55,65 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "git", self.git.__doc__, read_permission=True)
self.func = self.git
self.params = BaseTool.extract_params(self.git)
+ self.param_descriptions = {
+ "command_str": "Git 子命令及其参数,如 'status'、'diff --cached'、'log --oneline -5'",
+ }
- def git(self, command_str: str) -> str:
+ def git(self, command_str: str) -> ToolResult:
"""
- 执行 Git 命令(白名单限制)
-
- Parameters
- ----------
- command_str: Git 子命令及其参数,如 "status"、"diff --cached"、"log --oneline -5"
+ 执行 Git 命令
"""
if not command_str or not command_str.strip():
- return ToolErrorResponse(self.__class__.__name__, ValueError("command_str 不能为空")).to_str()
+ return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("command_str 不能为空")))
try:
tokens = shlex.split(command_str)
except ValueError as e:
- return ToolErrorResponse(self.__class__.__name__, e).to_str()
+ return self.make_failed_response(kwargs=locals().copy(), error=str(e))
if not tokens:
- return ToolErrorResponse(self.__class__.__name__, ValueError("无法解析命令")).to_str()
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=str(ValueError(f"无法解析命令: `{command_str}`"))
+ )
base_command = tokens[0]
# 1. 白名单检查
if base_command not in _ALLOWED_COMMANDS:
allowed_list = ", ".join(sorted(_ALLOWED_COMMANDS))
- return (
- f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n"
- f"Allowed commands: {allowed_list}"
+ return self.make_failed_response(
+ kwargs=locals().copy(),
+ error=(
+ f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n"
+ f"Allowed commands: {allowed_list}"
+ ),
)
# 2. 拦截正则检查
for pattern in _BLOCKED_PATTERNS:
if pattern.search(command_str):
- return (
- f"ERROR: The command was blocked by security policy.\n"
- f"Pattern matched: {pattern.pattern}\n"
- f"Command: {command_str}"
+ return self.make_failed_response(
+ kwargs=locals().copy(),
+ error=(
+ f"ERROR: The command was blocked by security policy.\n"
+ f"Pattern matched: {pattern.pattern}\n"
+ f"Command: {command_str}"
+ ),
)
# 3. restore 安全检查 — 必须指定文件路径
if base_command == "restore":
non_flag_args = [t for t in tokens[1:] if not t.startswith("-")]
if not non_flag_args:
- return ToolErrorResponse(
- self.__class__.__name__,
- ValueError("restore 需要指定文件路径,不允许裸 restore"),
- ).to_str()
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=str(ValueError("restore 需要指定文件路径,不允许裸 restore"))
+ )
for arg in non_flag_args:
stripped = arg.strip()
if stripped in (".", "*", "all") or stripped.startswith("*"):
- return ToolErrorResponse(
- self.__class__.__name__,
- ValueError("restore 需要指定具体文件路径,不允许使用通配符"),
- ).to_str()
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=str(ValueError("restore 需要指定具体文件路径,不允许使用通配符"))
+ )
# 4. 执行命令
try:
@@ -122,22 +127,19 @@ def git(self, command_str: str) -> str:
env=env,
)
except FileNotFoundError:
- return ToolErrorResponse(
- self.__class__.__name__,
- OSError("Git 未安装或不在系统 PATH 中"),
- ).to_str()
- except subprocess.TimeoutExpired:
- return ToolErrorResponse(
- self.__class__.__name__,
- TimeoutError("Git 命令执行超时(30 秒)"),
- ).to_str()
+ return self.make_failed_response(kwargs=locals().copy(), error=str(OSError("Git 未安装或不在系统 PATH 中")))
+ except subprocess.TimeoutExpired as time_out_exception:
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=f"TimeoutExpired(Git 命令执行超时: {time_out_exception})"
+ )
# 5. 处理输出
if result.returncode != 0:
stderr = (result.stderr or "").strip()
- if stderr:
- return f"Git command failed (exit code {result.returncode}):\n{stderr}"
- return f"Git command failed (exit code {result.returncode})"
+            return self.make_failed_response(
+                kwargs=locals().copy(),
+                error=f"Git command failed (exit code {result.returncode})" + (f":\n{stderr}" if stderr else ""),
+            )
# Combine stdout and stderr
output_parts = []
@@ -166,9 +168,13 @@ def git(self, command_str: str) -> str:
if err:
output_parts.append(err.rstrip("\n"))
if not output_parts and _result2.returncode != 0:
- return f"Git command failed (exit code {_result2.returncode})"
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=f"Git command failed (exit code {_result2.returncode})"
+ )
- return "\n".join(output_parts) if output_parts else "(no output)"
+ return self.make_success_response(
+ kwargs=locals().copy(), data="\n".join(output_parts) if output_parts else "(no output)"
+ )
@staticmethod
def is_safe_command(command_str: str) -> bool:
diff --git a/src/workspace/tools/glob_tool.py b/src/workspace/tools/glob_tool.py
index 55ddabb..df158c0 100644
--- a/src/workspace/tools/glob_tool.py
+++ b/src/workspace/tools/glob_tool.py
@@ -1,6 +1,6 @@
from itertools import islice
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -10,27 +10,25 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "glob", self.glob.__doc__)
self.func = self.glob
self.params = BaseTool.extract_params(self.glob)
+ self.param_descriptions = {
+ "pattern": "通配符",
+ "path": "目录路径",
+ "max_ret": "最多返回多少条检索结果",
+ }
@BaseTool.handle_tool_exceptions
- def glob(self, pattern: str, folder_path: str = ".", max_ret: int = 1000) -> list[str]:
+ def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> ToolResult:
"""
在工作区内按通配符模式匹配并列出所有路径,带[Folder]或[File]的类型标记. 失败时返回错误信息
-
- Parameters
- ----------
- pattern: 通配符
- folder_path: 目录路径
- max_ret: 最多返回多少条检索结果
-
- Returns
- -------
- 检索到的文件或文件夹的相对路径
"""
- root_path = self.workspace.path_validator.validate(folder_path)
+ root_path = self.workspace.path_validator.validate(path)
if not root_path.is_dir():
- return ToolErrorResponse(self.__class__.__name__, f"{root_path}不是一个文件夹路径").to_str()
+ return self.make_failed_response(kwargs=locals().copy(), error=f"{root_path}不是一个文件夹路径")
- return [
- f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}"
- for item in islice(root_path.glob(pattern), max_ret)
- ]
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=[
+ f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}"
+ for item in islice(root_path.glob(pattern), max_ret)
+ ],
+ )
diff --git a/src/workspace/tools/ls_tool.py b/src/workspace/tools/ls_tool.py
index 7f92edb..33f035d 100644
--- a/src/workspace/tools/ls_tool.py
+++ b/src/workspace/tools/ls_tool.py
@@ -1,6 +1,6 @@
from pathlib import Path
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -10,16 +10,22 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "ls", self.ls.__doc__)
self.func = self.ls
self.params = BaseTool.extract_params(self.ls)
+ self.param_descriptions = {
+ "path": "目录路径",
+ }
@BaseTool.handle_tool_exceptions
- def ls(self, folder_path: str = ".") -> list[str] | str:
+ def ls(self, path: str = ".") -> ToolResult:
"""
列出指定目录下的文件和文件夹. 返回相对路径列表, 并标记[Folder]或[File]
"""
- path: Path = self.workspace.path_validator.validate(folder_path)
- if not path.is_dir():
- return ToolErrorResponse(self.__class__.__name__, f'参数错误: "{path}"不是一个目录').to_str()
- return [
- f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}"
- for item in path.iterdir()
- ]
+ folder_path: Path = self.workspace.path_validator.validate(path)
+ if not folder_path.is_dir():
+ return self.make_failed_response(kwargs=locals().copy(), error=f'参数错误: "{folder_path}"不是一个目录')
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=[
+ f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}"
+ for item in folder_path.iterdir()
+ ],
+ )
diff --git a/src/workspace/tools/read_lines_tool.py b/src/workspace/tools/read_lines_tool.py
deleted file mode 100644
index a00c41f..0000000
--- a/src/workspace/tools/read_lines_tool.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from pathlib import Path
-
-from src.models.tool_error_response import ToolErrorResponse
-from src.utils.binary_detector import is_binary_file
-from src.workspace.tools.base_tool import BaseTool
-from src.workspace.workspace import Workspace
-
-
-class ReadLinesTool(BaseTool):
- def __init__(self, workspace: Workspace):
- super().__init__(workspace, "read_lines", self.read_lines.__doc__)
- self.func = self.read_lines
- self.params = BaseTool.extract_params(self.read_lines)
-
- @BaseTool.handle_tool_exceptions
- def read_lines(self, file_path: str, start: int, end: int, context: int = 2, encoding: str = "utf-8") -> str:
- """
- 读取文件的指定行范围(行号从1开始), 可指定上下文行数扩展返回的实际行数范围,返回带行号的格式化内容
-
- Parameters
- ----------
- file_path: 文件路径
- start: 开始行数
- end: 结束行数
- context: 扩展结果行数范围 行数范围最终为(start-context, end+context)
- encoding: 编码
- """
- path: Path = self.workspace.path_validator.validate(file_path)
-
- if not path.is_file():
- return ToolErrorResponse(self.__class__.__name__, ValueError(f"读取文件{path}时未读取到完整文件")).to_str()
-
- if is_binary_file(path):
- return ToolErrorResponse(
- self.__class__.__name__,
- ValueError(f"无法读取二进制文件: {path}. 请使用二进制安全工具或转换为 base64."),
- ).to_str()
-
- with open(path, encoding=encoding) as f:
- lines = f.readlines()
-
- total_lines = len(lines)
-
- context = max(0, context)
-
- start -= context
- end += context
-
- # 验证行号
- if start < 1:
- start = 1
- if start > total_lines:
- return f"错误:起始行 {start} 超过文件总行数 ({total_lines})"
-
- end = total_lines if end is None else min(end, total_lines)
-
- if end < start:
- return f"错误:结束行 {end} 小于起始行 {start}"
-
- result_lines = []
- for i in range(start - 1, end):
- line_num = i + 1
- content = lines[i].rstrip("\n\r")
- result_lines.append(f"{line_num:6d} | {content}")
-
- header = f"\n[文件: {path}]\n[行 {start}-{end} / 共 {total_lines} 行]\n"
- separator = "-" * 80 + "\n"
-
- self._record_read_meta(path)
-
- return header + separator + "\n".join(result_lines)
diff --git a/src/workspace/tools/read_tool.py b/src/workspace/tools/read_tool.py
index ba9b789..7bac9d8 100644
--- a/src/workspace/tools/read_tool.py
+++ b/src/workspace/tools/read_tool.py
@@ -1,55 +1,108 @@
+import os
from pathlib import Path
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.utils.binary_detector import is_binary_file
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
+# 最大文件读取大小, 超过此大小的文件将被拒绝读取(默认 10MB)
+MAX_FILE_SIZE = int(os.getenv("TOOL_READ_MAX_FILE_SIZE", str(10 * 1024 * 1024)))
+
+
+def _resolve_index(idx: int, total: int) -> int:
+ """Resolve a 1-based or negative index to a clamped 1-based line number."""
+ if idx < 0:
+ idx = total + 1 + idx
+ if idx < 1:
+ return 1
+ if idx > total:
+ return total
+ return idx
+
class ReadTool(BaseTool):
def __init__(self, workspace: Workspace):
super().__init__(workspace, "read", self.read.__doc__)
self.func = self.read
self.params = BaseTool.extract_params(self.read)
+ self.param_descriptions = {
+ "path": "文件路径",
+ "start": "起始行号(1开始; 负数表示倒数, -1=最后一行)",
+ "end": "结束行号(1开始; 负数表示倒数, -1=最后一行)",
+ "context": "扩展结果行数范围 行数范围最终为(start-context, end+context)",
+ "encoding": "编码",
+ }
@BaseTool.handle_tool_exceptions
- def read(self, file_path: str, max_lines: int = 0, encoding: str = "utf-8") -> str:
+ def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> ToolResult:
"""
- 读取文件内容,可限制最大行数,返回文件内容字符串(带行号)
-
- Parameters
- ----------
- file_path: 文件路径
- max_lines: 最大行数(0表示不限制)
- encoding: 编码
+ 读取文件内容, 返回带行号的格式化内容
"""
- path: Path = self.workspace.path_validator.validate(file_path)
+ file_path: Path = self.workspace.path_validator.validate(path)
+
+ if not file_path.is_file():
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=str(ValueError(f"读取文件{file_path}时未读取到完整文件"))
+ )
- if not path.is_file():
- return ToolErrorResponse(self.__class__.__name__, ValueError(f"读取文件{path}时未读取到完整文件")).to_str()
+ if is_binary_file(file_path):
+ return self.make_failed_response(
+ kwargs=locals().copy(),
+ error=str(ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64.")),
+ )
- if is_binary_file(path):
- return ToolErrorResponse(
- self.__class__.__name__,
- ValueError(f"无法读取二进制文件: {path}. 请使用二进制安全工具或转换为 base64."),
- ).to_str()
+ file_size = file_path.stat().st_size
+ if file_size > MAX_FILE_SIZE:
+ return self.make_failed_response(
+ kwargs=locals().copy(),
+ error=str(
+ ValueError(
+ f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. "
+ f"请使用范围参数 (start/end) 分批读取."
+ )
+ ),
+ )
- with open(path, encoding=encoding) as f:
+ with open(file_path, encoding=encoding) as f:
lines = f.readlines()
total_lines = len(lines)
- if max_lines > 0:
- lines = lines[:max_lines]
+ if total_lines == 0:
+ header = f"\n[文件: {file_path}]\n[行 0-0 / 共 0 行]\n"
+ separator = "-" * 80 + "\n"
+ self._record_read_meta(file_path)
+ return self.make_success_response(kwargs=locals().copy(), data=header + separator)
+
+ context = max(0, context)
+
+ actual_start = _resolve_index(start, total_lines) - context
+ actual_end = _resolve_index(end, total_lines) + context
+
+ if actual_start < 1:
+ actual_start = 1
+ if actual_end > total_lines:
+ actual_end = total_lines
+
+ if actual_end < actual_start:
+ return self.make_failed_response(
+ kwargs=locals().copy(),
+ error=(
+ f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} "
+ f"(原始参数: start={start}, end={end}, context={context})"
+ ),
+ )
result_lines = []
- for i, line in enumerate(lines, 1):
- content = line.rstrip("\n\r")
- result_lines.append(f"{i:6d} | {content}")
+ for i in range(actual_start - 1, actual_end):
+ line_num = i + 1
+ content = lines[i].rstrip("\n\r")
+ result_lines.append(f"{line_num:6d} | {content}")
- header = f"\n[文件: {path}]\n[行 1-{len(lines)} / 共 {total_lines} 行]\n"
+ header = f"\n[文件: {file_path}]\n[行 {actual_start}-{actual_end} / 共 {total_lines} 行]\n"
separator = "-" * 80 + "\n"
- self._record_read_meta(path)
+ self._record_read_meta(file_path)
- return header + separator + "\n".join(result_lines)
+ return self.make_success_response(kwargs=locals().copy(), data=header + separator + "\n".join(result_lines))
diff --git a/src/workspace/tools/regex_search_tool.py b/src/workspace/tools/regex_search_tool.py
index 59114ad..a8b6549 100644
--- a/src/workspace/tools/regex_search_tool.py
+++ b/src/workspace/tools/regex_search_tool.py
@@ -2,7 +2,7 @@
import re
from pathlib import Path
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -103,6 +103,14 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "regex_search", self.regex_search.__doc__)
self.func = self.regex_search
self.params = BaseTool.extract_params(self.regex_search)
+ self.param_descriptions = {
+ "pattern": "正则表达式模式",
+ "path": "搜索文件或文件夹路径",
+ "context": "显示匹配行的上下文行数",
+ "file_pattern": "文件匹配模式,支持通配符",
+ "limit": "最大匹配数量限制",
+ "ignore": "忽略匹配正则的文件或文件夹列表",
+ }
@BaseTool.handle_tool_exceptions
def regex_search(
@@ -113,18 +121,9 @@ def regex_search(
file_pattern: str = "*",
limit: int = 256,
ignore: list[str] | None = None,
- ) -> str:
+ ) -> ToolResult:
"""
使用正则表达式搜索文件内容, 支持上下文显示、文件过滤和忽略路径, 返回匹配详情; 适合代码与文档探索
-
- Parameters
- ----------
- pattern: 正则表达式模式
- path: 搜索路径,默认为当前目录
- context: 显示匹配行的上下文行数,默认为3
- file_pattern: 文件匹配模式,支持通配符,默认为"*"
- limit: 最大匹配数量限制,默认为256
- ignore: 忽略匹配正则的文件或文件夹列表
"""
# 验证搜索路径
search_path: Path = self.workspace.path_validator.validate(path)
@@ -133,7 +132,7 @@ def regex_search(
try:
regex = re.compile(pattern)
except re.error as e:
- return ToolErrorResponse(self.__class__.__name__, f"无效的正则表达式: {e}").to_str()
+ return self.make_failed_response(kwargs=locals().copy(), error=f"无效的正则表达式: {e}")
# 收集忽略模式
ignore_patterns = []
@@ -145,6 +144,8 @@ def regex_search(
# 搜索结果
results = []
file_count = 0
+ total_matches = 0
+ warnings = [""]
# 确定要搜索的文件列表(支持单文件或目录)
files_to_search = [search_path] if search_path.is_file() else list(search_path.rglob(file_pattern))
@@ -154,7 +155,7 @@ def regex_search(
if not file_path.is_file():
continue
# 检查是否达到限制
- if len(results) >= limit:
+ if total_matches >= limit:
break
# 检查是否应该忽略该文件或文件夹
@@ -181,9 +182,17 @@ def regex_search(
if file_results:
results.append({"file": str(file_path), "matches": file_results})
file_count += 1
+ total_matches += len(file_results)
- except OSError, UnicodeDecodeError, PermissionError:
+ except (OSError, UnicodeDecodeError, PermissionError) as e:
+ warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}")
continue # 跳过无法读取的文件
+ warnings.append("")
+
# 格式化输出
- return _format_regex_results(results, pattern, limit, file_count)
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=_format_regex_results(results, pattern, limit, file_count),
+ error="\n".join(warnings) if len(warnings) > 2 else None,
+ )
diff --git a/src/workspace/tools/stat_tool.py b/src/workspace/tools/stat_tool.py
index 41981cf..f6caeae 100644
--- a/src/workspace/tools/stat_tool.py
+++ b/src/workspace/tools/stat_tool.py
@@ -2,6 +2,7 @@
from datetime import datetime
from pathlib import Path
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -13,17 +14,14 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "stat", self.stat.__doc__)
self.func = self.stat
self.params = BaseTool.extract_params(self.stat)
+ self.param_descriptions = {
+ "path": "文件或目录路径",
+ }
@BaseTool.handle_tool_exceptions
- def stat(self, path: str = ".") -> str:
+ def stat(self, path: str = ".") -> ToolResult:
"""
获取工作区内文件或目录的详细信息,包括大小、行数(仅文件)、修改时间、权限等
-
- Args:
- path: 文件或目录路径,默认为当前目录
-
- Returns:
- 格式化的详细信息字符串
"""
# 验证路径
target_path: Path = self.workspace.path_validator.validate(path)
@@ -145,4 +143,4 @@ def format_timestamp(timestamp: float) -> str:
except PermissionError:
output.append("目录内容: 无法访问")
- return "\n".join(output)
+ return self.make_success_response(kwargs=locals().copy(), data="\n".join(output))
diff --git a/src/workspace/tools/symbol_ref_tool.py b/src/workspace/tools/symbol_ref_tool.py
index 6c0c4df..4f42291 100644
--- a/src/workspace/tools/symbol_ref_tool.py
+++ b/src/workspace/tools/symbol_ref_tool.py
@@ -1,16 +1,12 @@
"""符号引用查找工具 - 查找函数、类、变量等的定义和引用"""
import re
-from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
-# 默认排除的目录(与 workspace.py 保持一致, 后续改为从项目配置加载)
-DEFAULT_EXCLUDED_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build", ".idea", ".vscode"}
-
def _get_file_pattern_by_language(language: str) -> str:
"""根据语言获取默认的文件匹配模式"""
@@ -124,9 +120,9 @@ def _generate_patterns(symbol_name: str, language: str, include_def: bool, inclu
return unique_patterns
-def _search_single_pattern(
+def _search_all_patterns(
workspace: Workspace,
- pattern_info: dict,
+ patterns: list[dict],
search_path: str,
symbol_name: str,
context_lines: int,
@@ -134,112 +130,124 @@ def _search_single_pattern(
file_pattern: str,
ignore: list[str] | None,
) -> list[dict]:
- """使用 Workspace.search_content 执行单个模式的搜索(并发版本)"""
- results = []
+ """单次文件遍历搜索所有模式, 返回按文件分组且带上下文的结果.
+
+ 相比旧实现的优势:
+ - 文件系统只遍历 1 次(旧: N 次, N=模式数量)
+ - 每个文件只读取 1 次(旧: N 次)
+ - 直接使用结构化数据, 无需 格式化→正则解析 的反模式
+ - 真正实现上下文行读取(旧实现仅将匹配行本身作为上下文)
+ """
+ # 编译所有正则模式
+ compiled_patterns: list[tuple[re.Pattern, str]] = []
+ for p in patterns:
+ try:
+ regex = re.compile(p["pattern"], re.IGNORECASE)
+ compiled_patterns.append((regex, p["type"]))
+ except re.error:
+ continue
- try:
- # 使用 workspace 的并发搜索能力
- search_result = workspace.search_content(
- pattern=pattern_info["pattern"],
- folder_path=search_path,
- file_pattern=file_pattern,
- max_workers=4, # 并发线程数
- case_sensitive=False,
- )
+ if not compiled_patterns:
+ return []
- # 解析 search_content 的返回结果
- if not search_result or "未找到匹配" in search_result:
- return results
-
- # 解析结果格式: search_content 返回的是格式化的字符串
- # 格式: [文件] path\n----\n 行号 | 内容
- lines = search_result.split("\n")
- current_file = None
- current_matches = []
-
- for line in lines:
- if line.startswith("[文件] "):
- # 保存上一个文件的结果
- if current_file and current_matches:
- results.append({"file": current_file, "matches": current_matches, "type": pattern_info["type"]})
- current_file = line[6:] # 去掉 "[文件] "
- current_matches = []
- elif line.startswith("----"):
- continue
- elif line.strip() and current_file:
- # 解析匹配行: " 行号 | 内容"
- match = re.match(r"\s+(\d+)\s*\|\s*(.*)", line)
- if match:
- line_num = int(match.group(1))
- content = match.group(2)
-
- # 收集上下文(这里简化处理,因为 search_content 不直接提供上下文)
- # 为了保持兼容性,我们构建一个简化的上下文
- current_matches.append(
- {
- "line_num": line_num,
- "content": content,
- "context": [{"line_num": line_num, "content": content, "is_match": True}],
- "match_type": pattern_info["type"],
- "symbol_name": symbol_name,
- }
- )
+ # 单次遍历搜索: 所有模式在一次文件遍历中完成
+ matches = workspace.search_content_multi_pattern(
+ patterns=compiled_patterns,
+ folder_path=search_path,
+ file_pattern=file_pattern,
+ max_workers=4,
+ ignore=ignore,
+ )
- # 保存最后一个文件的结果
- if current_file and current_matches:
- results.append({"file": current_file, "matches": current_matches, "type": pattern_info["type"]})
+ if not matches:
+ return []
- except Exception:
- # 如果搜索失败,静默跳过
- pass
+ # 应用 limit 截断
+ matches = matches[:limit]
- return results
+ # 按文件分组并构建带上下文的结果
+ return _build_results_with_context(matches, context_lines, symbol_name, workspace.root_path)
-def _search_patterns_concurrent(
- workspace: Workspace,
- patterns: list[dict],
- search_path: str,
- symbol_name: str,
+def _build_results_with_context(
+ matches: list[dict],
context_lines: int,
- limit: int,
- file_pattern: str,
- ignore: list[str] | None,
+ symbol_name: str,
+ root_path: Path,
) -> list[dict]:
- """并发执行多个模式的搜索"""
- all_results = []
-
- # 使用线程池并发执行多个模式的搜索
- with ThreadPoolExecutor(max_workers=min(len(patterns), 4)) as executor:
- future_to_pattern = {
- executor.submit(
- _search_single_pattern,
- workspace,
- pattern_info,
- search_path,
- symbol_name,
- context_lines,
- limit - len(all_results),
- file_pattern,
- ignore,
- ): pattern_info
- for pattern_info in patterns
- }
+ """从扁平匹配列表构建按文件分组、带上下文的结果"""
+ context_lines = max(0, context_lines)
+
+ # 按文件分组, 保留首次出现的顺序
+ file_matches: dict[str, list[dict]] = {}
+ file_order: list[str] = []
+ for m in matches:
+ f = m["file"]
+ if f not in file_matches:
+ file_matches[f] = []
+ file_order.append(f)
+ file_matches[f].append(m)
+
+ results = []
+ for file_rel in file_order:
+ file_match_list = file_matches[file_rel]
+
+ # 收集需要读取的行号范围(匹配行 ± 上下文行)
+ needed_lines: set[int] = set()
+ for m in file_match_list:
+ for delta in range(-context_lines, context_lines + 1):
+ needed_lines.add(m["line_num"] + delta)
+
+ # 仅读取需要的行(不加载整个文件到内存)
+ line_cache: dict[int, str] = {}
+ file_full_path = root_path / file_rel
+ try:
+ if needed_lines:
+ max_needed = max(needed_lines)
+ with open(file_full_path, encoding="utf-8") as f:
+ for line_num, line in enumerate(f, 1):
+ if line_num in needed_lines:
+ line_cache[line_num] = line.rstrip("\n\r")
+ if line_num >= max_needed:
+ break
+ except (UnicodeDecodeError, PermissionError, OSError):
+ pass
+
+ # 为每个匹配项构建上下文
+ built_matches = []
+ for m in file_match_list:
+ match_line_num = m["line_num"]
+ context = []
+ for delta in range(-context_lines, context_lines + 1):
+ ctx_line_num = match_line_num + delta
+ if ctx_line_num in line_cache:
+ context.append(
+ {
+ "line_num": ctx_line_num,
+ "content": line_cache[ctx_line_num],
+ "is_match": delta == 0,
+ }
+ )
+
+ built_matches.append(
+ {
+ "line_num": match_line_num,
+ "content": m["content"],
+ "context": context,
+ "match_type": m["pattern_type"],
+ "symbol_name": symbol_name,
+ }
+ )
- for future in as_completed(future_to_pattern):
- try:
- results = future.result()
- all_results.extend(results)
- if len(all_results) >= limit:
- # 达到限制,取消剩余任务
- for f in future_to_pattern:
- f.cancel()
- break
- except Exception:
- # 单个模式失败不影响其他模式
- pass
+ results.append(
+ {
+ "file": file_rel,
+ "matches": built_matches,
+ "type": file_match_list[0]["pattern_type"],
+ }
+ )
- return all_results[:limit]
+ return results
def _detect_language(search_path: Path, specified_lang: str) -> str:
@@ -347,6 +355,17 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "symbol_ref", self.symbol_ref.__doc__, read_permission=True, write_permission=False)
self.func = self.symbol_ref
self.params = BaseTool.extract_params(self.symbol_ref)
+ self.param_descriptions = {
+ "symbol_name": "要查找的符号名称(如函数名、类名、变量名)",
+ "path": "搜索文件或文件夹路径",
+ "language": "语言类型(auto/python/javascript/typescript/markdown/general)",
+ "include_definitions": "是否包含定义位置",
+ "include_references": "是否包含引用位置",
+ "context_lines": "显示匹配行的上下文行数",
+ "limit": "最大匹配数量限制",
+ "ignore": "忽略匹配正则的文件或文件夹列表",
+ "file_pattern": "文件匹配模式(如 *.py),默认根据语言自动选择",
+ }
@BaseTool.handle_tool_exceptions
def symbol_ref(
@@ -360,31 +379,14 @@ def symbol_ref(
limit: int = 256,
ignore: list[str] | None = None,
file_pattern: str | None = None,
- ) -> str:
+ ) -> ToolResult:
"""
- 查找符号(函数、类、变量等)的定义和引用位置
-
- 支持自动识别语言类型,生成智能搜索模式来定位符号的定义和所有使用位置.
- 适用于代码探索、重构影响分析、理解代码结构等场景.
-
- Args:
- symbol_name: 要查找的符号名称(如函数名、类名、变量名)
- path: 搜索路径,默认为当前目录
- language: 语言类型(auto/python/javascript/typescript/markdown/general),默认auto
- include_definitions: 是否包含定义位置,默认True
- include_references: 是否包含引用位置,默认True
- context_lines: 显示匹配行的上下文行数,默认2
- limit: 最大匹配数量限制,默认为256
- ignore: 忽略匹配正则的文件或文件夹列表
- file_pattern: 文件匹配模式(如 *.py),默认根据语言自动选择
-
- Returns:
- 格式化的符号引用搜索结果
+ 查找符号(函数、类、变量等)的定义和引用位置, 适用于代码探索、重构影响分析、理解代码结构等场景.
"""
# 验证搜索路径
search_path: Path = self.workspace.path_validator.validate(path)
if not search_path.exists():
- return ToolErrorResponse(self.__class__.__name__, f"路径不存在: {path}").to_str()
+ return self.make_failed_response(kwargs=locals().copy(), error=f"路径不存在: {path}")
# 自动检测语言
detected_lang = _detect_language(search_path, language)
@@ -397,10 +399,12 @@ def symbol_ref(
patterns = _generate_patterns(symbol_name, detected_lang, include_definitions, include_references)
if not patterns:
- return f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})"
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})"
+ )
# 使用并发搜索执行所有模式
- all_results = _search_patterns_concurrent(
+ all_results = _search_all_patterns(
self.workspace,
patterns,
path,
@@ -412,4 +416,6 @@ def symbol_ref(
)
# 格式化输出
- return _format_results(all_results, symbol_name, detected_lang, limit)
+ return self.make_success_response(
+ kwargs=locals().copy(), data=_format_results(all_results, symbol_name, detected_lang, limit)
+ )
diff --git a/src/workspace/tools/write_tool.py b/src/workspace/tools/write_tool.py
index 1a0935e..fe1a09c 100644
--- a/src/workspace/tools/write_tool.py
+++ b/src/workspace/tools/write_tool.py
@@ -1,7 +1,7 @@
from pathlib import Path
from src.core.file_tracker import FileTracker
-from src.models.tool_error_response import ToolErrorResponse
+from src.models.tools.tool_result import ToolResult
from src.utils.binary_detector import is_binary_file
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
@@ -12,50 +12,48 @@ def __init__(self, workspace: Workspace):
super().__init__(workspace, "write", self.write.__doc__, write_permission=True)
self.func = self.write
self.params = BaseTool.extract_params(self.write)
+ self.param_descriptions = {
+ "path": "文件路径",
+ "content": "写入内容",
+ }
@BaseTool.handle_tool_exceptions
- def write(self, file_path: str, content: str = "") -> str:
+ def write(self, path: str, content: str = "") -> ToolResult:
"""
- 写入文件内容,如文件不存在则创建(含父目录)
-
- Parameters
- ----------
- file_path: 文件路径
- content: 写入内容
+ 写入文件内容, 如文件不存在则创建(含父目录)
"""
- source_file_path = Path(file_path)
- file_path: Path = self.workspace.path_validator.resolve_path(source_file_path)
+ source_path = Path(path)
+ path: Path = self.workspace.path_validator.resolve_path(source_path)
- if file_path.exists() and file_path.is_dir():
- return ToolErrorResponse(
- self.__class__.__name__, ValueError(f"路径 {file_path} 是一个目录,无法写入")
- ).to_str()
+ if path.exists() and path.is_dir():
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=str(ValueError(f"路径 {path} 是一个目录,无法写入"))
+ )
- if is_binary_file(file_path):
- return ToolErrorResponse(
- self.__class__.__name__,
- ValueError(f"禁止写入二进制文件: {file_path}"),
- ).to_str()
+ if is_binary_file(path):
+ return self.make_failed_response(
+ kwargs=locals().copy(), error=str(ValueError(f"禁止写入二进制文件: {path}"))
+ )
- mtime_error = self._validate_mtime(file_path)
+ mtime_error = self._validate_mtime(path)
if mtime_error:
- return mtime_error
+ return self.make_failed_response(kwargs=locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}")
old_content = ""
old_meta = None
- if file_path.exists() and file_path.is_file():
- old_meta = FileTracker.get_file_meta(file_path)
+ if path.exists() and path.is_file():
+ old_meta = FileTracker.get_file_meta(path)
try:
- old_content = file_path.read_text(encoding="utf-8")
+ old_content = path.read_text(encoding="utf-8")
except Exception:
old_content = ""
- rel_path = str(file_path.relative_to(self.workspace.root_path))
+ rel_path = str(path.relative_to(self.workspace.root_path))
old_hash = old_meta.get("checksum") if old_meta else None
new_hash = FileTracker.compute_checksum_from_string(content)
diff_content = self._generate_diff(old_content, content, rel_path)
- session_id = self.workspace._current_session_id
+ session_id = self.workspace.session_id
snapshot_id = self.workspace.db.record_file_snapshot(
rel_path,
old_hash,
@@ -66,4 +64,13 @@ def write(self, file_path: str, content: str = "") -> str:
pending_content=content,
)
- return f"[Write Preview]\nFile: {rel_path}\nSnapshot ID: {snapshot_id}\nDiff:\n{diff_content}"
+ return self.make_success_response(
+ kwargs=locals().copy(),
+ data=(
+ f"修改已推送到审核系统\n"
+ f"[Write Preview]\n"
+ f"File: {rel_path}\n"
+ f"Snapshot ID: {snapshot_id}\n"
+ f"Diff:\n{diff_content}"
+ ),
+ )
diff --git a/src/workspace/workspace.py b/src/workspace/workspace.py
index 538a280..56699f6 100644
--- a/src/workspace/workspace.py
+++ b/src/workspace/workspace.py
@@ -7,6 +7,9 @@
from src.models.tool_error_response import ToolErrorResponse
from src.workspace.path_validator import PathNotFoundError, PathValidator, WorkspaceBoundaryError
+# 默认排除的目录 后续改为从项目配置加载
+DEFAULT_EXCLUDED_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build", ".idea", ".vscode"}
+
def _highlight_matches(line: str, regex: re.Pattern) -> str:
"""
@@ -59,6 +62,15 @@ def db(self):
self._db = DatabaseManager(str(self.root_path))
return self._db
+ @property
+ def session_id(self) -> int | None:
+ """当前会话 ID 的公开 getter —— 工具类通过此接口访问, 避免直接访问私有属性."""
+ return self._current_session_id
+
+ @session_id.setter
+ def session_id(self, value: int | None) -> None:
+ self._current_session_id = value
+
def search_content(
self,
pattern: str,
@@ -73,7 +85,7 @@ def search_content(
path = self.path_validator.validate(folder_path)
# 初始化排除目录集合
- exclude_set = set(exclude_dirs or [".git", "__pycache__", "node_modules", ".venv", "venv", "dist", "build"])
+ exclude_set = set(exclude_dirs or DEFAULT_EXCLUDED_DIRS)
# 编译正则表达式
flags = 0 if case_sensitive else re.IGNORECASE
@@ -176,3 +188,116 @@ def _search_in_file(self, file_path: Path, regex: re.Pattern) -> list[tuple]:
pass
return results
+
+ def search_content_multi_pattern(
+ self,
+ patterns: list[tuple[re.Pattern, str]],
+ folder_path: str = ".",
+ file_pattern: str = "*",
+ max_workers: int = 4,
+ ignore: list[str] | None = None,
+ ) -> list[dict]:
+ """单次文件遍历匹配多个正则模式, 直接返回结构化数据.
+
+ 与 search_content 的区别:
+ - 接受多个已编译的正则 + 类型标签, 一次遍历全部匹配
+ - 返回结构化 list[dict] 而非格式化字符串, 消除下游正则解析反模式
+ - 不做高亮处理(高亮是展示层关注点, 不应混入数据层)
+
+ Args:
+ patterns: [(compiled_regex, type_label), ...] 已编译正则及其类型标签
+ folder_path: 搜索起始路径
+ file_pattern: 文件通配符(如 "*.py")
+ max_workers: 并发读取文件的线程数
+ ignore: 忽略路径正则列表
+
+ Returns:
+ [{"file": str, "line_num": int, "content": str, "pattern_type": str}, ...]
+ 按文件路径 → 行号排序, 同一行多个模式匹配则每个各一条记录
+ """
+ try:
+ path = self.path_validator.validate(folder_path)
+
+ # 预编译 ignore 正则
+ ignore_res: list[re.Pattern] = []
+ if ignore:
+ for ign in ignore:
+ try:
+ ignore_res.append(re.compile(ign))
+ except re.error:
+ continue
+
+ # 收集文件(一次遍历)
+ files_to_search: list[Path] = []
+ if path.is_file():
+ files_to_search = [path]
+ else:
+ for file_path in path.rglob(file_pattern):
+ if file_path.is_file():
+ if any(p.name in DEFAULT_EXCLUDED_DIRS for p in file_path.parents):
+ continue
+ rel = str(file_path.relative_to(self.root_path))
+ if any(ir.search(rel) for ir in ignore_res):
+ continue
+ files_to_search.append(file_path)
+
+ if not files_to_search:
+ return []
+
+ # 并发搜索所有文件, 每个文件内一次读取、一次测试所有模式
+ all_matches: list[dict] = []
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = {
+ executor.submit(self._search_multi_in_file, file_path, patterns): file_path
+ for file_path in files_to_search
+ }
+ for future in as_completed(futures):
+ try:
+ file_results = future.result()
+ if file_results:
+ all_matches.extend(file_results)
+ except Exception:
+ pass
+
+ # 按文件路径 → 行号排序
+ all_matches.sort(key=lambda m: (m["file"], m["line_num"]))
+ return all_matches
+
+ except (PathNotFoundError, WorkspaceBoundaryError, PermissionError):
+ return []
+ except Exception:
+ return []
+
+ def _search_multi_in_file(self, file_path: Path, patterns: list[tuple[re.Pattern, str]]) -> list[dict]:
+ """在单个文件中一次读取、逐行测试所有模式.
+
+ Args:
+ file_path: 文件绝对路径
+ patterns: [(compiled_regex, type_label), ...]
+
+ Returns:
+ 匹配项列表, 同一行多个模式匹配时每个各返回一条
+ """
+ results: list[dict] = []
+ relative_path = str(file_path.relative_to(self.root_path))
+
+ try:
+ with open(file_path, encoding="utf-8") as f:
+ for line_num, line in enumerate(f, 1):
+ stripped = line.rstrip("\n\r")
+ for regex, pattern_type in patterns:
+ if regex.search(stripped):
+ results.append(
+ {
+ "file": relative_path,
+ "line_num": line_num,
+ "content": stripped,
+ "pattern_type": pattern_type,
+ }
+ )
+ except (UnicodeDecodeError, PermissionError):
+ pass
+ except Exception:
+ pass
+
+ return results
diff --git a/tests/core/test_tool_registry.py b/tests/core/test_tool_registry.py
index e94f43e..ab1e2e2 100644
--- a/tests/core/test_tool_registry.py
+++ b/tests/core/test_tool_registry.py
@@ -33,31 +33,6 @@ def isolate_tool_registry():
DatabaseManager.reset_instances()
-def test_validate_config():
- """测试配置验证 - 一次性测试所有阈值"""
- config = ToolRegistry()
-
- # 设置所有值为过小
- config.MAX_RESULT_LENGTH = 5
- config.LIST_TRUNCATE_THRESHOLD = 3
- config.DICT_TRUNCATE_THRESHOLD = 2
-
- # 验证触发3个警告
- with pytest.warns(UserWarning) as record:
- config._validate_config()
-
- # 验证警告数量和内容
- assert len(record) == 3
- assert "TOOL_MAX_RESULT_LENGTH" in str(record[0].message)
- assert "TOOL_LIST_TRUNCATE_THRESHOLD" in str(record[1].message)
- assert "TOOL_DICT_TRUNCATE_THRESHOLD" in str(record[2].message)
-
- # 验证所有值都被修正
- assert config.MAX_RESULT_LENGTH == 100
- assert config.LIST_TRUNCATE_THRESHOLD == 50
- assert config.DICT_TRUNCATE_THRESHOLD == 50
-
-
def test_tool_registry_singleton():
"""测试单例模式"""
registry1 = ToolRegistry()
@@ -67,14 +42,6 @@ def test_tool_registry_singleton():
assert id(registry1) == id(registry2)
-def test_execute_nonexistent_tool():
- """测试执行不存在的工具"""
- registry = ToolRegistry()
-
- with pytest.raises(ValueError, match="未找到工具: nonexistent"):
- registry.execute("nonexistent")
-
-
def test_validate_tool_info():
"""测试工具信息验证"""
registry = ToolRegistry()
@@ -95,51 +62,6 @@ def test_validate_tool_info():
assert any(f"超过 {MAX_DOC_LENGTH} 字符" in str(warning.message) for warning in w)
-def test_compress_result_string():
- """测试字符串结果压缩"""
- registry = ToolRegistry()
-
- long_string = "x" * (MAX_RESULT_LENGTH + 10000)
-
- compressed = registry._compress_result(long_string)
- assert "结果已截断" in compressed
-
-
-def test_compress_result_list():
- """测试列表结果压缩"""
- registry = ToolRegistry()
-
- long_list = list(range(150))
-
- compressed = registry._compress_result(long_list)
- assert len(compressed) == 101
- assert "列表已截断" in compressed[-1]
-
-
-def test_compress_result_dict():
- """测试字典结果压缩"""
- registry = ToolRegistry()
-
- long_dict = {f"key_{i}": f"value_{i}" for i in range(150)}
-
- compressed = registry._compress_result(long_dict)
- assert len(compressed) == 101
- assert "字典已截断" in compressed["..."]
-
-
-def test_no_compress_short_results():
- """测试不对短结果进行压缩"""
- registry = ToolRegistry()
-
- short_string = "short"
- short_list = [1, 2, 3]
- short_dict = {"a": 1, "b": 2}
-
- assert registry._compress_result(short_string) == short_string
- assert registry._compress_result(short_list) == short_list
- assert registry._compress_result(short_dict) == short_dict
-
-
class TestToolCategorization:
"""测试工具分类逻辑(需要 workspace 注册工具)."""
@@ -156,13 +78,13 @@ def setup(self, tmp_path):
def test_query_tools_category(self):
registry = ToolRegistry()
- for name in ("glob", "ls", "regex_search", "stat", "read", "read_lines", "symbol_ref"):
+ for name in ("glob", "ls", "regex_search", "stat", "read", "symbol_ref"):
assert registry._tool_categories.get(name) == "query", f"{name} should be query"
- def test_edit_tools_category(self):
+ def test_write_tools_category(self):
registry = ToolRegistry()
for name in ("write", "edit"):
- assert registry._tool_categories.get(name) == "edit", f"{name} should be edit"
+ assert registry._tool_categories.get(name) == "write", f"{name} should be write"
def test_git_tool_category(self):
registry = ToolRegistry()
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/tools/__init__.py b/tests/models/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/tools/test_tool_result.py b/tests/models/tools/test_tool_result.py
new file mode 100644
index 0000000..660e098
--- /dev/null
+++ b/tests/models/tools/test_tool_result.py
@@ -0,0 +1,63 @@
+import pytest
+
+from src.models.tools.tool_result import ToolResult
+
+
+def test_compress_result_list():
+ """测试列表结果压缩"""
+ long_list = list(range(150))
+
+ compressed = ToolResult._compress_result(long_list)
+ assert len(compressed) == 101
+ assert "列表已截断" in compressed[-1]
+
+
+def test_compress_result_dict():
+ """测试字典结果压缩"""
+ long_dict = {f"key_{i}": f"value_{i}" for i in range(150)}
+
+ compressed = ToolResult._compress_result(long_dict)
+ assert len(compressed) == 101
+ assert "字典已截断" in compressed["..."]
+
+
+def test_no_compress_short_results():
+ """测试不对短结果进行压缩"""
+ short_string = "short"
+ short_list = [1, 2, 3]
+ short_dict = {"a": 1, "b": 2}
+
+ assert ToolResult._compress_result(short_string) == short_string
+ assert ToolResult._compress_result(short_list) == short_list
+ assert ToolResult._compress_result(short_dict) == short_dict
+
+
+def test_validate_config():
+ """测试配置验证 - 一次性测试所有阈值"""
+ # 设置所有值为过小
+ ToolResult.MAX_RESULT_LENGTH = 5
+ ToolResult.LIST_TRUNCATE_THRESHOLD = 3
+ ToolResult.DICT_TRUNCATE_THRESHOLD = 2
+
+ # 验证触发3个警告
+ with pytest.warns(UserWarning) as record:
+ ToolResult._validate_config()
+
+ # 验证警告数量和内容
+ assert len(record) == 3
+ assert "TOOL_MAX_RESULT_LENGTH" in str(record[0].message)
+ assert "TOOL_LIST_TRUNCATE_THRESHOLD" in str(record[1].message)
+ assert "TOOL_DICT_TRUNCATE_THRESHOLD" in str(record[2].message)
+
+ # 验证所有值都被修正
+ assert ToolResult.MAX_RESULT_LENGTH == 100
+ assert ToolResult.LIST_TRUNCATE_THRESHOLD == 50
+ assert ToolResult.DICT_TRUNCATE_THRESHOLD == 50
+
+
+def test_compress_result_string():
+ """测试字符串结果压缩"""
+ long_string = "x" * (ToolResult.MAX_RESULT_LENGTH + 10000)
+
+ compressed = ToolResult._compress_result(long_string)
+ assert "结果已截断" in compressed
diff --git a/tests/workspace/tools/test_base_tool.py b/tests/workspace/tools/test_base_tool.py
index d8cd286..8d43395 100644
--- a/tests/workspace/tools/test_base_tool.py
+++ b/tests/workspace/tools/test_base_tool.py
@@ -1,5 +1,4 @@
-from src.workspace.tools.base_tool import BaseTool, build_param_doc
-from src.workspace.workspace import Workspace
+from src.workspace.tools.base_tool import BaseTool
def test_extract_params():
@@ -41,61 +40,3 @@ def class_method(cls, b: str) -> str:
assert "cls" not in params2
assert "a" in params1
assert "b" in params2
-
-
-def test_build_param_doc():
- """测试参数文档生成使用简洁类型和required属性"""
-
- params_required = {"required": True, "annotation": ""}
- result = build_param_doc("file_path", params_required)
- assert 'type="string"' in result
- assert 'required="true"' in result
- assert "", "default": "0"}
- result = build_param_doc("max_lines", params_optional)
- assert 'type="integer"' in result
- assert 'required="false"' in result
- assert 'default="0"' in result
- assert " str:
- """Sample function for testing"""
- return f"{a}_{b}"
-
- class MockTool(BaseTool):
- def __init__(self, workspace: Workspace | None):
- super().__init__(workspace, "mock", "测试工具")
- self.func = sample_func
- self.params = BaseTool.extract_params(sample_func)
-
- tool = MockTool(None)
- doc = tool.to_doc()
-
- # 使用新格式标签
- assert doc.startswith('')
- assert "测试工具" in doc
- assert "" in doc
- assert "" in doc
- assert doc.endswith("")
-
- # 验证参数格式
- assert 'type="string"' in doc
- assert 'type="integer"' in doc
- assert 'required="true"' in doc
- assert 'required="false"' in doc
-
- # 验证没有旧格式标记
- assert "" not in doc
- assert "" not in doc
- assert "